Fine-tuning Keras Applications pre-trained models in SecretFlow federated learning#
Introduction#
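Loading and fine-tuning pre-trained models matters a great deal in machine learning. Training a very large model from scratch needs substantial compute and a lot of time. For that reason, conventional machine learning commonly starts from a pre-trained model and adapts it to a specific task through fine-tuning and transfer learning. The same holds for federated learning. If participants can load a pre-trained model and fine-tune it, they save compute, the barrier to joining the federation drops, and the model trains faster.
Because SecretFlow's federated learning module is highly compatible with TensorFlow Keras (tf.keras), it can load the tf.keras pre-trained models directly. This tutorial follows the official tf.keras InceptionV3 fine-tuning example. It shows how to fine-tune a tf.keras pre-trained model within the SecretFlow framework, and how little code has to change to do so.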
Loading the dataset#
About the dataset#
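The Flower dataset contains color photos of 5 flower species: daisy, dandelion, rose, sunflower and tulip. Each species is photographed from several angles and under different lighting conditions. The images vary in size and are resized to 180x180 when they are loaded below. The dataset is commonly used to train and evaluate image classification models.

The full TensorFlow archive contains 3,670 images: daisy (633), dandelion (898), rose (641), sunflower (699) and tulip (799). The copy downloaded below from the SecretFlow mirror is smaller, and the loader reports 1,201 files.
Download URL: http://download.tensorflow.org/example_images/flower_photos.tgz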
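The per-class counts can differ between copies of the dataset. This is a minimal standard-library sketch for counting the images in each class sub-folder of an extracted directory. The function name `count_images_per_class` is a hypothetical helper, not part of the tutorial. The extension set reflects our understanding of what `image_dataset_from_directory` accepts.

```python
from pathlib import Path

# extensions that image_dataset_from_directory accepts (to our understanding)
IMAGE_EXTENSIONS = {".bmp", ".gif", ".jpeg", ".jpg", ".png"}


def count_images_per_class(root):
    """Return {class_name: image_count} for each direct sub-folder of `root`."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip top-level files such as LICENSE.txt
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

After the download cell has run, call `count_images_per_class(path_to_flower_dataset)`. The helper only counts files directly inside each class folder. Keras also walks nested sub-folders, but flower_photos has none.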
Download and extract the dataset#
[1]:
import tempfile
import tensorflow as tf
_temp_dir = tempfile.mkdtemp()
path_to_flower_dataset = tf.keras.utils.get_file(
"flower_photos",
"https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz",
untar=True,
cache_dir=_temp_dir,
)
Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz
67588319/67588319 [==============================] - 2s 0us/step
Load the images#
[2]:
import math
import tensorflow as tf
img_height = 180
img_width = 180
batch_size = 32
# In this example, we use the TensorFlow interface for development.
data_set = tf.keras.utils.image_dataset_from_directory(
path_to_flower_dataset,
validation_split=0.2,
subset="both",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size,
)
Found 1201 files belonging to 5 classes.
Using 961 files for training.
Using 240 files for validation.
2023-10-11 07:12:05.321890: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 12653 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:3b:00.0, compute capability: 7.5
2023-10-11 07:12:05.324020: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13775 MB memory: -> device: 1, name: Tesla T4, pci bus id: 0000:3c:00.0, compute capability: 7.5
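The numbers in this output follow from simple arithmetic. As we understand it, Keras sizes the validation subset as `int(validation_split * num_files)` and assigns the remaining files to training. The batches per epoch are then ceilings. The sketch below reproduces the split reported above and the `31/31` steps per epoch that appear in the training logs. `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(num_files, validation_split, batch_size):
    """Reproduce the train/validation sizes and batches per epoch."""
    num_val = int(validation_split * num_files)  # validation subset size
    num_train = num_files - num_val  # remaining files go to training
    return {
        "train_files": num_train,
        "val_files": num_val,
        "train_batches": math.ceil(num_train / batch_size),
        "val_batches": math.ceil(num_val / batch_size),
    }


print(split_sizes(1201, 0.2, 32))
# -> {'train_files': 961, 'val_files': 240, 'train_batches': 31, 'val_batches': 8}
```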
Split the dataset#
[3]:
train_set = data_set[0]
test_set = data_set[1]
Inspect the dataset#
[4]:
print(type(train_set), type(test_set))
<class 'tensorflow.python.data.ops.batch_op._BatchDataset'> <class 'tensorflow.python.data.ops.batch_op._BatchDataset'>
[5]:
x, y = next(iter(train_set))
print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")
x.shape = (32, 180, 180, 3)
y.shape = (32,)
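Fine-tuning in standalone mode#
In standalone (single-machine) mode, fine-tuning the pre-trained model largely follows the official TensorFlow Keras example. The only change is to the compile arguments, which are adjusted to match the dataset format. `image_dataset_from_directory` produces integer labels, so the loss is `sparse_categorical_crossentropy`. This change has little effect on the overall procedure.
Fine-tune the top classifier#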
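The choice of loss follows from the label format. Each label in `y` is a single integer class index (shape `(32,)`), not a one-hot vector. For one sample, sparse categorical cross-entropy is `-log(p[label])`. That value equals categorical cross-entropy computed against the one-hot encoding of the same label. The pure-Python sketch below is for illustration only. Keras additionally clips probabilities and averages over the batch.

```python
import math


def sparse_ce(probs, label):
    """Cross-entropy for one sample with an integer class label."""
    return -math.log(probs[label])


def categorical_ce(probs, one_hot):
    """Cross-entropy for one sample with a one-hot label vector."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


probs = [0.1, 0.6, 0.1, 0.1, 0.1]  # softmax output over 5 classes
label = 1  # integer label, as produced by image_dataset_from_directory
one_hot = [1.0 if i == label else 0.0 for i in range(len(probs))]
print(sparse_ce(probs, label), categorical_ce(probs, one_hot))  # both equal -log(0.6)
```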
[6]:
import matplotlib.pyplot as plt
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D

# the Flower dataset has 5 classes
num_classes = 5

# create the base pre-trained model (ImageNet weights, without the top classifier)
base_model = InceptionV3(weights='imagenet', include_top=False)

# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# let's add a fully-connected layer
x = Dense(1024, activation='relu')(x)
# and a softmax layer with one output per class
predictions = Dense(num_classes, activation='softmax')(x)

# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)

# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional InceptionV3 layers
for layer in base_model.layers:
    layer.trainable = False

# compile the model (should be done *after* setting layers to non-trainable);
# the labels are integer class indices, hence the sparse loss
model.compile(
    optimizer='rmsprop',
    loss='sparse_categorical_crossentropy',
    metrics=["accuracy"],
)
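This cell does not normalize the input pixels. The Keras InceptionV3 weights expect inputs scaled to [-1, 1], and `tf.keras.applications.inception_v3.preprocess_input` performs that scaling (x / 127.5 - 1). `image_dataset_from_directory`, however, yields raw values in [0, 255]. Unscaled inputs may contribute to the very large first-epoch loss in the training log that follows. The pure-Python sketch below shows only the scaling itself. Adding this step is our suggestion, not part of the original example.

```python
def inception_scale(pixel):
    """Map a pixel value from [0, 255] to [-1, 1], as InceptionV3 expects."""
    return pixel / 127.5 - 1.0


print([inception_scale(p) for p in (0, 127.5, 255)])  # -> [-1.0, 0.0, 1.0]

# In TensorFlow (not executed here), the dataset-level equivalent would be:
# from tensorflow.keras.applications.inception_v3 import preprocess_input
# train_set = train_set.map(lambda x, y: (preprocess_input(x), y))
# test_set = test_set.map(lambda x, y: (preprocess_input(x), y))
```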
[7]:
# train the model on the new data for a few epochs
history = model.fit(train_set, validation_data=test_set, epochs=50)
# at this point, the top layers are well trained and we can start fine-tuning
# convolutional layers from inception V3. We will freeze the bottom N layers
# and train the remaining top layers.
Epoch 1/50
31/31 [==============================] - 15s 175ms/step - loss: 112.3790 - accuracy: 0.2206 - val_loss: 34.8170 - val_accuracy: 0.1500
Epoch 2/50
31/31 [==============================] - 2s 57ms/step - loss: 18.9387 - accuracy: 0.2737 - val_loss: 21.3622 - val_accuracy: 0.2625
Epoch 3/50
31/31 [==============================] - 2s 57ms/step - loss: 14.2041 - accuracy: 0.3091 - val_loss: 16.4007 - val_accuracy: 0.2625
Epoch 4/50
31/31 [==============================] - 2s 59ms/step - loss: 11.6635 - accuracy: 0.2716 - val_loss: 14.6091 - val_accuracy: 0.2167
Epoch 5/50
31/31 [==============================] - 2s 57ms/step - loss: 8.6897 - accuracy: 0.3215 - val_loss: 7.3476 - val_accuracy: 0.4042
Epoch 6/50
31/31 [==============================] - 2s 57ms/step - loss: 8.0223 - accuracy: 0.2914 - val_loss: 3.7781 - val_accuracy: 0.3417
Epoch 7/50
31/31 [==============================] - 2s 57ms/step - loss: 6.0596 - accuracy: 0.3413 - val_loss: 6.9468 - val_accuracy: 0.2417
Epoch 8/50
31/31 [==============================] - 2s 59ms/step - loss: 4.8735 - accuracy: 0.3486 - val_loss: 14.5767 - val_accuracy: 0.2583
Epoch 9/50
31/31 [==============================] - 2s 57ms/step - loss: 3.7563 - accuracy: 0.3954 - val_loss: 7.5974 - val_accuracy: 0.2625
Epoch 10/50
31/31 [==============================] - 2s 58ms/step - loss: 3.1183 - accuracy: 0.3871 - val_loss: 9.8358 - val_accuracy: 0.2583
Epoch 11/50
31/31 [==============================] - 2s 58ms/step - loss: 3.5958 - accuracy: 0.3725 - val_loss: 3.7865 - val_accuracy: 0.2792
Epoch 12/50
31/31 [==============================] - 2s 58ms/step - loss: 2.4401 - accuracy: 0.4287 - val_loss: 5.4358 - val_accuracy: 0.2833
Epoch 13/50
31/31 [==============================] - 2s 57ms/step - loss: 2.2799 - accuracy: 0.4194 - val_loss: 4.6291 - val_accuracy: 0.2208
Epoch 14/50
31/31 [==============================] - 2s 58ms/step - loss: 2.3558 - accuracy: 0.4204 - val_loss: 4.1438 - val_accuracy: 0.2125
Epoch 15/50
31/31 [==============================] - 2s 57ms/step - loss: 1.8603 - accuracy: 0.4828 - val_loss: 6.8592 - val_accuracy: 0.2917
Epoch 16/50
31/31 [==============================] - 2s 58ms/step - loss: 1.8960 - accuracy: 0.4880 - val_loss: 2.8224 - val_accuracy: 0.3042
Epoch 17/50
31/31 [==============================] - 2s 58ms/step - loss: 1.7029 - accuracy: 0.4984 - val_loss: 3.8373 - val_accuracy: 0.2250
Epoch 18/50
31/31 [==============================] - 2s 58ms/step - loss: 1.4418 - accuracy: 0.5120 - val_loss: 3.6055 - val_accuracy: 0.3292
Epoch 19/50
31/31 [==============================] - 2s 57ms/step - loss: 1.5190 - accuracy: 0.5245 - val_loss: 3.9276 - val_accuracy: 0.3458
Epoch 20/50
31/31 [==============================] - 2s 58ms/step - loss: 1.3073 - accuracy: 0.5619 - val_loss: 6.1296 - val_accuracy: 0.1875
Epoch 21/50
31/31 [==============================] - 2s 59ms/step - loss: 1.3950 - accuracy: 0.5390 - val_loss: 3.7171 - val_accuracy: 0.2375
Epoch 22/50
31/31 [==============================] - 2s 59ms/step - loss: 1.1851 - accuracy: 0.5640 - val_loss: 5.6681 - val_accuracy: 0.1708
Epoch 23/50
31/31 [==============================] - 2s 58ms/step - loss: 1.3379 - accuracy: 0.5838 - val_loss: 2.4407 - val_accuracy: 0.2625
Epoch 24/50
31/31 [==============================] - 2s 59ms/step - loss: 1.1016 - accuracy: 0.6098 - val_loss: 3.4062 - val_accuracy: 0.2917
Epoch 25/50
31/31 [==============================] - 2s 58ms/step - loss: 1.1200 - accuracy: 0.5931 - val_loss: 2.9323 - val_accuracy: 0.3042
Epoch 26/50
31/31 [==============================] - 2s 59ms/step - loss: 1.0031 - accuracy: 0.6389 - val_loss: 2.8517 - val_accuracy: 0.2708
Epoch 27/50
31/31 [==============================] - 2s 59ms/step - loss: 1.0253 - accuracy: 0.6306 - val_loss: 4.8238 - val_accuracy: 0.3000
Epoch 28/50
31/31 [==============================] - 2s 59ms/step - loss: 1.0335 - accuracy: 0.6576 - val_loss: 3.0405 - val_accuracy: 0.2958
Epoch 29/50
31/31 [==============================] - 2s 58ms/step - loss: 1.1657 - accuracy: 0.6181 - val_loss: 3.3026 - val_accuracy: 0.2375
Epoch 30/50
31/31 [==============================] - 2s 59ms/step - loss: 0.9623 - accuracy: 0.6629 - val_loss: 3.2407 - val_accuracy: 0.2875
Epoch 31/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8993 - accuracy: 0.6920 - val_loss: 2.2036 - val_accuracy: 0.3917
Epoch 32/50
31/31 [==============================] - 2s 58ms/step - loss: 0.9263 - accuracy: 0.6816 - val_loss: 3.2231 - val_accuracy: 0.2917
Epoch 33/50
31/31 [==============================] - 2s 60ms/step - loss: 0.8958 - accuracy: 0.6930 - val_loss: 3.6673 - val_accuracy: 0.2583
Epoch 34/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8155 - accuracy: 0.7045 - val_loss: 3.7752 - val_accuracy: 0.2667
Epoch 35/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8687 - accuracy: 0.6982 - val_loss: 3.5233 - val_accuracy: 0.3708
Epoch 36/50
31/31 [==============================] - 2s 59ms/step - loss: 0.9007 - accuracy: 0.7138 - val_loss: 2.5410 - val_accuracy: 0.3542
Epoch 37/50
31/31 [==============================] - 2s 60ms/step - loss: 0.7482 - accuracy: 0.7378 - val_loss: 3.7791 - val_accuracy: 0.3583
Epoch 38/50
31/31 [==============================] - 2s 59ms/step - loss: 0.8524 - accuracy: 0.7534 - val_loss: 2.9299 - val_accuracy: 0.3375
Epoch 39/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6633 - accuracy: 0.7607 - val_loss: 3.8067 - val_accuracy: 0.3125
Epoch 40/50
31/31 [==============================] - 2s 59ms/step - loss: 0.7970 - accuracy: 0.7336 - val_loss: 3.9137 - val_accuracy: 0.3333
Epoch 41/50
31/31 [==============================] - 2s 59ms/step - loss: 0.7073 - accuracy: 0.7648 - val_loss: 2.9778 - val_accuracy: 0.3542
Epoch 42/50
31/31 [==============================] - 2s 60ms/step - loss: 0.6207 - accuracy: 0.7815 - val_loss: 3.2358 - val_accuracy: 0.3458
Epoch 43/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6773 - accuracy: 0.7700 - val_loss: 3.8493 - val_accuracy: 0.3292
Epoch 44/50
31/31 [==============================] - 2s 60ms/step - loss: 0.5552 - accuracy: 0.7950 - val_loss: 4.1360 - val_accuracy: 0.3292
Epoch 45/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5993 - accuracy: 0.8044 - val_loss: 3.3667 - val_accuracy: 0.3625
Epoch 46/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6991 - accuracy: 0.7877 - val_loss: 2.8805 - val_accuracy: 0.3625
Epoch 47/50
31/31 [==============================] - 2s 59ms/step - loss: 0.6570 - accuracy: 0.7908 - val_loss: 3.2625 - val_accuracy: 0.3375
Epoch 48/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5095 - accuracy: 0.8293 - val_loss: 4.0133 - val_accuracy: 0.3750
Epoch 49/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5420 - accuracy: 0.8148 - val_loss: 3.1207 - val_accuracy: 0.3042
Epoch 50/50
31/31 [==============================] - 2s 59ms/step - loss: 0.5678 - accuracy: 0.8023 - val_loss: 3.8196 - val_accuracy: 0.3583
[8]:
history.history.keys()
[8]:
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
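`history.history` is a plain dictionary that maps each metric name to a list with one value per epoch, so summarizing a run needs no extra tooling. The small sketch below finds the epoch with the best value of a metric. `best_epoch` is a hypothetical helper. The sample values are the first five epochs of the log above. In the notebook, you would call `best_epoch(history.history)`.

```python
def best_epoch(history_dict, metric="val_accuracy", higher_is_better=True):
    """Return (1-based epoch, value) of the best value of `metric`."""
    values = history_dict[metric]
    pick = max if higher_is_better else min
    # max()/min() return the first index on ties, i.e. the earliest best epoch
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]


sample = {
    "accuracy": [0.2206, 0.2737, 0.3091, 0.2716, 0.3215],
    "val_accuracy": [0.1500, 0.2625, 0.2625, 0.2167, 0.4042],
    "val_loss": [34.8170, 21.3622, 16.4007, 14.6091, 7.3476],
}
print(best_epoch(sample))  # -> (5, 0.4042)
```

For losses, pass `higher_is_better=False`.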
[9]:
# Draw accuracy values for training & validation
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
Freeze the bottom layers and fine-tune the top layers#
[10]:
# let's visualize layer names and layer indices to see how many layers
# we should freeze:
for i, layer in enumerate(base_model.layers):
    print(i, layer.name)
0 input_1
1 conv2d
2 batch_normalization
3 activation
4 conv2d_1
5 batch_normalization_1
6 activation_1
7 conv2d_2
8 batch_normalization_2
9 activation_2
10 max_pooling2d
11 conv2d_3
12 batch_normalization_3
13 activation_3
14 conv2d_4
15 batch_normalization_4
16 activation_4
17 max_pooling2d_1
18 conv2d_8
19 batch_normalization_8
20 activation_8
21 conv2d_6
22 conv2d_9
23 batch_normalization_6
24 batch_normalization_9
25 activation_6
26 activation_9
27 average_pooling2d
28 conv2d_5
29 conv2d_7
30 conv2d_10
31 conv2d_11
32 batch_normalization_5
33 batch_normalization_7
34 batch_normalization_10
35 batch_normalization_11
36 activation_5
37 activation_7
38 activation_10
39 activation_11
40 mixed0
41 conv2d_15
42 batch_normalization_15
43 activation_15
44 conv2d_13
45 conv2d_16
46 batch_normalization_13
47 batch_normalization_16
48 activation_13
49 activation_16
50 average_pooling2d_1
51 conv2d_12
52 conv2d_14
53 conv2d_17
54 conv2d_18
55 batch_normalization_12
56 batch_normalization_14
57 batch_normalization_17
58 batch_normalization_18
59 activation_12
60 activation_14
61 activation_17
62 activation_18
63 mixed1
64 conv2d_22
65 batch_normalization_22
66 activation_22
67 conv2d_20
68 conv2d_23
69 batch_normalization_20
70 batch_normalization_23
71 activation_20
72 activation_23
73 average_pooling2d_2
74 conv2d_19
75 conv2d_21
76 conv2d_24
77 conv2d_25
78 batch_normalization_19
79 batch_normalization_21
80 batch_normalization_24
81 batch_normalization_25
82 activation_19
83 activation_21
84 activation_24
85 activation_25
86 mixed2
87 conv2d_27
88 batch_normalization_27
89 activation_27
90 conv2d_28
91 batch_normalization_28
92 activation_28
93 conv2d_26
94 conv2d_29
95 batch_normalization_26
96 batch_normalization_29
97 activation_26
98 activation_29
99 max_pooling2d_2
100 mixed3
101 conv2d_34
102 batch_normalization_34
103 activation_34
104 conv2d_35
105 batch_normalization_35
106 activation_35
107 conv2d_31
108 conv2d_36
109 batch_normalization_31
110 batch_normalization_36
111 activation_31
112 activation_36
113 conv2d_32
114 conv2d_37
115 batch_normalization_32
116 batch_normalization_37
117 activation_32
118 activation_37
119 average_pooling2d_3
120 conv2d_30
121 conv2d_33
122 conv2d_38
123 conv2d_39
124 batch_normalization_30
125 batch_normalization_33
126 batch_normalization_38
127 batch_normalization_39
128 activation_30
129 activation_33
130 activation_38
131 activation_39
132 mixed4
133 conv2d_44
134 batch_normalization_44
135 activation_44
136 conv2d_45
137 batch_normalization_45
138 activation_45
139 conv2d_41
140 conv2d_46
141 batch_normalization_41
142 batch_normalization_46
143 activation_41
144 activation_46
145 conv2d_42
146 conv2d_47
147 batch_normalization_42
148 batch_normalization_47
149 activation_42
150 activation_47
151 average_pooling2d_4
152 conv2d_40
153 conv2d_43
154 conv2d_48
155 conv2d_49
156 batch_normalization_40
157 batch_normalization_43
158 batch_normalization_48
159 batch_normalization_49
160 activation_40
161 activation_43
162 activation_48
163 activation_49
164 mixed5
165 conv2d_54
166 batch_normalization_54
167 activation_54
168 conv2d_55
169 batch_normalization_55
170 activation_55
171 conv2d_51
172 conv2d_56
173 batch_normalization_51
174 batch_normalization_56
175 activation_51
176 activation_56
177 conv2d_52
178 conv2d_57
179 batch_normalization_52
180 batch_normalization_57
181 activation_52
182 activation_57
183 average_pooling2d_5
184 conv2d_50
185 conv2d_53
186 conv2d_58
187 conv2d_59
188 batch_normalization_50
189 batch_normalization_53
190 batch_normalization_58
191 batch_normalization_59
192 activation_50
193 activation_53
194 activation_58
195 activation_59
196 mixed6
197 conv2d_64
198 batch_normalization_64
199 activation_64
200 conv2d_65
201 batch_normalization_65
202 activation_65
203 conv2d_61
204 conv2d_66
205 batch_normalization_61
206 batch_normalization_66
207 activation_61
208 activation_66
209 conv2d_62
210 conv2d_67
211 batch_normalization_62
212 batch_normalization_67
213 activation_62
214 activation_67
215 average_pooling2d_6
216 conv2d_60
217 conv2d_63
218 conv2d_68
219 conv2d_69
220 batch_normalization_60
221 batch_normalization_63
222 batch_normalization_68
223 batch_normalization_69
224 activation_60
225 activation_63
226 activation_68
227 activation_69
228 mixed7
229 conv2d_72
230 batch_normalization_72
231 activation_72
232 conv2d_73
233 batch_normalization_73
234 activation_73
235 conv2d_70
236 conv2d_74
237 batch_normalization_70
238 batch_normalization_74
239 activation_70
240 activation_74
241 conv2d_71
242 conv2d_75
243 batch_normalization_71
244 batch_normalization_75
245 activation_71
246 activation_75
247 max_pooling2d_3
248 mixed8
249 conv2d_80
250 batch_normalization_80
251 activation_80
252 conv2d_77
253 conv2d_81
254 batch_normalization_77
255 batch_normalization_81
256 activation_77
257 activation_81
258 conv2d_78
259 conv2d_79
260 conv2d_82
261 conv2d_83
262 average_pooling2d_7
263 conv2d_76
264 batch_normalization_78
265 batch_normalization_79
266 batch_normalization_82
267 batch_normalization_83
268 conv2d_84
269 batch_normalization_76
270 activation_78
271 activation_79
272 activation_82
273 activation_83
274 batch_normalization_84
275 activation_76
276 mixed9_0
277 concatenate
278 activation_84
279 mixed9
280 conv2d_89
281 batch_normalization_89
282 activation_89
283 conv2d_86
284 conv2d_90
285 batch_normalization_86
286 batch_normalization_90
287 activation_86
288 activation_90
289 conv2d_87
290 conv2d_88
291 conv2d_91
292 conv2d_92
293 average_pooling2d_8
294 conv2d_85
295 batch_normalization_87
296 batch_normalization_88
297 batch_normalization_91
298 batch_normalization_92
299 conv2d_93
300 batch_normalization_85
301 activation_87
302 activation_88
303 activation_91
304 activation_92
305 batch_normalization_93
306 activation_85
307 mixed9_1
308 concatenate_1
309 activation_93
310 mixed10
[11]:
# we chose to train the top 2 inception blocks, i.e. we will freeze
# the first 249 layers and unfreeze the rest:
for layer in model.layers[:249]:
    layer.trainable = False
for layer in model.layers[249:]:
    layer.trainable = True
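The cut-off 249 is the index of the first layer after the `mixed8` block in the listing above (`249 conv2d_80`). Instead of hard-coding it, you could look it up by name. Keras names the InceptionV3 block outputs explicitly (`mixed0` ... `mixed10`), whereas auto-generated names such as `conv2d_80` and raw indices are more fragile. The minimal sketch below operates on a list of names. The helper `first_index_after` is hypothetical. It assumes `model.layers` lists the base-model layers first and in the same order, which matches the listing above.

```python
def first_index_after(layer_names, block_name):
    """Index of the first layer that follows the layer named `block_name`."""
    return layer_names.index(block_name) + 1


# a short excerpt of the layer listing; in the notebook use
# names = [layer.name for layer in model.layers]
names = ["mixed7", "conv2d_72", "max_pooling2d_3", "mixed8", "conv2d_80", "mixed9"]
freeze_until = first_index_after(names, "mixed8")
frozen, trainable = names[:freeze_until], names[freeze_until:]
print(freeze_until, trainable[0])  # -> 4 conv2d_80
```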
[12]:
# we need to recompile the model for these modifications to take effect
# we use SGD with a low learning rate
from tensorflow.keras.optimizers import SGD
model.compile(
optimizer=SGD(learning_rate=0.0001, momentum=0.9),
loss='sparse_categorical_crossentropy',
metrics=["accuracy"],
)
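With `momentum=0.9`, Keras SGD keeps a velocity for every parameter. According to the Keras SGD documentation, each non-Nesterov step computes `velocity = momentum * velocity - learning_rate * g` and then `w = w + velocity`. The scalar pure-Python sketch below applies that rule with the learning rate used here. The constant gradient is an arbitrary illustration.

```python
def sgd_momentum_steps(w, grads, learning_rate=1e-4, momentum=0.9):
    """Apply Keras-style SGD with momentum (non-Nesterov) to a scalar weight."""
    velocity = 0.0
    for g in grads:
        velocity = momentum * velocity - learning_rate * g
        w = w + velocity
    return w


# with a constant gradient, the step size grows toward lr * g / (1 - momentum)
print(sgd_momentum_steps(1.0, [1.0] * 3))
```

With momentum 0.9, the effective step under a steady gradient approaches ten times the raw learning rate. That is one reason a small learning rate is used when the unfrozen convolutional layers are fine-tuned.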
[13]:
# we train our model again (this time fine-tuning the top 2 inception blocks
# alongside the top Dense layers)
history = model.fit(train_set, validation_data=test_set, epochs=50)
Epoch 1/50
31/31 [==============================] - 13s 118ms/step - loss: 1.7043 - accuracy: 0.2882 - val_loss: 1.6911 - val_accuracy: 0.3458
Epoch 2/50
31/31 [==============================] - 2s 71ms/step - loss: 1.6528 - accuracy: 0.3330 - val_loss: 1.6635 - val_accuracy: 0.3375
Epoch 3/50
31/31 [==============================] - 2s 70ms/step - loss: 1.6001 - accuracy: 0.3673 - val_loss: 1.6358 - val_accuracy: 0.3250
Epoch 4/50
31/31 [==============================] - 2s 71ms/step - loss: 1.5633 - accuracy: 0.3975 - val_loss: 1.6324 - val_accuracy: 0.3208
Epoch 5/50
31/31 [==============================] - 2s 71ms/step - loss: 1.5255 - accuracy: 0.3975 - val_loss: 1.6274 - val_accuracy: 0.3250
Epoch 6/50
31/31 [==============================] - 2s 71ms/step - loss: 1.4992 - accuracy: 0.4048 - val_loss: 1.6040 - val_accuracy: 0.3167
Epoch 7/50
31/31 [==============================] - 2s 71ms/step - loss: 1.4658 - accuracy: 0.4287 - val_loss: 1.5823 - val_accuracy: 0.3292
Epoch 8/50
31/31 [==============================] - 2s 72ms/step - loss: 1.4350 - accuracy: 0.4454 - val_loss: 1.5678 - val_accuracy: 0.3208
Epoch 9/50
31/31 [==============================] - 2s 71ms/step - loss: 1.4175 - accuracy: 0.4443 - val_loss: 1.5482 - val_accuracy: 0.3250
Epoch 10/50
31/31 [==============================] - 2s 71ms/step - loss: 1.3867 - accuracy: 0.4589 - val_loss: 1.5394 - val_accuracy: 0.3292
Epoch 11/50
31/31 [==============================] - 2s 71ms/step - loss: 1.3550 - accuracy: 0.4880 - val_loss: 1.5272 - val_accuracy: 0.3167
Epoch 12/50
31/31 [==============================] - 2s 72ms/step - loss: 1.3254 - accuracy: 0.4932 - val_loss: 1.5242 - val_accuracy: 0.3250
Epoch 13/50
31/31 [==============================] - 2s 72ms/step - loss: 1.3148 - accuracy: 0.5120 - val_loss: 1.5063 - val_accuracy: 0.3458
Epoch 14/50
31/31 [==============================] - 2s 72ms/step - loss: 1.2890 - accuracy: 0.5297 - val_loss: 1.4933 - val_accuracy: 0.3500
Epoch 15/50
31/31 [==============================] - 2s 72ms/step - loss: 1.2470 - accuracy: 0.5609 - val_loss: 1.4905 - val_accuracy: 0.3333
Epoch 16/50
31/31 [==============================] - 2s 71ms/step - loss: 1.2174 - accuracy: 0.5525 - val_loss: 1.4791 - val_accuracy: 0.3458
Epoch 17/50
31/31 [==============================] - 2s 71ms/step - loss: 1.1906 - accuracy: 0.5723 - val_loss: 1.4681 - val_accuracy: 0.3500
Epoch 18/50
31/31 [==============================] - 2s 72ms/step - loss: 1.1686 - accuracy: 0.5817 - val_loss: 1.4669 - val_accuracy: 0.3500
Epoch 19/50
31/31 [==============================] - 2s 72ms/step - loss: 1.1449 - accuracy: 0.6046 - val_loss: 1.4621 - val_accuracy: 0.3625
Epoch 20/50
31/31 [==============================] - 2s 70ms/step - loss: 1.1219 - accuracy: 0.6181 - val_loss: 1.4552 - val_accuracy: 0.3667
Epoch 21/50
31/31 [==============================] - 2s 72ms/step - loss: 1.0932 - accuracy: 0.6327 - val_loss: 1.4428 - val_accuracy: 0.3708
Epoch 22/50
31/31 [==============================] - 2s 71ms/step - loss: 1.0565 - accuracy: 0.6462 - val_loss: 1.4397 - val_accuracy: 0.3625
Epoch 23/50
31/31 [==============================] - 2s 71ms/step - loss: 1.0192 - accuracy: 0.6722 - val_loss: 1.4278 - val_accuracy: 0.3792
Epoch 24/50
31/31 [==============================] - 2s 72ms/step - loss: 0.9906 - accuracy: 0.6785 - val_loss: 1.4345 - val_accuracy: 0.3833
Epoch 25/50
31/31 [==============================] - 2s 72ms/step - loss: 0.9632 - accuracy: 0.6930 - val_loss: 1.4148 - val_accuracy: 0.3875
Epoch 26/50
31/31 [==============================] - 2s 72ms/step - loss: 0.9382 - accuracy: 0.7242 - val_loss: 1.4085 - val_accuracy: 0.4000
Epoch 27/50
31/31 [==============================] - 2s 73ms/step - loss: 0.9137 - accuracy: 0.7471 - val_loss: 1.4146 - val_accuracy: 0.3958
Epoch 28/50
31/31 [==============================] - 2s 72ms/step - loss: 0.8796 - accuracy: 0.7409 - val_loss: 1.4115 - val_accuracy: 0.4042
Epoch 29/50
31/31 [==============================] - 2s 72ms/step - loss: 0.8471 - accuracy: 0.7419 - val_loss: 1.4027 - val_accuracy: 0.4167
Epoch 30/50
31/31 [==============================] - 2s 73ms/step - loss: 0.8274 - accuracy: 0.7607 - val_loss: 1.3981 - val_accuracy: 0.3958
Epoch 31/50
31/31 [==============================] - 2s 72ms/step - loss: 0.8057 - accuracy: 0.7638 - val_loss: 1.3978 - val_accuracy: 0.4125
Epoch 32/50
31/31 [==============================] - 2s 72ms/step - loss: 0.7804 - accuracy: 0.7950 - val_loss: 1.4013 - val_accuracy: 0.4083
Epoch 33/50
31/31 [==============================] - 2s 72ms/step - loss: 0.7360 - accuracy: 0.8158 - val_loss: 1.3971 - val_accuracy: 0.4125
Epoch 34/50
31/31 [==============================] - 2s 73ms/step - loss: 0.7070 - accuracy: 0.8158 - val_loss: 1.3992 - val_accuracy: 0.3875
Epoch 35/50
31/31 [==============================] - 2s 72ms/step - loss: 0.6836 - accuracy: 0.8252 - val_loss: 1.3861 - val_accuracy: 0.3958
Epoch 36/50
31/31 [==============================] - 2s 72ms/step - loss: 0.6650 - accuracy: 0.8408 - val_loss: 1.3846 - val_accuracy: 0.4042
Epoch 37/50
31/31 [==============================] - 2s 72ms/step - loss: 0.6598 - accuracy: 0.8252 - val_loss: 1.4019 - val_accuracy: 0.4250
Epoch 38/50
31/31 [==============================] - 2s 73ms/step - loss: 0.6042 - accuracy: 0.8522 - val_loss: 1.3979 - val_accuracy: 0.4208
Epoch 39/50
31/31 [==============================] - 2s 72ms/step - loss: 0.5765 - accuracy: 0.8793 - val_loss: 1.3919 - val_accuracy: 0.4042
Epoch 40/50
31/31 [==============================] - 2s 73ms/step - loss: 0.5533 - accuracy: 0.8907 - val_loss: 1.3877 - val_accuracy: 0.4042
Epoch 41/50
31/31 [==============================] - 2s 71ms/step - loss: 0.5341 - accuracy: 0.8772 - val_loss: 1.4072 - val_accuracy: 0.4083
Epoch 42/50
31/31 [==============================] - 2s 72ms/step - loss: 0.5130 - accuracy: 0.9043 - val_loss: 1.4061 - val_accuracy: 0.3833
Epoch 43/50
31/31 [==============================] - 2s 73ms/step - loss: 0.4917 - accuracy: 0.9084 - val_loss: 1.4121 - val_accuracy: 0.3750
Epoch 44/50
31/31 [==============================] - 2s 72ms/step - loss: 0.4768 - accuracy: 0.9136 - val_loss: 1.4326 - val_accuracy: 0.4042
Epoch 45/50
31/31 [==============================] - 2s 72ms/step - loss: 0.4641 - accuracy: 0.9084 - val_loss: 1.4192 - val_accuracy: 0.4083
Epoch 46/50
31/31 [==============================] - 2s 71ms/step - loss: 0.4416 - accuracy: 0.9282 - val_loss: 1.4249 - val_accuracy: 0.4042
Epoch 47/50
31/31 [==============================] - 2s 71ms/step - loss: 0.4154 - accuracy: 0.9251 - val_loss: 1.4281 - val_accuracy: 0.3958
Epoch 48/50
31/31 [==============================] - 2s 72ms/step - loss: 0.3910 - accuracy: 0.9396 - val_loss: 1.4334 - val_accuracy: 0.4000
Epoch 49/50
31/31 [==============================] - 2s 71ms/step - loss: 0.3940 - accuracy: 0.9396 - val_loss: 1.4281 - val_accuracy: 0.4000
Epoch 50/50
31/31 [==============================] - 2s 71ms/step - loss: 0.3821 - accuracy: 0.9490 - val_loss: 1.4365 - val_accuracy: 0.4250
[14]:
# Draw accuracy values for training & validation
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
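Standalone mode summary#
Following the official example, we fine-tuned InceptionV3 on the Flower dataset in two stages. The first stage trained only the top classifier. The second froze the bottom layers and fine-tuned the top layers. In both stages, validation accuracy never exceeds about 0.43, while training accuracy climbs to about 0.8 and 0.95 respectively. This gap is a sign that the model overfits the small training split of 961 images. Next, we show how to extend this standalone fine-tuning to federated learning.
Fine-tuning in federated learning mode#
Environment setup#
First, we initialize the participating parties.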
[15]:
%load_ext autoreload
%autoreload 2
[16]:
import secretflow as sf
# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob', 'charlie'], address="local", log_to_driver=False)
alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')
The version of SecretFlow: 1.2.0.dev20230926
2023-10-11 07:16:09,152 INFO worker.py:1538 -- Started a local Ray instance.
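Define the dataset builder#
Following the TensorFlow DataBuilder tutorial, we define our own dataset builder. Each party calls it to build its local training or evaluation dataset. The builder returns the dataset together with the number of steps per epoch.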
[17]:
def create_dataset_builder(
    batch_size=32,
):
    def dataset_builder(folder_path, stage="train"):
        import math

        import tensorflow as tf

        img_height = 180
        img_width = 180
        # the fixed seed yields the same train/validation split on every call,
        # so the "train" and "eval" stages never overlap
        data_set = tf.keras.utils.image_dataset_from_directory(
            folder_path,
            validation_split=0.2,
            subset="both",
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size,
        )
        if stage == "train":
            train_dataset = data_set[0]
            train_step_per_epoch = math.ceil(len(data_set[0].file_paths) / batch_size)
            return train_dataset, train_step_per_epoch
        elif stage == "eval":
            eval_dataset = data_set[1]
            eval_step_per_epoch = math.ceil(len(data_set[1].file_paths) / batch_size)
            return eval_dataset, eval_step_per_epoch

    return dataset_builder
[18]:
data_builder_dict = {
alice: create_dataset_builder(
batch_size=32,
),
bob: create_dataset_builder(
batch_size=32,
),
}
Define the SecureAggregator#
[19]:
from secretflow.ml.nn import FLModel
from secretflow.security.aggregation import SecureAggregator
device_list = [alice, bob]
aggregator = SecureAggregator(charlie, [alice, bob])
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.
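`SecureAggregator` is designed to let charlie compute the sum of the parties' model parameters without seeing any single party's values. The log above shows a `_Masker` actor created for each party. Pairwise additive masking illustrates the underlying idea. Alice and bob share a random mask; alice adds it and bob subtracts it, so the masks cancel when the values are summed. The sketch below is a simplified illustration over integers modulo 2**32, as if the values were already fixed-point encoded. It is not SecretFlow's actual protocol, which differs in details such as how masks are derived and how floats are encoded. All names are hypothetical.

```python
import random

MOD = 2**32  # arithmetic in a finite ring so that masks cancel exactly


def mask_updates(updates, seed=0):
    """Mask each party's integer-encoded vector with pairwise random masks.

    For every pair (i, j) with i < j, party i adds a shared mask and
    party j subtracts it, so all masks cancel when the vectors are summed.
    """
    rng = random.Random(seed)  # stands in for pairwise shared secrets
    n, dim = len(updates), len(updates[0])
    masked = [list(u) for u in updates]
    for i in range(n):
        for j in range(i + 1, n):
            mask = [rng.randrange(MOD) for _ in range(dim)]
            for k in range(dim):
                masked[i][k] = (masked[i][k] + mask[k]) % MOD
                masked[j][k] = (masked[j][k] - mask[k]) % MOD
    return masked


def aggregate(masked):
    """What the aggregator computes: the element-wise sum modulo MOD."""
    return [sum(col) % MOD for col in zip(*masked)]


alice_update = [3, 10, 7]
bob_update = [5, 1, 2]
masked = mask_updates([alice_update, bob_update])
print(aggregate(masked))  # -> [8, 11, 9], the masks cancel
```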
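Define the data paths#
For simplicity, in this single-machine simulation both parties load the dataset from the same path. Alice and bob therefore hold identical copies of the data rather than distinct local datasets.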
[20]:
data = {
alice: path_to_flower_dataset,
bob: path_to_flower_dataset,
}
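Because both parties point at the same folder, alice and bob train on identical data. That is convenient for a demo, but it does not reflect a real federation in which each party holds different samples. One way to simulate disjoint local datasets is to deterministically partition the image file list before copying each shard into a per-party folder. The minimal sketch below covers only the partitioning step. `shard_files` and the round-robin scheme are illustrative assumptions, not SecretFlow APIs.

```python
import random


def shard_files(file_paths, parties, seed=1234):
    """Shuffle file paths deterministically and deal them round-robin to parties."""
    paths = sorted(file_paths)  # sort first so the result ignores input order
    random.Random(seed).shuffle(paths)
    shards = {party: [] for party in parties}
    for i, path in enumerate(paths):
        shards[parties[i % len(parties)]].append(path)
    return shards


files = [f"daisy/{i}.jpg" for i in range(5)]
shards = shard_files(files, ["alice", "bob"])
print({p: len(s) for p, s in shards.items()})  # -> {'alice': 3, 'bob': 2}
```

Each shard would then be copied into its own folder, keeping the per-class sub-folders, and the `data` dictionary would point each party at its folder.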
Define the federated training parameters#
[21]:
epochs = 50
batch_size = 32
aggregate_freq = 2
sampler_method = "batch"
random_seed = 1234
dp_spent_step_freq = 1
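Fine-tune the top classifier#
We only need to take the model definition from the example and place it inside a function. Apart from this wrapping, the code barely changes. To make comparison experiments easy, we also add an option that controls whether the pre-trained weights are loaded.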
[22]:
def create_inception_v3_model_classifier(num_classes, is_load_weight=True):
    def create_model():
        from tensorflow.keras.applications.inception_v3 import InceptionV3
        from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
        from tensorflow.keras.models import Model

        # Create model
        # create the base pre-trained model
        if is_load_weight:
            base_model = InceptionV3(weights='imagenet', include_top=False)
        else:
            base_model = InceptionV3(weights=None, include_top=False)
        # add a global spatial average pooling layer
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # let's add a fully-connected layer
        x = Dense(1024, activation='relu')(x)
        # and a softmax classification layer with num_classes outputs (5 flower classes here)
        predictions = Dense(num_classes, activation='softmax')(x)
        # this is the model we will train
        model = Model(inputs=base_model.input, outputs=predictions)
        # first: train only the top layers (which were randomly initialized)
        # i.e. freeze all convolutional InceptionV3 layers
        for layer in base_model.layers:
            layer.trainable = False
        # Compile model
        model.compile(
            optimizer='rmsprop',
            loss='sparse_categorical_crossentropy',
            metrics=["accuracy"],
        )
        return model

    return create_model
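The model is compiled with `sparse_categorical_crossentropy` because `image_dataset_from_directory` defaults to `label_mode="int"`, i.e. integer class indices rather than one-hot vectors. A NumPy sketch of this loss also gives a useful reference point for reading the training logs: a classifier that outputs uniform probabilities over 5 classes has loss ln 5 ≈ 1.609. The clipping constant mirrors Keras's small epsilon but is an assumption of this sketch.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean cross-entropy for integer labels y_true and softmax outputs y_prob."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    # For each sample, pick the predicted probability of its true class
    picked = y_prob[np.arange(len(y_true)), y_true]
    return float(-np.mean(np.log(picked)))


labels = np.array([0, 3, 4])
uniform = np.full((3, 5), 0.2)  # uninformative predictions over 5 classes
print(sparse_categorical_crossentropy(labels, uniform), np.log(5))
```

Losses well above ln 5 mean the model is confidently wrong on many samples, while values below it indicate it has learned something beyond guessing.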
Load the pretrained weights and fine-tune#
[23]:
# prepare model
num_classes = 5
# keras model
weight_model = create_inception_v3_model_classifier(
num_classes=num_classes, is_load_weight=True
)
fed_model = FLModel(
device_list=device_list,
model=weight_model,
aggregator=aggregator,
backend="tensorflow",
strategy="fed_avg_w",
random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
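With `strategy="fed_avg_w"`, each party trains locally and the parties' model weights (rather than gradients) are averaged at each aggregation. A FedAvg-style averaging of per-layer weight lists can be sketched as follows. Whether SecretFlow weights the average by per-party sample counts is not shown in this tutorial, so `sample_counts` is an optional, hypothetical parameter.

```python
import numpy as np


def fed_avg_weights(party_weights, sample_counts=None):
    """Average per-layer weight lists across parties (FedAvg-style sketch).

    party_weights: one list of np.ndarray per party, with identical layer shapes.
    sample_counts: optional per-party sample counts used as averaging weights.
    """
    n = len(party_weights)
    if sample_counts is None:
        sample_counts = [1] * n  # plain, unweighted mean
    coeffs = np.asarray(sample_counts, dtype=float) / float(sum(sample_counts))
    averaged = []
    for layer_idx in range(len(party_weights[0])):
        layer = sum(c * w[layer_idx] for c, w in zip(coeffs, party_weights))
        averaged.append(layer)
    return averaged


alice = [np.array([[1.0, 1.0]]), np.array([0.0])]  # e.g. a kernel and a bias
bob = [np.array([[3.0, 3.0]]), np.array([2.0])]
print(fed_avg_weights([alice, bob]))
print(fed_avg_weights([alice, bob], sample_counts=[3, 1]))
```

Averaging whole weight tensors is only meaningful because every party builds the same architecture from the same `create_model` function.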
[24]:
history = fed_model.fit(
data,
None,
validation_data=data,
epochs=epochs,
batch_size=batch_size,
aggregate_freq=aggregate_freq,
sampler_method=sampler_method,
random_seed=random_seed,
dp_spent_step_freq=dp_spent_step_freq,
dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbd4bb1cb50>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:44, 3.28s/it, epoch: 1/50 - loss:102.29530334472656 accuracy:0.25494277477264404 val_loss:36.37749099731445 val_accuracy:0.25 ]
32it [01:39, 3.10s/it, epoch: 2/50 - loss:23.1259708404541 accuracy:0.27976685762405396 val_loss:10.871068954467773 val_accuracy:0.26249998807907104 ]
32it [01:38, 3.08s/it, epoch: 3/50 - loss:13.160760879516602 accuracy:0.2706078290939331 val_loss:15.88563346862793 val_accuracy:0.2708333432674408 ]
32it [01:38, 3.06s/it, epoch: 4/50 - loss:11.84687328338623 accuracy:0.2930890917778015 val_loss:27.221759796142578 val_accuracy:0.25 ]
32it [01:37, 3.06s/it, epoch: 5/50 - loss:12.541316032409668 accuracy:0.3080765902996063 val_loss:46.98291778564453 val_accuracy:0.1458333283662796 ]
32it [01:38, 3.07s/it, epoch: 6/50 - loss:15.415522575378418 accuracy:0.29142382740974426 val_loss:10.973555564880371 val_accuracy:0.17916665971279144 ]
32it [01:38, 3.07s/it, epoch: 7/50 - loss:6.303822994232178 accuracy:0.3255620300769806 val_loss:8.56163501739502 val_accuracy:0.20416666567325592 ]
32it [01:54, 3.58s/it, epoch: 8/50 - loss:5.007680892944336 accuracy:0.34637802839279175 val_loss:6.916962146759033 val_accuracy:0.3499999940395355 ]
32it [01:55, 3.61s/it, epoch: 9/50 - loss:3.740983247756958 accuracy:0.3855120837688446 val_loss:3.589289665222168 val_accuracy:0.2083333283662796 ]
32it [01:55, 3.62s/it, epoch: 10/50 - loss:2.5561084747314453 accuracy:0.35803496837615967 val_loss:5.7687087059021 val_accuracy:0.15833333134651184 ]
32it [01:40, 3.13s/it, epoch: 11/50 - loss:2.6316604614257812 accuracy:0.3921732008457184 val_loss:4.043315410614014 val_accuracy:0.25 ]
32it [01:38, 3.09s/it, epoch: 12/50 - loss:2.2317721843719482 accuracy:0.4063280522823334 val_loss:2.3408310413360596 val_accuracy:0.27916666865348816 ]
32it [01:39, 3.12s/it, epoch: 13/50 - loss:1.6880862712860107 accuracy:0.41631972789764404 val_loss:4.694206237792969 val_accuracy:0.17083333432674408 ]
32it [01:39, 3.11s/it, epoch: 14/50 - loss:2.0930988788604736 accuracy:0.4113239049911499 val_loss:3.81128191947937 val_accuracy:0.17916665971279144 ]
32it [01:50, 3.44s/it, epoch: 15/50 - loss:1.8601850271224976 accuracy:0.42381349205970764 val_loss:2.3201189041137695 val_accuracy:0.28333333134651184 ]
32it [01:47, 3.36s/it, epoch: 16/50 - loss:1.5195019245147705 accuracy:0.43880099058151245 val_loss:1.9348057508468628 val_accuracy:0.3499999940395355 ]
32it [02:12, 4.13s/it, epoch: 17/50 - loss:1.4174845218658447 accuracy:0.47960034012794495 val_loss:2.867670774459839 val_accuracy:0.27916666865348816 ]
32it [01:45, 3.29s/it, epoch: 18/50 - loss:1.5716767311096191 accuracy:0.4995836913585663 val_loss:2.5769591331481934 val_accuracy:0.2874999940395355 ]
32it [01:40, 3.14s/it, epoch: 19/50 - loss:1.5218170881271362 accuracy:0.4787676930427551 val_loss:2.084118604660034 val_accuracy:0.27916666865348816 ]
32it [01:38, 3.08s/it, epoch: 20/50 - loss:1.3385194540023804 accuracy:0.4970857501029968 val_loss:2.267355442047119 val_accuracy:0.34166666865348816 ]
32it [01:40, 3.13s/it, epoch: 21/50 - loss:1.4143364429473877 accuracy:0.5145711898803711 val_loss:1.8168503046035767 val_accuracy:0.2916666567325592 ]
32it [01:41, 3.17s/it, epoch: 22/50 - loss:1.223915696144104 accuracy:0.5253955125808716 val_loss:2.421297550201416 val_accuracy:0.34166666865348816 ]
32it [01:51, 3.48s/it, epoch: 23/50 - loss:1.3997722864151 accuracy:0.5278934240341187 val_loss:2.1914479732513428 val_accuracy:0.38333332538604736 ]
32it [01:39, 3.10s/it, epoch: 24/50 - loss:1.2561317682266235 accuracy:0.5770191550254822 val_loss:2.868597984313965 val_accuracy:0.3291666805744171 ]
32it [01:44, 3.28s/it, epoch: 25/50 - loss:1.4317070245742798 accuracy:0.5611990094184875 val_loss:2.570885181427002 val_accuracy:0.3166666626930237 ]
32it [01:45, 3.30s/it, epoch: 26/50 - loss:1.3605040311813354 accuracy:0.5578684210777283 val_loss:1.9396979808807373 val_accuracy:0.3083333373069763 ]
32it [01:39, 3.11s/it, epoch: 27/50 - loss:1.1538910865783691 accuracy:0.5711906552314758 val_loss:2.737755060195923 val_accuracy:0.3333333432674408 ]
32it [01:52, 3.53s/it, epoch: 28/50 - loss:1.295885682106018 accuracy:0.5961698293685913 val_loss:2.398267984390259 val_accuracy:0.36666667461395264 ]
32it [01:42, 3.20s/it, epoch: 29/50 - loss:1.2216304540634155 accuracy:0.5845128893852234 val_loss:5.456529140472412 val_accuracy:0.3166666626930237 ]
32it [01:42, 3.22s/it, epoch: 30/50 - loss:1.8455870151519775 accuracy:0.6069941520690918 val_loss:2.622253894805908 val_accuracy:0.3499999940395355 ]
32it [01:37, 3.06s/it, epoch: 31/50 - loss:1.2096481323242188 accuracy:0.607826828956604 val_loss:2.6999380588531494 val_accuracy:0.3583333194255829 ]
32it [01:52, 3.52s/it, epoch: 32/50 - loss:1.25503671169281 accuracy:0.6153205633163452 val_loss:3.1902101039886475 val_accuracy:0.3916666805744171 ]
32it [01:40, 3.13s/it, epoch: 33/50 - loss:1.321710467338562 accuracy:0.6319733262062073 val_loss:2.9719605445861816 val_accuracy:0.3125 ]
32it [02:03, 3.87s/it, epoch: 34/50 - loss:1.2177612781524658 accuracy:0.6394671201705933 val_loss:2.9436466693878174 val_accuracy:0.3499999940395355 ]
32it [01:38, 3.07s/it, epoch: 35/50 - loss:1.2386927604675293 accuracy:0.6561198830604553 val_loss:4.200243949890137 val_accuracy:0.2666666805744171 ]
32it [01:41, 3.19s/it, epoch: 36/50 - loss:1.5145163536071777 accuracy:0.6344712972640991 val_loss:2.730128765106201 val_accuracy:0.375 ]
32it [01:40, 3.16s/it, epoch: 37/50 - loss:1.150541067123413 accuracy:0.6644462943077087 val_loss:3.111203670501709 val_accuracy:0.34166666865348816 ]
32it [01:50, 3.45s/it, epoch: 38/50 - loss:1.268009066581726 accuracy:0.6477935314178467 val_loss:3.186652183532715 val_accuracy:0.2958333194255829 ]
32it [01:43, 3.23s/it, epoch: 39/50 - loss:1.1637378931045532 accuracy:0.6727727055549622 val_loss:2.7717678546905518 val_accuracy:0.32499998807907104 ]
32it [01:43, 3.25s/it, epoch: 40/50 - loss:1.1945478916168213 accuracy:0.6594504714012146 val_loss:4.414767265319824 val_accuracy:0.3375000059604645 ]
32it [01:41, 3.19s/it, epoch: 41/50 - loss:1.4453009366989136 accuracy:0.6860949397087097 val_loss:3.308169364929199 val_accuracy:0.3166666626930237 ]
32it [01:43, 3.22s/it, epoch: 42/50 - loss:1.1969460248947144 accuracy:0.6835970282554626 val_loss:3.0412395000457764 val_accuracy:0.40833333134651184 ]
32it [01:50, 3.44s/it, epoch: 43/50 - loss:1.0718883275985718 accuracy:0.7085762023925781 val_loss:4.295853137969971 val_accuracy:0.32499998807907104 ]
32it [01:48, 3.38s/it, epoch: 44/50 - loss:1.3032809495925903 accuracy:0.7027477025985718 val_loss:3.5949547290802 val_accuracy:0.375 ]
32it [01:47, 3.35s/it, epoch: 45/50 - loss:1.3326811790466309 accuracy:0.6827643513679504 val_loss:3.4334940910339355 val_accuracy:0.3791666626930237 ]
32it [01:43, 3.23s/it, epoch: 46/50 - loss:1.1082518100738525 accuracy:0.7252289652824402 val_loss:2.774402141571045 val_accuracy:0.3499999940395355 ]
32it [01:43, 3.22s/it, epoch: 47/50 - loss:1.052221655845642 accuracy:0.6994171738624573 val_loss:2.918473958969116 val_accuracy:0.36250001192092896 ]
32it [01:40, 3.14s/it, epoch: 48/50 - loss:1.0545026063919067 accuracy:0.7152373194694519 val_loss:3.524212121963501 val_accuracy:0.375 ]
32it [01:43, 3.22s/it, epoch: 49/50 - loss:1.1607937812805176 accuracy:0.7243963479995728 val_loss:3.213914632797241 val_accuracy:0.3499999940395355 ]
32it [01:38, 3.08s/it, epoch: 50/50 - loss:1.0676075220108032 accuracy:0.7277268767356873 val_loss:4.0257086753845215 val_accuracy:0.34583333134651184 ]
[25]:
import matplotlib.pyplot as plt

# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
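Load only the network architecture with random initialization#
With `is_load_weight=False`, the frozen base keeps its random initial weights, so only the classification head is trained, on top of random features.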
[26]:
# keras model
no_weight_model = create_inception_v3_model_classifier(
num_classes=num_classes, is_load_weight=False
)
fed_model = FLModel(
device_list=device_list,
model=no_weight_model,
aggregator=aggregator,
backend="tensorflow",
strategy="fed_avg_w",
random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
[27]:
history = fed_model.fit(
data,
None,
validation_data=data,
epochs=epochs,
batch_size=batch_size,
aggregate_freq=aggregate_freq,
sampler_method=sampler_method,
random_seed=random_seed,
dp_spent_step_freq=dp_spent_step_freq,
dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbd4baefbe0>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:49, 3.41s/it, epoch: 1/50 - loss:1.6169096231460571 accuracy:0.2788761854171753 val_loss:2.0081751346588135 val_accuracy:0.25 ]
32it [01:41, 3.18s/it, epoch: 2/50 - loss:1.6776905059814453 accuracy:0.2805995047092438 val_loss:1.5821692943572998 val_accuracy:0.2083333283662796 ]
32it [01:40, 3.15s/it, epoch: 3/50 - loss:1.5732810497283936 accuracy:0.27560365200042725 val_loss:1.6988409757614136 val_accuracy:0.25 ]
32it [01:43, 3.22s/it, epoch: 4/50 - loss:1.5928469896316528 accuracy:0.29059118032455444 val_loss:1.6616723537445068 val_accuracy:0.25 ]
32it [01:43, 3.25s/it, epoch: 5/50 - loss:1.5776023864746094 accuracy:0.2972522974014282 val_loss:1.5816072225570679 val_accuracy:0.25 ]
32it [01:39, 3.10s/it, epoch: 6/50 - loss:1.5489486455917358 accuracy:0.3064113259315491 val_loss:1.5895153284072876 val_accuracy:0.25 ]
32it [01:38, 3.09s/it, epoch: 7/50 - loss:1.5415023565292358 accuracy:0.33888426423072815 val_loss:1.6028796434402466 val_accuracy:0.2874999940395355 ]
32it [01:42, 3.21s/it, epoch: 8/50 - loss:1.5391955375671387 accuracy:0.3488759398460388 val_loss:1.6169476509094238 val_accuracy:0.2666666805744171 ]
32it [01:39, 3.11s/it, epoch: 9/50 - loss:1.5235228538513184 accuracy:0.37552040815353394 val_loss:1.6155515909194946 val_accuracy:0.25 ]
32it [01:42, 3.21s/it, epoch: 10/50 - loss:1.5191270112991333 accuracy:0.3522064983844757 val_loss:1.600549340248108 val_accuracy:0.32499998807907104 ]
32it [01:39, 3.09s/it, epoch: 11/50 - loss:1.5050208568572998 accuracy:0.3705245554447174 val_loss:1.6656410694122314 val_accuracy:0.25 ]
32it [01:45, 3.31s/it, epoch: 12/50 - loss:1.5016016960144043 accuracy:0.3563697040081024 val_loss:1.5471644401550293 val_accuracy:0.2541666626930237 ]
32it [01:36, 3.02s/it, epoch: 13/50 - loss:1.465255618095398 accuracy:0.37968358397483826 val_loss:1.54298996925354 val_accuracy:0.3791666626930237 ]
32it [01:42, 3.21s/it, epoch: 14/50 - loss:1.4461535215377808 accuracy:0.40549543499946594 val_loss:1.5411031246185303 val_accuracy:0.3791666626930237 ]
32it [01:38, 3.07s/it, epoch: 15/50 - loss:1.4370651245117188 accuracy:0.40882596373558044 val_loss:1.490233063697815 val_accuracy:0.3291666805744171 ]
32it [01:41, 3.17s/it, epoch: 16/50 - loss:1.413556456565857 accuracy:0.4046627879142761 val_loss:1.590057373046875 val_accuracy:0.3333333432674408 ]
32it [01:42, 3.19s/it, epoch: 17/50 - loss:1.425291895866394 accuracy:0.40549543499946594 val_loss:1.5726838111877441 val_accuracy:0.26249998807907104 ]
32it [01:46, 3.34s/it, epoch: 18/50 - loss:1.411468267440796 accuracy:0.3930058181285858 val_loss:1.6937826871871948 val_accuracy:0.25 ]
32it [01:39, 3.10s/it, epoch: 19/50 - loss:1.4267196655273438 accuracy:0.4004995822906494 val_loss:1.7930010557174683 val_accuracy:0.25 ]
32it [01:41, 3.16s/it, epoch: 20/50 - loss:1.444821834564209 accuracy:0.39134055376052856 val_loss:1.909441351890564 val_accuracy:0.2708333432674408 ]
32it [01:43, 3.22s/it, epoch: 21/50 - loss:1.4686118364334106 accuracy:0.40216487646102905 val_loss:1.622828722000122 val_accuracy:0.3083333373069763 ]
32it [01:39, 3.10s/it, epoch: 22/50 - loss:1.4035924673080444 accuracy:0.4104912579059601 val_loss:1.453572392463684 val_accuracy:0.3708333373069763 ]
32it [01:45, 3.31s/it, epoch: 23/50 - loss:1.359868049621582 accuracy:0.41631972789764404 val_loss:1.6003872156143188 val_accuracy:0.3583333194255829 ]
32it [01:39, 3.09s/it, epoch: 24/50 - loss:1.3811829090118408 accuracy:0.42381349205970764 val_loss:1.5266224145889282 val_accuracy:0.3166666626930237 ]
32it [01:42, 3.22s/it, epoch: 25/50 - loss:1.3589435815811157 accuracy:0.40965861082077026 val_loss:2.0248773097991943 val_accuracy:0.2708333432674408 ]
32it [01:39, 3.11s/it, epoch: 26/50 - loss:1.4731003046035767 accuracy:0.4154870808124542 val_loss:1.4912035465240479 val_accuracy:0.34583333134651184 ]
32it [01:42, 3.20s/it, epoch: 27/50 - loss:1.3412293195724487 accuracy:0.41631972789764404 val_loss:1.3930459022521973 val_accuracy:0.3791666626930237 ]
32it [01:42, 3.20s/it, epoch: 28/50 - loss:1.314136266708374 accuracy:0.43130725622177124 val_loss:1.555684208869934 val_accuracy:0.3166666626930237 ]
32it [01:52, 3.51s/it, epoch: 29/50 - loss:1.3471096754074097 accuracy:0.42797669768333435 val_loss:1.6817519664764404 val_accuracy:0.2666666805744171 ]
32it [01:38, 3.06s/it, epoch: 30/50 - loss:1.3709834814071655 accuracy:0.407993346452713 val_loss:1.4343681335449219 val_accuracy:0.36666667461395264 ]
32it [01:48, 3.41s/it, epoch: 31/50 - loss:1.3120677471160889 accuracy:0.43463781476020813 val_loss:1.4698201417922974 val_accuracy:0.3499999940395355 ]
32it [01:37, 3.05s/it, epoch: 32/50 - loss:1.3166753053665161 accuracy:0.43130725622177124 val_loss:1.4647256135940552 val_accuracy:0.34166666865348816 ]
32it [01:40, 3.15s/it, epoch: 33/50 - loss:1.316184401512146 accuracy:0.4263114035129547 val_loss:1.7285205125808716 val_accuracy:0.25833332538604736 ]
32it [01:41, 3.17s/it, epoch: 34/50 - loss:1.3727495670318604 accuracy:0.4154870808124542 val_loss:1.3565254211425781 val_accuracy:0.3916666805744171 ]
32it [01:44, 3.26s/it, epoch: 35/50 - loss:1.280211329460144 accuracy:0.45045796036720276 val_loss:1.469690203666687 val_accuracy:0.30000001192092896 ]
32it [01:44, 3.27s/it, epoch: 36/50 - loss:1.3061801195144653 accuracy:0.41715237498283386 val_loss:1.4475810527801514 val_accuracy:0.38749998807907104 ]
32it [01:39, 3.11s/it, epoch: 37/50 - loss:1.2898013591766357 accuracy:0.443796843290329 val_loss:1.6531250476837158 val_accuracy:0.28333333134651184 ]
32it [01:43, 3.22s/it, epoch: 38/50 - loss:1.3455705642700195 accuracy:0.42714405059814453 val_loss:1.5225203037261963 val_accuracy:0.38749998807907104 ]
32it [01:39, 3.12s/it, epoch: 39/50 - loss:1.30066978931427 accuracy:0.4371357262134552 val_loss:1.4351023435592651 val_accuracy:0.38749998807907104 ]
32it [01:42, 3.21s/it, epoch: 40/50 - loss:1.2890324592590332 accuracy:0.45295587182044983 val_loss:1.563452959060669 val_accuracy:0.3125 ]
32it [01:37, 3.03s/it, epoch: 41/50 - loss:1.3143178224563599 accuracy:0.43963363766670227 val_loss:1.3913853168487549 val_accuracy:0.4124999940395355 ]
32it [01:46, 3.34s/it, epoch: 42/50 - loss:1.2762908935546875 accuracy:0.45711907744407654 val_loss:1.3839548826217651 val_accuracy:0.3916666805744171 ]
32it [01:36, 3.00s/it, epoch: 43/50 - loss:1.2671079635620117 accuracy:0.4587843418121338 val_loss:1.4633164405822754 val_accuracy:0.38749998807907104 ]
32it [01:42, 3.20s/it, epoch: 44/50 - loss:1.2803606986999512 accuracy:0.45711907744407654 val_loss:1.5430619716644287 val_accuracy:0.30000001192092896 ]
32it [01:39, 3.11s/it, epoch: 45/50 - loss:1.3015990257263184 accuracy:0.437968373298645 val_loss:2.1071672439575195 val_accuracy:0.2958333194255829 ]
32it [01:43, 3.24s/it, epoch: 46/50 - loss:1.4335447549819946 accuracy:0.4371357262134552 val_loss:1.4610460996627808 val_accuracy:0.3583333194255829 ]
32it [01:42, 3.21s/it, epoch: 47/50 - loss:1.276992678642273 accuracy:0.46128225326538086 val_loss:1.4364838600158691 val_accuracy:0.375 ]
32it [01:44, 3.25s/it, epoch: 48/50 - loss:1.2780916690826416 accuracy:0.4662781059741974 val_loss:1.4993922710418701 val_accuracy:0.3916666805744171 ]
32it [01:42, 3.22s/it, epoch: 49/50 - loss:1.2837833166122437 accuracy:0.443796843290329 val_loss:1.5112351179122925 val_accuracy:0.3333333432674408 ]
32it [01:41, 3.17s/it, epoch: 50/50 - loss:1.29127836227417 accuracy:0.45045796036720276 val_loss:1.6569411754608154 val_accuracy:0.34583333134651184 ]
[28]:
import matplotlib.pyplot as plt

# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
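In both classifier runs above, validation accuracy swings widely from one epoch to the next (in the pretrained run, for example, between about 0.15 at epoch 5 and 0.41 at epoch 42), which makes the raw curves hard to compare. A trailing moving average is one simple way to smooth them before plotting; the helper and its window size are illustrative choices, not part of the original tutorial.

```python
import numpy as np


def moving_average(values, window=5):
    """Trailing moving average; early points average whatever history is available."""
    values = np.asarray(values, dtype=float)
    return np.array(
        [values[max(0, i - window + 1): i + 1].mean() for i in range(len(values))]
    )


print(moving_average([1, 2, 3, 4], window=2))
```

For example, `plt.plot(moving_average(history.global_history['val_accuracy']))` draws the smoothed validation curve alongside the raw one.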
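Freeze the lower layers and fine-tune the top layers#
This variant follows the fine-tuning stage of the Keras InceptionV3 example. The model is the same as above, but the first 249 layers are frozen and all later layers (in the Keras example, the top two inception blocks, plus the new classification head) are trained with SGD at a low learning rate (0.0001) and momentum 0.9. Unlike the Keras example, the head is not trained separately first. As before, `is_load_weight` controls whether the ImageNet weights are loaded, so the two settings can be compared.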
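The freeze boundary is a position in `model.layers`, not a named block, so everything before index 249 is frozen and everything from it onward is trainable. The pattern can be sketched without TensorFlow using a hypothetical `StubLayer` that carries only a `trainable` flag. With a real Keras model, `model.layers[249].name` shows where the boundary falls.

```python
from dataclasses import dataclass


@dataclass
class StubLayer:
    # Hypothetical stand-in for a Keras layer: only the attributes used here
    name: str
    trainable: bool = True


def freeze_below(layers, boundary):
    """Freeze layers[:boundary], unfreeze layers[boundary:], return the trainable count."""
    for layer in layers[:boundary]:
        layer.trainable = False
    for layer in layers[boundary:]:
        layer.trainable = True
    return sum(layer.trainable for layer in layers)


layers = [StubLayer(f"layer_{i}") for i in range(10)]
print(freeze_below(layers, 7))
```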
[29]:
def create_inception_v3_model_fine_tune(num_classes, is_load_weight=True):
    def create_model():
        from tensorflow.keras.applications.inception_v3 import InceptionV3
        from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
        from tensorflow.keras.models import Model
        from tensorflow.keras.optimizers import SGD

        # Create model
        # create the base pre-trained model
        if is_load_weight:
            base_model = InceptionV3(weights='imagenet', include_top=False)
        else:
            base_model = InceptionV3(weights=None, include_top=False)
        # add a global spatial average pooling layer
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # let's add a fully-connected layer
        x = Dense(1024, activation='relu')(x)
        # and a softmax classification layer with num_classes outputs (5 flower classes here)
        predictions = Dense(num_classes, activation='softmax')(x)
        # this is the model we will train
        model = Model(inputs=base_model.input, outputs=predictions)
        # freeze the first 249 layers; train the top inception blocks and the new head
        for layer in model.layers[:249]:
            layer.trainable = False
        for layer in model.layers[249:]:
            layer.trainable = True
        # Compile model with a low learning rate so pretrained weights change slowly
        model.compile(
            optimizer=SGD(learning_rate=0.0001, momentum=0.9),
            loss='sparse_categorical_crossentropy',
            metrics=["accuracy"],
        )
        return model

    return create_model
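The fine-tuning model uses `SGD(learning_rate=0.0001, momentum=0.9)`. The update rule documented for Keras SGD without Nesterov momentum is `velocity = momentum * velocity - learning_rate * g`, then `w = w + velocity`. The NumPy sketch below applies that rule to a toy quadratic (a hypothetical objective, chosen only for illustration) to show how slowly a learning rate of 1e-4 moves the weights compared with a larger one, which is the point of using it on pretrained layers.

```python
import numpy as np


def sgd_momentum(grad_fn, w0, learning_rate=1e-4, momentum=0.9, steps=100):
    """Run SGD with classical momentum:
    velocity = momentum * velocity - learning_rate * g;  w = w + velocity
    """
    w = np.asarray(w0, dtype=float)
    velocity = np.zeros_like(w)
    for _ in range(steps):
        g = grad_fn(w)
        velocity = momentum * velocity - learning_rate * g
        w = w + velocity
    return w


# Toy objective f(w) = 0.5 * ||w||^2, whose gradient is w; minimum at 0
w_small = sgd_momentum(lambda w: w, [1.0, -2.0], learning_rate=1e-4)
w_large = sgd_momentum(lambda w: w, [1.0, -2.0], learning_rate=1e-1)
print(w_small, w_large)
```

After 100 steps the weights have barely moved with a learning rate of 1e-4 (about 10% closer to the minimum), while a learning rate of 0.1 drives them essentially to zero.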
Load the pretrained weights and fine-tune#
[30]:
# keras model
weight_model = create_inception_v3_model_fine_tune(
num_classes=num_classes, is_load_weight=True
)
fed_model = FLModel(
device_list=device_list,
model=weight_model,
aggregator=aggregator,
backend="tensorflow",
strategy="fed_avg_w",
random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
[31]:
history = fed_model.fit(
data,
None,
validation_data=data,
epochs=epochs,
batch_size=batch_size,
aggregate_freq=aggregate_freq,
sampler_method=sampler_method,
random_seed=random_seed,
dp_spent_step_freq=dp_spent_step_freq,
dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbc1c65c160>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:50, 3.46s/it, epoch: 1/50 - loss:1.650011658668518 accuracy:0.20915712416172028 val_loss:1.656758427619934 val_accuracy:0.24583333730697632 ]
32it [01:46, 3.32s/it, epoch: 2/50 - loss:1.5472569465637207 accuracy:0.32389676570892334 val_loss:1.58930242061615 val_accuracy:0.28333333134651184 ]
32it [01:45, 3.29s/it, epoch: 3/50 - loss:1.4915741682052612 accuracy:0.36552873253822327 val_loss:1.5358479022979736 val_accuracy:0.30000001192092896 ]
32it [01:43, 3.25s/it, epoch: 4/50 - loss:1.4296882152557373 accuracy:0.4004995822906494 val_loss:1.5131781101226807 val_accuracy:0.32499998807907104 ]
32it [01:45, 3.30s/it, epoch: 5/50 - loss:1.3764379024505615 accuracy:0.43963363766670227 val_loss:1.485734462738037 val_accuracy:0.36666667461395264 ]
32it [01:41, 3.16s/it, epoch: 6/50 - loss:1.334008812904358 accuracy:0.4845961630344391 val_loss:1.4631234407424927 val_accuracy:0.3541666567325592 ]
32it [01:53, 3.56s/it, epoch: 7/50 - loss:1.2897144556045532 accuracy:0.5104079842567444 val_loss:1.4368029832839966 val_accuracy:0.3499999940395355 ]
32it [01:38, 3.08s/it, epoch: 8/50 - loss:1.2628018856048584 accuracy:0.526228129863739 val_loss:1.4215635061264038 val_accuracy:0.36666667461395264 ]
32it [01:44, 3.27s/it, epoch: 9/50 - loss:1.2168503999710083 accuracy:0.5553705096244812 val_loss:1.4049427509307861 val_accuracy:0.3916666805744171 ]
32it [01:36, 3.03s/it, epoch: 10/50 - loss:1.187528371810913 accuracy:0.5611990094184875 val_loss:1.3869177103042603 val_accuracy:0.3916666805744171 ]
32it [01:37, 3.04s/it, epoch: 11/50 - loss:1.1617594957351685 accuracy:0.5903413891792297 val_loss:1.3873077630996704 val_accuracy:0.3958333432674408 ]
32it [01:42, 3.19s/it, epoch: 12/50 - loss:1.1248602867126465 accuracy:0.5936719179153442 val_loss:1.3818570375442505 val_accuracy:0.4000000059604645 ]
32it [01:39, 3.11s/it, epoch: 13/50 - loss:1.0991374254226685 accuracy:0.6369692087173462 val_loss:1.366493821144104 val_accuracy:0.42500001192092896 ]
32it [01:45, 3.31s/it, epoch: 14/50 - loss:1.0843088626861572 accuracy:0.6353039145469666 val_loss:1.3633487224578857 val_accuracy:0.42916667461395264 ]
32it [01:41, 3.17s/it, epoch: 15/50 - loss:1.059274673461914 accuracy:0.6636136770248413 val_loss:1.3560353517532349 val_accuracy:0.4375 ]
32it [01:44, 3.26s/it, epoch: 16/50 - loss:1.0273948907852173 accuracy:0.6644462943077087 val_loss:1.3572434186935425 val_accuracy:0.4333333373069763 ]
32it [01:39, 3.12s/it, epoch: 17/50 - loss:0.9911747574806213 accuracy:0.6994171738624573 val_loss:1.3419238328933716 val_accuracy:0.44999998807907104 ]
32it [01:50, 3.44s/it, epoch: 18/50 - loss:0.9911114573478699 accuracy:0.67194002866745 val_loss:1.3493646383285522 val_accuracy:0.42500001192092896 ]
32it [01:40, 3.16s/it, epoch: 19/50 - loss:0.9671102166175842 accuracy:0.6852622628211975 val_loss:1.3433830738067627 val_accuracy:0.42916667461395264 ]
32it [01:46, 3.32s/it, epoch: 20/50 - loss:0.937901496887207 accuracy:0.7052456140518188 val_loss:1.3399254083633423 val_accuracy:0.4416666626930237 ]
32it [01:42, 3.21s/it, epoch: 21/50 - loss:0.906328558921814 accuracy:0.7277268767356873 val_loss:1.328268051147461 val_accuracy:0.4541666805744171 ]
32it [01:52, 3.53s/it, epoch: 22/50 - loss:0.8943104147911072 accuracy:0.7402164936065674 val_loss:1.3352607488632202 val_accuracy:0.44583332538604736 ]
32it [01:44, 3.26s/it, epoch: 23/50 - loss:0.8718297481536865 accuracy:0.7452123165130615 val_loss:1.310567021369934 val_accuracy:0.4583333432674408 ]
32it [01:53, 3.53s/it, epoch: 24/50 - loss:0.8595983982086182 accuracy:0.746044933795929 val_loss:1.3058205842971802 val_accuracy:0.46666666865348816 ]
32it [01:40, 3.14s/it, epoch: 25/50 - loss:0.8207837343215942 accuracy:0.7618651390075684 val_loss:1.3195736408233643 val_accuracy:0.47083333134651184 ]
32it [01:41, 3.18s/it, epoch: 26/50 - loss:0.8151105642318726 accuracy:0.7718567848205566 val_loss:1.3052656650543213 val_accuracy:0.4625000059604645 ]
32it [01:39, 3.11s/it, epoch: 27/50 - loss:0.8001042008399963 accuracy:0.7718567848205566 val_loss:1.3049904108047485 val_accuracy:0.46666666865348816 ]
32it [01:42, 3.20s/it, epoch: 28/50 - loss:0.7668044567108154 accuracy:0.7901748418807983 val_loss:1.307594895362854 val_accuracy:0.4625000059604645 ]
32it [01:40, 3.13s/it, epoch: 29/50 - loss:0.7354434728622437 accuracy:0.8043297529220581 val_loss:1.306896686553955 val_accuracy:0.47083333134651184 ]
32it [01:42, 3.21s/it, epoch: 30/50 - loss:0.7426672577857971 accuracy:0.8018317818641663 val_loss:1.3118059635162354 val_accuracy:0.4833333194255829 ]
32it [01:38, 3.07s/it, epoch: 31/50 - loss:0.7213742733001709 accuracy:0.8126561045646667 val_loss:1.296970248222351 val_accuracy:0.4625000059604645 ]
32it [01:37, 3.04s/it, epoch: 32/50 - loss:0.700598955154419 accuracy:0.8184846043586731 val_loss:1.305118203163147 val_accuracy:0.4583333432674408 ]
32it [01:37, 3.05s/it, epoch: 33/50 - loss:0.6826915144920349 accuracy:0.8259783387184143 val_loss:1.3077011108398438 val_accuracy:0.4791666567325592 ]
32it [01:37, 3.05s/it, epoch: 34/50 - loss:0.6565826535224915 accuracy:0.8251457214355469 val_loss:1.301528811454773 val_accuracy:0.46666666865348816 ]
32it [01:40, 3.13s/it, epoch: 35/50 - loss:0.656404972076416 accuracy:0.8226478099822998 val_loss:1.306329369544983 val_accuracy:0.4749999940395355 ]
32it [01:40, 3.13s/it, epoch: 36/50 - loss:0.6250473856925964 accuracy:0.8384679555892944 val_loss:1.3014804124832153 val_accuracy:0.46666666865348816 ]
32it [01:45, 3.29s/it, epoch: 37/50 - loss:0.6090459227561951 accuracy:0.8426311612129211 val_loss:1.3036556243896484 val_accuracy:0.46666666865348816 ]
32it [01:41, 3.17s/it, epoch: 38/50 - loss:0.6007826924324036 accuracy:0.8409658670425415 val_loss:1.3221803903579712 val_accuracy:0.46666666865348816 ]
32it [01:45, 3.29s/it, epoch: 39/50 - loss:0.5874541997909546 accuracy:0.8467943668365479 val_loss:1.3001883029937744 val_accuracy:0.4583333432674408 ]
32it [01:40, 3.14s/it, epoch: 40/50 - loss:0.5784027576446533 accuracy:0.8459616899490356 val_loss:1.29843270778656 val_accuracy:0.4583333432674408 ]
32it [01:45, 3.31s/it, epoch: 41/50 - loss:0.5616000890731812 accuracy:0.8434637784957886 val_loss:1.3310949802398682 val_accuracy:0.4749999940395355 ]
32it [01:44, 3.25s/it, epoch: 42/50 - loss:0.564281165599823 accuracy:0.8559533953666687 val_loss:1.3145432472229004 val_accuracy:0.47083333134651184 ]
32it [01:43, 3.25s/it, epoch: 43/50 - loss:0.54606032371521 accuracy:0.8592839241027832 val_loss:1.311663269996643 val_accuracy:0.4625000059604645 ]
32it [01:44, 3.25s/it, epoch: 44/50 - loss:0.5349538922309875 accuracy:0.8601165413856506 val_loss:1.3167394399642944 val_accuracy:0.4625000059604645 ]
32it [01:46, 3.32s/it, epoch: 45/50 - loss:0.5256470441818237 accuracy:0.8667776584625244 val_loss:1.3188197612762451 val_accuracy:0.4791666567325592 ]
32it [01:41, 3.16s/it, epoch: 46/50 - loss:0.5097993016242981 accuracy:0.8676103353500366 val_loss:1.3174797296524048 val_accuracy:0.4791666567325592 ]
32it [01:40, 3.15s/it, epoch: 47/50 - loss:0.49822941422462463 accuracy:0.8742714524269104 val_loss:1.3001458644866943 val_accuracy:0.4583333432674408 ]
32it [01:47, 3.35s/it, epoch: 48/50 - loss:0.4908905625343323 accuracy:0.8667776584625244 val_loss:1.3209236860275269 val_accuracy:0.46666666865348816 ]
32it [01:39, 3.10s/it, epoch: 49/50 - loss:0.48140043020248413 accuracy:0.8726061582565308 val_loss:1.3325889110565186 val_accuracy:0.4749999940395355 ]
32it [01:43, 3.23s/it, epoch: 50/50 - loss:0.4809170365333557 accuracy:0.8792672753334045 val_loss:1.336763858795166 val_accuracy:0.4833333194255829 ]
[32]:
import matplotlib.pyplot as plt

# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
Load only the network architecture with random initialization#
[33]:
# keras model
no_weight_model = create_inception_v3_model_fine_tune(
num_classes=num_classes, is_load_weight=False
)
fed_model = FLModel(
device_list=device_list,
model=no_weight_model,
aggregator=aggregator,
backend="tensorflow",
strategy="fed_avg_w",
random_seed=1234,
)
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.tensorflow.strategy.fed_avg_w.PYUFedAvgW'> with party bob.
[34]:
history = fed_model.fit(
data,
None,
validation_data=data,
epochs=epochs,
batch_size=batch_size,
aggregate_freq=aggregate_freq,
sampler_method=sampler_method,
random_seed=random_seed,
dp_spent_step_freq=dp_spent_step_freq,
dataset_builder=data_builder_dict,
)
INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7fbc24074700>, 'x': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 50, 'verbose': 1, 'callbacks': None, 'validation_data': {PYURuntime(alice): '/tmp/tmpc7wdf9us/datasets/flower_photos', PYURuntime(bob): '/tmp/tmpc7wdf9us/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbb907ee700>, PYURuntime(bob): <function create_dataset_builder.<locals>.dataset_builder at 0x7fbadc0d5040>}, 'wait_steps': 100}
32it [01:48, 3.38s/it, epoch: 1/50 - loss:1.5617533922195435 accuracy:0.27783557772636414 val_loss:1.6185896396636963 val_accuracy:0.13750000298023224 ]
32it [01:45, 3.30s/it, epoch: 2/50 - loss:1.4592828750610352 accuracy:0.34970858693122864 val_loss:1.6081194877624512 val_accuracy:0.15833333134651184 ]
32it [01:41, 3.16s/it, epoch: 3/50 - loss:1.388456106185913 accuracy:0.4254787564277649 val_loss:1.5979021787643433 val_accuracy:0.2750000059604645 ]
32it [01:45, 3.28s/it, epoch: 4/50 - loss:1.3454346656799316 accuracy:0.43547043204307556 val_loss:1.5885674953460693 val_accuracy:0.2916666567325592 ]
32it [01:43, 3.23s/it, epoch: 5/50 - loss:1.3147145509719849 accuracy:0.46461281180381775 val_loss:1.57625150680542 val_accuracy:0.28333333134651184 ]
32it [01:42, 3.21s/it, epoch: 6/50 - loss:1.2754597663879395 accuracy:0.4854288101196289 val_loss:1.563430905342102 val_accuracy:0.30000001192092896 ]
32it [01:38, 3.07s/it, epoch: 7/50 - loss:1.2425297498703003 accuracy:0.5045795440673828 val_loss:1.5455553531646729 val_accuracy:0.2750000059604645 ]
32it [01:45, 3.30s/it, epoch: 8/50 - loss:1.2141278982162476 accuracy:0.5129058957099915 val_loss:1.5239124298095703 val_accuracy:0.2874999940395355 ]
32it [01:51, 3.48s/it, epoch: 9/50 - loss:1.1969578266143799 accuracy:0.5237302184104919 val_loss:1.4922692775726318 val_accuracy:0.38333332538604736 ]
32it [01:37, 3.05s/it, epoch: 10/50 - loss:1.1684715747833252 accuracy:0.5653622150421143 val_loss:1.4653935432434082 val_accuracy:0.36666667461395264 ]
32it [01:38, 3.09s/it, epoch: 11/50 - loss:1.1470826864242554 accuracy:0.5678601264953613 val_loss:1.439568281173706 val_accuracy:0.40833333134651184 ]
32it [01:42, 3.19s/it, epoch: 12/50 - loss:1.1196565628051758 accuracy:0.572023332118988 val_loss:1.399143934249878 val_accuracy:0.44583332538604736 ]
32it [01:47, 3.36s/it, epoch: 13/50 - loss:1.1069118976593018 accuracy:0.6036636233329773 val_loss:1.3591387271881104 val_accuracy:0.47083333134651184 ]
32it [01:48, 3.40s/it, epoch: 14/50 - loss:1.0912777185440063 accuracy:0.6036636233329773 val_loss:1.327014684677124 val_accuracy:0.48750001192092896 ]
32it [01:50, 3.45s/it, epoch: 15/50 - loss:1.0678472518920898 accuracy:0.6094920635223389 val_loss:1.2993195056915283 val_accuracy:0.5083333253860474 ]
32it [01:41, 3.18s/it, epoch: 16/50 - loss:1.0440828800201416 accuracy:0.6111573576927185 val_loss:1.2653034925460815 val_accuracy:0.5166666507720947 ]
32it [01:48, 3.38s/it, epoch: 17/50 - loss:1.0225656032562256 accuracy:0.6294754147529602 val_loss:1.231634497642517 val_accuracy:0.5249999761581421 ]
32it [01:42, 3.20s/it, epoch: 18/50 - loss:1.020599365234375 accuracy:0.6203163862228394 val_loss:1.2129393815994263 val_accuracy:0.5333333611488342 ]
32it [01:45, 3.29s/it, epoch: 19/50 - loss:0.9966413974761963 accuracy:0.6286427974700928 val_loss:1.1990076303482056 val_accuracy:0.5333333611488342 ]
32it [01:42, 3.21s/it, epoch: 20/50 - loss:0.9871851801872253 accuracy:0.6402997374534607 val_loss:1.176818609237671 val_accuracy:0.550000011920929 ]
32it [01:51, 3.49s/it, epoch: 21/50 - loss:0.9735381603240967 accuracy:0.6511240601539612 val_loss:1.164151906967163 val_accuracy:0.512499988079071 ]
32it [01:44, 3.27s/it, epoch: 22/50 - loss:0.9548952579498291 accuracy:0.6486261487007141 val_loss:1.1686121225357056 val_accuracy:0.5333333611488342 ]
32it [01:43, 3.22s/it, epoch: 23/50 - loss:0.9554542899131775 accuracy:0.6444629430770874 val_loss:1.1495128870010376 val_accuracy:0.5333333611488342 ]
32it [01:55, 3.61s/it, epoch: 24/50 - loss:0.9482444524765015 accuracy:0.6494587659835815 val_loss:1.136460542678833 val_accuracy:0.5375000238418579 ]
32it [01:52, 3.53s/it, epoch: 25/50 - loss:0.9186312556266785 accuracy:0.6702747941017151 val_loss:1.1416422128677368 val_accuracy:0.5416666865348816 ]
32it [01:42, 3.21s/it, epoch: 26/50 - loss:0.9188517928123474 accuracy:0.6736053228378296 val_loss:1.1264033317565918 val_accuracy:0.550000011920929 ]
32it [01:41, 3.17s/it, epoch: 27/50 - loss:0.9028259515762329 accuracy:0.681931734085083 val_loss:1.1303426027297974 val_accuracy:0.550000011920929 ]
32it [01:45, 3.31s/it, epoch: 28/50 - loss:0.8891870975494385 accuracy:0.6844296455383301 val_loss:1.1231117248535156 val_accuracy:0.5458333492279053 ]
32it [01:43, 3.23s/it, epoch: 29/50 - loss:0.8797008991241455 accuracy:0.6977518796920776 val_loss:1.1278127431869507 val_accuracy:0.5416666865348816 ]
32it [01:47, 3.37s/it, epoch: 30/50 - loss:0.878724217414856 accuracy:0.6752706170082092 val_loss:1.1188760995864868 val_accuracy:0.574999988079071 ]
32it [01:40, 3.13s/it, epoch: 31/50 - loss:0.8734909296035767 accuracy:0.703580379486084 val_loss:1.121822476387024 val_accuracy:0.550000011920929 ]
32it [01:50, 3.44s/it, epoch: 32/50 - loss:0.8486524820327759 accuracy:0.6985844969749451 val_loss:1.127307415008545 val_accuracy:0.5416666865348816 ]
32it [01:45, 3.29s/it, epoch: 33/50 - loss:0.8515317440032959 accuracy:0.6944212913513184 val_loss:1.1144821643829346 val_accuracy:0.5625 ]
32it [01:44, 3.26s/it, epoch: 34/50 - loss:0.8291710615158081 accuracy:0.7135720252990723 val_loss:1.1182698011398315 val_accuracy:0.5458333492279053 ]
32it [01:49, 3.41s/it, epoch: 35/50 - loss:0.8350727558135986 accuracy:0.7019150853157043 val_loss:1.1012609004974365 val_accuracy:0.5833333134651184 ]
32it [01:48, 3.38s/it, epoch: 36/50 - loss:0.8297075033187866 accuracy:0.7227310538291931 val_loss:1.1068880558013916 val_accuracy:0.5625 ]
32it [01:49, 3.41s/it, epoch: 37/50 - loss:0.809542715549469 accuracy:0.7293921709060669 val_loss:1.1091495752334595 val_accuracy:0.5458333492279053 ]
32it [01:43, 3.24s/it, epoch: 38/50 - loss:0.8065611720085144 accuracy:0.7185678482055664 val_loss:1.09843909740448 val_accuracy:0.5708333253860474 ]
32it [01:46, 3.34s/it, epoch: 39/50 - loss:0.7942371964454651 accuracy:0.7310574650764465 val_loss:1.105940580368042 val_accuracy:0.5708333253860474 ]
32it [01:41, 3.16s/it, epoch: 40/50 - loss:0.7934896349906921 accuracy:0.7227310538291931 val_loss:1.1048657894134521 val_accuracy:0.5666666626930237 ]
32it [01:43, 3.23s/it, epoch: 41/50 - loss:0.7780565023422241 accuracy:0.7327227592468262 val_loss:1.0963842868804932 val_accuracy:0.574999988079071 ]
32it [01:43, 3.24s/it, epoch: 42/50 - loss:0.7623825669288635 accuracy:0.7452123165130615 val_loss:1.0992350578308105 val_accuracy:0.5583333373069763 ]
32it [01:44, 3.27s/it, epoch: 43/50 - loss:0.7639929056167603 accuracy:0.7385511994361877 val_loss:1.0936287641525269 val_accuracy:0.5708333253860474 ]
32it [01:40, 3.13s/it, epoch: 44/50 - loss:0.7649388909339905 accuracy:0.7485429048538208 val_loss:1.0890998840332031 val_accuracy:0.5625 ]
32it [01:44, 3.25s/it, epoch: 45/50 - loss:0.7667314410209656 accuracy:0.7477102279663086 val_loss:1.1028246879577637 val_accuracy:0.5833333134651184 ]
32it [01:39, 3.11s/it, epoch: 46/50 - loss:0.7435023188591003 accuracy:0.7468776106834412 val_loss:1.0851361751556396 val_accuracy:0.5791666507720947 ]
32it [01:38, 3.07s/it, epoch: 47/50 - loss:0.7359048128128052 accuracy:0.7643630504608154 val_loss:1.0931051969528198 val_accuracy:0.5625 ]
32it [01:37, 3.05s/it, epoch: 48/50 - loss:0.729657769203186 accuracy:0.7626977562904358 val_loss:1.083522915840149 val_accuracy:0.5708333253860474 ]
32it [01:44, 3.27s/it, epoch: 49/50 - loss:0.7189591526985168 accuracy:0.7601998448371887 val_loss:1.0874546766281128 val_accuracy:0.574999988079071 ]
32it [01:39, 3.10s/it, epoch: 50/50 - loss:0.7207762598991394 accuracy:0.7760199904441833 val_loss:1.0822055339813232 val_accuracy:0.5708333253860474 ]
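The `fed_model.fit` call above passes `aggregate_freq=2`, which controls how often the parties synchronize. The toy simulation below assumes that this parameter counts local training steps (batches) between aggregations. It uses a linear regression model rather than InceptionV3, and `local_sgd_step` and `train_federated` are hypothetical helpers written for illustration, not SecretFlow APIs.

```python
import numpy as np


def local_sgd_step(w, x, y, lr):
    # One gradient step on mean squared error for a linear model y ≈ x @ w
    grad = 2.0 * x.T @ (x @ w - y) / len(y)
    return w - lr * grad


def train_federated(party_data, total_steps, aggregate_freq, lr=0.1, seed=0):
    """Each party trains locally; weights are averaged (plain mean) every
    `aggregate_freq` local steps and broadcast back to all parties."""
    rng = np.random.default_rng(seed)
    dim = party_data[0][0].shape[1]
    global_w = rng.normal(size=dim)  # shared random initialization
    local_ws = [global_w.copy() for _ in party_data]
    n_aggregations = 0
    for step in range(1, total_steps + 1):
        # Local training: every party updates its own copy independently
        local_ws = [local_sgd_step(w, x, y, lr) for w, (x, y) in zip(local_ws, party_data)]
        if step % aggregate_freq == 0:
            # Aggregation round: average, then every party restarts from the average
            global_w = np.mean(local_ws, axis=0)
            local_ws = [global_w.copy() for _ in party_data]
            n_aggregations += 1
    return global_w, n_aggregations


# Two parties ("alice" and "bob") with noiseless data from the same linear model
rng = np.random.default_rng(42)
true_w = np.array([1.5, -2.0])
party_data = []
for _ in range(2):
    x = rng.normal(size=(64, 2))
    party_data.append((x, x @ true_w))

w, n = train_federated(party_data, total_steps=100, aggregate_freq=2)
print(n)  # 50 aggregation rounds for 100 local steps
print(w)  # close to [1.5, -2.0]
```

A larger `aggregate_freq` means fewer communication rounds, at the cost of letting the parties' local models drift further apart between aggregations.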
[35]:
# Draw accuracy values for training & validation
plt.plot(history.global_history['accuracy'])
plt.plot(history.global_history['val_accuracy'])
plt.title('FLModel accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
# Draw loss for training & validation
plt.plot(history.global_history['loss'])
plt.plot(history.global_history['val_loss'])
plt.title('FLModel loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Valid'], loc='upper left')
plt.show()
Federated learning recap
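As the steps above show, SecretFlow supports the fine-tuning approach from the official TensorFlow tutorial without modification. Because it is compatible with Keras pretrained models, we do not need to write the architecture of a complex network ourselves; the InceptionV3 architecture is defined in the Keras source (source code of Inception V3). The comparison experiment also shows the effect of loading the pretrained weights: in the final epoch, the run that loads them reaches a training accuracy of about 0.88, compared with about 0.78 for the randomly initialized run. Note, however, that its validation accuracy in these logs is lower (about 0.48 versus 0.57).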
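To compare the two runs precisely rather than by reading the plots, the final and best values can be taken from `history.global_history`. The sketch below assumes that `global_history` maps each metric name to a per-epoch list, as the indexing in the plotting cells suggests. `summarize` is a hypothetical helper, and the demo dictionaries contain only epochs 48 to 50 of each run, copied from the training logs above and rounded to three decimals.

```python
def summarize(global_history, metrics=("accuracy", "val_accuracy"), first_epoch=1):
    """Return {metric: (final_value, best_value, best_epoch)}.

    global_history: dict mapping metric name -> list of per-epoch values.
    first_epoch: epoch number of the first list entry (1 for a full history).
    """
    summary = {}
    for name in metrics:
        values = global_history[name]
        # Index of the highest value (the first one, if there are ties)
        best_idx = max(range(len(values)), key=lambda i: values[i])
        summary[name] = (values[-1], values[best_idx], first_epoch + best_idx)
    return summary


# Epochs 48-50 of each run, taken from the logs above (rounded)
pretrained_tail = {"accuracy": [0.867, 0.873, 0.879], "val_accuracy": [0.467, 0.475, 0.483]}
random_init_tail = {"accuracy": [0.763, 0.760, 0.776], "val_accuracy": [0.571, 0.575, 0.571]}

for label, tail in [("pretrained", pretrained_tail), ("random init", random_init_tail)]:
    for metric, (final, best, epoch) in summarize(tail, first_epoch=48).items():
        print(f"{label:12s} {metric:13s} final={final:.3f} best={best:.3f} (epoch {epoch})")
```

In a live session, pass each run's `history.global_history` directly with the default `first_epoch=1` to cover all 50 epochs.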
Summary
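In this tutorial, we used Inception V3 as an example to show how to load TensorFlow.Keras pretrained models directly in SecretFlow's federated learning mode. Loading pretrained models directly lets us:
- avoid writing the structure code of complex models again
- fine-tune and perform transfer learning on top of pretrained models
- use pretrained weights to help the federated model achieve better performance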