3.2 A PyTorch example
This page shows how to adapt custom PyTorch code to the format required by the AI Labs FL/FV platform, using the FLaVor PyTorch example. Please read the previous section, 3.1 (an overview of FLaVor FL), before continuing.
Prerequisites
- Training code e.g. ./main.py
- Dataset e.g. ./MNIST/*
- (Optional) Pre-trained weights of the model e.g. ./weights/weight.pth
main.py
The modified parts are marked in red, and the corresponding comments are shown in green.
from __future__ import print_function
import argparse import os
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.optim.lr_scheduler import StepLR from torchvision import datasets, transforms
from flavor.cook.utils import SaveInfoJson, SetEvent, WaitEvent # import Flavor
class Net(nn.Module):
    """Convolutional classifier: two conv layers followed by two linear layers.

    ``forward`` returns log-probabilities over 10 classes (``log_softmax``
    along dim 1). ``fc1`` takes 9216 input features, which equals
    64 * 12 * 12, i.e. what the conv/pool stack yields for a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order so that parameter
        # initialisation under a fixed seed is unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature maps.
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` via ``optimizer``.

    Args:
        args: Object exposing ``log_interval`` (batches between progress
            lines) and ``dry_run`` (stop after the first logged batch).
        model: Module whose output is fed to ``F.nll_loss``.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches; it must also
            support ``len()`` and expose ``.dataset`` for progress logging.
        optimizer: Optimizer stepped once per batch.
        epoch: Value printed in the progress line; not used otherwise.
    """
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Everything below only happens on logging steps.
        if step % args.log_interval != 0:
            continue

        samples_seen = step * len(inputs)
        percent_done = 100.0 * step / len(train_loader)
        print(
            "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                samples_seen,
                len(train_loader.dataset),
                percent_done,
                batch_loss.item(),
            )
        )
        # The dry-run exit is checked only right after a progress line.
        if args.dry_run:
            break
def test(model, device, test_loader): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print( "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) ) )
return 100.0 * correct / len(test_loader.dataset)
def main():
    """Parse CLI options, build the MNIST pipeline and run the FLaVor training loop.

    Environment variables read:
        INPUT_PATH: Root directory passed to ``datasets.MNIST``.
        GLOBAL_MODEL_PATH: Checkpoint loaded at the start of a round.
        LOCAL_MODEL_PATH: Where the locally trained checkpoint is written.

    Each round waits for ``TrainStarted``, loads the global checkpoint,
    evaluates it, reports metrics with ``SaveInfoJson``, trains one pass,
    saves the local checkpoint and signals ``TrainFinished``. The loop has
    no exit condition of its own.

    ``--epochs``, ``--no-mps`` and ``--save-model`` are parsed but not read
    here; they are kept so existing command lines keep working.
    """
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        metavar="N",
        help="number of epochs to train (default: 300)",
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--no-mps", action="store_true", default=False, help="disables macOS GPU training"
    )
    parser.add_argument(
        "--dry-run", action="store_true", default=False, help="quickly check a single pass"
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For Saving the current Model"
    )
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {"batch_size": args.batch_size}
    test_kwargs = {"batch_size": args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # Use the environment variable os.environ["INPUT_PATH"] instead of "./MNIST"
    dataset1 = datasets.MNIST(
        os.environ["INPUT_PATH"], train=True, download=True, transform=transform
    )
    dataset2 = datasets.MNIST(os.environ["INPUT_PATH"], train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # Tell the server that all preparations for training have been completed.
    SetEvent("TrainInitDone")

    epoch = -1
    while True:
        epoch += 1

        # Wait for the server
        WaitEvent("TrainStarted")

        # Load checkpoint sent from the server. On the first round the file is
        # optional (pre-trained weights, e.g. /weights/weights.pth); on later
        # rounds it is always loaded.
        if epoch != 0 or os.path.exists(os.environ["GLOBAL_MODEL_PATH"]):
            # map_location remaps the stored tensors onto this client's device,
            # so a checkpoint saved from CUDA tensors also loads on CPU.
            checkpoint = torch.load(os.environ["GLOBAL_MODEL_PATH"], map_location=device)
            model.load_state_dict(checkpoint["state_dict"])

        # Verify the performance of the global model before training
        precision = test(model, device, test_loader)

        # Save information that the server needs to know
        output_dict = {}
        output_dict["metadata"] = {"epoch": epoch, "datasetSize": len(dataset1), "importance": 1.0}
        output_dict["metrics"] = {
            "precision": precision,
            "basic/confusion_tp": -1,  # If N/A or you don't want to track, fill in -1.
            "basic/confusion_fp": -1,
            "basic/confusion_fn": -1,
            "basic/confusion_tn": -1,
        }
        SaveInfoJson(output_dict)  # Save json to the server

        train(args, model, device, train_loader, optimizer, epoch)
        scheduler.step()

        # Save checkpoint
        torch.save({"state_dict": model.state_dict()}, os.environ["LOCAL_MODEL_PATH"])

        # Tell the server that this round of training work has ended.
        SetEvent("TrainFinished")
# Run the training entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()
Dockerfile
# Base image tag: PyTorch 1.12.0 runtime with CUDA 11.3 and cuDNN 8.
FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime

# Copy the build context (main.py, requirements.txt, ...) into the image
# and make it the working directory.
COPY . /app
WORKDIR /app

RUN pip install -r requirements.txt

# Training command handed to flavor-fl through its -m option.
ENV PROCESS="python main.py"
CMD flavor-fl -m "${PROCESS}"
The contents of requirements.txt are:
torch
torchvision
# FLaVor, installed from the archive of its release/stable branch.
https://github.com/ailabstw/FLaVor/archive/refs/heads/release/stable.zip
Remember to verify correctness with the check-fl command described in the previous section.