3.1.1 A quick review of the FL input and output interfaces

    Overall

    Before

    load the dataset from local storage
    train for all epochs
    save the model locally

    After

    import flavor # step 1
    load dataset from os.environ["INPUT_PATH"] # step 4
    SetEvent("TrainInitDone") # step 5
    epoch=-1
    while True: # step 2
            epoch+=1
            WaitEvent("TrainStarted") # step 5
            load model from aggregator by os.environ["GLOBAL_MODEL_PATH"] # step 4
            Prepare an output dictionary # step 3
            SaveInfoJson(dictionary) # step 5
            training 1 epoch # step 2
            save model to os.environ["LOCAL_MODEL_PATH"] # step 4
            SetEvent("TrainFinished") # step 5
    SetEvent("ProcessFinished") # step 5 (Optional in most cases)
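The control flow of the "After" version can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the library's implementation: `set_event`, `wait_event`, and `save_info_json` are hypothetical local stand-ins for `SetEvent`, `WaitEvent`, and `SaveInfoJson` so the example runs without flavor installed, and `run_client`, `train_one_epoch`, `max_rounds`, and the `env` argument are illustrative names. The real loop is `while True`; a fixed number of rounds is used here only so the sketch terminates.

```python
import os

event_log = []  # records the order of calls, for illustration only


def set_event(name):
    """Hypothetical stand-in for flavor's SetEvent."""
    event_log.append(("set", name))


def wait_event(name):
    """Hypothetical stand-in for flavor's WaitEvent (returns immediately here)."""
    event_log.append(("wait", name))


def save_info_json(info):
    """Hypothetical stand-in for flavor's SaveInfoJson."""
    event_log.append(("info", info["metadata"]["epoch"]))


def run_client(train_one_epoch, max_rounds, env=os.environ):
    """Follow the 'After' control flow for a fixed number of rounds."""
    input_path = env["INPUT_PATH"]                # step 4: dataset location
    set_event("TrainInitDone")                    # step 5
    for epoch in range(max_rounds):               # step 2 (the real loop is `while True`)
        wait_event("TrainStarted")                # step 5
        global_model_path = env["GLOBAL_MODEL_PATH"]  # step 4: load the global model from here
        info = {                                  # step 3: output dictionary
            "metadata": {"epoch": epoch, "datasetSize": 0, "importance": 1.0},
            "metrics": {},
        }
        save_info_json(info)                      # step 5
        train_one_epoch(epoch)                    # step 2: train exactly one epoch
        local_model_path = env["LOCAL_MODEL_PATH"]    # step 4: save the local model here
        set_event("TrainFinished")                # step 5
    set_event("ProcessFinished")                  # step 5 (optional in most cases)
```

Calling `run_client(my_training_step, max_rounds=2)` with the three environment variables set would emit one `TrainInitDone`, then a `TrainStarted`/`TrainFinished` pair per round, then `ProcessFinished`.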

    Input interface

    • INPUT_PATH with examples

      Before

      # import cv2
      img = cv2.imread("./dataset/img1.jpg")
      # import pandas as pd
      df = pd.read_csv("./data.csv")

      After

      # import cv2
      img = cv2.imread(os.environ["INPUT_PATH"]+"/dataset/img1.jpg")
      # import pandas as pd
      df = pd.read_csv(os.environ["INPUT_PATH"]+"/data.csv")
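String concatenation works, but joining path components avoids doubled or missing slashes when `INPUT_PATH` does or does not end with a separator. The helper below is a small sketch of that idea; `dataset_path` and its `env` argument are hypothetical names, not part of flavor.

```python
import os


def dataset_path(*parts, env=os.environ):
    """Join path components onto the directory named by INPUT_PATH."""
    return os.path.join(env["INPUT_PATH"], *parts)
```

With this helper the examples above would read `cv2.imread(dataset_path("dataset", "img1.jpg"))` and `pd.read_csv(dataset_path("data.csv"))`.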

    • LOCAL_MODEL_PATH with an example

      PyTorch

      weights = {"state_dict": model.state_dict()}
      torch.save(weights, os.environ["LOCAL_MODEL_PATH"])

      TensorFlow Keras

      weights = {"state_dict": {str(key): val for key, val in enumerate(model.get_weights())}}
      with open(os.environ["LOCAL_MODEL_PATH"], "wb") as F:
              pickle.dump(weights, F, protocol=pickle.HIGHEST_PROTOCOL)

    • GLOBAL_MODEL_PATH with an example

      PyTorch

      if epoch != 0 or os.path.exists(os.environ["GLOBAL_MODEL_PATH"]):
              model.load_state_dict(
                      torch.load(os.environ["GLOBAL_MODEL_PATH"])["state_dict"])

      TensorFlow Keras

      if epoch != 0 or os.path.exists(os.environ["GLOBAL_MODEL_PATH"]):
              with open(os.environ["GLOBAL_MODEL_PATH"], "rb") as F:
                      weights = pickle.load(F)
                      model.set_weights(list(weights["state_dict"].values()))
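The pickle-based format used in the Keras examples can be exercised without any deep learning framework. The sketch below is illustrative only: `save_weights` and `load_weights` are hypothetical helper names, plain Python lists stand in for the arrays returned by `model.get_weights()`, and a temporary file stands in for the model paths. It also applies the same guard as above, skipping the load in epoch 0 when no global model file exists yet.

```python
import os
import pickle
import tempfile


def save_weights(weight_list, path):
    """Pickle layer weights as {"state_dict": {"0": w0, "1": w1, ...}}."""
    weights = {"state_dict": {str(i): w for i, w in enumerate(weight_list)}}
    with open(path, "wb") as f:
        pickle.dump(weights, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_weights(path, epoch):
    """Return the weight list, or None when no global model exists yet."""
    if epoch == 0 and not os.path.exists(path):
        return None  # first round: keep the locally initialised weights
    with open(path, "rb") as f:
        state = pickle.load(f)["state_dict"]
    # Order by the integer value of each key so the result does not
    # depend on the order in which the dictionary was built.
    return [state[key] for key in sorted(state, key=int)]


with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "weights.ckpt")
    before_first_round = load_weights(model_path, epoch=0)
    save_weights([[1.0, 2.0], [3.0]], model_path)
    restored = load_weights(model_path, epoch=1)
```

Sorting by integer key is a defensive choice in this sketch; the original snippet relies on the dictionary preserving insertion order, which also works when the file is written and read by the code shown above.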

    Output interface

    Prepare a JSON file in the following format and save it to OUTPUT_PATH.

    • metadata: dict # Required
      • epoch: int
      • datasetSize: int   # The number of training samples.
      • importance: float # Epoch weight. Set to 1.0 by default for equal weights.
    • metrics: dict # The keys below are required; you can add any other keys (str) with values (float).
      • precision: float
      • basic/confusion_tp: float
      • basic/confusion_fp: float
      • basic/confusion_fn: float
      • basic/confusion_tn: float
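A dictionary in this format can be assembled as follows. This is a sketch under assumptions: `build_info` is a hypothetical helper, precision is computed as tp / (tp + fp) and set to 0.0 when there are no positive predictions (a convention chosen here, not one stated by the format), and the sample counts are made up for illustration.

```python
import json


def build_info(epoch, dataset_size, tp, fp, fn, tn, importance=1.0, extra_metrics=None):
    """Build the output dictionary with the required metadata and metrics keys."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    metrics = {
        "precision": float(precision),
        "basic/confusion_tp": float(tp),
        "basic/confusion_fp": float(fp),
        "basic/confusion_fn": float(fn),
        "basic/confusion_tn": float(tn),
    }
    # Optional extra metrics: keys are coerced to str and values to float.
    for key, value in (extra_metrics or {}).items():
        metrics[str(key)] = float(value)
    return {
        "metadata": {
            "epoch": int(epoch),
            "datasetSize": int(dataset_size),
            "importance": float(importance),
        },
        "metrics": metrics,
    }


info = build_info(epoch=0, dataset_size=100, tp=40, fp=10, fn=5, tn=45,
                  extra_metrics={"loss": 0.25})
payload = json.dumps(info, indent=2)  # JSON text of the dictionary
```

The resulting `info` dictionary is what would be passed to `SaveInfoJson` in the overview loop.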

    Taiwan AI Labs (AILabs.tw) Copyright © 2023