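Using a Custom DataBuilder in SecretFlow (Torch)#

The following code is for demonstration only. For system security reasons, do not use it directly in production.

This tutorial shows how to load data with a custom DataBuilder and train a model in SecretFlow's multi-party secure environment.

It uses an image classification task on the Flower dataset to illustrate how to carry out federated learning with a custom DataBuilder.

Environment Setup#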

[1]:
%load_ext autoreload
%autoreload 2
[2]:
import secretflow as sf

# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address="local", log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
2023-04-17 15:14:51,955 INFO worker.py:1538 -- Started a local Ray instance.

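Interface Overview#

FLModel in SecretFlow supports reading data through a custom DataBuilder, which lets users handle data input more flexibly according to their needs.

The example below shows how to train a federated model with a custom DataBuilder.

Steps for using a DataBuilder:

  1. Develop with the single-machine PyTorch engine and write a DataBuilder function that constructs a DataLoader in PyTorch. Note: the dataset_builder function must accept a stage parameter.

  2. Wrap each party's DataBuilder function to obtain create_dataset_builder.

  3. Construct data_builder_dict, a mapping {PYU: dataset_builder}.

  4. Pass the resulting data_builder_dict to the dataset_builder argument of the fit function. The x argument then carries the input that dataset_builder needs (in this example, the path to the images actually used).

To use a DataBuilder in FLModel, a data_builder_dict must be defined in advance. Each builder in it must return a dataset object (a DataLoader in this Torch tutorial) together with steps_per_epoch, and the steps_per_epoch returned by all parties must be identical.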

data_builder_dict = {
    alice: create_alice_dataset_builder(
        batch_size=32,
    ),  # the builder created here must return (Dataset, steps_per_epoch)
    bob: create_bob_dataset_builder(
        batch_size=32,
    ),  # the builder created here must return (Dataset, steps_per_epoch)
}
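
Because every party must report the same steps_per_epoch, it can help to compare the values locally before training. The following is a minimal sketch under stated assumptions: `check_steps_consistent` is a hypothetical helper that is not part of SecretFlow, and parties are represented by plain strings rather than PYU devices.

```python
def check_steps_consistent(steps_by_party):
    """Return the shared steps_per_epoch, or raise if parties disagree.

    steps_by_party: dict mapping a party name to the steps_per_epoch
    value reported by that party's dataset builder.
    """
    if not steps_by_party:
        raise ValueError("no parties given")
    distinct = set(steps_by_party.values())
    if len(distinct) != 1:
        raise ValueError(
            f"steps_per_epoch differs across parties: {steps_by_party}"
        )
    return distinct.pop()


# Hypothetical values, for illustration only.
shared_steps = check_steps_consistent({"alice": 30, "bob": 30})
print(shared_steps)
```

In a real deployment each party only sees its own data, so a check like this would have to run on values the parties agree to reveal.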

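Download the Data#

About the Flower dataset: it contains 4323 color images of 5 kinds of flowers (daisy, dandelion, rose, sunflower, tulip). Each kind is photographed from multiple angles and under different lighting, and each image has a resolution of 320x240. The dataset is commonly used to train and test image classification and machine learning algorithms. The number of images per class is: daisy (633), dandelion (898), rose (641), sunflower (699), tulip (852).

Download link: http://download.tensorflow.org/example_images/flower_photos.tgz

[Figure: flower_dataset_demo.png]

Download and Extract the Data#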

[3]:
# The TensorFlow interface is reused to download the images, and the output is a folder, as shown in the following figure.
import tempfile
import tensorflow as tf

_temp_dir = tempfile.mkdtemp()
path_to_flower_dataset = tf.keras.utils.get_file(
    "flower_photos",
    "https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz",
    untar=True,
    cache_dir=_temp_dir,
)
Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz
67588319/67588319 [==============================] - 1s 0us/step
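
The builder used later reads the extracted folder with torchvision's ImageFolder, which expects one sub-folder per class. As a quick sanity check of that layout, the sketch below counts image files per class folder using only the standard library. `count_images_per_class` and its `suffixes` default are hypothetical names chosen for illustration.

```python
from pathlib import Path


def count_images_per_class(root, suffixes=(".jpg", ".jpeg", ".png")):
    """Map each class sub-folder of `root` to its number of image files."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip loose files such as a license text
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in suffixes
        )
    return counts


# Usage, assuming `path_to_flower_dataset` from the cell above:
# print(count_images_per_class(path_to_flower_dataset))
```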

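Next, we build the custom DataBuilder.#

1. Develop the DataBuilder with a single-machine engine#

When developing the DataBuilder, you can follow ordinary single-machine development logic. The only goal is to construct a Torch DataLoader object.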

[4]:
import math

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

# parameter
batch_size = 32
shuffle = True
random_seed = 1234
train_split = 0.8

# Define dataset
flower_transform = transforms.Compose(
    [
        transforms.Resize((180, 180)),
        transforms.ToTensor(),
    ]
)
flower_dataset = datasets.ImageFolder(
    path_to_flower_dataset, transform=flower_transform
)
dataset_size = len(flower_dataset)
# Define sampler

indices = list(range(dataset_size))
if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
split = int(np.floor(train_split * dataset_size))
train_indices, val_indices = indices[:split], indices[split:]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Define databuilder
train_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler=train_sampler)
valid_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler=valid_sampler)
[5]:
x, y = next(iter(train_loader))
print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")
x.shape = torch.Size([32, 3, 180, 180])
y.shape = torch.Size([32])
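
The split and step counts used in this tutorial follow from simple arithmetic: the training set takes the first floor(train_split * dataset_size) indices, and a DataLoader without drop_last yields ceil(n / batch_size) batches. A small sketch of that calculation follows, with `split_and_steps` as a hypothetical helper and made-up sizes.

```python
import math


def split_and_steps(dataset_size, train_split, batch_size):
    """Return (train_size, val_size, train_steps, val_steps)."""
    train_size = int(math.floor(train_split * dataset_size))
    val_size = dataset_size - train_size
    train_steps = math.ceil(train_size / batch_size)
    val_steps = math.ceil(val_size / batch_size)
    return train_size, val_size, train_steps, val_steps


# Illustrative numbers only, not the real dataset size.
print(split_and_steps(dataset_size=1000, train_split=0.8, batch_size=32))
# -> (800, 200, 25, 7)
```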

2. Wrap the finished DataBuilder#

At run time the DataBuilder has to be distributed to each executing machine, so we wrap it to make it serializable.

Note that:

  • FLModel requires the DataBuilder to accept a stage parameter (stage="train").

  • FLModel requires the DataBuilder to return two results: (data_set, steps_per_epoch).

[6]:
def create_dataset_builder(
    batch_size=32,
    train_split=0.8,
    shuffle=True,
    random_seed=1234,
):
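    def dataset_builder(x, stage="train"):
        """Build the DataLoader for one party.

        x: path to an image folder with one sub-folder per class.
        stage: "train" or "eval".
        Returns (DataLoader, steps_per_epoch) for the requested stage.
        """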
        import math

        import numpy as np
        from torch.utils.data import DataLoader
        from torch.utils.data.sampler import SubsetRandomSampler
        from torchvision import datasets, transforms

        # Define dataset
        flower_transform = transforms.Compose(
            [
                transforms.Resize((180, 180)),
                transforms.ToTensor(),
            ]
        )
        flower_dataset = datasets.ImageFolder(x, transform=flower_transform)
        dataset_size = len(flower_dataset)
        # Define sampler

        indices = list(range(dataset_size))
        if shuffle:
            np.random.seed(random_seed)
            np.random.shuffle(indices)
        split = int(np.floor(train_split * dataset_size))
        train_indices, val_indices = indices[:split], indices[split:]
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)

        # Define databuilder
        train_loader = DataLoader(
            flower_dataset, batch_size=batch_size, sampler=train_sampler
        )
        valid_loader = DataLoader(
            flower_dataset, batch_size=batch_size, sampler=valid_sampler
        )

        # Return the loader and step count for the requested stage
        if stage == "train":
            train_step_per_epoch = math.ceil(split / batch_size)
            return train_loader, train_step_per_epoch
        elif stage == "eval":
            eval_step_per_epoch = math.ceil((dataset_size - split) / batch_size)
            return valid_loader, eval_step_per_epoch
        else:
            raise ValueError(f"Unsupported stage: {stage!r}")

    return dataset_builder
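
Before handing a builder to FLModel, it can be useful to call it locally for each stage and confirm that it honors the (data, steps_per_epoch) contract described above. The sketch below is one way to do that. `check_builder` and `toy_builder` are hypothetical names, and the toy builder uses plain lists in place of a DataLoader so that the example runs without Torch.

```python
def check_builder(builder, x, stages=("train", "eval")):
    """Call `builder` for each stage and verify the (data, steps) contract."""
    report = {}
    for stage in stages:
        result = builder(x, stage=stage)
        if not (isinstance(result, tuple) and len(result) == 2):
            raise TypeError(f"stage={stage!r}: expected (data, steps_per_epoch)")
        data, steps = result
        if len(data) != steps:
            raise ValueError(
                f"stage={stage!r}: steps_per_epoch={steps}, "
                f"but data yields {len(data)} batches"
            )
        report[stage] = steps
    return report


# Toy stand-in builder: batches are plain lists instead of a DataLoader.
def toy_builder(x, stage="train"):
    items = x[:8] if stage == "train" else x[8:]
    batches = [items[i:i + 4] for i in range(0, len(items), 4)]
    return batches, len(batches)


print(check_builder(toy_builder, list(range(10))))
```

The same check should apply to the Torch builder, since len() of a DataLoader built as above is its number of batches.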

3. Build the data_builder_dict#

[7]:
# prepare dataset dict
data_builder_dict = {
    alice: create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=False,
        random_seed=1234,
    ),
    bob: create_dataset_builder(
        batch_size=32,
        train_split=0.8,
        shuffle=False,
        random_seed=1234,
    ),
}
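
Both parties here pass shuffle=False. ImageFolder enumerates samples class by class, so an unshuffled 80/20 split places only the tail of the class list in the validation set. The sketch below illustrates this with the per-class counts quoted in the dataset description. It is a simplified model, not the tutorial's code path: it uses the standard library's random module instead of NumPy, and `validation_classes` is a hypothetical helper.

```python
import random

# Per-class counts as listed in the dataset description above.
class_counts = {
    "daisy": 633,
    "dandelion": 898,
    "rose": 641,
    "sunflower": 699,
    "tulip": 852,
}

# Samples ordered class by class, mirroring how ImageFolder lists them.
labels = [name for name, n in sorted(class_counts.items()) for _ in range(n)]


def validation_classes(labels, train_split=0.8, shuffle=False, seed=1234):
    """Return the set of classes that land in the validation split."""
    indices = list(range(len(labels)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    split = int(train_split * len(labels))
    return {labels[i] for i in indices[split:]}


print(validation_classes(labels, shuffle=False))  # only the last class block
print(sorted(validation_classes(labels, shuffle=True)))
```

Under these assumptions the unshuffled validation split contains a single class, which may contribute to low validation metrics. Passing shuffle=True with a fixed seed would mix all classes into both splits.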

4. With the data_builder_dict in hand, we can use it for federated training.#

Next, we define an FLModel with the Torch backend for training.

Define the Model Architecture#

[8]:
from secretflow.ml.nn.utils import BaseModule
from torch import nn


class ConvRGBNet(BaseModule):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.network = nn.Sequential(
            nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 45 * 45, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, xb):
        return self.network(xb)
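
The input size of the first Linear layer, 16 * 45 * 45, follows from the 180x180 input and the two stride-2 pooling layers. The sketch below reproduces that arithmetic with the standard convolution output-size formula. `conv_out` is a hypothetical helper.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 180  # images are resized to 180x180
size = conv_out(size, kernel=3, stride=1, padding=1)  # Conv2d(3, 16)
size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(2, 2)
size = conv_out(size, kernel=3, stride=1, padding=1)  # Conv2d(16, 16)
size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(2, 2)

flat_features = 16 * size * size  # 16 output channels after the last conv
print(size, flat_features)  # -> 45 32400
```

If the Resize target in the transform changes, the Linear layer's input size has to be recomputed the same way.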
[9]:
from secretflow.ml.nn import FLModel
from secretflow.security.aggregation import SecureAggregator
from torch import nn, optim
from torchmetrics import Accuracy, Precision
from secretflow.ml.nn.fl.utils import metric_wrapper, optim_wrapper
from secretflow.ml.nn.utils import TorchModel
[10]:
device_list = [alice, bob]
aggregator = SecureAggregator(charlie, [alice, bob])
# prepare model
num_classes = 5

input_shape = (180, 180, 3)
# torch model
loss_fn = nn.CrossEntropyLoss
optim_fn = optim_wrapper(optim.Adam, lr=1e-3)
model_def = TorchModel(
    model_fn=ConvRGBNet,
    loss_fn=loss_fn,
    optim_fn=optim_fn,
    metrics=[
        metric_wrapper(
            Accuracy, task="multiclass", num_classes=num_classes, average='micro'
        ),
        metric_wrapper(
            Precision, task="multiclass", num_classes=num_classes, average='micro'
        ),
    ],
)

fed_model = FLModel(
    device_list=device_list,
    model=model_def,
    aggregator=aggregator,
    backend="torch",
    strategy="fed_avg_w",
    random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
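
The model above combines strategy="fed_avg_w" with a SecureAggregator. As a rough illustration of the general idea of weight averaging only, and not of SecretFlow's implementation (which additionally protects each party's update during aggregation), the sketch below averages per-party weight vectors in plain Python. The helper name, the weighting by sample count, and all values are assumptions made for illustration.

```python
def weighted_average(party_weights, party_sizes):
    """Average per-party weight vectors, weighted by local sample counts."""
    total = sum(party_sizes.values())
    length = len(next(iter(party_weights.values())))
    averaged = [0.0] * length
    for party, weights in party_weights.items():
        share = party_sizes[party] / total
        for i, w in enumerate(weights):
            averaged[i] += share * w
    return averaged


# Hypothetical flattened weights and sample counts, for illustration only.
weights = {"alice": [1.0, 2.0], "bob": [3.0, 4.0]}
sizes = {"alice": 100, "bob": 100}
print(weighted_average(weights, sizes))  # -> [2.0, 3.0]
```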

The dataset builder we constructed takes the path of an image dataset as input, so the input data here must be a dict:

data = {
    alice: folder_path_of_alice,
    bob: folder_path_of_bob
}
[11]:
data = {
    alice: path_to_flower_dataset,
    bob: path_to_flower_dataset,
}
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=5,
    batch_size=32,
    aggregate_freq=2,
    sampler_method="batch",
    random_seed=1234,
    dp_spent_step_freq=1,
    dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7ff4efc87af0>, 'x': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 5, 'verbose': 1, 'callbacks': None, 'validation_data': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {alice: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff600fb7ee0>, bob: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff6007148b0>}}
100%|██████████| 30/30 [00:32<00:00,  1.08s/it, epoch: 1/5 -  multiclassaccuracy:0.3760416805744171  multiclassprecision:0.3760416805744171  val_multiclassaccuracy:0.0  val_multiclassprecision:0.0 ]
100%|██████████| 8/8 [00:10<00:00,  1.27s/it, epoch: 2/5 -  multiclassaccuracy:0.5078125  multiclassprecision:0.5078125  val_multiclassaccuracy:0.1618257313966751  val_multiclassprecision:0.1618257313966751 ]
100%|██████████| 8/8 [00:10<00:00,  1.28s/it, epoch: 3/5 -  multiclassaccuracy:0.51171875  multiclassprecision:0.51171875  val_multiclassaccuracy:0.004149377811700106  val_multiclassprecision:0.004149377811700106 ]
100%|██████████| 8/8 [00:10<00:00,  1.27s/it, epoch: 4/5 -  multiclassaccuracy:0.5390625  multiclassprecision:0.5390625  val_multiclassaccuracy:0.02074688859283924  val_multiclassprecision:0.02074688859283924 ]
100%|██████████| 8/8 [00:10<00:00,  1.28s/it, epoch: 5/5 -  multiclassaccuracy:0.5703125  multiclassprecision:0.5703125  val_multiclassaccuracy:0.016597511246800423  val_multiclassprecision:0.016597511246800423 ]