Fine-tuning PyTorch Pretrained Models in SecretFlow Federated Learning#
Introduction#
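Loading and fine-tuning pretrained models is an important technique in machine learning. Training a very large model from scratch takes a great deal of compute and a long time. In conventional machine learning it is therefore common to start from a pretrained model and then fine-tune it or apply transfer learning for the specific task. The same holds for federated learning. If participants can load a pretrained model for fine-tuning and transfer learning, they save compute, the barrier to joining is lower, and the model trains faster.
SecretFlow's federated learning module is compatible with PyTorch, so it can load PyTorch's pretrained models directly. This tutorial builds on PyTorch's AlexNet fine-tuning tutorial. It shows how to fine-tune a PyTorch pretrained model in the SecretFlow framework and how little code that takes.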
[1]:
%load_ext autoreload
%autoreload 2
Loading the dataset#
About the dataset#
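The flower dataset (flower_photos) contains 3,670 color photos of 5 kinds of flowers, stored in one subfolder per class. The per-class counts are daisy (633), dandelion (898), roses (641), sunflowers (699) and tulips (799). Each flower is photographed from various angles and under different lighting. The images are not all the same size, and the data loader below resizes them to 180x180. The dataset is commonly used to train and test image classification algorithms.
Download URL: http://download.tensorflow.org/example_images/flower_photos.tgz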
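You can check the per-class counts on your own copy of the data with a small helper that walks the extracted folder. The sketch below assumes the flower_photos layout, with one subdirectory per class and the image files directly inside it. `count_images_per_class` is a hypothetical helper and is not part of SecretFlow or TensorFlow.

```python
import os
import tempfile


def count_images_per_class(root, extensions=(".jpg", ".jpeg", ".png")):
    """Return {class_name: image_count} for a folder with one subdirectory per class."""
    counts = {}
    for entry in sorted(os.listdir(root)):
        class_dir = os.path.join(root, entry)
        if not os.path.isdir(class_dir):
            # Skip top-level files such as LICENSE.txt
            continue
        counts[entry] = sum(
            1 for name in os.listdir(class_dir) if name.lower().endswith(extensions)
        )
    return counts


if __name__ == "__main__":
    # Build a tiny fake dataset so the example runs without downloading anything
    with tempfile.TemporaryDirectory() as root:
        for cls, n in {"daisy": 3, "tulips": 2}.items():
            os.makedirs(os.path.join(root, cls))
            for i in range(n):
                open(os.path.join(root, cls, f"{i}.jpg"), "w").close()
        open(os.path.join(root, "LICENSE.txt"), "w").close()
        print(count_images_per_class(root))  # → {'daisy': 3, 'tulips': 2}
```

On the real download, calling `count_images_per_class(path_to_flower_dataset)` should list the five class folders and their sizes.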
Downloading and extracting the dataset#
[2]:
# Reuse TensorFlow's download utility to fetch and extract the archive; the result is a folder with one subfolder per class.
import tempfile
import tensorflow as tf
_temp_dir = tempfile.mkdtemp()
path_to_flower_dataset = tf.keras.utils.get_file(
    "flower_photos",
    "https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz",
    untar=True,
    cache_dir=_temp_dir,
)
Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz
67588319/67588319 [==============================] - 3s 0us/step
Environment setup#
First, we initialize the participating parties.
[3]:
import secretflow as sf
# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address="local", log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
The version of SecretFlow: 1.2.0.dev20231009
2023-10-29 02:11:04,382 INFO worker.py:1538 -- Started a local Ray instance.
Defining the DataLoader#
Following the DataBuilder tutorial for PyTorch, we define our own DataBuilder.
[4]:
def create_dataset_builder(
    batch_size=32,
    train_split=0.8,
    shuffle=True,
    random_seed=1234,
):
    def dataset_builder(x, stage="train"):
        """Build a DataLoader over the image folder `x` for stage "train" or "eval"."""
        import math
        import numpy as np
        from torch.utils.data import DataLoader
        from torch.utils.data.sampler import SubsetRandomSampler
        from torchvision import datasets, transforms
        # Define the dataset: resize every image to 180x180 and convert it to a tensor
        flower_transform = transforms.Compose(
            [
                transforms.Resize((180, 180)),
                transforms.ToTensor(),
            ]
        )
        flower_dataset = datasets.ImageFolder(x, transform=flower_transform)
        dataset_size = len(flower_dataset)
        # Split the sample indices into a training part and a validation part
        indices = list(range(dataset_size))
        if shuffle:
            np.random.seed(random_seed)
            np.random.shuffle(indices)
        split = int(np.floor(train_split * dataset_size))
        train_indices, val_indices = indices[:split], indices[split:]
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)
        # Build one DataLoader per split
        train_loader = DataLoader(
            flower_dataset, batch_size=batch_size, sampler=train_sampler
        )
        valid_loader = DataLoader(
            flower_dataset, batch_size=batch_size, sampler=valid_sampler
        )
        # Return the loader and the number of steps per epoch for the requested stage
        if stage == "train":
            train_step_per_epoch = math.ceil(split / batch_size)
            return train_loader, train_step_per_epoch
        elif stage == "eval":
            eval_step_per_epoch = math.ceil((dataset_size - split) / batch_size)
            return valid_loader, eval_step_per_epoch
    return dataset_builder
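The builder assigns the first `floor(train_split * N)` indices to training and the rest to validation. It reports steps per epoch as a ceiling division by `batch_size`. You can check this arithmetic on its own. `split_sizes` is a hypothetical helper that mirrors the builder's logic, and N = 3,670 is the size of the TensorFlow flower_photos download.

```python
import math


def split_sizes(dataset_size, train_split=0.8, batch_size=32):
    """Mirror dataset_builder's split: (train_size, val_size, train_steps, eval_steps)."""
    split = math.floor(train_split * dataset_size)
    val_size = dataset_size - split
    train_steps = math.ceil(split / batch_size)
    eval_steps = math.ceil(val_size / batch_size)
    return split, val_size, train_steps, eval_steps


print(split_sizes(3670))  # → (2936, 734, 92, 23)
```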
[5]:
# prepare dataset dict
data_builder_dict = {
    alice: create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=False,
        random_seed=1234,
    ),
    bob: create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=False,
        random_seed=1234,
    ),
}
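Both parties above pass `shuffle=False`. `ImageFolder` lists samples grouped by class, in sorted class order. Without shuffling, the validation subset is therefore just the tail of that list. The sketch below shows which classes land in the validation slice. It assumes the per-class counts of the TensorFlow flower_photos download, and `classes_in_slice` is a hypothetical helper. If your copy of the data has the same counts, setting `shuffle=True` gives a validation set that covers all five classes.

```python
import math

# Assumed per-class image counts, in the sorted class order ImageFolder uses
CLASS_COUNTS = [
    ("daisy", 633),
    ("dandelion", 898),
    ("roses", 641),
    ("sunflowers", 699),
    ("tulips", 799),
]


def classes_in_slice(class_counts, start, stop):
    """Return the class names that have at least one sample in indices [start, stop)."""
    found = set()
    offset = 0
    for name, count in class_counts:
        # This class occupies indices [offset, offset + count)
        if offset < stop and offset + count > start:
            found.add(name)
        offset += count
    return found


total = sum(count for _, count in CLASS_COUNTS)
split = math.floor(0.8 * total)
print(total, split)  # → 3670 2936
print(classes_in_slice(CLASS_COUNTS, split, total))  # → {'tulips'}
```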
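Loading the model#
To use the pretrained model, we take the model definition from the PyTorch tutorial and wrap it in a class. The code needs almost no changes beyond that wrapper. SecretFlow handles the training and evaluation loops, so once the model is defined we can start training right away instead of writing those functions ourselves. If we instead write the whole network from scratch, we have to consult AlexNet's source code and adapt it to SecretFlow's secretflow.ml.nn.utils.BaseModule. For comparison, both implementations are given below:
Loading the pretrained model#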
[6]:
from secretflow.ml.nn.utils import BaseModule
from torchvision.models import alexnet
from torch import nn
class AlexNet(BaseModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Load AlexNet with ImageNet weights and replace the last layer with a 5-class head
        self.finetune_net = alexnet(weights='IMAGENET1K_V1')
        self.finetune_net.classifier[6] = nn.Linear(4096, 5)
        nn.init.xavier_uniform_(self.finetune_net.classifier[6].weight)
        # Train only the classifier layers; freeze the convolutional feature extractor
        for name, param in self.finetune_net.named_parameters():
            if 'classifier' in name:
                print('Will train', name)
                param.requires_grad = True
            else:
                print("Won't train", name)
                param.requires_grad = False
    def forward(self, xb):
        return self.finetune_net(xb)
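Freezing `features` and training only `classifier` leaves most of the parameters trainable, because most of AlexNet's weights sit in the fully connected layers. The saving comes mainly from skipping weight gradients for the convolutional layers. The sketch below counts parameters by hand from the layer shapes of torchvision's AlexNet with a 5-class head. These are the same shapes as in the hand-written version below.

```python
def conv_params(in_ch, out_ch, kernel):
    # weight (out_ch * in_ch * k * k) + bias (out_ch)
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    # weight (out * in) + bias (out)
    return out_features * in_features + out_features


features = [
    conv_params(3, 64, 11),
    conv_params(64, 192, 5),
    conv_params(192, 384, 3),
    conv_params(384, 256, 3),
    conv_params(256, 256, 3),
]
classifier = [
    linear_params(256 * 6 * 6, 4096),
    linear_params(4096, 4096),
    linear_params(4096, 5),  # replaced head: 5 flower classes
]

frozen = sum(features)
trainable = sum(classifier)
print(f"frozen={frozen:,} trainable={trainable:,} ratio={trainable / (frozen + trainable):.3f}")
# → frozen=2,469,696 trainable=54,554,629 ratio=0.957
```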
Writing the network by hand#
[7]:
import torch
class handy_AlexNet(BaseModule):
    def __init__(self, num_classes: int = 5, dropout: float = 0.5, **kwargs) -> None:
        super().__init__(**kwargs)
        # Convolutional feature extractor, same layout as torchvision's AlexNet
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Fully connected classifier ending in a num_classes-way output
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        # Train every layer (requires_grad is already True by default; set explicitly for clarity)
        for named_param in self.features.named_parameters():
            named_param[1].requires_grad = True
        for named_param in self.classifier.named_parameters():
            named_param[1].requires_grad = True
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
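The data loader resizes images to 180x180 rather than the 224x224 commonly used for ImageNet. We can trace the spatial size through `features` with the standard output-size formula floor((n + 2p - k) / s) + 1. A 180x180 input leaves a 4x4 map before `avgpool`. `nn.AdaptiveAvgPool2d((6, 6))` then resizes it to the 6x6 that `classifier` expects, so the same classifier works at both input sizes. A sketch of that arithmetic:

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length of a conv / max-pool window along one spatial dimension (floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of each conv / max-pool in `features`, in order
LAYERS = [
    (11, 4, 2),  # Conv2d(3, 64)
    (3, 2, 0),   # MaxPool2d
    (5, 1, 2),   # Conv2d(64, 192)
    (3, 2, 0),   # MaxPool2d
    (3, 1, 1),   # Conv2d(192, 384)
    (3, 1, 1),   # Conv2d(384, 256)
    (3, 1, 1),   # Conv2d(256, 256)
    (3, 2, 0),   # MaxPool2d
]


def feature_map_size(n):
    for kernel, stride, padding in LAYERS:
        n = out_size(n, kernel, stride, padding)
    return n


print(feature_map_size(180), feature_map_size(224))  # → 4 6
```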
[8]:
from secretflow.ml.nn import FLModel
from secretflow.security.aggregation import SecureAggregator
from torch import nn, optim
from torchmetrics import Accuracy, Precision
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn.utils import TorchModel
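The torchmetrics `Accuracy` and `Precision` imported here are configured below with `task="multiclass"` and `average='micro'`. Micro precision pools true and false positives over all classes before dividing. With single-label predictions, every sample produces exactly one positive prediction, which is either a true positive or a false positive. Micro precision therefore reduces to accuracy, which is why the training log reports identical values for the two metrics. The pure-Python sketch below shows the two definitions side by side. It is not torchmetrics' implementation.

```python
def accuracy(preds, targets):
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def micro_precision(preds, targets, num_classes):
    # Pool true/false positives over all classes before dividing
    tp = fp = 0
    for c in range(num_classes):
        tp += sum(p == c and t == c for p, t in zip(preds, targets))
        fp += sum(p == c and t != c for p, t in zip(preds, targets))
    return tp / (tp + fp)


preds = [0, 2, 1, 4, 4, 3]
targets = [0, 1, 1, 4, 2, 3]
print(accuracy(preds, targets), micro_precision(preds, targets, num_classes=5))
# → 0.6666666666666666 0.6666666666666666
```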
Defining and training the FLModel#
Defining a Torch-backend FLModel on top of the pretrained model#
[9]:
device_list = [alice, bob]
# charlie acts as the aggregator for alice and bob
aggregator = SecureAggregator(charlie, [alice, bob])
# In this simulation both parties read the same local image folder
data = {
    alice: path_to_flower_dataset,
    bob: path_to_flower_dataset,
}
# prepare model
num_classes = 5
# torch model: the loss is passed as a class, the optimizer through optim_wrapper
loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-3)
model_def = TorchModel(
    model_fn=AlexNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(
            Accuracy, task="multiclass", num_classes=num_classes, average='micro'
        ),
        metric_wrapper(
            Precision, task="multiclass", num_classes=num_classes, average='micro'
        ),
    ],
)
fed_model = FLModel(
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    backend="torch",
    strategy="fed_avg_w",
    random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
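`SecureAggregator(charlie, [alice, bob])` makes charlie the aggregator, and the log above shows a `_Masker` actor being created for each of alice and bob. The sketch below illustrates the general pairwise-masking idea behind secure aggregation. Each pair of parties shares a random mask that one party adds and the other subtracts. The aggregator sees only masked vectors, but the masks cancel in the sum. This is a textbook illustration under simplifying assumptions: values are integer-encoded and masks are shared out of band. It is not SecretFlow's actual protocol or code.

```python
import random

MODULUS = 2**32  # all arithmetic is done modulo a fixed ring size


def mask_updates(updates, seed=0):
    """Add pairwise-cancelling masks to each party's integer vector (mod MODULUS)."""
    rng = random.Random(seed)
    parties = sorted(updates)
    dim = len(next(iter(updates.values())))
    masked = {p: list(v) for p, v in updates.items()}
    for i, a in enumerate(parties):
        for b in parties[i + 1:]:
            # Party a adds the shared mask, party b subtracts it
            mask = [rng.randrange(MODULUS) for _ in range(dim)]
            masked[a] = [(x + m) % MODULUS for x, m in zip(masked[a], mask)]
            masked[b] = [(x - m) % MODULUS for x, m in zip(masked[b], mask)]
    return masked


def aggregate(masked):
    """What the aggregator computes: the element-wise sum of the masked vectors."""
    return [sum(col) % MODULUS for col in zip(*masked.values())]


updates = {"alice": [3, 10, 7], "bob": [5, 1, 2]}
masked = mask_updates(updates)
print(masked["alice"])  # looks random to the aggregator
print(aggregate(masked))  # → [8, 11, 9]
```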
Training the FLModel built on the pretrained model#
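The FLModel above uses `strategy="fed_avg_w"`, and the `fit` call that follows passes `aggregate_freq=2`. As we understand SecretFlow's FedAvg-on-weights strategy, this means each party runs two local optimizer steps. The parties' model weights are then averaged and the average is sent back to every party. The toy simulation below shows that loop on scalar "models" using a hypothetical `local_step`. The optional sample-count weighting in `fed_avg` is an assumption, and none of this is SecretFlow's code. With 30 local steps, 15 aggregation rounds take place.

```python
def local_step(w, target, lr=0.1):
    """One gradient step on the toy loss (w - target)^2 / 2 (hypothetical)."""
    return w - lr * (w - target)


def fed_avg(weights, num_samples=None):
    """Average the parties' weights, optionally weighted by sample counts."""
    if num_samples is None:
        num_samples = [1] * len(weights)
    total = sum(num_samples)
    return sum(w * n for w, n in zip(weights, num_samples)) / total


def train(targets, steps, aggregate_freq, w0=0.0):
    weights = [w0] * len(targets)  # every party starts from the same model
    rounds = 0
    for step in range(1, steps + 1):
        weights = [local_step(w, t) for w, t in zip(weights, targets)]
        if step % aggregate_freq == 0:
            avg = fed_avg(weights)
            weights = [avg] * len(targets)  # broadcast the global model
            rounds += 1
    return weights, rounds


weights, rounds = train(targets=[1.0, 3.0], steps=30, aggregate_freq=2)
print(rounds, weights)  # 15 rounds; both parties end close to 2.0
```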
[10]:
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=20,
    batch_size=32,
    aggregate_freq=2,
    sampler_method="batch",
    random_seed=1234,
    dp_spent_step_freq=1,
    dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'x': {PYURuntime(alice): '/tmp/tmphh_f8sq7/datasets/flower_photos', PYURuntime(bob): '/tmp/tmphh_f8sq7/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 20, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmphh_f8sq7/datasets/flower_photos', PYURuntime(bob): '/tmp/tmphh_f8sq7/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7f683654edc0>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7f68365845e0>}, 'wait_steps': 100, 'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7f689e8d44c0>}
100%|██████████| 30/30 [01:53<00:00, 3.80s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
2023-10-29 02:13:31.470244: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13757 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:3b:00.0, compute capability: 7.5
2023-10-29 02:13:31.481341: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13757 MB memory: -> device: 1, name: Tesla T4, pci bus id: 0000:3c:00.0, compute capability: 7.5
100%|██████████| 30/30 [02:17<00:00, 4.59s/it, epoch: 1/20 - multiclassaccuracy:0.625 multiclassprecision:0.625 val_multiclassaccuracy:0.4564315378665924 val_multiclassprecision:0.4564315378665924 ]
100%|██████████| 30/30 [01:43<00:00, 3.81s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [01:57<00:00, 3.92s/it, epoch: 2/20 - multiclassaccuracy:0.7822916507720947 multiclassprecision:0.7822916507720947 val_multiclassaccuracy:0.17842324078083038 val_multiclassprecision:0.17842324078083038 ]
100%|██████████| 30/30 [01:44<00:00, 3.82s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [01:58<00:00, 3.94s/it, epoch: 3/20 - multiclassaccuracy:0.815625011920929 multiclassprecision:0.815625011920929 val_multiclassaccuracy:0.6224066615104675 val_multiclassprecision:0.6224066615104675 ]
100%|██████████| 30/30 [01:44<00:00, 3.81s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [01:58<00:00, 3.96s/it, epoch: 4/20 - multiclassaccuracy:0.8885416388511658 multiclassprecision:0.8885416388511658 val_multiclassaccuracy:0.5352697372436523 val_multiclassprecision:0.5352697372436523 ]
100%|██████████| 30/30 [01:44<00:00, 4.07s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [01:58<00:00, 3.97s/it, epoch: 5/20 - multiclassaccuracy:0.8916666507720947 multiclassprecision:0.8916666507720947 val_multiclassaccuracy:0.7966805100440979 val_multiclassprecision:0.7966805100440979 ]
100%|██████████| 30/30 [01:54<00:00, 3.99s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:08<00:00, 4.29s/it, epoch: 6/20 - multiclassaccuracy:0.8999999761581421 multiclassprecision:0.8999999761581421 val_multiclassaccuracy:0.7178423404693604 val_multiclassprecision:0.7178423404693604 ]
100%|██████████| 30/30 [01:58<00:00, 4.97s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:16<00:00, 4.55s/it, epoch: 7/20 - multiclassaccuracy:0.9041666388511658 multiclassprecision:0.9041666388511658 val_multiclassaccuracy:0.4107883870601654 val_multiclassprecision:0.4107883870601654 ]
100%|██████████| 30/30 [01:59<00:00, 4.45s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:13<00:00, 4.46s/it, epoch: 8/20 - multiclassaccuracy:0.9437500238418579 multiclassprecision:0.9437500238418579 val_multiclassaccuracy:0.6473029255867004 val_multiclassprecision:0.6473029255867004 ]
100%|██████████| 30/30 [01:56<00:00, 4.36s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:11<00:00, 4.39s/it, epoch: 9/20 - multiclassaccuracy:0.940625011920929 multiclassprecision:0.940625011920929 val_multiclassaccuracy:0.22406639158725739 val_multiclassprecision:0.22406639158725739 ]
100%|██████████| 30/30 [02:13<00:00, 5.26s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:32<00:00, 5.09s/it, epoch: 10/20 - multiclassaccuracy:0.9541666507720947 multiclassprecision:0.9541666507720947 val_multiclassaccuracy:0.8091286420822144 val_multiclassprecision:0.8091286420822144 ]
100%|██████████| 30/30 [02:09<00:00, 4.73s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:26<00:00, 4.87s/it, epoch: 11/20 - multiclassaccuracy:0.9447916746139526 multiclassprecision:0.9447916746139526 val_multiclassaccuracy:0.5228216052055359 val_multiclassprecision:0.5228216052055359 ]
100%|██████████| 30/30 [02:03<00:00, 4.56s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:22<00:00, 4.77s/it, epoch: 12/20 - multiclassaccuracy:0.9427083134651184 multiclassprecision:0.9427083134651184 val_multiclassaccuracy:0.5933610200881958 val_multiclassprecision:0.5933610200881958 ]
100%|██████████| 30/30 [02:05<00:00, 4.60s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:21<00:00, 4.71s/it, epoch: 13/20 - multiclassaccuracy:0.9395833611488342 multiclassprecision:0.9395833611488342 val_multiclassaccuracy:0.6514523029327393 val_multiclassprecision:0.6514523029327393 ]
100%|██████████| 30/30 [02:03<00:00, 5.06s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:18<00:00, 4.61s/it, epoch: 14/20 - multiclassaccuracy:0.9385416507720947 multiclassprecision:0.9385416507720947 val_multiclassaccuracy:0.5062240958213806 val_multiclassprecision:0.5062240958213806 ]
100%|██████████| 30/30 [02:12<00:00, 4.70s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:29<00:00, 4.99s/it, epoch: 15/20 - multiclassaccuracy:0.9635416865348816 multiclassprecision:0.9635416865348816 val_multiclassaccuracy:0.4149377644062042 val_multiclassprecision:0.4149377644062042 ]
100%|██████████| 30/30 [02:11<00:00, 5.30s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:26<00:00, 4.89s/it, epoch: 16/20 - multiclassaccuracy:0.96875 multiclassprecision:0.96875 val_multiclassaccuracy:0.6763485670089722 val_multiclassprecision:0.6763485670089722 ]
100%|██████████| 30/30 [01:52<00:00, 4.39s/it]WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:07<00:00, 4.24s/it, epoch: 17/20 - multiclassaccuracy:0.9750000238418579 multiclassprecision:0.9750000238418579 val_multiclassaccuracy:0.8713693022727966 val_multiclassprecision:0.8713693022727966 ]
100%|██████████| 30/30 [02:25<00:00, 4.84s/it, epoch: 18/20 - multiclassaccuracy:0.9489583373069763 multiclassprecision:0.9489583373069763 val_multiclassaccuracy:0.6182572841644287 val_multiclassprecision:0.6182572841644287 ]
100%|██████████| 30/30 [02:23<00:00, 4.79s/it, epoch: 19/20 - multiclassaccuracy:0.9739583134651184 multiclassprecision:0.9739583134651184 val_multiclassaccuracy:0.39419087767601013 val_multiclassprecision:0.39419087767601013 ]
100%|██████████| 30/30 [02:16<00:00, 4.55s/it, epoch: 20/20 - multiclassaccuracy:0.9489583373069763 multiclassprecision:0.9489583373069763 val_multiclassaccuracy:0.5643153786659241 val_multiclassprecision:0.5643153786659241 ]
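Every epoch line above reports identical values for multiclassaccuracy and multiclassprecision. This is expected when the metrics are configured with average='micro', as in the TorchModel definitions of this tutorial: micro averaging pools true positives (TP) and false positives (FP) over all classes, and because each image receives exactly one predicted label, every prediction is either a TP or an FP, so TP + FP equals the number of samples and micro precision reduces to accuracy. The plain-Python sketch below checks this on a small set of hypothetical labels; it re-implements the standard definitions and does not call torchmetrics.

```python
def accuracy(y_true, y_pred):
    # Fraction of samples whose predicted label equals the true label
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def micro_precision(y_true, y_pred, num_classes):
    # Count true and false positives per predicted class, then pool them
    tp = [0] * num_classes
    fp = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[p] += 1
        else:
            fp[p] += 1
    return sum(tp) / (sum(tp) + sum(fp))


# Hypothetical labels for the 5 flower classes
y_true = [0, 1, 2, 3, 4, 0, 1, 2]
y_pred = [0, 1, 1, 3, 0, 0, 2, 2]
print(accuracy(y_true, y_pred), micro_precision(y_true, y_pred, num_classes=5))  # 0.625 0.625
```

Macro-averaged precision (average='macro') would generally differ from accuracy, so the two columns only coincide because of the micro setting.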
Visualizing the training results#
[11]:
from matplotlib import pyplot as plt
# Draw accuracy values for training & validation
plt.plot(history.global_history['multiclassaccuracy'])
plt.plot(history.global_history['val_multiclassaccuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw precision for training & validation
plt.plot(history.global_history['multiclassprecision'])
plt.plot(history.global_history['val_multiclassprecision'])
plt.title('FLModel precision')
plt.ylabel('Precision')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
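Because the validation accuracy fluctuates strongly from epoch to epoch (between about 0.39 and 0.87 over epochs 15 to 20 in the log above), it can help to read the best epoch directly from the history instead of from the plot. The helper below is a hypothetical sketch: best_epoch and first_epoch are names introduced here, and it assumes history.global_history behaves like a dict that maps each metric name to a list of per-epoch floats, which is how the plotting code above uses it.

```python
def best_epoch(global_history, metric="val_multiclassaccuracy", first_epoch=1):
    """Return (epoch, value) for the epoch with the highest value of `metric`.

    Hypothetical helper: assumes global_history maps metric names to lists of
    per-epoch floats; `first_epoch` is the epoch number of the first entry.
    """
    values = list(global_history[metric])
    if not values:
        raise ValueError(f"no values recorded for {metric!r}")
    best_idx = max(range(len(values)), key=values.__getitem__)
    return first_epoch + best_idx, values[best_idx]


# Validation accuracy of epochs 15-20 from the log above, rounded to 4 digits
partial_history = {"val_multiclassaccuracy": [0.4149, 0.6763, 0.8714, 0.6183, 0.3942, 0.5643]}
print(best_epoch(partial_history, first_epoch=15))  # (17, 0.8714)
```

With the complete history, the call would presumably be best_epoch(history.global_history), provided its values are plain floats rather than tensors.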
Defining a Torch-backend FLModel with a hand-written network#
[12]:
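# Hand-written AlexNet (handy_AlexNet) without pretrained weights; loss_fn, optim_fn,
# num_classes, device_list and aggregator are the objects defined earlier in this tutorial
model_def = TorchModel(
    model_fn=handy_AlexNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(
            Accuracy, task="multiclass", num_classes=num_classes, average='micro'
        ),
        metric_wrapper(
            Precision, task="multiclass", num_classes=num_classes, average='micro'
        ),
    ],
)
# Federated model over the parties in device_list, aggregated with the fed_avg_w strategy
fed_model = FLModel(
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    backend="torch",
    strategy="fed_avg_w",
    random_seed=1234,
)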
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
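The strategy="fed_avg_w" argument creates the PYUFedAvgW actors shown in the log above; as the strategy name indicates, the parties' model weights are averaged at each aggregation. The NumPy sketch below illustrates only that averaging arithmetic in plaintext. It is not SecretFlow's implementation, it ignores the aggregator (which in SecretFlow may perform the aggregation securely), and the optional weighting by per-party sample counts is an assumption added for illustration.

```python
import numpy as np


def fed_avg_weights(party_weights, sample_counts=None):
    """Average every parameter array across parties (plaintext FedAvg sketch).

    party_weights: one list of NumPy arrays per party; the i-th array of every
        party must belong to the same layer and have the same shape.
    sample_counts: optional per-party sample counts used as averaging weights
        (an assumption for illustration); a plain mean is used when omitted.
    """
    if sample_counts is None:
        sample_counts = [1] * len(party_weights)
    coeffs = np.asarray(sample_counts, dtype=np.float64)
    coeffs = coeffs / coeffs.sum()  # normalize so the coefficients sum to 1
    averaged = []
    for layer_params in zip(*party_weights):  # same layer taken from every party
        stacked = np.stack(layer_params)  # shape: (n_parties, *layer_shape)
        averaged.append(np.tensordot(coeffs, stacked, axes=1))
    return averaged


# Hypothetical two-layer model held by alice and bob
alice = [np.array([[1.0, 2.0]]), np.array([0.0])]
bob = [np.array([[3.0, 4.0]]), np.array([2.0])]
print(fed_avg_weights([alice, bob]))  # plain mean of each layer
print(fed_avg_weights([alice, bob], sample_counts=[1, 3]))  # bob counts three times as much
```

After averaging, every party would continue local training from the same averaged weights, which is what keeps the two copies of AlexNet in sync.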
Training the FLModel based on the hand-written network#
[13]:
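history = fed_model.fit(
    data,  # x: per-party dataset paths
    None,  # y: no separate labels, the datasets are built by dataset_builder
    validation_data=data,
    epochs=20,
    batch_size=32,
    aggregate_freq=2,  # aggregate the parties' models every 2 local steps
    sampler_method="batch",
    random_seed=1234,
    dp_spent_step_freq=1,
    dataset_builder=data_builder_dict,  # per-party functions that build the image datasets
)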
INFO:root:FL Train Params: {'x': {PYURuntime(alice): '/tmp/tmphh_f8sq7/datasets/flower_photos', PYURuntime(bob): '/tmp/tmphh_f8sq7/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 20, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmphh_f8sq7/datasets/flower_photos', PYURuntime(bob): '/tmp/tmphh_f8sq7/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7f683654edc0>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7f68365845e0>}, 'wait_steps': 100, 'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7f689e53fac0>}
WARNING:root:Please pay attention to local metrics, global only do naive aggregation
100%|██████████| 30/30 [02:47<00:00, 5.57s/it, epoch: 1/20 - multiclassaccuracy:0.2979166805744171 multiclassprecision:0.2979166805744171 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:45<00:00, 5.51s/it, epoch: 2/20 - multiclassaccuracy:0.39375001192092896 multiclassprecision:0.39375001192092896 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:38<00:00, 5.28s/it, epoch: 3/20 - multiclassaccuracy:0.4010416567325592 multiclassprecision:0.4010416567325592 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:35<00:00, 5.19s/it, epoch: 4/20 - multiclassaccuracy:0.47083333134651184 multiclassprecision:0.47083333134651184 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:36<00:00, 5.21s/it, epoch: 5/20 - multiclassaccuracy:0.47083333134651184 multiclassprecision:0.47083333134651184 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:39<00:00, 5.30s/it, epoch: 6/20 - multiclassaccuracy:0.45520833134651184 multiclassprecision:0.45520833134651184 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:58<00:00, 5.95s/it, epoch: 7/20 - multiclassaccuracy:0.5197916626930237 multiclassprecision:0.5197916626930237 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:51<00:00, 5.70s/it, epoch: 8/20 - multiclassaccuracy:0.5385416746139526 multiclassprecision:0.5385416746139526 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:47<00:00, 5.57s/it, epoch: 9/20 - multiclassaccuracy:0.5333333611488342 multiclassprecision:0.5333333611488342 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:44<00:00, 5.47s/it, epoch: 10/20 - multiclassaccuracy:0.5677083134651184 multiclassprecision:0.5677083134651184 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:54<00:00, 5.83s/it, epoch: 11/20 - multiclassaccuracy:0.5895833373069763 multiclassprecision:0.5895833373069763 val_multiclassaccuracy:0.01244813296943903 val_multiclassprecision:0.01244813296943903 ]
100%|██████████| 30/30 [02:56<00:00, 5.87s/it, epoch: 12/20 - multiclassaccuracy:0.6156250238418579 multiclassprecision:0.6156250238418579 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:43<00:00, 5.46s/it, epoch: 13/20 - multiclassaccuracy:0.6260416507720947 multiclassprecision:0.6260416507720947 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:58<00:00, 5.96s/it, epoch: 14/20 - multiclassaccuracy:0.640625 multiclassprecision:0.640625 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:54<00:00, 5.81s/it, epoch: 15/20 - multiclassaccuracy:0.640625 multiclassprecision:0.640625 val_multiclassaccuracy:0.4730290472507477 val_multiclassprecision:0.4730290472507477 ]
100%|██████████| 30/30 [02:43<00:00, 5.44s/it, epoch: 16/20 - multiclassaccuracy:0.6395833492279053 multiclassprecision:0.6395833492279053 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:42<00:00, 5.42s/it, epoch: 17/20 - multiclassaccuracy:0.6979166865348816 multiclassprecision:0.6979166865348816 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:52<00:00, 5.75s/it, epoch: 18/20 - multiclassaccuracy:0.7197916507720947 multiclassprecision:0.7197916507720947 val_multiclassaccuracy:0.24066390097141266 val_multiclassprecision:0.24066390097141266 ]
100%|██████████| 30/30 [02:42<00:00, 5.41s/it, epoch: 19/20 - multiclassaccuracy:0.734375 multiclassprecision:0.734375 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
100%|██████████| 30/30 [02:58<00:00, 5.95s/it, epoch: 20/20 - multiclassaccuracy:0.6635416746139526 multiclassprecision:0.6635416746139526 val_multiclassaccuracy:0.0 val_multiclassprecision:0.0 ]
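The fit call above sets aggregate_freq=2. Assuming this means that each party runs two local training steps between consecutive aggregations, the toy simulation below mimics that schedule on a one-parameter linear regression split across two parties. The data, the learning rate, and the function names local_sgd_step and simulate_fedavg are hypothetical and chosen only to make the loop easy to follow; SecretFlow's actual training loop is considerably more involved.

```python
import numpy as np


def local_sgd_step(w, x, y, lr):
    # One gradient-descent step on the mean squared error of y ≈ w * x
    grad = 2.0 * np.mean((w * x - y) * x)
    return w - lr * grad


def simulate_fedavg(parties, total_steps, aggregate_freq, lr=0.1):
    """Toy federated loop: local steps on every party, weight averaging every
    `aggregate_freq` steps. Returns the final averaged weight."""
    local_w = [0.0] * len(parties)  # every party starts from the same model
    for step in range(1, total_steps + 1):
        # Each party updates its own copy of the model on its own data
        local_w = [local_sgd_step(w, x, y, lr) for w, (x, y) in zip(local_w, parties)]
        if step % aggregate_freq == 0:
            # Aggregation: average the weights and send the result back to all parties
            global_w = float(np.mean(local_w))
            local_w = [global_w] * len(parties)
    return float(np.mean(local_w))


# Hypothetical data: both parties hold noise-free samples of y = 3x
alice = (np.array([1.0, 2.0]), np.array([3.0, 6.0]))
bob = (np.array([0.5, 1.5]), np.array([1.5, 4.5]))
w = simulate_fedavg([alice, bob], total_steps=20, aggregate_freq=2)
print(round(w, 3))  # close to the true slope 3.0
```

A larger aggregate_freq reduces communication between parties but lets the local models drift further apart between aggregations, which matters more when the parties' data distributions differ.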
Visualizing the training results#
[14]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['multiclassaccuracy'])
plt.plot(history.global_history['val_multiclassaccuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw precision for training & validation
plt.plot(history.global_history['multiclassprecision'])
plt.plot(history.global_history['val_multiclassprecision'])
plt.title('FLModel precision')
plt.ylabel('Precision')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
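As the logs show, on the same task and with the same model architecture, loading the pretrained model not only saves time and effort but also yields a better model. Over epochs 15 to 20, the fine-tuned pretrained AlexNet reaches a training accuracy between roughly 0.95 and 0.98, whereas the hand-written AlexNet stays between roughly 0.64 and 0.73, and its validation accuracy is 0 in 17 of its 20 epochs. Note that the validation accuracy fluctuates strongly from epoch to epoch for both models.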
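To put numbers on this comparison, the snippet below averages the training accuracy of epochs 15 to 20 for each run and counts the epochs in which the hand-written model scored 0 on validation. All values are copied from the training logs of this tutorial; the variable names are introduced here for illustration.

```python
from statistics import mean

# Training accuracy of epochs 15-20, copied from the two training logs
pretrained_train_acc = [0.9635416865348816, 0.96875, 0.9750000238418579,
                        0.9489583373069763, 0.9739583134651184, 0.9489583373069763]
scratch_train_acc = [0.640625, 0.6395833492279053, 0.6979166865348816,
                     0.7197916507720947, 0.734375, 0.6635416746139526]

# Validation accuracy of all 20 epochs of the hand-written AlexNet
scratch_val_acc = [0.0] * 10 + [0.01244813296943903, 0.0, 0.0, 0.0,
                                0.4730290472507477, 0.0, 0.0,
                                0.24066390097141266, 0.0, 0.0]
zero_epochs = sum(v == 0.0 for v in scratch_val_acc)

print(f"pretrained: {mean(pretrained_train_acc):.3f}")  # pretrained: 0.963
print(f"scratch:    {mean(scratch_train_acc):.3f}")  # scratch:    0.683
print(f"zero-accuracy validation epochs: {zero_epochs} / {len(scratch_val_acc)}")  # 17 / 20
```

The averages smooth out the epoch-to-epoch noise, which makes the gap of roughly 28 percentage points in training accuracy easier to read than any single epoch.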
Summary#
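SecretFlow is seamlessly compatible with PyTorch pretrained models, so we do not have to rewrite the structure of complex networks ourselves, which saves considerable time and effort for large architectures. Loading the pretrained weights also improves the model's performance.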
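In this tutorial, we used AlexNet as an example to show how to load a PyTorch pretrained model directly in SecretFlow's federated learning mode. Loading a pretrained model directly lets us:
- skip writing the structure code of complex models
- fine-tune and perform transfer learning on top of a pretrained model
- obtain better federated model performance by using the pretrained weights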