Fine-Tuning Pretrained Keras Applications Models in SecretFlow Federated Learning

Introduction

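Loading and fine-tuning pretrained models is an important technique in machine learning. Training a very large model from scratch requires substantial compute and a long training time. In conventional machine learning it is therefore common to start from a pretrained model and fine-tune it, or apply transfer learning, for the specific task at hand. The same applies to federated learning. If participants can load a pretrained model for fine-tuning and transfer learning, they save compute, face a lower barrier to joining the federation, and see the model converge faster.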

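SecretFlow's federated learning module is compatible with TensorFlow Keras, so it can load the pretrained models shipped with Keras Applications directly. This tutorial is based on the Keras tutorial for fine-tuning InceptionV3. It shows how to fine-tune a Keras pretrained model inside the SecretFlow framework and how little the code has to change to do so.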

Loading the Dataset

Dataset Overview

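The Flower dataset is a collection of color photos of 5 kinds of flowers: daisy, dandelion, rose, sunflower, and tulip. It is described as containing 4,323 images in total. Each class includes photos taken from several angles and under different lighting, and each image is described as 320x240 pixels. Note that image_dataset_from_directory below resizes every image to 180x180 anyway. The dataset is commonly used for training and evaluating image classification algorithms.

The stated per-class counts are: daisy (633), dandelion (898), rose (641), sunflower (699), tulip (852). These counts sum to 3,723 rather than 4,323. The copy downloaded below also contains only 1,201 images in total (see the output of cell [2]), so these figures do not describe the data actually used in this tutorial.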
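Because the stated counts and the downloaded copy disagree, it can help to count the images per class yourself. The sketch below does this with a hypothetical helper, `count_images_per_class`, which is not part of TensorFlow or SecretFlow. It assumes the layout that `image_dataset_from_directory` expects: one subdirectory per class directly under the root. It also assumes the same image extensions that function accepts.

```python
from pathlib import Path

# Extensions accepted by tf.keras.utils.image_dataset_from_directory
IMAGE_SUFFIXES = {".bmp", ".gif", ".jpeg", ".jpg", ".png"}


def count_images_per_class(root):
    """Count image files in each immediate subdirectory (= class) of `root`.

    Plain files at the top level (e.g. a LICENSE.txt) are ignored.
    """
    counts = {}
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

After the download cell has run, `count_images_per_class(path_to_flower_dataset)` would return a dict that maps each class name to its image count. Its values should sum to the "Found N files" figure printed by `image_dataset_from_directory`.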

Original download URL: http://download.tensorflow.org/example_images/flower_photos.tgz (the code below downloads from a SecretFlow mirror instead).

Downloading and Extracting the Dataset

[1]:
import tempfile
import tensorflow as tf


_temp_dir = tempfile.mkdtemp()
path_to_flower_dataset = tf.keras.utils.get_file(
    "flower_photos",
    "https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz",
    untar=True,
    cache_dir=_temp_dir,
)
Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz
67588319/67588319 [==============================] - 2s 0us/step

Loading the Images

[2]:
import math
import tensorflow as tf

img_height = 180
img_width = 180
batch_size = 32
# In this example, we use the TensorFlow interface for development.
data_set = tf.keras.utils.image_dataset_from_directory(
    path_to_flower_dataset,
    validation_split=0.2,
    subset="both",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
Found 1201 files belonging to 5 classes.
Using 961 files for training.
Using 240 files for validation.
2023-10-11 07:12:05.321890: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 12653 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:3b:00.0, compute capability: 7.5
2023-10-11 07:12:05.324020: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13775 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:3c:00.0, compute capability: 7.5

Splitting the Dataset

[3]:
train_set = data_set[0]
test_set = data_set[1]
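The split sizes printed above follow from simple arithmetic. With `subset="both"`, Keras reserves `int(validation_split * num_files)` files for validation and keeps the rest for training. Keras then batches each subset, and the last batch may be partial. The sketch reproduces that arithmetic for the 1,201 files found here. `split_sizes` and `num_batches` are illustrative helpers, not Keras functions.

```python
import math


def split_sizes(num_files, validation_split):
    """Return (train, validation) counts using int() truncation for the validation part."""
    num_val = int(validation_split * num_files)
    return num_files - num_val, num_val


def num_batches(num_samples, batch_size):
    """Number of batches per epoch when the last batch is allowed to be partial."""
    return math.ceil(num_samples / batch_size)


train_n, val_n = split_sizes(1201, 0.2)
print(train_n, val_n)            # 961 240
print(num_batches(train_n, 32))  # 31
print(num_batches(val_n, 32))    # 8
```

The 31 training batches match the `31/31` progress bars in the training logs below. The same ceiling division reappears in the federated `dataset_builder` as the steps per epoch.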

Inspecting the Dataset

[4]:
print(type(train_set), type(test_set))
<class 'tensorflow.python.data.ops.batch_op._BatchDataset'> <class 'tensorflow.python.data.ops.batch_op._BatchDataset'>
[5]:
x, y = next(iter(train_set))
print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")
x.shape = (32, 180, 180, 3)
y.shape = (32,)

Fine-Tuning in Standalone Mode

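In standalone mode, fine-tuning the pretrained model largely follows the official TensorFlow Keras tutorial. The only changes are small adjustments to the `compile` arguments to match the dataset format, and they have little effect on the procedure.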
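The main adjustment concerns the loss. `image_dataset_from_directory` yields integer class ids by default (`y.shape = (32,)` above), so the model is compiled with `sparse_categorical_crossentropy`. The Keras Applications fine-tuning snippet compiles with `categorical_crossentropy`, which expects one-hot labels. The NumPy sketch below shows that both losses compute the same value. The only difference is the label encoding they expect. The function names mirror the Keras losses but are plain NumPy re-implementations.

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs):
    """labels: int class ids, shape (N,); probs: softmax outputs, shape (N, C)."""
    return -np.log(probs[np.arange(len(labels)), labels])


def categorical_crossentropy(one_hot, probs):
    """one_hot: one-hot labels, shape (N, C); probs: softmax outputs, shape (N, C)."""
    return -np.sum(one_hot * np.log(probs), axis=1)


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
labels = np.array([0, 2])          # integer labels, as produced by the dataset above
one_hot = np.eye(3)[labels]        # the same labels, one-hot encoded

print(sparse_categorical_crossentropy(labels, probs))
print(categorical_crossentropy(one_hot, probs))  # identical values
```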

Fine-Tuning the Top Classifier

[6]:
import matplotlib.pyplot as plt

from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D


# create the base pre-trained model
base_model = InceptionV3(weights='imagenet', include_top=False)

# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# let's add a fully-connected layer
x = Dense(1024, activation='relu')(x)
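# and a softmax output layer -- the official example assumes 10 classes; Flower
# has 5, and with integer labels 0-4 the 5 extra outputs are never a target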
predictions = Dense(10, activation='softmax')(x)

# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)

# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional InceptionV3 layers
for layer in base_model.layers:
    layer.trainable = False

# compile the model (should be done *after* setting layers to non-trainable)
model.compile(
    optimizer='rmsprop',
    loss='sparse_categorical_crossentropy',
    metrics=["accuracy"],
)
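One caveat about the cells in this section is that they feed raw pixel values in [0, 255] into InceptionV3. The Keras ImageNet weights for InceptionV3 are meant to be used with `tf.keras.applications.inception_v3.preprocess_input`, which rescales pixels to [-1, 1]. The recorded runs here did not apply it. The sketch reproduces that rescaling in NumPy to show what it does. `inception_preprocess` is an illustrative name, not a Keras function.

```python
import numpy as np


def inception_preprocess(images):
    """Scale pixel values from [0, 255] to [-1, 1].

    Same arithmetic as tf.keras.applications.inception_v3.preprocess_input:
    x / 127.5 - 1.
    """
    return np.asarray(images, dtype=np.float32) / 127.5 - 1.0


batch = np.array([[0.0, 127.5, 255.0]])
scaled = inception_preprocess(batch)  # maps 0, 127.5, 255 to -1, 0, 1
```

In TensorFlow you could apply the Keras function directly to the datasets, for example `train_set.map(lambda x, y: (tf.keras.applications.inception_v3.preprocess_input(x), y))`. Losses and accuracies would then differ from the logs shown below.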
[7]:
# train the model on the new data for a few epochs
history = model.fit(train_set, validation_data=test_set, epochs=50)

# at this point, the top layers are well trained and we can start fine-tuning
# convolutional layers from inception V3. We will freeze the bottom N layers
# and train the remaining top layers.
Epoch 1/50
31/31 [==============================] - ETA: 0s - loss: 112.3790 - accuracy: 0.2206
31/31 [==============================] - 15s 175ms/step - loss: 112.3790 - accuracy: 0.2206 - val_loss: 34.8170 - val_accuracy: 0.1500
Epoch 2/50
31/31 [==============================] - 2s 57ms/step - loss: 18.9387 - accuracy: 0.2737 - val_loss: 21.3622 - val_accuracy: 0.2625
Epoch 3/50
31/31 [==============================] - 2s 57ms/step - loss: 14.2041 - accuracy: 0.3091 - val_loss: 16.4007 - val_accuracy: 0.2625
Epoch 4/50
31/31 [==============================] - 2s 59ms/step - loss: 11.6635 - accuracy: 0.2716 - val_loss: 14.6091 - val_accuracy: 0.2167
Epoch 5/50
31/31 [==============================] - 2s 57ms/step - loss: 8.6897 - accuracy: 0.3215 - val_loss: 7.3476 - val_accuracy: 0.4042
Epoch 6/50
31/31 [==============================] - 2s 57ms/step - loss: 8.0223 - accuracy: 0.2914 - val_loss: 3.7781 - val_accuracy: 0.3417
Epoch 7/50
31/31 [==============================] - 2s 57ms/step - loss: 6.0596 - accuracy: 0.3413 - val_loss: 6.9468 - val_accuracy: 0.2417
Epoch 8/50
31/31 [==============================] - 2s 59ms/step - loss: 4.8735 - accuracy: 0.3486 - val_loss: 14.5767 - val_accuracy: 0.2583
Epoch 9/50
31/31 [==============================] - 2s 57ms/step - loss: 3.7563 - accuracy: 0.3954 - val_loss: 7.5974 - val_accuracy: 0.2625
Epoch 10/50
31/31 [==============================] - 2s 58ms/step - loss: 3.1183 - accuracy: 0.3871 - val_loss: 9.8358 - val_accuracy: 0.2583
Epoch 11/50
31/31 [==============================] - 2s 58ms/step - loss: 3.5958 - accuracy: 0.3725 - val_loss: 3.7865 - val_accuracy: 0.2792
Epoch 12/50
31/31 [==============================] - 2s 58ms/step - loss: 2.4401 - accuracy: 0.4287 - val_loss: 5.4358 - val_accuracy: 0.2833
Epoch 13/50
31/31 [==============================] - 2s 57ms/step - loss: 2.2799 - accuracy: 0.4194 - val_loss: 4.6291 - val_accuracy: 0.2208
Epoch 14/50
31/31 [==============================] - 2s 58ms/step - loss: 2.3558 - accuracy: 0.4204 - val_loss: 4.1438 - val_accuracy: 0.2125
Epoch 15/50
31/31 [==============================] - 2s 57ms/step - loss: 1.8603 - accuracy: 0.4828 - val_loss: 6.8592 - val_accuracy: 0.2917
Epoch 16/50
31/31 [==============================] - 2s 58ms/step - loss: 1.8960 - accuracy: 0.4880 - val_loss: 2.8224 - val_accuracy: 0.3042
Epoch 17/50
31/31 [==============================] - 2s 58ms/step - loss: 1.7029 - accuracy: 0.4984 - val_loss: 3.8373 - val_accuracy: 0.2250
Epoch 18/50
31/31 [==============================] - 2s 58ms/step - loss: 1.4418 - accuracy: 0.5120 - val_loss: 3.6055 - val_accuracy: 0.3292
Epoch 19/50
31/31 [==============================] - 2s 57ms/step - loss: 1.5190 - accuracy: 0.5245 - val_loss: 3.9276 - val_accuracy: 0.3458
Epoch 20/50
31/31 [==============================] - 2s 58ms/step - loss: 1.3073 - accuracy: 0.5619 - val_loss: 6.1296 - val_accuracy: 0.1875
Epoch 21/50
31/31 [==============================] - 2s 59ms/step - loss: 1.3950 - accuracy: 0.5390 - val_loss: 3.7171 - val_accuracy: 0.2375
Epoch 22/50
31/31 [==============================] - 2s 59ms/step - loss: 1.1851 - accuracy: 0.5640 - val_loss: 5.6681 - val_accuracy: 0.1708
Epoch 23/50
31/31 [==============================] - 2s 58ms/step - loss: 1.3379 - accuracy: 0.5838 - val_loss: 2.4407 - val_accuracy: 0.2625
Epoch 24/50
31/31 [==============================] - 2s 59ms/step - loss: 1.1016 - accuracy: 0.6098 - val_loss: 3.4062 - val_accuracy: 0.2917
Epoch 25/50
31/31 [==============================] - 2s 58ms/step - loss: 1.1200 - accuracy: 0.5931 - val_loss: 2.9323 - val_accuracy: 0.3042
Epoch 26/50
31/31 [==============================] - 2s 59ms/step - loss: 1.0031 - accuracy: 0.6389 - val_loss: 2.8517 - val_accuracy: 0.2708
Epoch 27/50
31/31 [==============================] - 2s 59ms/step - loss: 1.0253 - accuracy: 0.6306 - val_loss: 4.8238 - val_accuracy: 0.3000
Epoch 28/50
31/31 [==============================] - 2s 59ms/step - loss: 1.0335 - accuracy: 0.6576 - val_loss: 3.0405 - val_accuracy: 0.2958
Epoch 29/50
31/31 [==============================] - 2s 58ms/step - loss: 1.1657 - accuracy: 0.6181 - val_loss: 3.3026 - val_accuracy: 0.2375
Epoch 30/50
31/31 [==============================] - 2s 59ms/step - loss: 0.9623 - accuracy: 0.6629 - val_loss: 3.2407 - val_accuracy: 0.2875
Epoch 31/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8993 - accuracy: 0.6920 - val_loss: 2.2036 - val_accuracy: 0.3917
Epoch 32/50
31/31 [==============================] - 2s 58ms/step - loss: 0.9263 - accuracy: 0.6816 - val_loss: 3.2231 - val_accuracy: 0.2917
Epoch 33/50
31/31 [==============================] - 2s 60ms/step - loss: 0.8958 - accuracy: 0.6930 - val_loss: 3.6673 - val_accuracy: 0.2583
Epoch 34/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8155 - accuracy: 0.7045 - val_loss: 3.7752 - val_accuracy: 0.2667
Epoch 35/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8687 - accuracy: 0.6982 - val_loss: 3.5233 - val_accuracy: 0.3708
Epoch 36/50
31/31 [==============================] - 2s 59ms/step - loss: 0.9007 - accuracy: 0.7138 - val_loss: 2.5410 - val_accuracy: 0.3542
Epoch 37/50
31/31 [==============================] - 2s 60ms/step - loss: 0.7482 - accuracy: 0.7378 - val_loss: 3.7791 - val_accuracy: 0.3583
Epoch 38/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8524 - accuracy: 0.7534 - val_loss: 2.9299 - val_accuracy: 0.3375
Epoch 39/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6633 - accuracy: 0.7607 - val_loss: 3.8067 - val_accuracy: 0.3125
Epoch 40/50
31/31 [==============================] - 2s 59ms/step - loss: 0.7970 - accuracy: 0.7336 - val_loss: 3.9137 - val_accuracy: 0.3333
Epoch 41/50
31/31 [==============================] - 2s 59ms/step - loss: 0.7073 - accuracy: 0.7648 - val_loss: 2.9778 - val_accuracy: 0.3542
Epoch 42/50
31/31 [==============================] - 2s 60ms/step - loss: 0.6207 - accuracy: 0.7815 - val_loss: 3.2358 - val_accuracy: 0.3458
Epoch 43/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6773 - accuracy: 0.7700 - val_loss: 3.8493 - val_accuracy: 0.3292
Epoch 44/50
31/31 [==============================] - 2s 60ms/step - loss: 0.5552 - accuracy: 0.7950 - val_loss: 4.1360 - val_accuracy: 0.3292
Epoch 45/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5993 - accuracy: 0.8044 - val_loss: 3.3667 - val_accuracy: 0.3625
Epoch 46/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6991 - accuracy: 0.7877 - val_loss: 2.8805 - val_accuracy: 0.3625
Epoch 47/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6570 - accuracy: 0.7908 - val_loss: 3.2625 - val_accuracy: 0.3375
Epoch 48/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5095 - accuracy: 0.8293 - val_loss: 4.0133 - val_accuracy: 0.3750
Epoch 49/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5420 - accuracy: 0.8148 - val_loss: 3.1207 - val_accuracy: 0.3042
Epoch 50/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5678 - accuracy: 0.8023 - val_loss: 3.8196 - val_accuracy: 0.3583
[8]:
history.history.keys()
[8]:
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
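`history.history` is a plain dict that maps each metric name to a list with one value per epoch. Its final entries are not necessarily the best ones. The hypothetical helper below returns the epoch (1-based, as in the logs) with the best value of a chosen metric.

```python
def best_epoch(history, metric="val_accuracy", mode="max"):
    """Return (epoch, value) for the best entry of `metric`.

    `history` is a Keras History.history dict; epochs are 1-based.
    mode="max" for metrics like accuracy, anything else selects the minimum.
    Ties resolve to the earliest epoch.
    """
    values = history[metric]
    pick = max if mode == "max" else min
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]
```

For the first training run above, `best_epoch(history.history)` would point to epoch 5 (val_accuracy 0.4042) rather than the last epoch. `best_epoch(history.history, "val_loss", "min")` does the same for the validation loss.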
[9]:
# Draw accuracy values for training & validation
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Draw loss for training & validation
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
[Figure: Model accuracy (Train vs. Test) by epoch]
[Figure: Model loss (Train vs. Test) by epoch]

Freezing the Bottom Layers and Fine-Tuning the Top Layers

[10]:
# let's visualize layer names and layer indices to see how many layers
# we should freeze:
for i, layer in enumerate(base_model.layers):
    print(i, layer.name)
0 input_1
1 conv2d
2 batch_normalization
3 activation
4 conv2d_1
5 batch_normalization_1
6 activation_1
7 conv2d_2
8 batch_normalization_2
9 activation_2
10 max_pooling2d
11 conv2d_3
12 batch_normalization_3
13 activation_3
14 conv2d_4
15 batch_normalization_4
16 activation_4
17 max_pooling2d_1
18 conv2d_8
19 batch_normalization_8
20 activation_8
21 conv2d_6
22 conv2d_9
23 batch_normalization_6
24 batch_normalization_9
25 activation_6
26 activation_9
27 average_pooling2d
28 conv2d_5
29 conv2d_7
30 conv2d_10
31 conv2d_11
32 batch_normalization_5
33 batch_normalization_7
34 batch_normalization_10
35 batch_normalization_11
36 activation_5
37 activation_7
38 activation_10
39 activation_11
40 mixed0
41 conv2d_15
42 batch_normalization_15
43 activation_15
44 conv2d_13
45 conv2d_16
46 batch_normalization_13
47 batch_normalization_16
48 activation_13
49 activation_16
50 average_pooling2d_1
51 conv2d_12
52 conv2d_14
53 conv2d_17
54 conv2d_18
55 batch_normalization_12
56 batch_normalization_14
57 batch_normalization_17
58 batch_normalization_18
59 activation_12
60 activation_14
61 activation_17
62 activation_18
63 mixed1
64 conv2d_22
65 batch_normalization_22
66 activation_22
67 conv2d_20
68 conv2d_23
69 batch_normalization_20
70 batch_normalization_23
71 activation_20
72 activation_23
73 average_pooling2d_2
74 conv2d_19
75 conv2d_21
76 conv2d_24
77 conv2d_25
78 batch_normalization_19
79 batch_normalization_21
80 batch_normalization_24
81 batch_normalization_25
82 activation_19
83 activation_21
84 activation_24
85 activation_25
86 mixed2
87 conv2d_27
88 batch_normalization_27
89 activation_27
90 conv2d_28
91 batch_normalization_28
92 activation_28
93 conv2d_26
94 conv2d_29
95 batch_normalization_26
96 batch_normalization_29
97 activation_26
98 activation_29
99 max_pooling2d_2
100 mixed3
101 conv2d_34
102 batch_normalization_34
103 activation_34
104 conv2d_35
105 batch_normalization_35
106 activation_35
107 conv2d_31
108 conv2d_36
109 batch_normalization_31
110 batch_normalization_36
111 activation_31
112 activation_36
113 conv2d_32
114 conv2d_37
115 batch_normalization_32
116 batch_normalization_37
117 activation_32
118 activation_37
119 average_pooling2d_3
120 conv2d_30
121 conv2d_33
122 conv2d_38
123 conv2d_39
124 batch_normalization_30
125 batch_normalization_33
126 batch_normalization_38
127 batch_normalization_39
128 activation_30
129 activation_33
130 activation_38
131 activation_39
132 mixed4
133 conv2d_44
134 batch_normalization_44
135 activation_44
136 conv2d_45
137 batch_normalization_45
138 activation_45
139 conv2d_41
140 conv2d_46
141 batch_normalization_41
142 batch_normalization_46
143 activation_41
144 activation_46
145 conv2d_42
146 conv2d_47
147 batch_normalization_42
148 batch_normalization_47
149 activation_42
150 activation_47
151 average_pooling2d_4
152 conv2d_40
153 conv2d_43
154 conv2d_48
155 conv2d_49
156 batch_normalization_40
157 batch_normalization_43
158 batch_normalization_48
159 batch_normalization_49
160 activation_40
161 activation_43
162 activation_48
163 activation_49
164 mixed5
165 conv2d_54
166 batch_normalization_54
167 activation_54
168 conv2d_55
169 batch_normalization_55
170 activation_55
171 conv2d_51
172 conv2d_56
173 batch_normalization_51
174 batch_normalization_56
175 activation_51
176 activation_56
177 conv2d_52
178 conv2d_57
179 batch_normalization_52
180 batch_normalization_57
181 activation_52
182 activation_57
183 average_pooling2d_5
184 conv2d_50
185 conv2d_53
186 conv2d_58
187 conv2d_59
188 batch_normalization_50
189 batch_normalization_53
190 batch_normalization_58
191 batch_normalization_59
192 activation_50
193 activation_53
194 activation_58
195 activation_59
196 mixed6
197 conv2d_64
198 batch_normalization_64
199 activation_64
200 conv2d_65
201 batch_normalization_65
202 activation_65
203 conv2d_61
204 conv2d_66
205 batch_normalization_61
206 batch_normalization_66
207 activation_61
208 activation_66
209 conv2d_62
210 conv2d_67
211 batch_normalization_62
212 batch_normalization_67
213 activation_62
214 activation_67
215 average_pooling2d_6
216 conv2d_60
217 conv2d_63
218 conv2d_68
219 conv2d_69
220 batch_normalization_60
221 batch_normalization_63
222 batch_normalization_68
223 batch_normalization_69
224 activation_60
225 activation_63
226 activation_68
227 activation_69
228 mixed7
229 conv2d_72
230 batch_normalization_72
231 activation_72
232 conv2d_73
233 batch_normalization_73
234 activation_73
235 conv2d_70
236 conv2d_74
237 batch_normalization_70
238 batch_normalization_74
239 activation_70
240 activation_74
241 conv2d_71
242 conv2d_75
243 batch_normalization_71
244 batch_normalization_75
245 activation_71
246 activation_75
247 max_pooling2d_3
248 mixed8
249 conv2d_80
250 batch_normalization_80
251 activation_80
252 conv2d_77
253 conv2d_81
254 batch_normalization_77
255 batch_normalization_81
256 activation_77
257 activation_81
258 conv2d_78
259 conv2d_79
260 conv2d_82
261 conv2d_83
262 average_pooling2d_7
263 conv2d_76
264 batch_normalization_78
265 batch_normalization_79
266 batch_normalization_82
267 batch_normalization_83
268 conv2d_84
269 batch_normalization_76
270 activation_78
271 activation_79
272 activation_82
273 activation_83
274 batch_normalization_84
275 activation_76
276 mixed9_0
277 concatenate
278 activation_84
279 mixed9
280 conv2d_89
281 batch_normalization_89
282 activation_89
283 conv2d_86
284 conv2d_90
285 batch_normalization_86
286 batch_normalization_90
287 activation_86
288 activation_90
289 conv2d_87
290 conv2d_88
291 conv2d_91
292 conv2d_92
293 average_pooling2d_8
294 conv2d_85
295 batch_normalization_87
296 batch_normalization_88
297 batch_normalization_91
298 batch_normalization_92
299 conv2d_93
300 batch_normalization_85
301 activation_87
302 activation_88
303 activation_91
304 activation_92
305 batch_normalization_93
306 activation_85
307 mixed9_1
308 concatenate_1
309 activation_93
310 mixed10
[11]:
# we chose to train the top 2 inception blocks, i.e. we will freeze
# the first 249 layers and unfreeze the rest:
for layer in model.layers[:249]:
    layer.trainable = False
for layer in model.layers[249:]:
    layer.trainable = True
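The cut-off 249 comes from the listing above. `mixed8` is layer 248, the output of the third-to-last inception block, so every layer after it belongs to the top two blocks (`mixed9`, `mixed10`). InceptionV3 names its `mixed*` layers explicitly. You can therefore look the index up by name instead of hard-coding 249. The helper `first_index_after` is hypothetical and works on any sequence of objects with a `.name` attribute, such as `model.layers`.

```python
def first_index_after(layers, name):
    """Return the index just past the layer called `name`.

    Raises ValueError if no layer has that name.
    """
    for i, layer in enumerate(layers):
        if layer.name == name:
            return i + 1
    raise ValueError(f"no layer named {name!r}")
```

With the model above, `cut = first_index_after(model.layers, "mixed8")` should give 249, after which `model.layers[:cut]` and `model.layers[cut:]` can be frozen and unfrozen as in the cell.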
[12]:
# we need to recompile the model for these modifications to take effect
# we use SGD with a low learning rate
from tensorflow.keras.optimizers import SGD

model.compile(
    optimizer=SGD(learning_rate=0.0001, momentum=0.9),
    loss='sparse_categorical_crossentropy',
    metrics=["accuracy"],
)
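A low learning rate keeps the updates to the pretrained weights small. The sketch below applies the SGD-with-momentum update in the form documented for `tf.keras.optimizers.SGD` (with `nesterov=False`): `velocity = momentum * velocity - learning_rate * g`, then `w = w + velocity`. It is a NumPy illustration of that rule, not Keras code. With a constant gradient, the step size grows toward `learning_rate / (1 - momentum)`, here 1e-3 per step.

```python
import numpy as np


def sgd_momentum_step(w, velocity, grad, learning_rate=1e-4, momentum=0.9):
    """One SGD-with-momentum update (non-Nesterov form)."""
    velocity = momentum * velocity - learning_rate * grad
    return w + velocity, velocity


w = np.array([1.0])       # a single pretrained weight
v = np.zeros(1)           # velocity starts at zero
g = np.array([1.0])       # constant gradient for illustration
for _ in range(3):
    w, v = sgd_momentum_step(w, v, g)
# after 3 steps: v = -2.71e-4, w = 1 - 5.61e-4
```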
[13]:
# we train our model again (this time fine-tuning the top 2 inception blocks
# alongside the top Dense layers
history = model.fit(train_set, validation_data=test_set, epochs=50)
Epoch 1/50
31/31 [==============================] - 13s 118ms/step - loss: 1.7043 - accuracy: 0.2882 - val_loss: 1.6911 - val_accuracy: 0.3458
Epoch 2/50
31/31 [==============================] - 2s 71ms/step - loss: 1.6528 - accuracy: 0.3330 - val_loss: 1.6635 - val_accuracy: 0.3375
Epoch 3/50
31/31 [==============================] - 2s 70ms/step - loss: 1.6001 - accuracy: 0.3673 - val_loss: 1.6358 - val_accuracy: 0.3250
Epoch 4/50
31/31 [==============================] - 2s 71ms/step - loss: 1.5633 - accuracy: 0.3975 - val_loss: 1.6324 - val_accuracy: 0.3208
Epoch 5/50
31/31 [==============================] - 2s 71ms/step - loss: 1.5255 - accuracy: 0.3975 - val_loss: 1.6274 - val_accuracy: 0.3250
Epoch 6/50
31/31 [==============================] - 2s 71ms/step - loss: 1.4992 - accuracy: 0.4048 - val_loss: 1.6040 - val_accuracy: 0.3167
Epoch 7/50
31/31 [==============================] - 2s 71ms/step - loss: 1.4658 - accuracy: 0.4287 - val_loss: 1.5823 - val_accuracy: 0.3292
Epoch 8/50
31/31 [==============================] - 2s 72ms/step - loss: 1.4350 - accuracy: 0.4454 - val_loss: 1.5678 - val_accuracy: 0.3208
Epoch 9/50
31/31 [==============================] - 2s 71ms/step - loss: 1.4175 - accuracy: 0.4443 - val_loss: 1.5482 - val_accuracy: 0.3250
Epoch 10/50
31/31 [==============================] - 2s 71ms/step - loss: 1.3867 - accuracy: 0.4589 - val_loss: 1.5394 - val_accuracy: 0.3292
Epoch 11/50
31/31 [==============================] - 2s 71ms/step - loss: 1.3550 - accuracy: 0.4880 - val_loss: 1.5272 - val_accuracy: 0.3167
Epoch 12/50
31/31 [==============================] - 2s 72ms/step - loss: 1.3254 - accuracy: 0.4932 - val_loss: 1.5242 - val_accuracy: 0.3250
Epoch 13/50
31/31 [==============================] - 2s 72ms/step - loss: 1.3148 - accuracy: 0.5120 - val_loss: 1.5063 - val_accuracy: 0.3458
Epoch 14/50
31/31 [==============================] - 2s 72ms/step - loss: 1.2890 - accuracy: 0.5297 - val_loss: 1.4933 - val_accuracy: 0.3500
Epoch 15/50
31/31 [==============================] - 2s 72ms/step - loss: 1.2470 - accuracy: 0.5609 - val_loss: 1.4905 - val_accuracy: 0.3333
Epoch 16/50
31/31 [==============================] - 2s 71ms/step - loss: 1.2174 - accuracy: 0.5525 - val_loss: 1.4791 - val_accuracy: 0.3458
Epoch 17/50
31/31 [==============================] - 2s 71ms/step - loss: 1.1906 - accuracy: 0.5723 - val_loss: 1.4681 - val_accuracy: 0.3500
Epoch 18/50
31/31 [==============================] - 2s 72ms/step - loss: 1.1686 - accuracy: 0.5817 - val_loss: 1.4669 - val_accuracy: 0.3500
Epoch 19/50
31/31 [==============================] - 2s 72ms/step - loss: 1.1449 - accuracy: 0.6046 - val_loss: 1.4621 - val_accuracy: 0.3625
Epoch 20/50
31/31 [==============================] - 2s 70ms/step - loss: 1.1219 - accuracy: 0.6181 - val_loss: 1.4552 - val_accuracy: 0.3667
Epoch 21/50
31/31 [==============================] - 2s 72ms/step - loss: 1.0932 - accuracy: 0.6327 - val_loss: 1.4428 - val_accuracy: 0.3708
Epoch 22/50
31/31 [==============================] - 2s 71ms/step - loss: 1.0565 - accuracy: 0.6462 - val_loss: 1.4397 - val_accuracy: 0.3625
Epoch 23/50
31/31 [==============================] - 2s 71ms/step - loss: 1.0192 - accuracy: 0.6722 - val_loss: 1.4278 - val_accuracy: 0.3792
Epoch 24/50
31/31 [==============================] - 2s 72ms/step - loss: 0.9906 - accuracy: 0.6785 - val_loss: 1.4345 - val_accuracy: 0.3833
Epoch 25/50
31/31 [==============================] - 2s 72ms/step - loss: 0.9632 - accuracy: 0.6930 - val_loss: 1.4148 - val_accuracy: 0.3875
Epoch 26/50
31/31 [==============================] - 2s 72ms/step - loss: 0.9382 - accuracy: 0.7242 - val_loss: 1.4085 - val_accuracy: 0.4000
Epoch 27/50
31/31 [==============================] - 2s 73ms/step - loss: 0.9137 - accuracy: 0.7471 - val_loss: 1.4146 - val_accuracy: 0.3958
Epoch 28/50
31/31 [==============================] - 2s 72ms/step - loss: 0.8796 - accuracy: 0.7409 - val_loss: 1.4115 - val_accuracy: 0.4042
Epoch 29/50
31/31 [==============================] - 2s 72ms/step - loss: 0.8471 - accuracy: 0.7419 - val_loss: 1.4027 - val_accuracy: 0.4167
Epoch 30/50
31/31 [==============================] - 2s 73ms/step - loss: 0.8274 - accuracy: 0.7607 - val_loss: 1.3981 - val_accuracy: 0.3958
Epoch 31/50
31/31 [==============================] - 2s 72ms/step - loss: 0.8057 - accuracy: 0.7638 - val_loss: 1.3978 - val_accuracy: 0.4125
Epoch 32/50
31/31 [==============================] - 2s 72ms/step - loss: 0.7804 - accuracy: 0.7950 - val_loss: 1.4013 - val_accuracy: 0.4083
Epoch 33/50
31/31 [==============================] - 2s 72ms/step - loss: 0.7360 - accuracy: 0.8158 - val_loss: 1.3971 - val_accuracy: 0.4125
Epoch 34/50
31/31 [==============================] - 2s 73ms/step - loss: 0.7070 - accuracy: 0.8158 - val_loss: 1.3992 - val_accuracy: 0.3875
Epoch 35/50
31/31 [==============================] - 2s 72ms/step - loss: 0.6836 - accuracy: 0.8252 - val_loss: 1.3861 - val_accuracy: 0.3958
Epoch 36/50
31/31 [==============================] - 2s 72ms/step - loss: 0.6650 - accuracy: 0.8408 - val_loss: 1.3846 - val_accuracy: 0.4042
Epoch 37/50
31/31 [==============================] - 2s 72ms/step - loss: 0.6598 - accuracy: 0.8252 - val_loss: 1.4019 - val_accuracy: 0.4250
Epoch 38/50
31/31 [==============================] - 2s 73ms/step - loss: 0.6042 - accuracy: 0.8522 - val_loss: 1.3979 - val_accuracy: 0.4208
Epoch 39/50
31/31 [==============================] - 2s 72ms/step - loss: 0.5765 - accuracy: 0.8793 - val_loss: 1.3919 - val_accuracy: 0.4042
Epoch 40/50
31/31 [==============================] - 2s 73ms/step - loss: 0.5533 - accuracy: 0.8907 - val_loss: 1.3877 - val_accuracy: 0.4042
Epoch 41/50
31/31 [==============================] - 2s 71ms/step - loss: 0.5341 - accuracy: 0.8772 - val_loss: 1.4072 - val_accuracy: 0.4083
Epoch 42/50
31/31 [==============================] - 2s 72ms/step - loss: 0.5130 - accuracy: 0.9043 - val_loss: 1.4061 - val_accuracy: 0.3833
Epoch 43/50
31/31 [==============================] - 2s 73ms/step - loss: 0.4917 - accuracy: 0.9084 - val_loss: 1.4121 - val_accuracy: 0.3750
Epoch 44/50
31/31 [==============================] - 2s 72ms/step - loss: 0.4768 - accuracy: 0.9136 - val_loss: 1.4326 - val_accuracy: 0.4042
Epoch 45/50
31/31 [==============================] - 2s 72ms/step - loss: 0.4641 - accuracy: 0.9084 - val_loss: 1.4192 - val_accuracy: 0.4083
Epoch 46/50
31/31 [==============================] - 2s 71ms/step - loss: 0.4416 - accuracy: 0.9282 - val_loss: 1.4249 - val_accuracy: 0.4042
Epoch 47/50
31/31 [==============================] - 2s 71ms/step - loss: 0.4154 - accuracy: 0.9251 - val_loss: 1.4281 - val_accuracy: 0.3958
Epoch 48/50
31/31 [==============================] - 2s 72ms/step - loss: 0.3910 - accuracy: 0.9396 - val_loss: 1.4334 - val_accuracy: 0.4000
Epoch 49/50
31/31 [==============================] - 2s 71ms/step - loss: 0.3940 - accuracy: 0.9396 - val_loss: 1.4281 - val_accuracy: 0.4000
Epoch 50/50
31/31 [==============================] - 2s 71ms/step - loss: 0.3821 - accuracy: 0.9490 - val_loss: 1.4365 - val_accuracy: 0.4250
[14]:
# Draw accuracy values for training & validation
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Draw loss for training & validation
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
[Figure: Model accuracy (Train vs. Test) by epoch]
[Figure: Model loss (Train vs. Test) by epoch]

Standalone Mode Summary

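Following the official tutorial, we have fine-tuned InceptionV3 on the Flower dataset in two stages:

(1) fine-tuning only the top classifier, and
(2) freezing the bottom layers and fine-tuning the top layers.

In the recorded runs, training accuracy reaches about 0.95 after the second stage. Validation accuracy stays between roughly 0.32 and 0.43, so the model overfits this small dataset. Next, we show how to extend standalone fine-tuning to federated learning.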

Fine-Tuning in Federated Learning Mode

Environment Setup

First, we initialize the participants.

[15]:
%load_ext autoreload
%autoreload 2
[16]:
import secretflow as sf

# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address="local", log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
The version of SecretFlow: 1.2.0.dev20230926
2023-10-11 07:16:09,152 INFO worker.py:1538 -- Started a local Ray instance.

Defining the Dataloader

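We can define our own DataBuilder by following the DataBuilder tutorial for the TensorFlow backend. For the requested stage ("train" or "eval"), the builder returns a dataset together with the number of steps per epoch.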

[17]:
def create_dataset_builder(
    batch_size=32,
):
    def dataset_builder(folder_path, stage="train"):
        import math

        import tensorflow as tf

        img_height = 180
        img_width = 180
        data_set = tf.keras.utils.image_dataset_from_directory(
            folder_path,
            validation_split=0.2,
            subset="both",
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size,
        )
        if stage == "train":
            train_dataset = data_set[0]
            train_step_per_epoch = math.ceil(len(data_set[0].file_paths) / batch_size)
            return train_dataset, train_step_per_epoch
        elif stage == "eval":
            eval_dataset = data_set[1]
            eval_step_per_epoch = math.ceil(len(data_set[1].file_paths) / batch_size)
            return eval_dataset, eval_step_per_epoch

    return dataset_builder
[18]:
data_builder_dict = {
    alice: create_dataset_builder(
        batch_size=32,
    ),
    bob: create_dataset_builder(
        batch_size=32,
    ),
}

Defining the SecureAggregator

[19]:
from secretflow.ml.nn import FLModel
from secretflow.security.aggregation import SecureAggregator

device_list = [alice, bob]
aggregator = SecureAggregator(charlie, [alice, bob])
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.
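The log shows one `_Masker` actor per party. This reflects the general idea behind mask-based secure aggregation: each party hides its update behind random masks that cancel out when the aggregator (charlie) sums all contributions. The toy NumPy sketch below illustrates only that cancellation idea and is not SecretFlow's implementation. Real protocols derive masks through key agreement, usually work over integers in a finite field rather than floats, and handle dropped parties. `masked_updates` is a hypothetical helper.

```python
import numpy as np


def masked_updates(updates, seed=0):
    """Toy pairwise masking.

    For each pair (i, j) with i < j, party i adds a shared random mask and
    party j subtracts it, so every mask cancels in the sum of all parties.
    """
    rng = np.random.default_rng(seed)
    masked = [np.array(u, dtype=np.float64) for u in updates]  # copies
    n = len(masked)
    for i in range(n):
        for j in range(i + 1, n):
            mask = rng.normal(size=masked[i].shape)  # stands in for a key-agreed mask
            masked[i] += mask
            masked[j] -= mask
    return masked


alice_w = np.array([1.0, 2.0])
bob_w = np.array([3.0, 4.0])
masked = masked_updates([alice_w, bob_w])
aggregate = sum(masked)  # the aggregator only needs the sum: close to [4.0, 6.0]
```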

Define the data paths#

For simplicity, in single-machine simulation mode both alice and bob load the dataset from the same path.

[20]:
data = {
    alice: path_to_flower_dataset,
    bob: path_to_flower_dataset,
}
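In a real deployment each party holds its own, disjoint data. To simulate that locally, one could copy the class folders into separate per-party directories and point `data[alice]` and `data[bob]` at them. Each directory would keep the `class_name/file` layout that `image_dataset_from_directory` expects. The hypothetical helper below only computes such a split, dealing each class's files round-robin so that every party gets a similar share of every class. Copying the files is left out.

```python
from collections import defaultdict


def partition_by_class(file_paths, num_parties):
    """Split (class_name, file_name) pairs into disjoint per-party shards."""
    by_class = defaultdict(list)
    for class_name, file_name in file_paths:
        by_class[class_name].append(file_name)
    shards = [defaultdict(list) for _ in range(num_parties)]
    for class_name, files in sorted(by_class.items()):
        # Deal files round-robin so each party sees every class.
        for idx, file_name in enumerate(sorted(files)):
            shards[idx % num_parties][class_name].append(file_name)
    return [dict(s) for s in shards]


files = [("daisy", "a.jpg"), ("daisy", "b.jpg"), ("rose", "c.jpg")]
alice_shard, bob_shard = partition_by_class(files, 2)
print(alice_shard)  # {'daisy': ['a.jpg'], 'rose': ['c.jpg']}
print(bob_shard)    # {'daisy': ['b.jpg']}
```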

Define the federated learning training parameters#

[21]:
epochs = 50
batch_size = 32
aggregate_freq = 2
sampler_method = "batch"
random_seed = 1234
dp_spent_step_freq = 1
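Roughly speaking, `aggregate_freq = 2` means that each party runs two local batches between weight aggregations. The sketch below lists the local steps after which aggregation would occur, under the assumption that aggregation happens every `aggregate_freq` steps. It does not model how SecretFlow handles a trailing remainder of steps. The value of 32 steps per epoch is purely illustrative.

```python
def aggregation_steps(steps_per_epoch, aggregate_freq):
    """Local steps (1-based) after which weights would be aggregated.

    Assumes one aggregation every `aggregate_freq` local steps.
    """
    return [s for s in range(1, steps_per_epoch + 1) if s % aggregate_freq == 0]


rounds = aggregation_steps(steps_per_epoch=32, aggregate_freq=2)
print(len(rounds))       # 16 aggregations per epoch
print(len(rounds) * 50)  # 800 aggregations over 50 epochs
```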

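Fine-tuning the top classifier#

We only need to follow the model definition from the Keras tutorial and wrap it inside a function. Apart from this wrapping, the code needs almost no changes. To make comparison experiments easy, we add an extra option that controls whether the pretrained weights are loaded.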

[22]:
def create_inception_v3_model_classifier(num_classes, is_load_weight=True):
    def create_model():
        # Import inside the function so it also works on the remote PYU actors
        from tensorflow.keras.applications.inception_v3 import InceptionV3
        from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
        from tensorflow.keras.models import Model

        # Create model
        # create the base pre-trained model
        if is_load_weight:
            base_model = InceptionV3(weights='imagenet', include_top=False)
        else:
            base_model = InceptionV3(weights=None, include_top=False)

        # add a global spatial average pooling layer
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # let's add a fully-connected layer
        x = Dense(1024, activation='relu')(x)
        # and a softmax classification layer with num_classes outputs
        predictions = Dense(num_classes, activation='softmax')(x)

        # this is the model we will train
        model = Model(inputs=base_model.input, outputs=predictions)

        # first: train only the top layers (which were randomly initialized)
        # i.e. freeze all convolutional InceptionV3 layers
        for layer in base_model.layers:
            layer.trainable = False

        # Compile model
        model.compile(
            optimizer='rmsprop',
            loss='sparse_categorical_crossentropy',
            metrics=["accuracy"],
        )

        return model

    return create_model
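Because the entire InceptionV3 base is frozen in this setup, training updates only the two new Dense layers. InceptionV3 without its top ends in 2048 feature maps, and global average pooling turns them into a 2048-dimensional vector. A Dense layer has `in_features * units` weights plus `units` biases, which gives the count of trainable parameters below for `num_classes = 5`. `dense_params` is a hypothetical helper.

```python
def dense_params(in_features, units):
    """Trainable parameters of a Dense layer: weight matrix plus bias."""
    return in_features * units + units


hidden = dense_params(2048, 1024)  # Dense(1024, activation='relu')
output = dense_params(1024, 5)     # Dense(num_classes, activation='softmax')
print(hidden, output, hidden + output)  # 2098176 5125 2103301
```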

Loading pretrained weights and fine-tuning#

[23]:
# prepare model
num_classes = 5

# keras model
weight_model = create_inception_v3_model_classifier(
    num_classes=num_classes, is_load_weight=True
)


fed_model = FLModel(
    device_list=device_list,
    model=weight_model,
    aggregator=aggregator,
    backend="tensorflow",
    strategy="fed_avg_w",
    random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
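`strategy="fed_avg_w"` selects FedAvg over model weights. Its core aggregation step can be sketched as a sample-weighted average of each party's list of weight arrays. This is a simplified illustration. Weighting by local sample count is an assumption here, not a description of SecretFlow's exact code. `fedavg_weights` is a hypothetical helper.

```python
import numpy as np


def fedavg_weights(party_weights, party_sizes):
    """Average per-layer weight arrays across parties.

    party_weights: one list of numpy arrays per party (matching shapes).
    party_sizes: local sample count per party, used as averaging weights.
    """
    total = float(sum(party_sizes))
    averaged = []
    for layer_arrays in zip(*party_weights):
        layer_avg = sum(w * (n / total) for w, n in zip(layer_arrays, party_sizes))
        averaged.append(layer_avg)
    return averaged


alice_w = [np.array([1.0, 2.0]), np.array([0.0])]
bob_w = [np.array([3.0, 4.0]), np.array([2.0])]
avg = fedavg_weights([alice_w, bob_w], party_sizes=[100, 100])
print(avg)  # [array([2., 3.]), array([1.])]
```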
[24]:
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=epochs,
    batch_size=batch_size,
    aggregate_freq=aggregate_freq,
    sampler_method=sampler_method,
    random_seed=random_seed,
    dp_spent_step_freq=dp_spent_step_freq,
    dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbd4bb1cb50>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:44,  3.28s/it, epoch: 1/50 -  loss:102.29530334472656  accuracy:0.25494277477264404  val_loss:36.37749099731445  val_accuracy:0.25 ]
32it [01:39,  3.10s/it, epoch: 2/50 -  loss:23.1259708404541  accuracy:0.27976685762405396  val_loss:10.871068954467773  val_accuracy:0.26249998807907104 ]
32it [01:38,  3.08s/it, epoch: 3/50 -  loss:13.160760879516602  accuracy:0.2706078290939331  val_loss:15.88563346862793  val_accuracy:0.2708333432674408 ]
32it [01:38,  3.06s/it, epoch: 4/50 -  loss:11.84687328338623  accuracy:0.2930890917778015  val_loss:27.221759796142578  val_accuracy:0.25 ]
32it [01:37,  3.06s/it, epoch: 5/50 -  loss:12.541316032409668  accuracy:0.3080765902996063  val_loss:46.98291778564453  val_accuracy:0.1458333283662796 ]
32it [01:38,  3.07s/it, epoch: 6/50 -  loss:15.415522575378418  accuracy:0.29142382740974426  val_loss:10.973555564880371  val_accuracy:0.17916665971279144 ]
32it [01:38,  3.07s/it, epoch: 7/50 -  loss:6.303822994232178  accuracy:0.3255620300769806  val_loss:8.56163501739502  val_accuracy:0.20416666567325592 ]
32it [01:54,  3.58s/it, epoch: 8/50 -  loss:5.007680892944336  accuracy:0.34637802839279175  val_loss:6.916962146759033  val_accuracy:0.3499999940395355 ]
32it [01:55,  3.61s/it, epoch: 9/50 -  loss:3.740983247756958  accuracy:0.3855120837688446  val_loss:3.589289665222168  val_accuracy:0.2083333283662796 ]
32it [01:55,  3.62s/it, epoch: 10/50 -  loss:2.5561084747314453  accuracy:0.35803496837615967  val_loss:5.7687087059021  val_accuracy:0.15833333134651184 ]
32it [01:40,  3.13s/it, epoch: 11/50 -  loss:2.6316604614257812  accuracy:0.3921732008457184  val_loss:4.043315410614014  val_accuracy:0.25 ]
32it [01:38,  3.09s/it, epoch: 12/50 -  loss:2.2317721843719482  accuracy:0.4063280522823334  val_loss:2.3408310413360596  val_accuracy:0.27916666865348816 ]
32it [01:39,  3.12s/it, epoch: 13/50 -  loss:1.6880862712860107  accuracy:0.41631972789764404  val_loss:4.694206237792969  val_accuracy:0.17083333432674408 ]
32it [01:39,  3.11s/it, epoch: 14/50 -  loss:2.0930988788604736  accuracy:0.4113239049911499  val_loss:3.81128191947937  val_accuracy:0.17916665971279144 ]
32it [01:50,  3.44s/it, epoch: 15/50 -  loss:1.8601850271224976  accuracy:0.42381349205970764  val_loss:2.3201189041137695  val_accuracy:0.28333333134651184 ]
32it [01:47,  3.36s/it, epoch: 16/50 -  loss:1.5195019245147705  accuracy:0.43880099058151245  val_loss:1.9348057508468628  val_accuracy:0.3499999940395355 ]
32it [02:12,  4.13s/it, epoch: 17/50 -  loss:1.4174845218658447  accuracy:0.47960034012794495  val_loss:2.867670774459839  val_accuracy:0.27916666865348816 ]
32it [01:45,  3.29s/it, epoch: 18/50 -  loss:1.5716767311096191  accuracy:0.4995836913585663  val_loss:2.5769591331481934  val_accuracy:0.2874999940395355 ]
32it [01:40,  3.14s/it, epoch: 19/50 -  loss:1.5218170881271362  accuracy:0.4787676930427551  val_loss:2.084118604660034  val_accuracy:0.27916666865348816 ]
32it [01:38,  3.08s/it, epoch: 20/50 -  loss:1.3385194540023804  accuracy:0.4970857501029968  val_loss:2.267355442047119  val_accuracy:0.34166666865348816 ]
32it [01:40,  3.13s/it, epoch: 21/50 -  loss:1.4143364429473877  accuracy:0.5145711898803711  val_loss:1.8168503046035767  val_accuracy:0.2916666567325592 ]
32it [01:41,  3.17s/it, epoch: 22/50 -  loss:1.223915696144104  accuracy:0.5253955125808716  val_loss:2.421297550201416  val_accuracy:0.34166666865348816 ]
32it [01:51,  3.48s/it, epoch: 23/50 -  loss:1.3997722864151  accuracy:0.5278934240341187  val_loss:2.1914479732513428  val_accuracy:0.38333332538604736 ]
32it [01:39,  3.10s/it, epoch: 24/50 -  loss:1.2561317682266235  accuracy:0.5770191550254822  val_loss:2.868597984313965  val_accuracy:0.3291666805744171 ]
32it [01:44,  3.28s/it, epoch: 25/50 -  loss:1.4317070245742798  accuracy:0.5611990094184875  val_loss:2.570885181427002  val_accuracy:0.3166666626930237 ]
32it [01:45,  3.30s/it, epoch: 26/50 -  loss:1.3605040311813354  accuracy:0.5578684210777283  val_loss:1.9396979808807373  val_accuracy:0.3083333373069763 ]
32it [01:39,  3.11s/it, epoch: 27/50 -  loss:1.1538910865783691  accuracy:0.5711906552314758  val_loss:2.737755060195923  val_accuracy:0.3333333432674408 ]
32it [01:52,  3.53s/it, epoch: 28/50 -  loss:1.295885682106018  accuracy:0.5961698293685913  val_loss:2.398267984390259  val_accuracy:0.36666667461395264 ]
32it [01:42,  3.20s/it, epoch: 29/50 -  loss:1.2216304540634155  accuracy:0.5845128893852234  val_loss:5.456529140472412  val_accuracy:0.3166666626930237 ]
32it [01:42,  3.22s/it, epoch: 30/50 -  loss:1.8455870151519775  accuracy:0.6069941520690918  val_loss:2.622253894805908  val_accuracy:0.3499999940395355 ]
32it [01:37,  3.06s/it, epoch: 31/50 -  loss:1.2096481323242188  accuracy:0.607826828956604  val_loss:2.6999380588531494  val_accuracy:0.3583333194255829 ]
32it [01:52,  3.52s/it, epoch: 32/50 -  loss:1.25503671169281  accuracy:0.6153205633163452  val_loss:3.1902101039886475  val_accuracy:0.3916666805744171 ]
32it [01:40,  3.13s/it, epoch: 33/50 -  loss:1.321710467338562  accuracy:0.6319733262062073  val_loss:2.9719605445861816  val_accuracy:0.3125 ]
32it [02:03,  3.87s/it, epoch: 34/50 -  loss:1.2177612781524658  accuracy:0.6394671201705933  val_loss:2.9436466693878174  val_accuracy:0.3499999940395355 ]
32it [01:38,  3.07s/it, epoch: 35/50 -  loss:1.2386927604675293  accuracy:0.6561198830604553  val_loss:4.200243949890137  val_accuracy:0.2666666805744171 ]
32it [01:41,  3.19s/it, epoch: 36/50 -  loss:1.5145163536071777  accuracy:0.6344712972640991  val_loss:2.730128765106201  val_accuracy:0.375 ]
32it [01:40,  3.16s/it, epoch: 37/50 -  loss:1.150541067123413  accuracy:0.6644462943077087  val_loss:3.111203670501709  val_accuracy:0.34166666865348816 ]
32it [01:50,  3.45s/it, epoch: 38/50 -  loss:1.268009066581726  accuracy:0.6477935314178467  val_loss:3.186652183532715  val_accuracy:0.2958333194255829 ]
32it [01:43,  3.23s/it, epoch: 39/50 -  loss:1.1637378931045532  accuracy:0.6727727055549622  val_loss:2.7717678546905518  val_accuracy:0.32499998807907104 ]
32it [01:43,  3.25s/it, epoch: 40/50 -  loss:1.1945478916168213  accuracy:0.6594504714012146  val_loss:4.414767265319824  val_accuracy:0.3375000059604645 ]
32it [01:41,  3.19s/it, epoch: 41/50 -  loss:1.4453009366989136  accuracy:0.6860949397087097  val_loss:3.308169364929199  val_accuracy:0.3166666626930237 ]
32it [01:43,  3.22s/it, epoch: 42/50 -  loss:1.1969460248947144  accuracy:0.6835970282554626  val_loss:3.0412395000457764  val_accuracy:0.40833333134651184 ]
32it [01:50,  3.44s/it, epoch: 43/50 -  loss:1.0718883275985718  accuracy:0.7085762023925781  val_loss:4.295853137969971  val_accuracy:0.32499998807907104 ]
32it [01:48,  3.38s/it, epoch: 44/50 -  loss:1.3032809495925903  accuracy:0.7027477025985718  val_loss:3.5949547290802  val_accuracy:0.375 ]
32it [01:47,  3.35s/it, epoch: 45/50 -  loss:1.3326811790466309  accuracy:0.6827643513679504  val_loss:3.4334940910339355  val_accuracy:0.3791666626930237 ]
32it [01:43,  3.23s/it, epoch: 46/50 -  loss:1.1082518100738525  accuracy:0.7252289652824402  val_loss:2.774402141571045  val_accuracy:0.3499999940395355 ]
32it [01:43,  3.22s/it, epoch: 47/50 -  loss:1.052221655845642  accuracy:0.6994171738624573  val_loss:2.918473958969116  val_accuracy:0.36250001192092896 ]
32it [01:40,  3.14s/it, epoch: 48/50 -  loss:1.0545026063919067  accuracy:0.7152373194694519  val_loss:3.524212121963501  val_accuracy:0.375 ]
32it [01:43,  3.22s/it, epoch: 49/50 -  loss:1.1607937812805176  accuracy:0.7243963479995728  val_loss:3.213914632797241  val_accuracy:0.3499999940395355 ]
32it [01:38,  3.08s/it, epoch: 50/50 -  loss:1.0676075220108032  accuracy:0.7277268767356873  val_loss:4.0257086753845215  val_accuracy:0.34583333134651184 ]
[25]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()

# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
[Figures: FLModel accuracy and FLModel loss per epoch, Train vs. Valid]
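Comparing runs is easier with a few numbers alongside the plots. The plotting code above indexes `history.global_history` by metric name and gets one value per epoch. Assuming that structure, the hypothetical helper below reports the best and final validation accuracy. The toy values exist only to show the call and are not results from this tutorial.

```python
def summarize(global_history):
    """Return (best_epoch, best_val_acc, final_val_acc) from a metrics dict.

    Epochs are 1-based to match the training log.
    """
    val_acc = global_history["val_accuracy"]
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    return best_idx + 1, val_acc[best_idx], val_acc[-1]


# Toy values for illustration only; pass history.global_history in practice.
toy = {"val_accuracy": [0.25, 0.40, 0.35]}
print(summarize(toy))  # (2, 0.4, 0.35)
```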

Loading only the architecture with random initialization#

[26]:
# keras model
no_weight_model = create_inception_v3_model_classifier(
    num_classes=num_classes, is_load_weight=False
)


fed_model = FLModel(
    device_list=device_list,
    model=no_weight_model,
    aggregator=aggregator,
    backend="tensorflow",
    strategy="fed_avg_w",
    random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
[27]:
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=epochs,
    batch_size=batch_size,
    aggregate_freq=aggregate_freq,
    sampler_method=sampler_method,
    random_seed=random_seed,
    dp_spent_step_freq=dp_spent_step_freq,
    dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbd4baefbe0>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:49,  3.41s/it, epoch: 1/50 -  loss:1.6169096231460571  accuracy:0.2788761854171753  val_loss:2.0081751346588135  val_accuracy:0.25 ]
32it [01:41,  3.18s/it, epoch: 2/50 -  loss:1.6776905059814453  accuracy:0.2805995047092438  val_loss:1.5821692943572998  val_accuracy:0.2083333283662796 ]
32it [01:40,  3.15s/it, epoch: 3/50 -  loss:1.5732810497283936  accuracy:0.27560365200042725  val_loss:1.6988409757614136  val_accuracy:0.25 ]
32it [01:43,  3.22s/it, epoch: 4/50 -  loss:1.5928469896316528  accuracy:0.29059118032455444  val_loss:1.6616723537445068  val_accuracy:0.25 ]
32it [01:43,  3.25s/it, epoch: 5/50 -  loss:1.5776023864746094  accuracy:0.2972522974014282  val_loss:1.5816072225570679  val_accuracy:0.25 ]
32it [01:39,  3.10s/it, epoch: 6/50 -  loss:1.5489486455917358  accuracy:0.3064113259315491  val_loss:1.5895153284072876  val_accuracy:0.25 ]
32it [01:38,  3.09s/it, epoch: 7/50 -  loss:1.5415023565292358  accuracy:0.33888426423072815  val_loss:1.6028796434402466  val_accuracy:0.2874999940395355 ]
32it [01:42,  3.21s/it, epoch: 8/50 -  loss:1.5391955375671387  accuracy:0.3488759398460388  val_loss:1.6169476509094238  val_accuracy:0.2666666805744171 ]
32it [01:39,  3.11s/it, epoch: 9/50 -  loss:1.5235228538513184  accuracy:0.37552040815353394  val_loss:1.6155515909194946  val_accuracy:0.25 ]
32it [01:42,  3.21s/it, epoch: 10/50 -  loss:1.5191270112991333  accuracy:0.3522064983844757  val_loss:1.600549340248108  val_accuracy:0.32499998807907104 ]
32it [01:39,  3.09s/it, epoch: 11/50 -  loss:1.5050208568572998  accuracy:0.3705245554447174  val_loss:1.6656410694122314  val_accuracy:0.25 ]
32it [01:45,  3.31s/it, epoch: 12/50 -  loss:1.5016016960144043  accuracy:0.3563697040081024  val_loss:1.5471644401550293  val_accuracy:0.2541666626930237 ]
32it [01:36,  3.02s/it, epoch: 13/50 -  loss:1.465255618095398  accuracy:0.37968358397483826  val_loss:1.54298996925354  val_accuracy:0.3791666626930237 ]
32it [01:42,  3.21s/it, epoch: 14/50 -  loss:1.4461535215377808  accuracy:0.40549543499946594  val_loss:1.5411031246185303  val_accuracy:0.3791666626930237 ]
32it [01:38,  3.07s/it, epoch: 15/50 -  loss:1.4370651245117188  accuracy:0.40882596373558044  val_loss:1.490233063697815  val_accuracy:0.3291666805744171 ]
32it [01:41,  3.17s/it, epoch: 16/50 -  loss:1.413556456565857  accuracy:0.4046627879142761  val_loss:1.590057373046875  val_accuracy:0.3333333432674408 ]
32it [01:42,  3.19s/it, epoch: 17/50 -  loss:1.425291895866394  accuracy:0.40549543499946594  val_loss:1.5726838111877441  val_accuracy:0.26249998807907104 ]
32it [01:46,  3.34s/it, epoch: 18/50 -  loss:1.411468267440796  accuracy:0.3930058181285858  val_loss:1.6937826871871948  val_accuracy:0.25 ]
32it [01:39,  3.10s/it, epoch: 19/50 -  loss:1.4267196655273438  accuracy:0.4004995822906494  val_loss:1.7930010557174683  val_accuracy:0.25 ]
32it [01:41,  3.16s/it, epoch: 20/50 -  loss:1.444821834564209  accuracy:0.39134055376052856  val_loss:1.909441351890564  val_accuracy:0.2708333432674408 ]
32it [01:43,  3.22s/it, epoch: 21/50 -  loss:1.4686118364334106  accuracy:0.40216487646102905  val_loss:1.622828722000122  val_accuracy:0.3083333373069763 ]
32it [01:39,  3.10s/it, epoch: 22/50 -  loss:1.4035924673080444  accuracy:0.4104912579059601  val_loss:1.453572392463684  val_accuracy:0.3708333373069763 ]
32it [01:45,  3.31s/it, epoch: 23/50 -  loss:1.359868049621582  accuracy:0.41631972789764404  val_loss:1.6003872156143188  val_accuracy:0.3583333194255829 ]
32it [01:39,  3.09s/it, epoch: 24/50 -  loss:1.3811829090118408  accuracy:0.42381349205970764  val_loss:1.5266224145889282  val_accuracy:0.3166666626930237 ]
32it [01:42,  3.22s/it, epoch: 25/50 -  loss:1.3589435815811157  accuracy:0.40965861082077026  val_loss:2.0248773097991943  val_accuracy:0.2708333432674408 ]
32it [01:39,  3.11s/it, epoch: 26/50 -  loss:1.4731003046035767  accuracy:0.4154870808124542  val_loss:1.4912035465240479  val_accuracy:0.34583333134651184 ]
32it [01:42,  3.20s/it, epoch: 27/50 -  loss:1.3412293195724487  accuracy:0.41631972789764404  val_loss:1.3930459022521973  val_accuracy:0.3791666626930237 ]
32it [01:42,  3.20s/it, epoch: 28/50 -  loss:1.314136266708374  accuracy:0.43130725622177124  val_loss:1.555684208869934  val_accuracy:0.3166666626930237 ]
32it [01:52,  3.51s/it, epoch: 29/50 -  loss:1.3471096754074097  accuracy:0.42797669768333435  val_loss:1.6817519664764404  val_accuracy:0.2666666805744171 ]
32it [01:38,  3.06s/it, epoch: 30/50 -  loss:1.3709834814071655  accuracy:0.407993346452713  val_loss:1.4343681335449219  val_accuracy:0.36666667461395264 ]
32it [01:48,  3.41s/it, epoch: 31/50 -  loss:1.3120677471160889  accuracy:0.43463781476020813  val_loss:1.4698201417922974  val_accuracy:0.3499999940395355 ]
32it [01:37,  3.05s/it, epoch: 32/50 -  loss:1.3166753053665161  accuracy:0.43130725622177124  val_loss:1.4647256135940552  val_accuracy:0.34166666865348816 ]
32it [01:40,  3.15s/it, epoch: 33/50 -  loss:1.316184401512146  accuracy:0.4263114035129547  val_loss:1.7285205125808716  val_accuracy:0.25833332538604736 ]
32it [01:41,  3.17s/it, epoch: 34/50 -  loss:1.3727495670318604  accuracy:0.4154870808124542  val_loss:1.3565254211425781  val_accuracy:0.3916666805744171 ]
32it [01:44,  3.26s/it, epoch: 35/50 -  loss:1.280211329460144  accuracy:0.45045796036720276  val_loss:1.469690203666687  val_accuracy:0.30000001192092896 ]
32it [01:44,  3.27s/it, epoch: 36/50 -  loss:1.3061801195144653  accuracy:0.41715237498283386  val_loss:1.4475810527801514  val_accuracy:0.38749998807907104 ]
32it [01:39,  3.11s/it, epoch: 37/50 -  loss:1.2898013591766357  accuracy:0.443796843290329  val_loss:1.6531250476837158  val_accuracy:0.28333333134651184 ]
32it [01:43,  3.22s/it, epoch: 38/50 -  loss:1.3455705642700195  accuracy:0.42714405059814453  val_loss:1.5225203037261963  val_accuracy:0.38749998807907104 ]
32it [01:39,  3.12s/it, epoch: 39/50 -  loss:1.30066978931427  accuracy:0.4371357262134552  val_loss:1.4351023435592651  val_accuracy:0.38749998807907104 ]
32it [01:42,  3.21s/it, epoch: 40/50 -  loss:1.2890324592590332  accuracy:0.45295587182044983  val_loss:1.563452959060669  val_accuracy:0.3125 ]
32it [01:37,  3.03s/it, epoch: 41/50 -  loss:1.3143178224563599  accuracy:0.43963363766670227  val_loss:1.3913853168487549  val_accuracy:0.4124999940395355 ]
32it [01:46,  3.34s/it, epoch: 42/50 -  loss:1.2762908935546875  accuracy:0.45711907744407654  val_loss:1.3839548826217651  val_accuracy:0.3916666805744171 ]
32it [01:36,  3.00s/it, epoch: 43/50 -  loss:1.2671079635620117  accuracy:0.4587843418121338  val_loss:1.4633164405822754  val_accuracy:0.38749998807907104 ]
32it [01:42,  3.20s/it, epoch: 44/50 -  loss:1.2803606986999512  accuracy:0.45711907744407654  val_loss:1.5430619716644287  val_accuracy:0.30000001192092896 ]
32it [01:39,  3.11s/it, epoch: 45/50 -  loss:1.3015990257263184  accuracy:0.437968373298645  val_loss:2.1071672439575195  val_accuracy:0.2958333194255829 ]
32it [01:43,  3.24s/it, epoch: 46/50 -  loss:1.4335447549819946  accuracy:0.4371357262134552  val_loss:1.4610460996627808  val_accuracy:0.3583333194255829 ]
32it [01:42,  3.21s/it, epoch: 47/50 -  loss:1.276992678642273  accuracy:0.46128225326538086  val_loss:1.4364838600158691  val_accuracy:0.375 ]
32it [01:44,  3.25s/it, epoch: 48/50 -  loss:1.2780916690826416  accuracy:0.4662781059741974  val_loss:1.4993922710418701  val_accuracy:0.3916666805744171 ]
32it [01:42,  3.22s/it, epoch: 49/50 -  loss:1.2837833166122437  accuracy:0.443796843290329  val_loss:1.5112351179122925  val_accuracy:0.3333333432674408 ]
32it [01:41,  3.17s/it, epoch: 50/50 -  loss:1.29127836227417  accuracy:0.45045796036720276  val_loss:1.6569411754608154  val_accuracy:0.34583333134651184 ]
[28]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()

# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
[Figures: FLModel accuracy and FLModel loss per epoch, Train vs. Valid]

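Freezing the lower layers and fine-tuning the upper layers#

As before, we wrap the model definition from the Keras tutorial in a function and keep the option to load or skip the pretrained weights. This time we follow the tutorial's second stage. The first 249 layers stay frozen, while the remaining layers (the top two inception blocks plus the new classifier head) are trained. Training uses SGD with a low learning rate (0.0001, momentum 0.9) instead of rmsprop.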

[29]:
def create_inception_v3_model_fine_tune(num_classes, is_load_weight=True):
    def create_model():
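        # Import inside the function so it also works on the remote PYU actors
        from tensorflow.keras.applications.inception_v3 import InceptionV3
        from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
        from tensorflow.keras.models import Model
        from tensorflow.keras.optimizers import SGD

        # Create model
        # create the base pre-trained model
        if is_load_weight:
            base_model = InceptionV3(weights='imagenet', include_top=False)
        else:
            base_model = InceptionV3(weights=None, include_top=False)

        # add a global spatial average pooling layer
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # let's add a fully-connected layer
        x = Dense(1024, activation='relu')(x)
        # and a softmax classification layer with num_classes outputs
        predictions = Dense(num_classes, activation='softmax')(x)

        # this is the model we will train
        model = Model(inputs=base_model.input, outputs=predictions)

        # freeze the first 249 layers and train the rest,
        # i.e. the top two inception blocks plus the new classifier head
        for layer in model.layers[:249]:
            layer.trainable = False
        for layer in model.layers[249:]:
            layer.trainable = True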

        # Compile model
        model.compile(
            optimizer=SGD(learning_rate=0.0001, momentum=0.9),
            loss='sparse_categorical_crossentropy',
            metrics=["accuracy"],
        )

        return model

    return create_model
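In this stage the unfrozen inception blocks start from pretrained values, so the tutorial switches from rmsprop to SGD with a small learning rate and momentum. The usual motivation is that small steps adjust the pretrained features without overwriting them. The sketch below shows a plain (non-Nesterov) momentum update of the form velocity ← momentum · velocity − lr · grad, followed by weight ← weight + velocity, applied to a single scalar weight. `sgd_momentum_step` is a hypothetical helper.

```python
def sgd_momentum_step(w, v, grad, learning_rate=1e-4, momentum=0.9):
    """One SGD-with-momentum update (non-Nesterov).

    velocity <- momentum * velocity - learning_rate * grad
    weight   <- weight + velocity
    """
    v = momentum * v - learning_rate * grad
    return w + v, v


w, v = 1.0, 0.0
for _ in range(2):
    w, v = sgd_momentum_step(w, v, grad=1.0)
print(round(v, 6), round(w, 6))  # -0.00019 0.99971
```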

Loading pretrained weights and fine-tuning#

[30]:
# keras model
weight_model = create_inception_v3_model_fine_tune(
    num_classes=num_classes, is_load_weight=True
)


fed_model = FLModel(
    device_list=device_list,
    model=weight_model,
    aggregator=aggregator,
    backend="tensorflow",
    strategy="fed_avg_w",
    random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
[31]:
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=epochs,
    batch_size=batch_size,
    aggregate_freq=aggregate_freq,
    sampler_method=sampler_method,
    random_seed=random_seed,
    dp_spent_step_freq=dp_spent_step_freq,
    dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbc1c65c160>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:50,  3.46s/it, epoch: 1/50 -  loss:1.650011658668518  accuracy:0.20915712416172028  val_loss:1.656758427619934  val_accuracy:0.24583333730697632 ]
32it [01:46,  3.32s/it, epoch: 2/50 -  loss:1.5472569465637207  accuracy:0.32389676570892334  val_loss:1.58930242061615  val_accuracy:0.28333333134651184 ]
32it [01:45,  3.29s/it, epoch: 3/50 -  loss:1.4915741682052612  accuracy:0.36552873253822327  val_loss:1.5358479022979736  val_accuracy:0.30000001192092896 ]
32it [01:43,  3.25s/it, epoch: 4/50 -  loss:1.4296882152557373  accuracy:0.4004995822906494  val_loss:1.5131781101226807  val_accuracy:0.32499998807907104 ]
32it [01:45,  3.30s/it, epoch: 5/50 -  loss:1.3764379024505615  accuracy:0.43963363766670227  val_loss:1.485734462738037  val_accuracy:0.36666667461395264 ]
32it [01:41,  3.16s/it, epoch: 6/50 -  loss:1.334008812904358  accuracy:0.4845961630344391  val_loss:1.4631234407424927  val_accuracy:0.3541666567325592 ]
32it [01:53,  3.56s/it, epoch: 7/50 -  loss:1.2897144556045532  accuracy:0.5104079842567444  val_loss:1.4368029832839966  val_accuracy:0.3499999940395355 ]
32it [01:38,  3.08s/it, epoch: 8/50 -  loss:1.2628018856048584  accuracy:0.526228129863739  val_loss:1.4215635061264038  val_accuracy:0.36666667461395264 ]
32it [01:44,  3.27s/it, epoch: 9/50 -  loss:1.2168503999710083  accuracy:0.5553705096244812  val_loss:1.4049427509307861  val_accuracy:0.3916666805744171 ]
32it [01:36,  3.03s/it, epoch: 10/50 -  loss:1.187528371810913  accuracy:0.5611990094184875  val_loss:1.3869177103042603  val_accuracy:0.3916666805744171 ]
32it [01:37,  3.04s/it, epoch: 11/50 -  loss:1.1617594957351685  accuracy:0.5903413891792297  val_loss:1.3873077630996704  val_accuracy:0.3958333432674408 ]
32it [01:42,  3.19s/it, epoch: 12/50 -  loss:1.1248602867126465  accuracy:0.5936719179153442  val_loss:1.3818570375442505  val_accuracy:0.4000000059604645 ]
32it [01:39,  3.11s/it, epoch: 13/50 -  loss:1.0991374254226685  accuracy:0.6369692087173462  val_loss:1.366493821144104  val_accuracy:0.42500001192092896 ]
32it [01:45,  3.31s/it, epoch: 14/50 -  loss:1.0843088626861572  accuracy:0.6353039145469666  val_loss:1.3633487224578857  val_accuracy:0.42916667461395264 ]
32it [01:41,  3.17s/it, epoch: 15/50 -  loss:1.059274673461914  accuracy:0.6636136770248413  val_loss:1.3560353517532349  val_accuracy:0.4375 ]
32it [01:44,  3.26s/it, epoch: 16/50 -  loss:1.0273948907852173  accuracy:0.6644462943077087  val_loss:1.3572434186935425  val_accuracy:0.4333333373069763 ]
32it [01:39,  3.12s/it, epoch: 17/50 -  loss:0.9911747574806213  accuracy:0.6994171738624573  val_loss:1.3419238328933716  val_accuracy:0.44999998807907104 ]
32it [01:50,  3.44s/it, epoch: 18/50 -  loss:0.9911114573478699  accuracy:0.67194002866745  val_loss:1.3493646383285522  val_accuracy:0.42500001192092896 ]
32it [01:40,  3.16s/it, epoch: 19/50 -  loss:0.9671102166175842  accuracy:0.6852622628211975  val_loss:1.3433830738067627  val_accuracy:0.42916667461395264 ]
32it [01:46,  3.32s/it, epoch: 20/50 -  loss:0.937901496887207  accuracy:0.7052456140518188  val_loss:1.3399254083633423  val_accuracy:0.4416666626930237 ]
32it [01:42,  3.21s/it, epoch: 21/50 -  loss:0.906328558921814  accuracy:0.7277268767356873  val_loss:1.328268051147461  val_accuracy:0.4541666805744171 ]
32it [01:52,  3.53s/it, epoch: 22/50 -  loss:0.8943104147911072  accuracy:0.7402164936065674  val_loss:1.3352607488632202  val_accuracy:0.44583332538604736 ]
32it [01:44,  3.26s/it, epoch: 23/50 -  loss:0.8718297481536865  accuracy:0.7452123165130615  val_loss:1.310567021369934  val_accuracy:0.4583333432674408 ]
32it [01:53,  3.53s/it, epoch: 24/50 -  loss:0.8595983982086182  accuracy:0.746044933795929  val_loss:1.3058205842971802  val_accuracy:0.46666666865348816 ]
32it [01:40,  3.14s/it, epoch: 25/50 -  loss:0.8207837343215942  accuracy:0.7618651390075684  val_loss:1.3195736408233643  val_accuracy:0.47083333134651184 ]
32it [01:41,  3.18s/it, epoch: 26/50 -  loss:0.8151105642318726  accuracy:0.7718567848205566  val_loss:1.3052656650543213  val_accuracy:0.4625000059604645 ]
32it [01:39,  3.11s/it, epoch: 27/50 -  loss:0.8001042008399963  accuracy:0.7718567848205566  val_loss:1.3049904108047485  val_accuracy:0.46666666865348816 ]
32it [01:42,  3.20s/it, epoch: 28/50 -  loss:0.7668044567108154  accuracy:0.7901748418807983  val_loss:1.307594895362854  val_accuracy:0.4625000059604645 ]
32it [01:40,  3.13s/it, epoch: 29/50 -  loss:0.7354434728622437  accuracy:0.8043297529220581  val_loss:1.306896686553955  val_accuracy:0.47083333134651184 ]
32it [01:42,  3.21s/it, epoch: 30/50 -  loss:0.7426672577857971  accuracy:0.8018317818641663  val_loss:1.3118059635162354  val_accuracy:0.4833333194255829 ]
32it [01:38,  3.07s/it, epoch: 31/50 -  loss:0.7213742733001709  accuracy:0.8126561045646667  val_loss:1.296970248222351  val_accuracy:0.4625000059604645 ]
32it [01:37,  3.04s/it, epoch: 32/50 -  loss:0.700598955154419  accuracy:0.8184846043586731  val_loss:1.305118203163147  val_accuracy:0.4583333432674408 ]
32it [01:37,  3.05s/it, epoch: 33/50 -  loss:0.6826915144920349  accuracy:0.8259783387184143  val_loss:1.3077011108398438  val_accuracy:0.4791666567325592 ]
32it [01:37,  3.05s/it, epoch: 34/50 -  loss:0.6565826535224915  accuracy:0.8251457214355469  val_loss:1.301528811454773  val_accuracy:0.46666666865348816 ]
32it [01:40,  3.13s/it, epoch: 35/50 -  loss:0.656404972076416  accuracy:0.8226478099822998  val_loss:1.306329369544983  val_accuracy:0.4749999940395355 ]
32it [01:40,  3.13s/it, epoch: 36/50 -  loss:0.6250473856925964  accuracy:0.8384679555892944  val_loss:1.3014804124832153  val_accuracy:0.46666666865348816 ]
32it [01:45,  3.29s/it, epoch: 37/50 -  loss:0.6090459227561951  accuracy:0.8426311612129211  val_loss:1.3036556243896484  val_accuracy:0.46666666865348816 ]
32it [01:41,  3.17s/it, epoch: 38/50 -  loss:0.6007826924324036  accuracy:0.8409658670425415  val_loss:1.3221803903579712  val_accuracy:0.46666666865348816 ]
32it [01:45,  3.29s/it, epoch: 39/50 -  loss:0.5874541997909546  accuracy:0.8467943668365479  val_loss:1.3001883029937744  val_accuracy:0.4583333432674408 ]
32it [01:40,  3.14s/it, epoch: 40/50 -  loss:0.5784027576446533  accuracy:0.8459616899490356  val_loss:1.29843270778656  val_accuracy:0.4583333432674408 ]
32it [01:45,  3.31s/it, epoch: 41/50 -  loss:0.5616000890731812  accuracy:0.8434637784957886  val_loss:1.3310949802398682  val_accuracy:0.4749999940395355 ]
32it [01:44,  3.25s/it, epoch: 42/50 -  loss:0.564281165599823  accuracy:0.8559533953666687  val_loss:1.3145432472229004  val_accuracy:0.47083333134651184 ]
32it [01:43,  3.25s/it, epoch: 43/50 -  loss:0.54606032371521  accuracy:0.8592839241027832  val_loss:1.311663269996643  val_accuracy:0.4625000059604645 ]
32it [01:44,  3.25s/it, epoch: 44/50 -  loss:0.5349538922309875  accuracy:0.8601165413856506  val_loss:1.3167394399642944  val_accuracy:0.4625000059604645 ]
32it [01:46,  3.32s/it, epoch: 45/50 -  loss:0.5256470441818237  accuracy:0.8667776584625244  val_loss:1.3188197612762451  val_accuracy:0.4791666567325592 ]
32it [01:41,  3.16s/it, epoch: 46/50 -  loss:0.5097993016242981  accuracy:0.8676103353500366  val_loss:1.3174797296524048  val_accuracy:0.4791666567325592 ]
32it [01:40,  3.15s/it, epoch: 47/50 -  loss:0.49822941422462463  accuracy:0.8742714524269104  val_loss:1.3001458644866943  val_accuracy:0.4583333432674408 ]
32it [01:47,  3.35s/it, epoch: 48/50 -  loss:0.4908905625343323  accuracy:0.8667776584625244  val_loss:1.3209236860275269  val_accuracy:0.46666666865348816 ]
32it [01:39,  3.10s/it, epoch: 49/50 -  loss:0.48140043020248413  accuracy:0.8726061582565308  val_loss:1.3325889110565186  val_accuracy:0.4749999940395355 ]
32it [01:43,  3.23s/it, epoch: 50/50 -  loss:0.4809170365333557  accuracy:0.8792672753334045  val_loss:1.336763858795166  val_accuracy:0.4833333194255829 ]
[32]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()

# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
[Figure: FLModel accuracy per epoch (Train vs. Valid)]
[Figure: FLModel loss per epoch (Train vs. Valid)]
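In the log above, training accuracy keeps rising (about 0.81 to 0.88 over epochs 31 to 50) while validation loss stays around 1.3, which suggests overfitting. The plotting cell reads `history.global_history` as a mapping from metric names to per-epoch values. Assuming that shape (a dict of lists), the hypothetical helper `summarize_history` below finds the best validation epoch and the final gap between training and validation accuracy. The sample values are epochs 46 to 50 of the log above, rounded to four decimals. In practice you would pass `history.global_history` itself with `first_epoch=1`.

```python
def summarize_history(global_history, monitor="val_accuracy", first_epoch=1):
    """Summarize a history dict of per-epoch metric lists (hypothetical helper).

    Returns the epoch with the highest `monitor` value and the gap between
    training and validation accuracy at the last recorded epoch.
    """
    values = global_history[monitor]
    # Index of the highest value; ties resolve to the earliest epoch.
    best_idx = max(range(len(values)), key=lambda i: values[i])
    final_gap = global_history["accuracy"][-1] - global_history["val_accuracy"][-1]
    return {
        "best_epoch": first_epoch + best_idx,
        "best_value": values[best_idx],
        "final_gap": final_gap,
    }


# Epochs 46-50 of the run above (rounded); stands in for history.global_history.
tail = {
    "accuracy": [0.8676, 0.8743, 0.8668, 0.8726, 0.8793],
    "val_accuracy": [0.4792, 0.4583, 0.4667, 0.4750, 0.4833],
}
summary = summarize_history(tail, first_epoch=46)
print(summary)
```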

Load only the network architecture, with random initialization

[33]:
# Keras model: InceptionV3 architecture only, randomly initialized (no pretrained weights)
no_weight_model = create_inception_v3_model_fine_tune(
    num_classes=num_classes, is_load_weight=False
)


fed_model = FLModel(
    device_list=device_list,
    model=no_weight_model,
    aggregator=aggregator,
    backend="tensorflow",
    strategy="fed_avg_w",
    random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
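As its name suggests, `strategy="fed_avg_w"` selects federated averaging of model weights: each party (here `alice` and `bob`) trains locally, and the aggregator combines the parties' weights into one global model. The sketch below shows only the arithmetic of a sample-count-weighted average in plain Python. It assumes each party's weights are a list of flat float lists, one per layer. It is not SecretFlow's implementation: the real aggregation runs through the `aggregator` object (which may use secure aggregation), and its exact weighting is not shown in this tutorial. The toy weights and sample counts are hypothetical.

```python
from typing import List, Sequence

Weights = List[List[float]]  # one flat list of floats per layer


def fed_avg_weights(party_weights: Sequence[Weights], sample_counts: Sequence[int]) -> Weights:
    """FedAvg-style average of per-party weights, weighted by local sample count."""
    if not party_weights or len(party_weights) != len(sample_counts):
        raise ValueError("need at least one party and one sample count per party")
    total = float(sum(sample_counts))
    if total <= 0:
        raise ValueError("total sample count must be positive")
    averaged: Weights = []
    for layer in range(len(party_weights[0])):
        acc = [0.0] * len(party_weights[0][layer])
        for weights, n in zip(party_weights, sample_counts):
            coef = n / total  # this party's share of all samples
            for i, w in enumerate(weights[layer]):
                acc[i] += coef * w
        averaged.append(acc)
    return averaged


# Two hypothetical parties with a toy two-layer model.
alice = [[1.0, 2.0], [0.5]]
bob = [[3.0, 4.0], [1.5]]
global_weights = fed_avg_weights([alice, bob], sample_counts=[100, 300])
print(global_weights)
```

With equal sample counts, this reduces to a plain mean of the parties' weights.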
[34]:
history = fed_model.fit(
    data,
    None,
    validation_data=data,
    epochs=epochs,
    batch_size=batch_size,
    aggregate_freq=aggregate_freq,
    sampler_method=sampler_method,
    random_seed=random_seed,
    dp_spent_step_freq=dp_spent_step_freq,
    dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbc24074700>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:48,  3.38s/it, epoch: 1/50 -  loss:1.5617533922195435  accuracy:0.27783557772636414  val_loss:1.6185896396636963  val_accuracy:0.13750000298023224 ]
32it [01:45,  3.30s/it, epoch: 2/50 -  loss:1.4592828750610352  accuracy:0.34970858693122864  val_loss:1.6081194877624512  val_accuracy:0.15833333134651184 ]
32it [01:41,  3.16s/it, epoch: 3/50 -  loss:1.388456106185913  accuracy:0.4254787564277649  val_loss:1.5979021787643433  val_accuracy:0.2750000059604645 ]
32it [01:45,  3.28s/it, epoch: 4/50 -  loss:1.3454346656799316  accuracy:0.43547043204307556  val_loss:1.5885674953460693  val_accuracy:0.2916666567325592 ]
32it [01:43,  3.23s/it, epoch: 5/50 -  loss:1.3147145509719849  accuracy:0.46461281180381775  val_loss:1.57625150680542  val_accuracy:0.28333333134651184 ]
32it [01:42,  3.21s/it, epoch: 6/50 -  loss:1.2754597663879395  accuracy:0.4854288101196289  val_loss:1.563430905342102  val_accuracy:0.30000001192092896 ]
32it [01:38,  3.07s/it, epoch: 7/50 -  loss:1.2425297498703003  accuracy:0.5045795440673828  val_loss:1.5455553531646729  val_accuracy:0.2750000059604645 ]
32it [01:45,  3.30s/it, epoch: 8/50 -  loss:1.2141278982162476  accuracy:0.5129058957099915  val_loss:1.5239124298095703  val_accuracy:0.2874999940395355 ]
32it [01:51,  3.48s/it, epoch: 9/50 -  loss:1.1969578266143799  accuracy:0.5237302184104919  val_loss:1.4922692775726318  val_accuracy:0.38333332538604736 ]
32it [01:37,  3.05s/it, epoch: 10/50 -  loss:1.1684715747833252  accuracy:0.5653622150421143  val_loss:1.4653935432434082  val_accuracy:0.36666667461395264 ]
32it [01:38,  3.09s/it, epoch: 11/50 -  loss:1.1470826864242554  accuracy:0.5678601264953613  val_loss:1.439568281173706  val_accuracy:0.40833333134651184 ]
32it [01:42,  3.19s/it, epoch: 12/50 -  loss:1.1196565628051758  accuracy:0.572023332118988  val_loss:1.399143934249878  val_accuracy:0.44583332538604736 ]
32it [01:47,  3.36s/it, epoch: 13/50 -  loss:1.1069118976593018  accuracy:0.6036636233329773  val_loss:1.3591387271881104  val_accuracy:0.47083333134651184 ]
32it [01:48,  3.40s/it, epoch: 14/50 -  loss:1.0912777185440063  accuracy:0.6036636233329773  val_loss:1.327014684677124  val_accuracy:0.48750001192092896 ]
32it [01:50,  3.45s/it, epoch: 15/50 -  loss:1.0678472518920898  accuracy:0.6094920635223389  val_loss:1.2993195056915283  val_accuracy:0.5083333253860474 ]
32it [01:41,  3.18s/it, epoch: 16/50 -  loss:1.0440828800201416  accuracy:0.6111573576927185  val_loss:1.2653034925460815  val_accuracy:0.5166666507720947 ]
32it [01:48,  3.38s/it, epoch: 17/50 -  loss:1.0225656032562256  accuracy:0.6294754147529602  val_loss:1.231634497642517  val_accuracy:0.5249999761581421 ]
32it [01:42,  3.20s/it, epoch: 18/50 -  loss:1.020599365234375  accuracy:0.6203163862228394  val_loss:1.2129393815994263  val_accuracy:0.5333333611488342 ]
32it [01:45,  3.29s/it, epoch: 19/50 -  loss:0.9966413974761963  accuracy:0.6286427974700928  val_loss:1.1990076303482056  val_accuracy:0.5333333611488342 ]
32it [01:42,  3.21s/it, epoch: 20/50 -  loss:0.9871851801872253  accuracy:0.6402997374534607  val_loss:1.176818609237671  val_accuracy:0.550000011920929 ]
32it [01:51,  3.49s/it, epoch: 21/50 -  loss:0.9735381603240967  accuracy:0.6511240601539612  val_loss:1.164151906967163  val_accuracy:0.512499988079071 ]
32it [01:44,  3.27s/it, epoch: 22/50 -  loss:0.9548952579498291  accuracy:0.6486261487007141  val_loss:1.1686121225357056  val_accuracy:0.5333333611488342 ]
32it [01:43,  3.22s/it, epoch: 23/50 -  loss:0.9554542899131775  accuracy:0.6444629430770874  val_loss:1.1495128870010376  val_accuracy:0.5333333611488342 ]
32it [01:55,  3.61s/it, epoch: 24/50 -  loss:0.9482444524765015  accuracy:0.6494587659835815  val_loss:1.136460542678833  val_accuracy:0.5375000238418579 ]
32it [01:52,  3.53s/it, epoch: 25/50 -  loss:0.9186312556266785  accuracy:0.6702747941017151  val_loss:1.1416422128677368  val_accuracy:0.5416666865348816 ]
32it [01:42,  3.21s/it, epoch: 26/50 -  loss:0.9188517928123474  accuracy:0.6736053228378296  val_loss:1.1264033317565918  val_accuracy:0.550000011920929 ]
32it [01:41,  3.17s/it, epoch: 27/50 -  loss:0.9028259515762329  accuracy:0.681931734085083  val_loss:1.1303426027297974  val_accuracy:0.550000011920929 ]
32it [01:45,  3.31s/it, epoch: 28/50 -  loss:0.8891870975494385  accuracy:0.6844296455383301  val_loss:1.1231117248535156  val_accuracy:0.5458333492279053 ]
32it [01:43,  3.23s/it, epoch: 29/50 -  loss:0.8797008991241455  accuracy:0.6977518796920776  val_loss:1.1278127431869507  val_accuracy:0.5416666865348816 ]
32it [01:47,  3.37s/it, epoch: 30/50 -  loss:0.878724217414856  accuracy:0.6752706170082092  val_loss:1.1188760995864868  val_accuracy:0.574999988079071 ]
32it [01:40,  3.13s/it, epoch: 31/50 -  loss:0.8734909296035767  accuracy:0.703580379486084  val_loss:1.121822476387024  val_accuracy:0.550000011920929 ]
32it [01:50,  3.44s/it, epoch: 32/50 -  loss:0.8486524820327759  accuracy:0.6985844969749451  val_loss:1.127307415008545  val_accuracy:0.5416666865348816 ]
32it [01:45,  3.29s/it, epoch: 33/50 -  loss:0.8515317440032959  accuracy:0.6944212913513184  val_loss:1.1144821643829346  val_accuracy:0.5625 ]
32it [01:44,  3.26s/it, epoch: 34/50 -  loss:0.8291710615158081  accuracy:0.7135720252990723  val_loss:1.1182698011398315  val_accuracy:0.5458333492279053 ]
32it [01:49,  3.41s/it, epoch: 35/50 -  loss:0.8350727558135986  accuracy:0.7019150853157043  val_loss:1.1012609004974365  val_accuracy:0.5833333134651184 ]
32it [01:48,  3.38s/it, epoch: 36/50 -  loss:0.8297075033187866  accuracy:0.7227310538291931  val_loss:1.1068880558013916  val_accuracy:0.5625 ]
32it [01:49,  3.41s/it, epoch: 37/50 -  loss:0.809542715549469  accuracy:0.7293921709060669  val_loss:1.1091495752334595  val_accuracy:0.5458333492279053 ]
32it [01:43,  3.24s/it, epoch: 38/50 -  loss:0.8065611720085144  accuracy:0.7185678482055664  val_loss:1.09843909740448  val_accuracy:0.5708333253860474 ]
32it [01:46,  3.34s/it, epoch: 39/50 -  loss:0.7942371964454651  accuracy:0.7310574650764465  val_loss:1.105940580368042  val_accuracy:0.5708333253860474 ]
32it [01:41,  3.16s/it, epoch: 40/50 -  loss:0.7934896349906921  accuracy:0.7227310538291931  val_loss:1.1048657894134521  val_accuracy:0.5666666626930237 ]
32it [01:43,  3.23s/it, epoch: 41/50 -  loss:0.7780565023422241  accuracy:0.7327227592468262  val_loss:1.0963842868804932  val_accuracy:0.574999988079071 ]
32it [01:43,  3.24s/it, epoch: 42/50 -  loss:0.7623825669288635  accuracy:0.7452123165130615  val_loss:1.0992350578308105  val_accuracy:0.5583333373069763 ]
32it [01:44,  3.27s/it, epoch: 43/50 -  loss:0.7639929056167603  accuracy:0.7385511994361877  val_loss:1.0936287641525269  val_accuracy:0.5708333253860474 ]
32it [01:40,  3.13s/it, epoch: 44/50 -  loss:0.7649388909339905  accuracy:0.7485429048538208  val_loss:1.0890998840332031  val_accuracy:0.5625 ]
32it [01:44,  3.25s/it, epoch: 45/50 -  loss:0.7667314410209656  accuracy:0.7477102279663086  val_loss:1.1028246879577637  val_accuracy:0.5833333134651184 ]
32it [01:39,  3.11s/it, epoch: 46/50 -  loss:0.7435023188591003  accuracy:0.7468776106834412  val_loss:1.0851361751556396  val_accuracy:0.5791666507720947 ]
32it [01:38,  3.07s/it, epoch: 47/50 -  loss:0.7359048128128052  accuracy:0.7643630504608154  val_loss:1.0931051969528198  val_accuracy:0.5625 ]
32it [01:37,  3.05s/it, epoch: 48/50 -  loss:0.729657769203186  accuracy:0.7626977562904358  val_loss:1.083522915840149  val_accuracy:0.5708333253860474 ]
32it [01:44,  3.27s/it, epoch: 49/50 -  loss:0.7189591526985168  accuracy:0.7601998448371887  val_loss:1.0874546766281128  val_accuracy:0.574999988079071 ]
32it [01:39,  3.10s/it, epoch: 50/50 -  loss:0.7207762598991394  accuracy:0.7760199904441833  val_loss:1.0822055339813232  val_accuracy:0.5708333253860474 ]
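The `fit` call passes `aggregate_freq=aggregate_freq`, which appears as `'aggregate_freq': 2` in the logged train params. Under the assumption that this counts local training steps (batches), the parties synchronize their weights every two steps rather than after every step. The toy simulation below illustrates that schedule. It is not SecretFlow code. It assumes plain gradient descent on a one-dimensional quadratic loss per party, and `local_step`, `simulate` and the per-party targets are hypothetical.

```python
def local_step(w: float, target: float, lr: float) -> float:
    """One gradient-descent step on the toy loss (w - target) ** 2."""
    grad = 2.0 * (w - target)
    return w - lr * grad


def simulate(targets, total_steps, aggregate_freq, lr=0.1):
    """Parties train locally; every `aggregate_freq` steps all adopt the mean weight."""
    weights = [0.0] * len(targets)  # every party starts from the same global weight
    sync_steps = []
    for step in range(1, total_steps + 1):
        weights = [local_step(w, t, lr) for w, t in zip(weights, targets)]
        if step % aggregate_freq == 0:
            mean = sum(weights) / len(weights)
            weights = [mean] * len(weights)
            sync_steps.append(step)
    return weights, sync_steps


final_weights, synced_at = simulate(targets=[1.0, 3.0], total_steps=6, aggregate_freq=2)
print(synced_at)      # steps at which aggregation happened
print(final_weights)  # all parties hold the same weight after the last sync
```

A larger `aggregate_freq` means fewer communication rounds, but the parties' local models drift further apart between synchronizations.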
[35]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()

# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
[Figure: FLModel accuracy per epoch (Train vs. Valid), random initialization]
[Figure: FLModel loss per epoch (Train vs. Valid), random initialization]

Federated learning recap

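Following the official TensorFlow tutorial step by step, SecretFlow works seamlessly with the fine-tuning approach it describes. Because pretrained models are supported, we do not have to write complex network architectures ourselves (the InceptionV3 architecture source code is available at: source code of Inception V3). The comparison experiment also shows that loading the pretrained weights improves training performance: after 50 epochs, the model initialized with pretrained weights reaches a training accuracy of about 0.88, versus about 0.78 for the randomly initialized model.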
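To make the comparison concrete, the sketch below reports the final-epoch metrics of the two runs side by side. It assumes each run's history is a dict of per-epoch lists like `history.global_history`. The run names `pretrained` and `scratch` are hypothetical labels, and the values are the epoch-50 numbers from the two training logs above, rounded to four decimals.

```python
def compare_final(runs, metrics=("accuracy", "val_accuracy", "loss", "val_loss")):
    """Return {metric: {run_name: value at the last recorded epoch}}."""
    return {m: {name: hist[m][-1] for name, hist in runs.items()} for m in metrics}


# Epoch-50 values copied from the two logs above (rounded).
runs = {
    "pretrained": {"accuracy": [0.8793], "val_accuracy": [0.4833], "loss": [0.4809], "val_loss": [1.3368]},
    "scratch": {"accuracy": [0.7760], "val_accuracy": [0.5708], "loss": [0.7208], "val_loss": [1.0822]},
}
result = compare_final(runs)
for metric, values in result.items():
    print(f"{metric:>12}: " + "  ".join(f"{name}={v:.4f}" for name, v in values.items()))
```

Printing training and validation metrics together makes it easy to see whether an advantage on the training set carries over to the validation set.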

Summary

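In this tutorial, using Inception V3 as an example, we showed how to load TensorFlow.Keras pretrained models directly in SecretFlow's federated learning mode. Loading pretrained models directly lets us:

- skip writing the code for complex model architectures;
- fine-tune and perform transfer learning on top of pretrained models;
- give the federated model better performance by starting from pretrained weights.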