3.4 An XGBoost example
This page shows an example of converting custom XGBoost training code into the format required by the AI Labs FL/FV platform, based on the FLaVor XGBoost example. Please read the previous section 3-1, an overview of flavor-fl, in advance.
Prerequisites
- Training code e.g. ./main.py
- Dataset e.g. ./MNIST/*
- (Optional) Pre-trained weights of the model e.g. ./weights/weight.json
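If you do start from pre-trained weights, they can be loaded into a Booster before training begins. The snippet below is a minimal sketch, assuming the weights were saved in XGBoost's JSON model format at the example path above; it is not part of the FLaVor example itself.

import xgboost as xgb

# Hypothetical illustration of the optional prerequisite:
# resume from pre-trained weights in XGBoost's JSON model format.
bst = xgb.Booster()
bst.load_model("./weights/weight.json")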
main.py that uses the low-level API: xgboost.train and xgboost.Booster.update
The parts modified for federated learning are annotated with comments in the code below.
import os

import xgboost as xgb
from sklearn.datasets import load_digits
from sklearn.metrics import top_k_accuracy_score
from sklearn.model_selection import train_test_split

# import flavor
from flavor.cook.utils import SaveInfoJson, SetEvent, WaitEvent
from flavor.gadget.xgboost import load_update_xgbmodel, save_update_xgbmodel

total_round = 10

# Load the example dataset from sklearn; in a real-world scenario, load from INPUT_PATH instead.
data = load_digits()
X_train, X_val, y_train, y_val = train_test_split(
    data["data"], data["target"], test_size=0.2, random_state=0
)

# LOW-LEVEL API START
# Training hyperparameters
param = {
    "objective": "multi:softprob",
    "eta": 0.1,
    "max_depth": 8,
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "hist",
    "num_class": 10,
}

# Tell the server that all preparations for training have been completed.
SetEvent("TrainInitDone")

for round_idx in range(total_round):  # the term "round" here means "epoch"

    # Wait for the server
    WaitEvent("TrainStarted")

    # Reloading the data into DMatrix is required whenever a model file is loaded, due to an xgb bug
    dtrain = xgb.DMatrix(X_train, y_train)
    dval_X = xgb.DMatrix(X_val)

    if round_idx == 0:
        acc = 0.0
        # First round: set num_boost_round to 1
        bst = xgb.train(param, dtrain, num_boost_round=1)
    else:
        # Load the checkpoint sent from the server
        load_update_xgbmodel(bst, os.environ["GLOBAL_MODEL_PATH"])

        # Save the latest checkpoint
        bst.save_model(os.path.join(os.environ["OUTPUT_PATH"], "latest.json"))

        # Eval the global model
        y_pred = bst.predict(dval_X)
        acc = top_k_accuracy_score(y_val, y_pred, k=1)

        # Local train
        bst.update(dtrain, bst.num_boosted_rounds())

    # Save the information that the server needs to know
    output_dict = {}
    output_dict["metadata"] = {"epoch": round_idx, "datasetSize": len(X_train), "importance": 1.0}
    output_dict["metrics"] = {
        "accuracy": acc,
        "basic/confusion_tp": -1,  # If N/A or you don't want to track it, fill in -1.
        "basic/confusion_fp": -1,
        "basic/confusion_fn": -1,
        "basic/confusion_tn": -1,
    }
    SaveInfoJson(output_dict)

    # Save the newly added trees by slicing, instead of the entire model
    save_update_xgbmodel(bst, os.environ["LOCAL_MODEL_PATH"])

    # Tell the server that this round of training has ended.
    SetEvent("TrainFinished")
# LOW-LEVEL API END
main.py that uses the high-level API: xgboost.sklearn.XGBClassifier.fit
For users of the high-level API, such as

xgb = xgboost.sklearn.XGBClassifier(n_estimators=epoch, ...)
xgb.fit(xTrain, yTrain)

please replace it with the low-level API (the code between "LOW-LEVEL API START" and "LOW-LEVEL API END" above).
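As a rough guide, the sketch below shows how such an estimator-style call maps onto the low-level calls used in the loop above. The parameter mapping (e.g. learning_rate to eta, n_estimators to the number of boosting rounds) is an assumption you should adapt to your own model; this is not a drop-in conversion.

import xgboost as xgb
from sklearn.datasets import load_digits

X_train, y_train = load_digits(return_X_y=True)

# High-level style (to be replaced):
#   clf = xgb.XGBClassifier(n_estimators=10, learning_rate=0.1, max_depth=8)
#   clf.fit(X_train, y_train)

# Low-level equivalent, one boosting round at a time as in the federated loop above:
param = {"objective": "multi:softprob", "num_class": 10, "eta": 0.1, "max_depth": 8}
dtrain = xgb.DMatrix(X_train, y_train)
bst = xgb.train(param, dtrain, num_boost_round=1)   # first round: grow the initial trees
bst.update(dtrain, bst.num_boosted_rounds())        # each later round: add one more boosting round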
Dockerfile
FROM python:3.8.16-slim
COPY . /app
WORKDIR /app
RUN pip install -r requirements.txt
ENV PROCESS="python main.py"
CMD flavor-fl --ignore-ckpt -m "${PROCESS}"
The requirements.txt is:
scikit-learn
https://github.com/ailabstw/FLaVor/archive/refs/heads/release/stable.zip
Note that for the XGBoost framework, both the flavor-fl and check-fl commands must include the --ignore-ckpt flag.
Remember to verify correctness with the check-fl command described in the previous section.