3.1.1 A quick review of the FL input and output interfaces
Overall
Before:
load dataset from local

After:
```python
import flavor  # step 1

# step 4: load dataset from os.environ["INPUT_PATH"]
SetEvent("TrainInitDone")  # step 5

epoch = -1
while True:  # step 2
    epoch += 1
    WaitEvent("TrainStarted")  # step 5
    # step 4: load model from aggregator via os.environ["GLOBAL_MODEL_PATH"]
    # step 3: prepare an output dictionary
    SaveInfoJson(dictionary)  # step 5
    # step 2: train 1 epoch
    # step 4: save model to os.environ["LOCAL_MODEL_PATH"]
    SetEvent("TrainFinished")  # step 5

SetEvent("ProcessFinished")  # step 5 (optional in most cases; runs only if the loop is exited)
```
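To check the event ordering of the loop above without a running aggregator, the following dry-run sketch replaces SetEvent, WaitEvent and SaveInfoJson with hypothetical local stand-ins (set_event, wait_event, save_info_json) that only record calls, and bounds the loop with a hypothetical max_rounds parameter instead of `while True`. It illustrates the order of events under those assumptions; it is not flavor's API.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical stand-ins for flavor's SetEvent / WaitEvent / SaveInfoJson.
# They only record what the client does, in order; nothing is sent anywhere.
event_log: List[Tuple[str, Any]] = []


def set_event(name: str) -> None:
    event_log.append(("set", name))


def wait_event(name: str) -> None:
    # The real WaitEvent blocks until the aggregator signals; this returns at once.
    event_log.append(("wait", name))


def save_info_json(info: Dict[str, Any]) -> None:
    event_log.append(("info", info["metadata"]["epoch"]))


def dry_run(max_rounds: int) -> None:
    # step 4: the dataset would be loaded from INPUT_PATH here
    set_event("TrainInitDone")
    for epoch in range(max_rounds):  # the real client loops with `while True`
        wait_event("TrainStarted")
        # step 4: load the global model (skipped in the dry run)
        info = {"metadata": {"epoch": epoch}, "metrics": {}}  # step 3, placeholder values
        save_info_json(info)
        # step 2: train one epoch; step 4: save the local model (both skipped)
        set_event("TrainFinished")
    set_event("ProcessFinished")


dry_run(2)
print(event_log)
```

Each round waits for TrainStarted, reports its info dictionary, and ends with TrainFinished; TrainInitDone and ProcessFinished bracket the whole run.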
Input interface
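- INPUT_PATH with examples

Read local data from the directory given by os.environ["INPUT_PATH"] instead of a path relative to the working directory.

Before:
```python
import cv2
import pandas as pd

img = cv2.imread("./dataset/img1.jpg")
df = pd.read_csv("./data.csv")
```

After:
```python
import os

import cv2
import pandas as pd

img = cv2.imread(os.path.join(os.environ["INPUT_PATH"], "dataset", "img1.jpg"))
df = pd.read_csv(os.path.join(os.environ["INPUT_PATH"], "data.csv"))
```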
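When many files are read, the prefixing above can be centralised in one helper. The sketch below is an assumption, not part of the interface: the name input_file is hypothetical, and falling back to the current directory when INPUT_PATH is unset is only a local-testing convenience. Leading slashes are stripped because os.path.join discards the base directory when its second argument is absolute, which would silently undo the INPUT_PATH prefix for a path written as "/data.csv".

```python
import os


def input_file(relative_path: str) -> str:
    """Resolve a dataset path against the directory in INPUT_PATH (hypothetical helper)."""
    # Assumption: fall back to the working directory when INPUT_PATH is unset (local runs).
    base = os.environ.get("INPUT_PATH", ".")
    # Strip leading "/" so os.path.join keeps `base` instead of treating the path as absolute.
    return os.path.normpath(os.path.join(base, relative_path.lstrip("/")))
```

Usage: `img = cv2.imread(input_file("dataset/img1.jpg"))`.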
- LOCAL_MODEL_PATH with an example

PyTorch:
```python
import os
import torch

weights = {"state_dict": model.state_dict()}
torch.save(weights, os.environ["LOCAL_MODEL_PATH"])
```

TensorFlow Keras:
```python
import os
import pickle

# get_weights() returns a list of arrays; key each array by its index as a string
weights = {"state_dict": {str(key): val for key, val in enumerate(model.get_weights())}}
with open(os.environ["LOCAL_MODEL_PATH"], "wb") as F:
    pickle.dump(weights, F, protocol=pickle.HIGHEST_PROTOCOL)
```
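The Keras branch above depends on TensorFlow only through get_weights(), so it can be factored into a framework-free function that is easy to test. save_local_weights is a hypothetical name; the sketch accepts any list of picklable objects (Keras passes NumPy arrays) and falls back to os.environ["LOCAL_MODEL_PATH"] when no path is given.

```python
import os
import pickle
from typing import Any, List, Optional


def save_local_weights(weight_list: List[Any], path: Optional[str] = None) -> str:
    """Pickle weights as {"state_dict": {"0": w0, "1": w1, ...}} and return the path used."""
    path = path or os.environ["LOCAL_MODEL_PATH"]
    # Same layout as the Keras snippet: index (as str) -> weight array
    weights = {"state_dict": {str(i): w for i, w in enumerate(weight_list)}}
    with open(path, "wb") as f:
        pickle.dump(weights, f, protocol=pickle.HIGHEST_PROTOCOL)
    return path
```

Usage in a Keras client: `save_local_weights(model.get_weights())`.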
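- GLOBAL_MODEL_PATH with an example

Load the aggregated model when epoch != 0 or the file already exists; loading is skipped only at epoch 0 when no global model has been written yet.

PyTorch:
```python
import os
import torch

if epoch != 0 or os.path.exists(os.environ["GLOBAL_MODEL_PATH"]):
    model.load_state_dict(torch.load(os.environ["GLOBAL_MODEL_PATH"])["state_dict"])
```

TensorFlow Keras:
```python
import os
import pickle

if epoch != 0 or os.path.exists(os.environ["GLOBAL_MODEL_PATH"]):
    with open(os.environ["GLOBAL_MODEL_PATH"], "rb") as F:
        weights = pickle.load(F)
    model.set_weights(list(weights["state_dict"].values()))
```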
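A matching loader can fold the epoch guard into one function. load_global_weights is a hypothetical name. Unlike the Keras snippet above, this sketch orders the arrays by the integer value of their keys instead of relying on dict insertion order; that is a design choice of the sketch, which keeps "10" after "9" even if the aggregator writes the keys in a different order. It returns None when loading is skipped, so the caller keeps its initial weights.

```python
import os
import pickle
from typing import Any, List, Optional


def load_global_weights(epoch: int, path: Optional[str] = None) -> Optional[List[Any]]:
    """Return the global weights as a list, or None when loading is skipped."""
    path = path or os.environ["GLOBAL_MODEL_PATH"]
    # Negation of the guard `epoch != 0 or os.path.exists(path)`
    if epoch == 0 and not os.path.exists(path):
        return None  # no aggregated model yet; keep the locally initialised weights
    # From epoch 1 on the file is expected to exist, so open() raises if it does not
    with open(path, "rb") as f:
        state_dict = pickle.load(f)["state_dict"]
    # Sort keys numerically ("0", "1", ..., "10") rather than trusting insertion order
    return [state_dict[k] for k in sorted(state_dict, key=int)]
```

Usage in a Keras client: `w = load_global_weights(epoch)`, then `model.set_weights(w)` if `w is not None`.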
Output interface
Prepare a JSON file with the following format and save it to OUTPUT_PATH:
- metadata: dict  # Required
  - epoch: int
  - datasetSize: int  # Number of training samples
  - importance: float  # Epoch weight; set to 1.0 by default (equal weights)
- metrics: dict  # The keys below are required; you may add any other keys (str) with float values
  - precision: float
  - basic/confusion_tp: float
  - basic/confusion_fp: float
  - basic/confusion_fn: float
  - basic/confusion_tn: float
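As an illustration of steps 3 and 5, the sketch below fills the required keys for a binary classifier. Assumptions, all hypothetical: labels and predictions are 0/1 with 1 as the positive class, the helper names confusion_counts and build_info are invented here, the sample labels are made up, and precision is reported as 0.0 when nothing is predicted positive. The result is the dictionary that would be serialised to JSON; how it reaches OUTPUT_PATH is left to the interface described above.

```python
import json
from typing import Any, Dict, Sequence


def confusion_counts(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Binary confusion-matrix counts (positive class = 1) under the required metric keys."""
    pairs = list(zip(y_true, y_pred))
    return {
        "basic/confusion_tp": float(sum(1 for t, p in pairs if t == 1 and p == 1)),
        "basic/confusion_fp": float(sum(1 for t, p in pairs if t == 0 and p == 1)),
        "basic/confusion_fn": float(sum(1 for t, p in pairs if t == 1 and p == 0)),
        "basic/confusion_tn": float(sum(1 for t, p in pairs if t == 0 and p == 0)),
    }


def build_info(epoch: int, dataset_size: int, y_true: Sequence[int],
               y_pred: Sequence[int], importance: float = 1.0) -> Dict[str, Any]:
    metrics = confusion_counts(y_true, y_pred)
    tp, fp = metrics["basic/confusion_tp"], metrics["basic/confusion_fp"]
    # precision = TP / (TP + FP); 0.0 when nothing is predicted positive (assumption)
    metrics["precision"] = tp / (tp + fp) if tp + fp > 0 else 0.0
    return {
        "metadata": {"epoch": epoch, "datasetSize": dataset_size, "importance": importance},
        "metrics": metrics,
    }


# dataset_size counts training samples; the four labels below are evaluation data
info = build_info(epoch=0, dataset_size=100, y_true=[1, 0, 1, 0], y_pred=[1, 1, 0, 0])
print(json.dumps(info, indent=2))
```

Extra metrics can be added to info["metrics"] as long as their keys are strings and their values are floats.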