NVIDIA FLARE™ (NVIDIA Federated Learning Application Runtime
Environment) is a domain-agnostic, open-source, and extensible SDK
for Federated Learning. It allows researchers and data scientists
to adapt existing ML/DL workflows to a federated paradigm and
enables platform developers to build a secure, privacy-preserving
offering for a distributed multi-party collaboration.
Begin your NVIDIA FLARE journey with these guides
designed to help you quickly grasp essential concepts.
Follow along with the videos below, and try it out yourself.
Step 1
Introduction to NVIDIA FLARE
Learn about the core concepts and fundamentals of NVIDIA FLARE to help you get started.
Try out these example code blocks below, where we showcase how simple it is to adapt
a popular machine learning framework to a federated learning scenario with NVIDIA FLARE.
For more details, refer to the getting started walkthrough guide above.
Installation
Install the required PyTorch dependencies for this example.
pip install nvflare torch torchvision
Installation
Install the required PyTorch Lightning dependencies for this example.
pip install nvflare pytorch_lightning
Installation
Install the required TensorFlow dependencies for this example.
pip install nvflare tensorflow
Client Code (cifar10_pt_fl.py)
We use the Client API to convert the centralized training PyTorch code into federated learning code with only a few lines of changes, marked by the numbered comments below. Essentially, the client receives a model from NVIDIA FLARE, performs local training and validation, and then sends the model back.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# (1) import nvflare client API
import nvflare.client as flare

# (optional) metrics
from nvflare.client.tracking import SummaryWriter

# (optional) use a fixed location so we don't need to download every time
DATASET_PATH = "/tmp/nvflare/data"
# If available, we use GPU to speed things up.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    batch_size = 4
    epochs = 2
    trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    net = Net()

    # (2) initializes NVFlare client API
    flare.init()
    summary_writer = SummaryWriter()

    while flare.is_running():
        # (3) receives FLModel from NVFlare
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (4) loads model from NVFlare
        net.load_state_dict(input_model.params)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

        # (optional) use GPU to speed things up
        net.to(DEVICE)
        # (optional) calculate total steps
        steps = epochs * len(trainloader)
        for epoch in range(epochs):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                # (optional) use GPU to speed things up
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                    global_step = input_model.current_round * steps + epoch * len(trainloader) + i
                    summary_writer.add_scalar(
                        tag="loss_for_each_batch",
                        scalar=running_loss,
                        global_step=global_step
                    )
                    running_loss = 0.0

        print("Finished Training")
        PATH = "./cifar_net.pth"
        torch.save(net.state_dict(), PATH)

        # (5) wraps evaluation logic into a method to re-use for
        #     evaluation on both trained and received model
        def evaluate(input_weights):
            net = Net()
            net.load_state_dict(input_weights)
            # (optional) use GPU to speed things up
            net.to(DEVICE)
            correct = 0
            total = 0
            # since we're not training, we don't need to calculate the gradients for our outputs
            with torch.no_grad():
                for data in testloader:
                    # (optional) use GPU to speed things up
                    images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                    # calculate outputs by running images through the network
                    outputs = net(images)
                    # the class with the highest energy is what we choose as prediction
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
            return 100 * correct // total

        # (6) evaluate on received model for model selection
        accuracy = evaluate(input_model.params)

        # (7) construct trained FL model
        output_model = flare.FLModel(
            params=net.cpu().state_dict(),
            metrics={"accuracy": accuracy},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )

        # (8) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
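The client script logs its loss against a global_step that keeps counting across federated rounds rather than restarting at zero each round. The arithmetic can be checked in isolation with the small sketch below. The helper name `global_step` and its parameters are hypothetical and only mirror the expression used in the script; the batch count assumes the full 50,000-image CIFAR-10 training set with batch_size=4, as in the example.

```python
def global_step(current_round: int, epochs: int, batches_per_epoch: int,
                epoch: int, batch_idx: int) -> int:
    """Return a logging step index that keeps increasing across FL rounds.

    Hypothetical helper mirroring the expression in the client script:
    current_round * steps + epoch * len(trainloader) + i
    """
    steps_per_round = epochs * batches_per_epoch
    return current_round * steps_per_round + epoch * batches_per_epoch + batch_idx


# CIFAR-10 has 50,000 training images; with batch_size=4 that is 12,500 batches.
BATCHES_PER_EPOCH = 50_000 // 4

# First logged point of round 0 (epoch 0, batch index 1999).
first = global_step(0, 2, BATCHES_PER_EPOCH, 0, 1999)
# Same batch index in the second epoch of round 1.
later = global_step(1, 2, BATCHES_PER_EPOCH, 1, 1999)

print(first, later)  # 1999 39499
```

Because each round contributes epochs * len(trainloader) steps, points logged in later rounds land to the right of earlier ones on the same TensorBoard curve.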
Client Code (cifar10_lightning_fl.py)
We use the Client API to convert the centralized training PyTorch Lightning code into federated learning code with only a few lines of changes, marked by the numbered comments below. Essentially, the client receives a model from NVIDIA FLARE, performs local training and validation, and then sends the model back.
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy

# (1) import nvflare lightning client API
import nvflare.client.lightning as flare

# Assumption: Net is the CNN defined in cifar10_pt_fl.py above;
# the original snippet does not show where it comes from.
from cifar10_pt_fl import Net

seed_everything(7)

DATASET_PATH = "/tmp/nvflare/data"
BATCH_SIZE = 4
# Assumption: the original snippet uses NUM_CLASSES and criterion without
# defining them; the values below match the PyTorch example.
NUM_CLASSES = 10  # CIFAR-10 has 10 classes
criterion = nn.CrossEntropyLoss()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)
        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage == "validate":
            cifar_full = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, download=False, transform=transform
            )
            self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.cifar_test = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, download=False, transform=transform
            )

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
    model = LitNet()
    cifar10_dm = CIFAR10DataModule()
    if torch.cuda.is_available():
        trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1 if torch.cuda.is_available() else None)
    else:
        trainer = Trainer(max_epochs=1, devices=None)

    # (2) patch the lightning trainer
    flare.patch(trainer)

    while flare.is_running():
        # (3) receives FLModel from NVFlare
        # Note that we don't need to pass this input_model to trainer
        # because after flare.patch the trainer.fit/validate will get the
        # global model internally
        input_model = flare.receive()
        print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")

        # (4) evaluate the current global model to allow server-side model selection
        print("--- validate global model ---")
        trainer.validate(model, datamodule=cifar10_dm)

        # perform local training starting with the received global model
        print("--- train new model ---")
        trainer.fit(model, datamodule=cifar10_dm)

        # test local model
        print("--- test new model ---")
        trainer.test(ckpt_path="best", datamodule=cifar10_dm)

        # get predictions
        print("--- prediction with new best model ---")
        trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
    main()
Client Code (cifar10_tf_fl.py)
We use the Client API to convert the centralized training TensorFlow code into federated learning code with only a few lines of changes, marked by the numbered comments below. Essentially, the client receives a model from NVIDIA FLARE, performs local training and validation, and then sends the model back.
from tensorflow.keras import datasets
from tensorflow.keras import Model, layers, losses


class TFNet(Model):
    def __init__(self, input_shape):
        super().__init__()
        self._input_shape = input_shape  # Required to get constructor arguments in job config
        self.conv1 = layers.Conv2D(6, 5, activation="relu")
        self.pool = layers.MaxPooling2D((2, 2), 2)
        self.conv2 = layers.Conv2D(16, 5, activation="relu")
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(120, activation="relu")
        self.fc2 = layers.Dense(84, activation="relu")
        self.fc3 = layers.Dense(10)
        loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
        self.compile(optimizer="sgd", loss=loss_fn, metrics=["accuracy"])
        self.build(input_shape)

    def call(self, x):
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


# (1) import nvflare client API
import nvflare.client as flare

PATH = "./tf_model.ckpt"


def main():
    (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
    # Normalize pixel values to be between 0 and 1
    train_images, test_images = train_images / 255.0, test_images / 255.0
    model = TFNet(input_shape=(None, 32, 32, 3))
    model.summary()

    # (2) initializes NVFlare client API
    flare.init()

    # (3) gets FLModel from NVFlare
    while flare.is_running():
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (optional) print system info
        system_info = flare.system_info()
        print(f"NVFlare system info: {system_info}")

        # (4) loads model from NVFlare
        for k, v in input_model.params.items():
            model.get_layer(k).set_weights(v)

        # (5) evaluate aggregated/received model
        _, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
        print(
            f"Accuracy of the received model on round {input_model.current_round} "
            f"on the 10000 test images: {test_global_acc * 100} %"
        )

        model.fit(train_images, train_labels, epochs=1, validation_data=(test_images, test_labels))
        print("Finished Training")
        model.save_weights(PATH)
        _, test_acc = model.evaluate(test_images, test_labels, verbose=2)
        print(f"Accuracy of the model on the 10000 test images: {test_acc * 100} %")

        # (6) construct trained FL model (A dict of {layer name: layer weights} from the keras model)
        output_model = flare.FLModel(
            params={layer.name: layer.get_weights() for layer in model.layers},
            metrics={"accuracy": test_global_acc}
        )

        # (7) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
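Steps (4) and (6) of the TensorFlow client exchange parameters as a dict of {layer name: layer weights}. The sketch below illustrates that round trip without requiring TensorFlow. `StubLayer`, `export_params`, and `load_params` are hypothetical stand-ins introduced only for illustration; they imitate the `name`, `get_weights()`, and `set_weights()` members of a Keras layer and are not part of NVFlare or Keras.

```python
from typing import Dict, List

Weights = List[List[float]]


class StubLayer:
    """Hypothetical stand-in for a Keras layer: a name plus a list of weight arrays."""

    def __init__(self, name: str, weights: Weights):
        self.name = name
        self._weights = [list(w) for w in weights]

    def get_weights(self) -> Weights:
        # Return copies so callers cannot mutate the layer by accident.
        return [list(w) for w in self._weights]

    def set_weights(self, weights: Weights) -> None:
        self._weights = [list(w) for w in weights]


def export_params(model_layers: List[StubLayer]) -> Dict[str, Weights]:
    """Build the {layer name: layer weights} dict, as in step (6)."""
    return {layer.name: layer.get_weights() for layer in model_layers}


def load_params(model_layers: List[StubLayer], params: Dict[str, Weights]) -> None:
    """Copy received weights into the layers with matching names, as in step (4)."""
    by_name = {layer.name: layer for layer in model_layers}
    for name, weights in params.items():
        by_name[name].set_weights(weights)


# A "trained" model on one side and a freshly initialized one on the other.
trained = [StubLayer("fc1", [[0.1, 0.2], [0.3]]), StubLayer("fc2", [[0.5]])]
fresh = [StubLayer("fc1", [[0.0, 0.0], [0.0]]), StubLayer("fc2", [[0.0]])]

params = export_params(trained)
load_params(fresh, params)
round_trip_ok = export_params(fresh) == params
print(round_trip_ok)  # True
```

Keying the dict by layer name means both sides must build the model with the same layer names, which is why the client and the job script construct the same TFNet class.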
Server Code (fedavg.py)
The ModelController API is used to write a federated routine with mechanisms for distributing and receiving models from clients. Here we implement the basic FedAvg algorithm using some helper functions from BaseFedAvg.
from nvflare.app_common.workflows.base_fedavg import BaseFedAvg


class FedAvg(BaseFedAvg):
    def run(self) -> None:
        self.info("Start FedAvg.")

        model = self.load_model()
        model.start_round = self.start_round
        model.total_rounds = self.num_rounds

        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
            self.info(f"Round {self.current_round} started.")
            model.current_round = self.current_round

            clients = self.sample_clients(self.num_clients)
            results = self.send_model_and_wait(targets=clients, data=model)
            aggregate_results = self.aggregate(results)
            model = self.update_model(model, aggregate_results)
            self.save_model(model)

        self.info("Finished FedAvg.")
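To make the aggregation step more concrete, here is a minimal sketch of the weighted averaging that FedAvg is commonly described as performing. It is not NVFlare's implementation of `self.aggregate`. The function name `weighted_average` and the plain-list parameter format are hypothetical, and weighting each client by its number of local steps is an assumption that mirrors the NUM_STEPS_CURRENT_ROUND value the PyTorch client reports in its meta.

```python
from typing import Dict, List, Tuple

Params = Dict[str, List[float]]


def weighted_average(client_updates: List[Tuple[Params, int]]) -> Params:
    """Average client parameters, weighting each client by its local step count.

    client_updates is a list of (params, num_steps) pairs; every params dict
    is assumed to have the same keys and vector lengths.
    """
    if not client_updates:
        raise ValueError("need at least one client update")
    total_steps = sum(steps for _, steps in client_updates)
    if total_steps <= 0:
        raise ValueError("total number of steps must be positive")

    reference = client_updates[0][0]
    averaged: Params = {}
    for key, reference_values in reference.items():
        accumulator = [0.0] * len(reference_values)
        for params, steps in client_updates:
            weight = steps / total_steps
            for j, value in enumerate(params[key]):
                accumulator[j] += weight * value
        averaged[key] = accumulator
    return averaged


# Two hypothetical sites: site-1 trained for three times as many steps as site-0.
site_0 = ({"w": [1.0, 2.0]}, 1)
site_1 = ({"w": [3.0, 4.0]}, 3)
global_params = weighted_average([site_0, site_1])
print(global_params)  # {'w': [2.5, 3.5]}
```

With weights 0.25 and 0.75, the result sits closer to the site that did more local work, which is the intuition behind step-weighted averaging.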
Job Code (fedavg_cifar10_pt_job.py)
Lastly we construct the job with our 'cifar10_pt_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.
from cifar10_pt_fl import Net
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 2
    train_script = "cifar10_pt_fl.py"

    # Create BaseFedJob with initial model
    job = BaseFedJob(
        name="cifar10_pt_fedavg",
        initial_model=Net(),
    )

    # Define the controller and send to server
    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to_server(controller)

    # Add clients
    for i in range(n_clients):
        runner = ScriptRunner(script=train_script)
        job.to(runner, f"site-{i}")

    # job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
Job Code (fedavg_cifar10_lightning_job.py)
Lastly we construct the job with our 'cifar10_lightning_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.
from cifar10_lightning_fl import LitNet
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 2
    train_script = "cifar10_lightning_fl.py"

    # Create BaseFedJob with initial model
    job = BaseFedJob(
        name="cifar10_lightning_fedavg",
        initial_model=LitNet(),
    )

    # Define the controller and send to server
    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to_server(controller)

    # Add clients
    for i in range(n_clients):
        runner = ScriptRunner(script=train_script)
        job.to(runner, f"site-{i}")

    # job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
Job Code (fedavg_cifar10_tf_job.py)
Lastly we construct the job with our 'cifar10_tf_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.
from cifar10_tf_fl import TFNet
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.tf.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import FrameworkType, ScriptRunner

if __name__ == "__main__":
    n_clients = 2
    num_rounds = 2
    train_script = "cifar10_tf_fl.py"

    # Create BaseFedJob with initial model
    job = BaseFedJob(
        name="cifar10_tf_fedavg",
        initial_model=TFNet(input_shape=(None, 32, 32, 3)),
    )

    # Define the controller and send to server
    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to_server(controller)

    # Add clients
    for i in range(n_clients):
        runner = ScriptRunner(
            script=train_script,
            framework=FrameworkType.TENSORFLOW,
        )
        job.to(runner, f"site-{i}")

    # job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")
Run the Job
To run the job with the simulator, copy the code and execute the job script, or run in Google Colab. Alternatively, export the job to a configuration and run the job in other modes.
python3 fedavg_cifar10_pt_job.py
Run the Job
To run the job with the simulator, copy the code and execute the job script, or run in Google Colab. Alternatively, export the job to a configuration and run the job in other modes.
python3 fedavg_cifar10_lightning_job.py
Run the Job
To run the job with the simulator, copy the code and execute the job script, or run in Google Colab. Alternatively, export the job to a configuration and run the job in other modes.
python3 fedavg_cifar10_tf_job.py
Tutorials
These progressive tutorial series cover various aspects of NVIDIA FLARE from core concepts and basic tools to advanced algorithms and deployments.
For a comprehensive list of all tutorials, view the tutorial catalog below.
Explore the process of running federated learning applications using NVIDIA FLARE. We start by setting up the environment and preparing the data, then train a classifier using PyTorch. Finally, we examine the job structure and configurations, including running the simulator, and conclude with a recap of the covered topics.
Chapter 3: NVIDIA FLARE's Federated Computing System
Tags: beg., doc
Explore the core concepts and system architecture that make NVIDIA FLARE (NVFlare) a powerful platform for federated computing, examine different aspects of the NVFlare system, learn how to simulate deployments, and discover various ways to interact with the system.
A self-paced course that introduces you to running and developing federated applications with NVIDIA FLARE. It covers the first 6 chapters of the 12-chapter tutorial, including running, developing, deploying, and monitoring federated applications, as well as privacy and security, with notebooks and videos using DLI-provided AWS resources.
Example of converting Deep Learning (DL) code to Federated Learning (FL) using the Client API, with configurations for NumPy, PyTorch, Lightning, and TensorFlow.
Demonstrates the flexibility of the ModelController API and shows how to write a Federated Averaging workflow with early stopping, model selection, and saving and loading.
Provides a CLI or UI to collect information about clients and users from different organizations, and generates startup kits with keys and certificates for users to download.
End-to-End Federated XGBoost for Financial Credit Card Detection
Tags: adv., algorithm, xgboost, finance
Shows the end-to-end process of feature engineering, pre-processing, and training in federated settings. You can use FLARE to perform federated ETL and then training.
Designed as a federated computing platform agnostic to frameworks, workloads, datasets, and domains. Federated learning apps are built on this foundation.
Flexible open architecture allows for extensive customization, with a modular design in which each layer can be easily extended with pluggable custom components.
Supports cross-cloud deployment with various management tools. Designed for robust, production-scale deployment in real-world federated learning scenarios.
Federated Learning is a distributed learning paradigm where training occurs across multiple clients, each with their own local datasets. This enables the creation of common robust models without sharing sensitive local data, helping solve issues of data privacy and security.
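The paradigm described above can be illustrated with a minimal, framework-free sketch. It does not use the NVIDIA FLARE API: the toy one-dimensional linear model, the two simulated clients, and the helper names (`local_train`, `fed_avg`, `make_client_data`, `run_federated`) are all assumptions made for illustration. Each client trains on its own private data, and only the model parameters and sample counts reach the server, which combines them with a sample-weighted average.

```python
import random


def local_train(w, b, data, lr=0.1, epochs=5):
    """Run a few epochs of gradient descent on one client's private data."""
    n = len(data)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for x, y in data:
            err = (w * x + b) - y
            grad_w += 2 * err * x
            grad_b += 2 * err
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b


def fed_avg(updates):
    """Average client models, weighting each by its number of local samples."""
    total = sum(n for _, _, n in updates)
    w = sum(w * n for w, _, n in updates) / total
    b = sum(b * n for _, b, n in updates) / total
    return w, b


def make_client_data(rng, n, lo, hi):
    """Hypothetical local dataset: points on y = 3x + 1 with x drawn from [lo, hi]."""
    return [(x, 3.0 * x + 1.0) for x in (rng.uniform(lo, hi) for _ in range(n))]


def run_federated(num_rounds=100, seed=0):
    rng = random.Random(seed)
    # Two clients with different data distributions and dataset sizes.
    clients = [
        make_client_data(rng, 20, -1.0, 0.0),
        make_client_data(rng, 60, 0.0, 1.0),
    ]
    w, b = 0.0, 0.0  # initial global model
    for _ in range(num_rounds):
        updates = []
        for data in clients:
            # The raw data never leaves this loop; only parameters are returned.
            local_w, local_b = local_train(w, b, data)
            updates.append((local_w, local_b, len(data)))
        w, b = fed_avg(updates)
    return w, b


if __name__ == "__main__":
    w, b = run_federated()
    print(f"global model: y = {w:.2f}x + {b:.2f}")
```

In this sketch neither client sees the other's range of x values, yet the averaged global model approaches the line that fits both datasets. A real deployment would replace the in-process loop with communication between sites, which is the role the FedAvg controller and client scripts play in the job examples.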