3.3 A TensorFlow (Keras) example
This page shows how to adapt custom TensorFlow (Keras) code to the format required by the AI Labs FL/FV platform, based on the FLaVor TensorFlow (Keras) example. Please read the previous section 3.1, an overview of FLaVor FL, beforehand.
Prerequisites
- Training code e.g. ./main.py
- Dataset e.g. ./MNIST/*
- (Optional) Pre-trained weights of the model e.g. ./weights/weight.ckpt
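Putting these together, the working directory might look like the following sketch (illustrative only; the Dockerfile and requirements.txt are shown later on this page):

```
.
├── main.py            # training code
├── requirements.txt
├── Dockerfile
├── MNIST/             # dataset
└── weights/
    └── weight.ckpt    # optional pre-trained weights
```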
main.py
The modified parts, and the comments explaining them, are marked inline in the code below.
```python
import argparse
import os  # Remember to import this for environment variables, i.e. os.environ
import pickle  # Remember to import this for saving a Keras model

import tensorflow as tf
import tensorflow_datasets as tfds

from flavor.cook.utils import SaveInfoJson, SetEvent, WaitEvent  # Import FLaVor


def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255.0, label


def main():

    parser = argparse.ArgumentParser(description="Keras MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        metavar="N",
        help="input batch size for training (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        metavar="N",
        help="number of epochs to train (default: 300)",
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)"
    )
    args = parser.parse_args()

    # Here we use the built-in dataset loader for demonstration. You should
    # load your custom dataset via os.environ["INPUT_PATH"] instead
    # (see the sketch below this code block).
    (ds_train, ds_test), ds_info = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )

    ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_train = ds_train.cache()
    ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
    ds_train = ds_train.batch(args.batch_size)
    ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

    ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_test = ds_test.batch(args.batch_size)
    ds_test = ds_test.cache()
    ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(args.lr),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    # Tell the server that all preparations for training have been completed.
    SetEvent("TrainInitDone")

    for epoch in range(args.epochs):

        # Wait for the server.
        WaitEvent("TrainStarted")

        # Load the checkpoint sent from the server. Use
        # os.environ["GLOBAL_MODEL_PATH"] instead of a local path
        # such as "./weights/weight.ckpt".
        if epoch != 0 or os.path.exists(os.environ["GLOBAL_MODEL_PATH"]):
            with open(os.environ["GLOBAL_MODEL_PATH"], "rb") as F:
                weights = pickle.load(F)
            model.set_weights(list(weights["state_dict"].values()))

        # Verify the performance of the global model before training.
        loss, accuracy = model.evaluate(ds_test)

        # Save the information that the server needs to know.
        output_dict = {}
        output_dict["metadata"] = {
            "epoch": epoch,
            "datasetSize": ds_info.splits["train"].num_examples,
            "importance": 1.0,
        }
        output_dict["metrics"] = {
            "accuracy": accuracy,
            "basic/confusion_tp": -1,  # If N/A or you don't want to track it, fill in -1.
            "basic/confusion_fp": -1,
            "basic/confusion_fn": -1,
            "basic/confusion_tn": -1,
        }
        SaveInfoJson(output_dict)

        model.fit(ds_train, epochs=1)

        # Save the checkpoint.
        weights = {"state_dict": {str(key): value for key, value in enumerate(model.get_weights())}}
        with open(os.environ["LOCAL_MODEL_PATH"], "wb") as F:
            pickle.dump(weights, F, protocol=pickle.HIGHEST_PROTOCOL)

        # Tell the server that this round of training has ended.
        SetEvent("TrainFinished")


if __name__ == "__main__":
    main()
```
Note: Do NOT use early stopping while training, since the weights cannot be passed back successfully.
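The `tfds.load` call above is only a stand-in. On the platform, each client's data is mounted at the path given by `os.environ["INPUT_PATH"]`, and you should build your `tf.data` pipeline from it. Below is a minimal sketch assuming the data was exported as a NumPy archive; the file name `mnist.npz` and the array keys are hypothetical, so adapt them to your actual data layout:

```python
import os

import numpy as np
import tensorflow as tf

# Hypothetical layout: INPUT_PATH contains a NumPy archive holding
# "x_train" and "y_train" arrays. Adjust to your real storage format.
data = np.load(os.path.join(os.environ["INPUT_PATH"], "mnist.npz"))
x_train, y_train = data["x_train"], data["y_train"]

# Build the same kind of pipeline as in main.py above.
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
ds_train = ds_train.map(
    lambda image, label: (tf.cast(image, tf.float32) / 255.0, label),
    num_parallel_calls=tf.data.AUTOTUNE,
)
ds_train = ds_train.shuffle(len(x_train)).batch(32).prefetch(tf.data.AUTOTUNE)
```

If you load data this way, remember to report the real sample count (here `len(x_train)`) as `datasetSize` in the info JSON, since the server may use it to weight the aggregation.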
Dockerfile
```dockerfile
FROM tensorflow/tensorflow:2.11.0-gpu
COPY . /app
WORKDIR /app
RUN pip install -r requirements.txt
ENV PROCESS="python main.py"
CMD flavor-fl -m "${PROCESS}"
```
"Tensorflow 2.11" is compatible to our system.
The requirements.txt is:

```
tensorflow
tensorflow_datasets
https://github.com/ailabstw/FLaVor/archive/refs/heads/release/stable.zip
```
Remember to verify the correctness with the `check-fl` command introduced in the previous section.
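In addition to `check-fl`, you can sanity-check the checkpoint format offline. The platform exchanges weights as a pickle containing a `state_dict` dict, exactly as in `main.py` above. The standalone script below is our own sketch (not a FLaVor utility); it saves and reloads weights in that format to confirm the round trip works:

```python
import pickle

import tensorflow as tf

# Same architecture as in main.py above.
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10),
    ]
)

# Save in the {"state_dict": {index: weight}} format used in main.py.
weights = {"state_dict": {str(k): v for k, v in enumerate(model.get_weights())}}
with open("weight.ckpt", "wb") as f:
    pickle.dump(weights, f, protocol=pickle.HIGHEST_PROTOCOL)

# Load it back the same way the training loop restores the global model.
with open("weight.ckpt", "rb") as f:
    restored = pickle.load(f)
model.set_weights(list(restored["state_dict"].values()))
print("checkpoint round-trip OK")
```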