Federated Learning with PyTorch Backend

The following code is for demonstration only. It is NOT intended for production use due to system security concerns; please DO NOT use it directly in production.

In this tutorial, we will walk you through how to use the PyTorch backend in SecretFlow for federated learning:

+ We will use an image classification task as the example.
+ We will use PyTorch as the backend.
+ We will show how to use multiple FL strategies.

If you want to learn more about federated learning, datasets, etc., please refer to the tutorial Federated Learning for Image Classification.

Here we go!

[1]:
%load_ext autoreload
%autoreload 2
[2]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob', 'charlie'], num_cpus=8, log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
BaseModule: similar to torch.nn.Module; the base class your PyTorch model should inherit from.
TorchModel: a wrapper class that bundles model_fn, loss_fn, optim_fn, and metrics.
metric_wrapper: wraps metrics for use on the workers.
optim_wrapper: wraps optim_fn for use on the workers.
FLModel: the federated model; use backend to specify which backend to use, and strategy to specify which federated strategy to use.
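The strategy argument of FLModel controls how the parties' models are combined. This tutorial uses 'fed_avg_w', a FedAvg variant that averages model weights (parameters) across parties. The following is a minimal, framework-free sketch of such a weight average, assuming each party reports a flat list of parameters together with its number of local training samples. Weighting by sample count is the classic FedAvg choice and is an assumption here; the helper fed_avg_weights is hypothetical and SecretFlow's actual implementation may differ.

```python
from typing import List, Sequence


def fed_avg_weights(
    party_params: Sequence[Sequence[float]],
    sample_counts: Sequence[int],
) -> List[float]:
    """Hypothetical helper: average parameter vectors, weighting each party
    by the number of samples it trained on (classic FedAvg assumption)."""
    if len(party_params) != len(sample_counts):
        raise ValueError("need one sample count per party")
    total = sum(sample_counts)
    if total <= 0:
        raise ValueError("total sample count must be positive")
    dim = len(party_params[0])
    if any(len(p) != dim for p in party_params):
        raise ValueError("all parties must send parameters of the same shape")

    averaged = [0.0] * dim
    for params, n in zip(party_params, sample_counts):
        weight = n / total  # this party's share of all training samples
        for i, value in enumerate(params):
            averaged[i] += weight * value
    return averaged


# Toy parameters for two parties holding 40% and 60% of the samples.
alice_params = [1.0, 2.0]
bob_params = [3.0, 4.0]
print(fed_avg_weights([alice_params, bob_params], [24000, 36000]))
```

The party with more data pulls the average toward its own parameters, which is why the result lies closer to bob's values than to alice's.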
[3]:
from secretflow.ml.nn.fl.backend.torch.utils import BaseModule, TorchModel
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn import FLModel
from torchmetrics import Accuracy, Precision
from secretflow.security.aggregation import SecureAggregator
from secretflow.utils.simulation.datasets import load_mnist
from torch import nn, optim
from torch.nn import functional as F

When defining the model, we only need to inherit from BaseModule instead of nn.Module; everything else is the same as in standard PyTorch.

[4]:

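class ConvNet(BaseModule):
    """Small ConvNet for MNIST."""

    def __init__(self):
        super(ConvNet, self).__init__()
        # 1 input channel -> 3 feature maps; a 3x3 kernel turns 28x28 into 26x26
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        # After 3x3 max pooling: 3 channels x 8 x 8 = 192 features
        self.fc_in_dim = 192
        self.fc = nn.Linear(self.fc_in_dim, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, self.fc_in_dim)  # flatten to (batch, 192)
        x = self.fc(x)
        # Note: nn.CrossEntropyLoss already applies log-softmax to its input,
        # so many PyTorch models return the raw logits here instead.
        return F.softmax(x, dim=1)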
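The hard-coded fc_in_dim = 192 follows from the MNIST input shape (1 x 28 x 28). As a quick check, the helpers below (hypothetical, not part of SecretFlow or PyTorch) apply the standard output-size formula for an unpadded, stride-1 convolution and for max_pool2d, whose stride defaults to its kernel size:

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a 2D convolution (no dilation, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size: int, kernel: int) -> int:
    """Spatial output size of max pooling with stride == kernel and no padding."""
    return (size - kernel) // kernel + 1


height = conv2d_out(28, kernel=3)          # 28 -> 26 after the 3x3 convolution
height = max_pool2d_out(height, kernel=3)  # 26 -> 8 after 3x3 max pooling
out_channels = 3                           # conv1 produces 3 feature maps
fc_in_dim = out_channels * height * height
print(fc_in_dim)  # 192
```

If you change the kernel sizes or the number of output channels of conv1, this value has to be recomputed in the same way.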

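We can continue to use the loss functions and optimizers defined in PyTorch. The only difference is that the optimizer and the metrics must be wrapped with the wrappers provided by SecretFlow (optim_wrapper and metric_wrapper); the loss function is passed as-is.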
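A likely reason for the wrapping is that a PyTorch optimizer must be constructed with the model's parameters, which only exist after each party has built its own copy of the model. Conceptually, a wrapper can capture the optimizer class and its hyperparameters and defer construction until then. The sketch below illustrates this pattern with a hypothetical defer_optimizer helper and a toy optimizer class; it is an assumption about the general idea, not SecretFlow's optim_wrapper source.

```python
from typing import Any, Callable, Iterable


def defer_optimizer(
    optim_cls: Callable[..., Any], *args: Any, **kwargs: Any
) -> Callable[[Iterable[Any]], Any]:
    """Hypothetical stand-in for a wrapper such as optim_wrapper: remember the
    optimizer class and its hyperparameters, and build the optimizer later."""

    def build(params: Iterable[Any]) -> Any:
        # Called once the model (and therefore its parameters) exists.
        return optim_cls(params, *args, **kwargs)

    return build


class RecordingOptimizer:
    """Toy optimizer that only records what reached its constructor."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


make_optimizer = defer_optimizer(RecordingOptimizer, lr=1e-2)
opt = make_optimizer([0.5, -0.25])  # stands in for model.parameters()
print(opt.params, opt.lr)  # [0.5, -0.25] 0.01
```

With PyTorch installed, defer_optimizer(torch.optim.Adam, lr=1e-2)(model.parameters()) would build an Adam optimizer in the same deferred way.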

[5]:
(train_data, train_label), (test_data, test_label) = load_mnist(
    parts={alice: 0.4, bob: 0.6},
    normalized_x=True,
    categorical_y=True,
    is_torch=True,
)

loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-2)
model_def = TorchModel(
    model_fn=ConvNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(Accuracy, num_classes=10, average='micro'),
        metric_wrapper(Precision, num_classes=10, average='micro'),
    ],
)
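The argument parts={alice: 0.4, bob: 0.6} assigns 40% of the training samples to alice and 60% to bob. The sketch below shows one simple way to turn such fractions into contiguous index ranges, assuming the standard 60,000-image MNIST training set; the helper split_by_fraction is hypothetical, and whether load_mnist shuffles before splitting is not covered here.

```python
from typing import Dict, List, Tuple


def split_by_fraction(n_samples: int, parts: Dict[str, float]) -> Dict[str, Tuple[int, int]]:
    """Hypothetical helper: map each party to a [start, end) slice of the data."""
    if abs(sum(parts.values()) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    names: List[str] = list(parts)
    ranges: Dict[str, Tuple[int, int]] = {}
    start = 0
    cumulative = 0.0
    for i, name in enumerate(names):
        cumulative += parts[name]
        # The last party takes everything left, so rounding never drops a sample.
        end = n_samples if i == len(names) - 1 else round(cumulative * n_samples)
        ranges[name] = (start, end)
        start = end
    return ranges


ranges = split_by_fraction(60000, {"alice": 0.4, "bob": 0.6})
print(ranges)  # {'alice': (0, 24000), 'bob': (24000, 60000)}
```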
[6]:
device_list = [alice, bob]
server = charlie
aggregator = SecureAggregator(server, [alice, bob])

# specify params
fl_model = FLModel(
    server=server,
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    strategy='fed_avg_w',  # FL strategy
    backend="torch",  # supported backends: ['tensorflow', 'torch']
)
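SecureAggregator is given the server (charlie) and the participating parties. A common idea behind secure aggregation is pairwise masking: each pair of parties shares a random mask that one adds and the other subtracts, so the server can recover the sum of all uploads without seeing any single party's values. The sketch below demonstrates this cancellation with integers modulo 2**32, under simplifying assumptions (a shared random generator stands in for pairwise-agreed seeds, no party drops out, values are already fixed-point encoded). It illustrates the general technique and is not SecretFlow's implementation.

```python
import random
from itertools import combinations
from typing import Dict, List

MODULUS = 2**32  # all arithmetic happens in this ring


def pairwise_masks(parties: List[str], dim: int, seed: int = 0) -> Dict[str, List[int]]:
    """For every pair (a, b): a adds a shared random vector, b subtracts it."""
    rng = random.Random(seed)  # stands in for seeds agreed between each pair
    masks = {p: [0] * dim for p in parties}
    for a, b in combinations(parties, 2):
        shared = [rng.randrange(MODULUS) for _ in range(dim)]
        for i in range(dim):
            masks[a][i] = (masks[a][i] + shared[i]) % MODULUS
            masks[b][i] = (masks[b][i] - shared[i]) % MODULUS
    return masks


def mask(values: List[int], party_mask: List[int]) -> List[int]:
    """What a party would upload: its encoded values plus its mask."""
    return [(v + m) % MODULUS for v, m in zip(values, party_mask)]


def server_sum(uploads: List[List[int]]) -> List[int]:
    """The server only adds uploads; the pairwise masks cancel in the total."""
    return [sum(column) % MODULUS for column in zip(*uploads)]


# Hypothetical integer-encoded updates for two parties.
updates = {"alice": [10, 20, 30], "bob": [1, 2, 3]}
masks = pairwise_masks(list(updates), dim=3)
uploads = [mask(updates[p], masks[p]) for p in updates]
print(server_sum(uploads))  # [11, 22, 33]
```

Real model weights are floating point, so a scheme like this needs a fixed-point encoding before masking and a decoding step after the server sums the uploads.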
[7]:
history = fl_model.fit(
    train_data,
    train_label,
    validation_data=(test_data, test_label),
    epochs=20,
    batch_size=32,
    aggregate_freq=1,
)
100%|██████████| 750/750 [00:16<00:00, 46.78it/s, epoch: 1/20 -  accuracy:0.9709533452987671  precision:0.8571249842643738  val_accuracy:0.9840199947357178  val_precision:0.8955000042915344 ]
100%|██████████| 125/125 [00:03<00:00, 31.28it/s, epoch: 2/20 -  accuracy:0.9825800061225891  precision:0.9190000295639038  val_accuracy:0.9850000143051147  val_precision:0.903249979019165 ]
100%|██████████| 125/125 [00:02<00:00, 41.70it/s, epoch: 3/20 -  accuracy:0.9850000143051147  precision:0.9302499890327454  val_accuracy:0.9856399893760681  val_precision:0.906499981880188 ]
100%|██████████| 125/125 [00:03<00:00, 41.66it/s, epoch: 4/20 -  accuracy:0.9859799742698669  precision:0.9334999918937683  val_accuracy:0.9861800074577332  val_precision:0.9085000157356262 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 5/20 -  accuracy:0.9870200157165527  precision:0.940500020980835  val_accuracy:0.9864799976348877  val_precision:0.9097499847412109 ]
100%|██████████| 125/125 [00:02<00:00, 41.67it/s, epoch: 6/20 -  accuracy:0.987779974937439  precision:0.9422500133514404  val_accuracy:0.9869400262832642  val_precision:0.9137499928474426 ]
100%|██████████| 125/125 [00:02<00:00, 41.92it/s, epoch: 7/20 -  accuracy:0.988099992275238  precision:0.9447500109672546  val_accuracy:0.9870200157165527  val_precision:0.9139999747276306 ]
100%|██████████| 125/125 [00:03<00:00, 41.55it/s, epoch: 8/20 -  accuracy:0.9887800216674805  precision:0.9477499723434448  val_accuracy:0.986739993095398  val_precision:0.9135000109672546 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 9/20 -  accuracy:0.9892399907112122  precision:0.9502500295639038  val_accuracy:0.9868199825286865  val_precision:0.9132500290870667 ]
100%|██████████| 125/125 [00:02<00:00, 41.81it/s, epoch: 10/20 -  accuracy:0.989359974861145  precision:0.9522500038146973  val_accuracy:0.9873600006103516  val_precision:0.9175000190734863 ]
100%|██████████| 125/125 [00:03<00:00, 41.46it/s, epoch: 11/20 -  accuracy:0.9898999929428101  precision:0.953249990940094  val_accuracy:0.9874200224876404  val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.77it/s, epoch: 12/20 -  accuracy:0.990119993686676  precision:0.953499972820282  val_accuracy:0.9871600270271301  val_precision:0.9154999852180481 ]
100%|██████████| 125/125 [00:02<00:00, 41.87it/s, epoch: 13/20 -  accuracy:0.9906600117683411  precision:0.9570000171661377  val_accuracy:0.9876800179481506  val_precision:0.9202499985694885 ]
100%|██████████| 125/125 [00:03<00:00, 40.91it/s, epoch: 14/20 -  accuracy:0.9910399913787842  precision:0.9572499990463257  val_accuracy:0.9880200028419495  val_precision:0.9227499961853027 ]
100%|██████████| 125/125 [00:03<00:00, 41.49it/s, epoch: 15/20 -  accuracy:0.9903600215911865  precision:0.9542499780654907  val_accuracy:0.9878000020980835  val_precision:0.9194999933242798 ]
100%|██████████| 125/125 [00:02<00:00, 41.68it/s, epoch: 16/20 -  accuracy:0.9914000034332275  precision:0.9585000276565552  val_accuracy:0.9878799915313721  val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:02<00:00, 42.21it/s, epoch: 17/20 -  accuracy:0.9915599822998047  precision:0.9597499966621399  val_accuracy:0.988099992275238  val_precision:0.921750009059906 ]
100%|██████████| 125/125 [00:03<00:00, 41.41it/s, epoch: 18/20 -  accuracy:0.9915800094604492  precision:0.9595000147819519  val_accuracy:0.9880399703979492  val_precision:0.921500027179718 ]
100%|██████████| 125/125 [00:02<00:00, 41.83it/s, epoch: 19/20 -  accuracy:0.9916200041770935  precision:0.9605000019073486  val_accuracy:0.9887400269508362  val_precision:0.9244999885559082 ]
100%|██████████| 125/125 [00:03<00:00, 41.34it/s, epoch: 20/20 -  accuracy:0.9922599792480469  precision:0.9637500047683716  val_accuracy:0.9883599877357483  val_precision:0.922249972820282 ]
[8]:
from matplotlib import pyplot as plt

# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()

[Figure: FLModel accuracy, Accuracy vs. Epoch for the Train and Valid curves]
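The history.global_history object can also be queried directly. Judging from the plotting code, it maps metric names such as 'val_accuracy' to one value per epoch. Assuming it behaves like a plain dict of lists, the sketch below finds the epoch with the highest validation accuracy; best_epoch is a hypothetical helper, and the dict literal is a stand-in filled with the val_accuracy values printed in the training log above.

```python
from typing import Dict, List, Tuple


def best_epoch(global_history: Dict[str, List[float]], metric: str = "val_accuracy") -> Tuple[int, float]:
    """Return the 1-based epoch and the value at which `metric` was highest."""
    values = global_history[metric]
    if not values:
        raise ValueError(f"no values recorded for {metric!r}")
    best_idx = max(range(len(values)), key=values.__getitem__)
    return best_idx + 1, values[best_idx]


# Stand-in for history.global_history, copied from the val_accuracy log entries.
logged_history = {
    "val_accuracy": [
        0.9840199947357178, 0.9850000143051147, 0.9856399893760681, 0.9861800074577332,
        0.9864799976348877, 0.9869400262832642, 0.9870200157165527, 0.986739993095398,
        0.9868199825286865, 0.9873600006103516, 0.9874200224876404, 0.9871600270271301,
        0.9876800179481506, 0.9880200028419495, 0.9878000020980835, 0.9878799915313721,
        0.988099992275238, 0.9880399703979492, 0.9887400269508362, 0.9883599877357483,
    ]
}
epoch, value = best_epoch(logged_history)
print(f"best val_accuracy {value:.4f} at epoch {epoch}")  # best val_accuracy 0.9887 at epoch 19
```

On the real object, best_epoch(history.global_history) should work the same way, provided global_history indeed maps each metric name to a list of per-epoch values.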