3.4 An XGBoost example

    This page shows an example of converting custom XGBoost code into the format required by the AILabs FL/FV platform, based on the FLaVor XGBoost example. Please read the previous section, 3.1 An overview of FLaVor FL, in advance.

    Prerequisites

    • Training code e.g. ./main.py
    • Dataset e.g. ./MNIST/*
    • (Optional) Pre-trained weights of the model e.g. ./weights/weight.json 
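
    Together with the Dockerfile and requirements.txt described below, a typical working directory might look like the following sketch (the dataset and weights paths are just the examples above):

    ./main.py               # training entry point
    ./MNIST/                # local dataset
    ./weights/weight.json   # optional pre-trained weights
    ./requirements.txt      # Python dependencies (see below)
    ./Dockerfile            # container build recipe (see below)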

    main.py that uses the low-level API: xgboost.train and xgboost.Booster.update

    The FLaVor-specific modifications are marked by inline comments such as # import flavor.

    import os

    import xgboost as xgb
    from sklearn.datasets import load_digits
    from sklearn.metrics import top_k_accuracy_score
    from sklearn.model_selection import train_test_split

    from flavor.cook.utils import SaveInfoJson, SetEvent, WaitEvent
    from flavor.gadget.xgboost import load_update_xgbmodel, save_update_xgbmodel  # import flavor

    total_round = 10

    # Load the example dataset from sklearn; load from INPUT_PATH in a real-world scenario
    data = load_digits()
    X_train, X_val, y_train, y_val = train_test_split(
        data["data"], data["target"], test_size=0.2, random_state=0
    )

    # LOW-LEVEL API START
    # Training hyperparameters
    param = {
        "objective": "multi:softprob",
        "eta": 0.1,
        "max_depth": 8,
        "nthread": 16,
        "num_parallel_tree": 1,
        "subsample": 1,
        "tree_method": "hist",
        "num_class": 10,
    }

    # Tell the server that all preparations for training have been completed.
    SetEvent("TrainInitDone")

    for round_idx in range(total_round):  # the term "round" here means "epoch"

        # Wait for the server
        WaitEvent("TrainStarted")

        # Reloading the data into DMatrix is required after loading any model file, due to an XGBoost bug
        dtrain = xgb.DMatrix(X_train, y_train)
        dval_X = xgb.DMatrix(X_val)

        if round_idx == 0:
            acc = 0.0

            # First round: set num_boost_round to 1
            bst = xgb.train(param, dtrain, num_boost_round=1)
        else:
            # Load the checkpoint sent from the server
            load_update_xgbmodel(bst, os.environ["GLOBAL_MODEL_PATH"])

            # Save the latest checkpoint
            bst.save_model(os.path.join(os.environ["OUTPUT_PATH"], "latest.json"))

            # Evaluate the global model
            y_pred = bst.predict(dval_X)
            acc = top_k_accuracy_score(y_val, y_pred, k=1)

            # Local training: grow one more boosted round
            bst.update(dtrain, bst.num_boosted_rounds())

        # Save information that the server needs to know
        output_dict = {}
        output_dict["metadata"] = {"epoch": round_idx, "datasetSize": len(X_train), "importance": 1.0}
        output_dict["metrics"] = {
            "accuracy": acc,
            "basic/confusion_tp": -1,  # If N/A, or you don't want to track it, fill in -1.
            "basic/confusion_fp": -1,
            "basic/confusion_fn": -1,
            "basic/confusion_tn": -1,
        }
        SaveInfoJson(output_dict)

        # Save only the newly added tree, by slicing, instead of the entire model
        save_update_xgbmodel(bst, os.environ["LOCAL_MODEL_PATH"])

        # Tell the server that this round of training has ended.
        SetEvent("TrainFinished")
    # LOW-LEVEL API END
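
    The save_update_xgbmodel helper keeps each upload small: instead of serializing the entire ensemble every round, only the newly added boosted round is saved. The following is a minimal sketch of that idea, assuming the Booster slicing available in recent XGBoost versions; it is an illustration, not the actual FLaVor implementation:

    import xgboost as xgb

    def save_update_sketch(bst: xgb.Booster, path: str) -> None:
        # Slice out a sub-model containing only the last boosted round
        # (for multi-class objectives, one round holds num_class trees).
        n = bst.num_boosted_rounds()
        bst[n - 1 : n].save_model(path)

    Conversely, load_update_xgbmodel applies the aggregated checkpoint from GLOBAL_MODEL_PATH to the existing booster, so both directions are handled by the FLaVor helpers.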

    main.py that uses the high-level API: xgboost.sklearn.XGBClassifier.fit

    For users of the high-level API, such as

    xgb = xgboost.sklearn.XGBClassifier(n_estimators=epoch, ...)
    xgb.fit(xTrain, yTrain)

    please replace it with the low-level API (from "LOW-LEVEL API START" to "LOW-LEVEL API END"), as sketched below.
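
    For reference, one fit call with n_estimators boosted rounds corresponds roughly to a single xgb.train call with num_boost_round=1 followed by repeated bst.update calls, which is exactly what the training loop above interleaves with the server events. A minimal sketch of that correspondence, without the FL plumbing (param, total_round, X_train, and y_train as defined above):

    import xgboost as xgb

    dtrain = xgb.DMatrix(X_train, y_train)

    # Round 0: train a model containing a single boosted round
    bst = xgb.train(param, dtrain, num_boost_round=1)

    # Rounds 1..total_round-1: grow one more boosted round per iteration
    for _ in range(total_round - 1):
        bst.update(dtrain, bst.num_boosted_rounds())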

    Dockerfile

    FROM python:3.8.16-slim
    COPY . /app
    WORKDIR /app
    RUN pip install -r requirements.txt
    ENV PROCESS="python main.py"
    CMD flavor-fl --ignore-ckpt -m "${PROCESS}"
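
    Note that flavor-fl, not main.py, is the container's entry point: it launches the command given in PROCESS and coordinates with it through the TrainInitDone, TrainStarted, and TrainFinished events used in the training loop above.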

    The requirements.txt is:

    scikit-learn
    https://github.com/ailabstw/FLaVor/archive/refs/heads/release/stable.zip

    Note that for the XGBoost framework, both the flavor-fl and check-fl commands should include the --ignore-ckpt flag.

    Remember to verify correctness with the check-fl command described in the previous section.
