使用 SPU 训练神经网络
请先阅读 Logistic Regression On SPU。
在实验 Logistic Regression On SPU 中,我们展示了如何使用 SecretFlow/SPU 将明文 JAX 训练程序转换为安全 MPC 训练程序。
在本实验中,思路与之非常相似,但这次我们将使用神经网络模型。
我们将沿用相同的数据集和全部设置。
首先,让我们构建明文模型。
以下代码仅作为示例,请勿在生产环境直接使用。
本教程所需的资源多于 SecretFlow 的最低配置要求(8 核 CPU、16 GB 内存)。
使用 JAX/Flax 训练模型
加载数据集
以下内容复制自实验 Logistic Regression On SPU,此处不再赘述。
[ ]:
import sys
# Install the Flax release this notebook was written against. This is an IPython
# shell escape (`!`), not plain Python; it uses the kernel's own interpreter
# (`sys.executable`) so the package is installed into the active environment.
!{sys.executable} -m pip install flax==0.6.0
[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    """Load the breast-cancer dataset, min-max scaled, optionally split by party.

    The data is divided into a fixed train/test split (20% test, ``random_state=42``),
    so every call sees the same rows.

    Args:
        party_id: Only used when ``train`` is True.
            ``1`` -> first 15 feature columns of the training set and no labels.
            Any other truthy id (``2`` in this tutorial) -> the remaining feature
            columns plus the labels.
            Falsy (``None``) -> all feature columns plus the labels.
        train: Return the training split when True, the full test split otherwise.

    Returns:
        ``(features, labels)``; ``labels`` is ``None`` for party 1.
    """
    # NOTE(review): keep the tuple-literal return annotation as is. The notebook
    # unpacks ``alice(breast_cancer)(...)`` into two names, and SecretFlow appears to
    # infer the number of PYU return values from this annotation -- confirm before
    # changing it.
    x, y = load_breast_cancer(return_X_y=True)
    # Scale with the min/max over the whole matrix (not per column) into [0, 1].
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )
    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        # Party 1 owns features only. Return an explicit None instead of the name
        # ``_``, which is undefined in this function's scope (it only resolved to
        # IPython's last-output variable and raises NameError in a plain script).
        return x_train[:, :15], None
    return x_train[:, 15:], y_train
定义模型
[2]:
from typing import Sequence
import flax.linen as nn
# Widths of the successive Dense layers (the last one is the single output unit).
# model_init also uses FEATURES[0] as the number of input columns.
FEATURES = [30, 15, 8, 1]


class MLP(nn.Module):
    """Fully connected network: ReLU-activated hidden Dense layers, then a linear head.

    ``features`` lists the width of every Dense layer in order; all but the last are
    followed by a ReLU, and the last one produces the raw model output.
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        output_width = self.features[-1]
        for width in hidden_widths:
            hidden = nn.Dense(width)(x)
            x = nn.relu(hidden)
        # No activation on the output layer.
        return nn.Dense(output_width)(x)
接下来定义训练方法。
[3]:
import jax.numpy as jnp
def predict(params, x):
    """Run the MLP forward pass with the given parameters.

    Args:
        params: Parameter tree as produced by ``MLP(FEATURES).init`` (see ``model_init``).
        x: Input feature batch.

    Returns:
        The raw output of the final Dense layer (no activation applied); presumably
        of shape (batch, FEATURES[-1]) -- confirm against the caller's inputs.
    """
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    # NOTE(review): the imports, FEATURES and MLP below deliberately shadow the
    # module-level definitions; presumably this keeps the function self-contained
    # when it is serialized and shipped to the SPU runtime. Keep them in sync with
    # the top-level MLP / FEATURES.
    from typing import Sequence
    import flax.linen as nn

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x

    return MLP(FEATURES).apply(params, x)
def loss_func(params, x, y):
    """Return half the mean squared error of the MLP on one batch.

    Args:
        params: Parameter tree accepted by ``predict``.
        x: Feature batch passed to ``predict``.
        y: Labels for the rows of ``x``.

    Returns:
        Scalar ``mean((y - pred) ** 2 / 2)``.
    """
    pred = predict(params, x)
    # The last Dense layer has a single unit, so ``pred`` carries a trailing axis of
    # size 1 (shape (batch, 1)), while the labels split from ``breast_cancer`` are
    # 1-D (shape (batch,)). Subtracting them directly would broadcast to a
    # (batch, batch) matrix and compare every label with every prediction, so align
    # ``pred`` with the label shape first (a no-op if ``y`` is already (batch, 1)).
    pred = jnp.reshape(pred, jnp.shape(y))

    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0

        return jnp.mean(squared_error(y, pred))

    return mse(y, pred)
def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    """Train the MLP with plain mini-batch gradient descent.

    Args:
        x1: Feature columns of the first party.
        x2: Feature columns of the second party (same number of rows as ``x1``);
            concatenated column-wise after ``x1``.
        y: Labels, one per row.
        params: Initial parameter tree (e.g. from ``model_init``).
        n_batch: Target number of rows per mini-batch (despite the name, a batch
            size). The data is cut into ``len(x) // n_batch`` near-equal batches,
            and always at least one.
        n_epochs: Number of passes over all batches.
        step_size: Learning rate of the gradient-descent update.

    Returns:
        The trained parameter tree.
    """
    x = jnp.concatenate((x1, x2), axis=1)
    # array_split expects a number of sections, not a batch size. Use integer
    # division and never request zero sections (which raises) when there are
    # fewer rows than ``n_batch``.
    n_sections = max(1, len(x) // n_batch)
    x_batches = jnp.array_split(x, n_sections, axis=0)
    y_batches = jnp.array_split(y, n_sections, axis=0)

    def one_epoch(_, params):
        # The Python loop over batches is unrolled when fori_loop traces this body.
        for x_batch, y_batch in zip(x_batches, y_batches):
            grads = jax.grad(loss_func)(params, x_batch, y_batch)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    return jax.lax.fori_loop(0, n_epochs, one_epoch, params)
def model_init(n_batch=10):
    """Create freshly initialised MLP parameters.

    A dummy all-ones batch of shape ``(n_batch, FEATURES[0])`` is passed through the
    model so Flax can infer every layer's parameter shapes. The PRNG seed is fixed
    at 1, so initialisation is reproducible.

    Args:
        n_batch: Number of rows in the dummy batch.

    Returns:
        The parameter tree returned by ``MLP.init``.
    """
    seed_key = jax.random.PRNGKey(1)
    dummy_batch = jnp.ones((n_batch, FEATURES[0]))
    network = MLP(FEATURES)
    return network.init(seed_key, dummy_batch)
验证模型
我们使用 AUC 作为验证指标。
[4]:
from sklearn.metrics import roc_auc_score
def validate_model(params, X_test, y_test):
    """Score the model on held-out data using ROC AUC.

    Args:
        params: Parameter tree accepted by ``predict``.
        X_test: Held-out features.
        y_test: Held-out binary labels.

    Returns:
        ``roc_auc_score`` of the raw model outputs against ``y_test``.
    """
    scores = predict(params, X_test)
    auc_value = roc_auc_score(y_test, scores)
    return auc_value
整合在一起
让我们把所有步骤整合在一起,训练一个明文 NN 模型!
[5]:
import jax
# Each party's slice of the training data: party 1 gets the first feature columns
# (its label slot is discarded), party 2 gets the remaining columns plus the labels.
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Hyperparameters (reused by the SPU cells below).
n_batch, n_epochs, step_size = 10, 10, 0.01

# Plaintext training.
init_params = model_init(n_batch)
params = train_auto_grad(
    x1, x2, y, init_params, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

# Evaluate on the held-out split.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test=X_test, y_test=y_test)
print(f'auc={auc}')
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9927939731411726
至此,我们已经把所有步骤整合在一起,训练出了一个明文神经网络模型。
使用 SPU 训练模型
[6]:
import secretflow as sf
# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

# In case you have a running secretflow runtime already.
sf.shutdown()

# Start a local (single-machine) runtime with two parties, alice and bob.
sf.init(['alice', 'bob'], address='local')
# One plaintext device (PYU) per party, plus an SPU device whose cluster config
# comes from SecretFlow's testing helper for the same two parties.
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

# Each party loads its own vertical slice of the training data on its own PYU:
# alice gets only the first feature columns, bob the remaining columns and labels.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)

device = spu
# Move the private inputs onto the SPU (MPC) device; the initial params are first
# placed on alice's PYU and then transferred to SPU.
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)

# Run the same JAX training function inside SPU. The hyperparameters are marked
# static (presumably like jax.jit's static_argnames) because they drive the batch
# splitting and the loop bound inside train_auto_grad.
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)
2023-04-28 14:35:12,293 INFO worker.py:1538 -- Started a local Ray instance.
(_run pid=175587) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=175587) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=175587) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=175587) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=176242) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=176242) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=176242) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=176242) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=180401) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=180401) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=180401) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=180401) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=177052) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=177052) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=177052) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=177052) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(SPURuntime pid=187368) 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime pid=187367) 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
让我们查看 SPU 程序训练得到的参数。
[7]:
# Reveal (convert back to plaintext) the parameters trained inside SPU by the
# previous cell. Do not call spu(train_auto_grad)(...) again here: that would rerun
# the whole, expensive MPC training just to inspect the result.
params = sf.reveal(params_spu)
print(params)
FrozenDict({
params: {
Dense_0: {
bias: array([ 6.7055225e-06, 6.7055225e-06, 6.7055225e-06, -8.4322095e-03,
4.7300994e-02, 4.5412779e-04, 6.7055225e-06, 4.5442879e-03,
6.7055225e-06, -3.4062415e-02, -8.3989948e-03, 6.7055225e-06,
6.7055225e-06, 5.6699291e-02, -4.8456341e-03, 6.7055225e-06,
3.5731569e-02, 6.3510090e-03, 3.0306578e-03, 3.2686546e-02,
6.7055225e-06, -2.1292433e-02, -7.7798963e-03, 6.7055225e-06,
2.8470993e-02, 6.7055225e-06, -3.0836165e-03, 4.5374036e-05,
1.4400020e-02, 2.0861626e-02], dtype=float32),
kernel: array([[-0.14870723, -0.23531294, -0.1493704 , -0.01558255, -0.13322462,
0.1917662 , -0.03679654, -0.03744406, -0.1417609 , 0.03231682,
0.12653404, -0.4025072 , -0.1689485 , 0.21399944, -0.13844648,
0.10585822, -0.11602122, 0.38625073, 0.05966607, 0.06318197,
0.0779368 , -0.01318966, -0.28804308, -0.09602153, 0.11111972,
-0.08543564, 0.07547122, -0.04118884, -0.38266844, 0.23767346],
[ 0.17795108, 0.2294012 , -0.24440196, -0.14849591, 0.33701816,
0.0258413 , -0.04214501, 0.41052908, 0.32439357, -0.16435765,
0.08169709, 0.05259326, 0.3113483 , 0.2931838 , 0.12270276,
-0.38752455, -0.38534215, -0.06536001, -0.25914845, -0.3322725 ,
-0.31587672, -0.29117638, -0.06018265, 0.2297913 , 0.10114388,
-0.01309358, 0.17881514, -0.23215818, 0.3828069 , 0.03806593],
[ 0.1428942 , -0.02135262, 0.16819091, 0.08982845, -0.38852412,
-0.04850367, 0.13870972, -0.05800854, 0.28472927, -0.12711032,
0.25702882, 0.09648418, 0.11670431, -0.1896179 , -0.03994708,
-0.09573121, 0.07308613, 0.14650668, 0.09226419, 0.03892082,
-0.24624617, 0.03725916, -0.01914255, -0.25209764, 0.17078896,
0.24982187, -0.0028675 , -0.09844984, 0.20797251, 0.08843645],
[-0.22511783, -0.0044653 , -0.04557581, -0.04286373, -0.13053825,
-0.3426896 , -0.00925103, 0.09015252, -0.2824888 , 0.22022144,
0.11647445, 0.04475737, -0.05021369, 0.29519165, -0.23622867,
0.05994891, 0.2596493 , 0.18784739, 0.14603132, 0.2965685 ,
0.03959064, 0.16071922, -0.11333965, -0.06968905, 0.26477575,
-0.317869 , 0.08121799, 0.25563055, -0.05901612, 0.19531868],
[-0.24254663, 0.07968816, -0.06768736, -0.11746876, 0.1875621 ,
-0.06137984, -0.05366111, -0.06934479, 0.07924516, 0.02541035,
-0.31857365, 0.28704768, -0.06027508, -0.30148876, -0.17660952,
0.07973847, 0.1614199 , 0.3279493 , 0.20515053, 0.30348837,
0.2711059 , 0.276556 , 0.07071564, 0.20800509, -0.07333609,
-0.10324922, 0.01553461, 0.31758228, 0.31677115, -0.06809102],
[ 0.07325926, 0.06064408, 0.0530773 , 0.17844556, 0.18787359,
0.17704393, 0.08110972, 0.01482402, -0.04424939, 0.06166127,
0.28827167, 0.05878101, 0.26427427, 0.12087436, -0.02181949,
-0.15166327, -0.04630022, 0.00738053, 0.2839891 , 0.10080083,
-0.3035335 , -0.31350654, -0.17609106, 0.11223568, 0.1156193 ,
-0.27605468, -0.06867941, 0.06136122, 0.3082044 , -0.28000844],
[-0.25858068, -0.01556443, 0.27713627, -0.38400537, 0.39872903,
-0.12919384, -0.02736983, 0.17572944, 0.13031955, 0.15870668,
-0.02625516, 0.29411823, 0.03559025, 0.03587727, -0.2966054 ,
-0.16969463, 0.0300006 , 0.16187829, -0.17532285, -0.08767432,
-0.04854703, -0.10537073, 0.08301418, -0.04356302, -0.25446534,
-0.09856299, -0.04166624, -0.04677388, -0.3353408 , -0.11825959],
[ 0.27912897, -0.07000226, -0.02481516, 0.04389155, -0.08830354,
-0.00139034, 0.08731189, -0.24834795, 0.15356407, -0.12887374,
-0.00434314, -0.00279981, -0.07792975, -0.1029453 , 0.2409295 ,
-0.25699303, 0.2918012 , -0.19479287, -0.27555436, 0.01553042,
0.12703311, 0.1288091 , 0.15366644, 0.1431344 , 0.06207459,
-0.11137639, 0.05906925, -0.11649235, 0.01587239, -0.20323639],
[ 0.06792891, 0.08563136, -0.09104523, 0.17886826, 0.07520616,
-0.13827898, 0.33567435, -0.14805417, -0.03184932, 0.39237124,
-0.1335338 , 0.19828805, 0.05121414, -0.04607381, -0.12948062,
-0.22250798, 0.12677568, 0.39128548, -0.11602047, 0.00093162,
-0.07845107, 0.17064299, 0.2707931 , 0.06743585, 0.07426128,
-0.00924093, -0.0035352 , -0.3685534 , -0.12302665, 0.22056273],
[ 0.02833928, -0.12450014, 0.17981096, 0.15364204, 0.05483492,
0.19171704, -0.0949284 , 0.06867886, -0.07678194, -0.01938733,
0.05701402, -0.39338416, 0.05287948, 0.3794972 , 0.24641661,
-0.1212198 , 0.04000506, -0.38034967, 0.19541413, -0.0905077 ,
0.3206088 , 0.01485404, -0.03493308, 0.11109039, -0.33723742,
-0.30601716, -0.11324729, 0.1596858 , 0.06751473, 0.1008921 ],
[ 0.16805576, 0.19498089, -0.09763785, -0.14558062, 0.10152206,
-0.31742054, -0.11583678, -0.2865575 , -0.10120936, -0.13012367,
0.19799586, -0.06929106, 0.00183079, -0.06139433, -0.23812771,
0.14183812, 0.41206583, -0.11150262, -0.07695962, -0.03937718,
-0.05823223, -0.25616592, 0.17551638, -0.05776715, 0.04627597,
0.12046237, 0.31444448, -0.1823728 , -0.16253875, -0.09766676],
[-0.06190741, -0.11557767, 0.07265058, 0.12529932, 0.20684099,
0.15767016, -0.08056761, -0.19449666, 0.02133167, 0.23543602,
-0.17700855, -0.35116544, -0.22017023, 0.03137846, 0.10100484,
-0.40086156, -0.13380852, -0.06593318, -0.14122422, -0.17200904,
-0.0666105 , 0.09940979, -0.03091712, 0.25939053, 0.06447808,
-0.2506336 , -0.0349206 , 0.08023839, 0.25556827, -0.2408923 ],
[ 0.00898188, -0.40073588, -0.06301974, 0.06183384, 0.3735768 ,
0.03177406, 0.27502847, -0.28810993, -0.2024756 , 0.16113877,
-0.21794656, 0.10632099, 0.00266866, -0.27301037, 0.07529524,
0.07778189, 0.02633543, 0.09457737, -0.28337651, 0.0255892 ,
0.17133063, 0.04773571, -0.01299471, -0.0919252 , 0.22021984,
-0.1989678 , 0.34153467, 0.08680797, -0.08852738, -0.0090448 ],
[ 0.12035008, 0.12541962, -0.36259866, 0.22371957, -0.07335131,
0.10498597, 0.00436583, -0.08324738, 0.22863485, -0.14954014,
0.08159503, -0.3141421 , 0.08762485, -0.05525228, 0.08568875,
-0.02316961, -0.2230854 , 0.02858485, 0.10418503, 0.09759469,
0.08704272, 0.01555008, 0.17367665, 0.08375961, -0.01750728,
-0.06537268, 0.05048656, -0.22944517, 0.05722432, 0.25090805],
[-0.39710748, 0.10012694, -0.07080103, -0.16264898, -0.13910918,
0.16161054, 0.16022125, -0.00788775, 0.05428429, 0.16593601,
0.22370476, 0.36696965, 0.06149913, -0.04857542, 0.3345247 ,
-0.07260554, 0.1938989 , -0.06002848, -0.30302036, 0.17182748,
0.29064724, 0.21397091, 0.04791559, 0.09810503, 0.1033058 ,
0.12732479, -0.0579783 , 0.15246823, -0.3666319 , 0.1779919 ],
[ 0.01064624, 0.0888928 , 0.26858085, 0.34396815, 0.06943932,
0.30761874, -0.15886313, 0.00265385, 0.04297891, -0.06383656,
-0.01197957, -0.10140778, 0.03901416, -0.02126652, 0.13493209,
-0.16070978, -0.27638012, -0.11028586, 0.12214845, -0.2560637 ,
-0.08863154, 0.03597671, -0.1732396 , 0.12559041, 0.14788477,
0.09702435, 0.17843248, 0.08070756, 0.0718791 , 0.08296195],
[ 0.14691886, 0.13540354, -0.05013047, -0.2566406 , -0.2376638 ,
0.21672072, 0.1372795 , -0.03882806, 0.39052176, 0.0047731 ,
0.14544334, -0.0696618 , -0.15187763, 0.06678917, -0.24012098,
0.31160212, 0.06627946, -0.2530402 , -0.20175886, -0.22604358,
0.1381416 , -0.14101216, 0.3429103 , 0.12955913, 0.2845845 ,
0.06188303, -0.22960348, 0.2912202 , -0.08082792, -0.3445377 ],
[-0.01824994, 0.12698065, 0.11829151, -0.08935194, -0.04362963,
-0.06175369, -0.1114524 , -0.06696388, -0.34100425, -0.25512362,
-0.1483988 , -0.20127416, -0.00367533, 0.05239835, 0.06488706,
0.08272076, 0.05891787, 0.2134408 , -0.13793291, 0.30933803,
-0.09876332, -0.15072244, -0.10377637, 0.03409749, 0.0937078 ,
-0.22452421, 0.3597254 , 0.24009626, -0.03083205, -0.10381168],
[-0.14538439, 0.17941016, 0.01639399, -0.2706253 , -0.02600642,
-0.03973 , -0.0325162 , 0.03153259, -0.15472709, -0.09655666,
0.04076509, 0.1300038 , -0.19558378, -0.17638195, 0.12240331,
-0.26903665, 0.2714493 , -0.07004572, -0.07335924, 0.03825237,
0.22632292, 0.3012138 , 0.02217355, -0.30002278, -0.06066401,
-0.07689169, -0.37136257, 0.19665234, -0.10525645, -0.27408272],
[ 0.05384398, 0.03158583, -0.00409974, -0.04451011, -0.10076478,
-0.06426084, 0.3136195 , -0.13606365, 0.1243284 , -0.10924114,
-0.03940558, 0.22020963, -0.07174113, 0.08709462, 0.04955287,
0.36317343, 0.00659794, -0.15838777, 0.09210019, -0.17414865,
-0.14202411, 0.3834263 , 0.02247368, 0.00736032, -0.02805607,
-0.15887989, 0.03910746, -0.0943727 , 0.21787158, 0.01440434],
[-0.09300622, -0.19802521, -0.31412005, 0.17171307, 0.1331803 ,
-0.14113024, -0.21318011, -0.16237472, 0.09434846, 0.14660788,
0.01858762, -0.02211154, -0.14670722, 0.39278403, -0.20136856,
0.10904545, -0.02885009, -0.15209475, 0.1743193 , 0.0778787 ,
0.09585676, 0.10286772, 0.1895318 , -0.15744607, 0.0972386 ,
-0.26544875, -0.05130047, 0.08041063, 0.05855417, 0.24786705],
[ 0.05508722, 0.23071642, 0.00278442, -0.05163229, -0.13318591,
0.17231207, -0.0383717 , 0.17234325, -0.12098849, -0.12200612,
-0.165717 , -0.08695543, -0.01522441, -0.31668693, 0.196136 ,
-0.20849878, 0.34565175, 0.252592 , 0.03059202, -0.23635055,
-0.02455017, -0.07401715, 0.18046305, 0.08005303, 0.02341022,
0.05160871, 0.0830403 , -0.10961437, 0.2051303 , 0.05485763],
[-0.29294217, 0.01583408, -0.00052598, 0.07539546, 0.17627907,
0.16075702, 0.00591798, -0.02526975, -0.2719347 , -0.2642147 ,
0.17578189, 0.26844388, -0.16066906, 0.00551553, -0.41348425,
0.1321568 , 0.2071938 , -0.09202607, -0.32119918, 0.03001858,
-0.03515013, -0.11420041, 0.00692059, 0.06027223, 0.31073922,
0.31373912, 0.15468763, 0.23844069, 0.20547047, 0.165754 ],
[-0.1317544 , -0.09716719, -0.2110814 , 0.30688593, 0.13689038,
0.25466746, -0.23185365, 0.265381 , -0.20205005, 0.26761973,
-0.01471928, -0.17001429, -0.00165382, 0.10118251, 0.28316593,
-0.10187137, 0.02500786, 0.09213623, -0.06184761, 0.051311 ,
-0.13956325, 0.29834348, 0.16425882, -0.20013842, 0.10159607,
-0.09226643, -0.09284794, -0.24736227, 0.28198415, 0.18465933],
[-0.19596493, -0.26223665, -0.02396852, 0.1405711 , -0.05117449,
0.09832071, 0.10009323, 0.08764507, -0.20915532, -0.04817107,
0.11512975, -0.0107393 , 0.06286559, -0.14394692, 0.1831078 ,
0.18601051, -0.01792853, -0.010507 , 0.2988264 , 0.02924132,
0.1502285 , -0.02573505, 0.10515428, 0.32683268, -0.06475027,
-0.07946308, -0.33095527, -0.33394814, 0.14654751, -0.18609025],
[-0.04332547, 0.18820217, 0.03160366, 0.11940409, -0.22678787,
0.09432799, -0.08720809, 0.25600654, -0.14890012, 0.09946848,
0.18772584, -0.19526623, 0.0827599 , -0.14669879, -0.12541471,
-0.13776924, 0.09574251, -0.2980466 , 0.10541511, -0.11811657,
-0.23554784, -0.01769215, -0.29761636, 0.04322377, -0.04169539,
0.04331157, -0.10865457, 0.3526432 , 0.27452517, 0.01664442],
[ 0.17763771, -0.07080895, -0.12558904, -0.13398908, -0.22847766,
-0.20403627, 0.07889682, 0.13384837, -0.31691694, -0.13476555,
-0.08197045, 0.02778772, 0.02476428, 0.10588782, -0.25830707,
-0.24311969, 0.03762388, 0.05451898, -0.13534577, -0.10997833,
-0.3139264 , 0.05126831, -0.00060226, -0.15891929, -0.17077953,
0.2362888 , 0.08467598, 0.01052356, 0.08872832, 0.16418251],
[ 0.37544525, 0.0681546 , 0.07722013, -0.40396348, -0.05511683,
0.00878677, 0.33257678, 0.18474084, -0.0799066 , 0.20011736,
0.14146338, -0.15846273, -0.15961201, -0.18772689, -0.17597765,
-0.13404477, 0.21314138, -0.13090074, 0.10695033, 0.28710032,
0.13358802, 0.3303852 , -0.2687422 , -0.22376198, 0.29356587,
-0.03488064, 0.14832059, 0.12624982, -0.20833445, 0.05823356],
[ 0.17862728, -0.12085862, -0.07798004, 0.16461669, 0.13114056,
0.1119384 , -0.02916402, -0.01834482, 0.03708343, -0.39161655,
0.04380961, 0.12685701, -0.20311095, 0.14991562, 0.08968998,
-0.1430527 , 0.3768945 , 0.2545389 , 0.09408659, 0.30030465,
0.00201878, -0.03300162, -0.31967437, 0.08429171, -0.10358454,
0.15462488, 0.15204427, -0.00353977, -0.15648344, 0.03190795],
[-0.3349845 , -0.18704857, 0.12660322, 0.27142197, -0.04179126,
-0.01659705, -0.15886122, 0.14643206, 0.10317151, 0.139131 ,
0.26203057, 0.03828669, 0.17041986, 0.28139216, 0.03020249,
-0.21715921, 0.05988631, 0.20941454, 0.27820507, -0.30283943,
0.21741417, 0.06876856, -0.0162366 , -0.09319973, 0.16716208,
-0.05672812, -0.01678701, -0.33967227, 0.04148872, 0.24174951]],
dtype=float32),
},
Dense_1: {
bias: array([-9.28075612e-03, -4.38565016e-03, 6.73728883e-02, 6.70552254e-06,
-1.11967325e-02, 1.06356591e-02, -2.26502120e-02, -3.45642865e-03,
3.70997190e-02, 9.00812894e-02, 6.70552254e-06, 6.70552254e-06,
6.70552254e-06, -1.53630227e-02, 3.88986021e-02], dtype=float32),
kernel: array([[-0.21578309, -0.08008368, -0.34167936, -0.03616343, -0.04043388,
-0.19278756, 0.07816273, 0.3847432 , -0.27097666, 0.03089739,
-0.11206758, 0.12151396, 0.38484663, 0.12947203, 0.03026646],
[ 0.3035894 , 0.14900179, 0.02244793, 0.17264117, 0.0011169 ,
-0.1606707 , 0.17210394, -0.19850568, -0.00882789, 0.06376703,
-0.09706031, -0.27143008, 0.32902688, 0.01248117, -0.20562333],
[ 0.01422736, 0.25237322, 0.26592904, -0.07876748, 0.02570754,
0.13746765, -0.3037846 , -0.30066282, 0.228537 , 0.07397157,
-0.05444951, 0.06826244, -0.11475235, -0.04363853, -0.00258049],
[-0.03807974, 0.36382473, -0.05991563, 0.1660564 , -0.18014075,
0.17624326, -0.24441232, -0.31741685, -0.06890935, -0.04919542,
0.13665393, -0.05236177, 0.12887959, 0.2582429 , -0.06479871],
[ 0.01582411, -0.00546367, 0.06451672, 0.00377437, 0.05299711,
0.09622552, -0.33355796, 0.15772232, -0.00315991, 0.3426076 ,
-0.01920256, -0.0157837 , 0.10247016, -0.02410382, 0.14005862],
[-0.1542452 , -0.18916714, -0.12516193, -0.15350978, -0.20895821,
-0.03576617, 0.0180776 , 0.16850933, 0.05937128, 0.03776275,
0.07396616, 0.03354299, 0.06906249, 0.15164083, -0.2608541 ],
[ 0.0905385 , 0.3133642 , -0.17575558, 0.05339232, -0.19663817,
0.22920834, 0.21465397, 0.14934285, 0.30395645, -0.2403111 ,
0.11673826, -0.0449398 , -0.0359446 , 0.3089288 , -0.01469092],
[-0.09350568, -0.09241582, 0.29311585, 0.13808654, 0.14410885,
0.11155026, 0.19201808, -0.22068372, 0.0091573 , -0.00837527,
-0.10839485, -0.0492375 , 0.15357326, 0.3894407 , -0.15209287],
[ 0.0029071 , 0.18366341, 0.03765289, -0.01738603, 0.18317957,
0.00410259, 0.09655431, 0.07968767, 0.21980065, 0.22737293,
-0.15136166, 0.20435053, 0.11874333, -0.3370184 , 0.11251831],
[-0.03699435, 0.05359124, -0.00424996, -0.00427449, -0.20195475,
-0.12829332, 0.06293778, 0.13848272, -0.17896764, -0.38953093,
-0.07185236, 0.22985502, -0.11224222, 0.04145651, -0.3817618 ],
[ 0.23528674, 0.1663438 , -0.08346738, 0.20346904, -0.20409097,
-0.07192652, 0.11208971, 0.24518102, 0.23959732, -0.1391288 ,
-0.02638906, -0.11256091, -0.27086872, -0.00492385, 0.13006589],
[-0.05570208, -0.34653068, 0.298495 , -0.16680127, 0.06143057,
0.09288131, 0.1472318 , -0.12598082, -0.01329006, -0.26823848,
0.08741044, 0.10009366, 0.1264808 , 0.13802043, 0.2563799 ],
[ 0.01380032, -0.19647142, 0.14879738, 0.0388497 , -0.14403345,
0.3500362 , -0.03261025, -0.11959814, -0.35041225, -0.09013529,
0.16815332, -0.17363463, -0.26452613, 0.18936844, -0.30342007],
[ 0.15264955, -0.16593191, 0.2803555 , -0.02613318, 0.09317887,
-0.1145407 , -0.02915843, 0.09115867, 0.16309327, 0.16567504,
-0.16353543, -0.02392778, 0.21730614, -0.37557966, 0.36441088],
[ 0.27545726, -0.0511765 , 0.03052394, 0.38374472, 0.18914919,
-0.30549794, -0.1365143 , -0.09850363, -0.08355592, -0.17305706,
0.00163533, 0.27035654, -0.01430997, 0.01418965, -0.23040168],
[-0.11281115, -0.08904882, 0.05267188, -0.03345008, 0.17955093,
0.15272899, -0.05194864, 0.10906464, 0.21673168, -0.05776855,
0.29315004, -0.272271 , 0.22718571, -0.04166271, 0.08242701],
[ 0.11221845, 0.15372409, -0.13822478, -0.1822467 , -0.26139548,
0.22891735, -0.12165104, -0.20519899, 0.39132354, 0.19772069,
0.00470251, -0.04089966, -0.17769708, 0.22472003, 0.24131916],
[-0.08916578, -0.13332213, 0.11583962, -0.3159293 , -0.05461061,
-0.03293931, 0.17573284, -0.03388457, -0.04562169, -0.00728241,
-0.20086795, -0.04282369, -0.06481767, 0.00174528, 0.08415282],
[ 0.10727057, 0.15353328, 0.09634779, 0.01951087, -0.00729644,
-0.25289363, -0.23461105, 0.35619986, 0.1761693 , -0.18046483,
-0.25238073, -0.05560882, -0.20357345, -0.13479468, 0.14422338],
[-0.32128555, 0.01506339, 0.3208 , 0.3084674 , 0.06561027,
-0.20671532, 0.07110539, 0.09107907, -0.05795462, 0.06884666,
-0.24340847, 0.09923317, -0.39770418, -0.1435657 , 0.18189654],
[ 0.17829058, -0.3734911 , -0.344891 , -0.18513158, -0.1252987 ,
-0.359349 , -0.21523046, 0.4066509 , -0.06088345, -0.12821774,
0.30891037, -0.05408247, 0.13263397, 0.01792602, 0.22215459],
[-0.15978388, -0.19274645, -0.39842924, -0.13794418, 0.24811234,
-0.30259767, 0.25340182, 0.36628515, -0.04467097, 0.2068153 ,
0.10091661, -0.17184901, -0.01158652, 0.28761518, 0.07140049],
[-0.38753265, -0.2148714 , -0.34941888, -0.37459916, 0.00249913,
-0.38012785, -0.26021895, 0.06027205, -0.05131304, 0.24082436,
0.20541278, -0.09037189, -0.1668249 , 0.24143052, -0.26692837],
[-0.20972599, -0.01015192, 0.16557814, 0.20875406, -0.19013 ,
-0.31780058, -0.0311262 , -0.06458683, 0.39772552, -0.26640862,
0.31138209, -0.06382139, -0.39696902, 0.10767588, 0.01154487],
[ 0.18972392, 0.0260205 , 0.10645114, -0.21743847, -0.26412916,
0.15006566, 0.13827245, -0.21839432, -0.0661045 , 0.27946475,
-0.10810606, -0.32184672, -0.03605315, 0.04213592, -0.01746434],
[ 0.04842271, -0.17424968, 0.1226363 , 0.3272752 , -0.08305582,
-0.31486374, 0.10645151, 0.09955801, 0.07176954, -0.20580491,
0.04142599, -0.00282551, 0.15971068, 0.19535801, -0.218681 ],
[ 0.18084396, 0.0928266 , -0.27906552, -0.3218416 , 0.08461788,
-0.13167469, -0.22216263, 0.06937028, 0.10845792, -0.15438123,
-0.02529359, 0.03964274, -0.01773006, 0.04081337, 0.15702999],
[ 0.17229266, -0.27421105, 0.03916015, -0.10643585, 0.15348236,
-0.40775347, -0.14518811, -0.19719355, 0.15164377, 0.08712097,
-0.01809341, 0.03163807, -0.31661576, -0.08889712, -0.3158114 ],
[-0.09769286, -0.0287679 , 0.35823774, -0.2710501 , 0.32775387,
0.08072445, 0.30246186, -0.19245733, -0.17830896, 0.2923805 ,
0.09355405, -0.2524669 , 0.12927876, 0.38659224, -0.394949 ],
[ 0.06278192, 0.08469887, -0.00950807, 0.10956495, 0.09936713,
-0.19083263, 0.21161489, 0.3930601 , 0.00441524, 0.20089766,
-0.13769451, 0.27256852, -0.0958501 , -0.05921303, 0.33085537]],
dtype=float32),
},
Dense_2: {
bias: array([-1.85250044e-02, 1.07492432e-01, 6.70552254e-06, -2.40181834e-02,
1.02186531e-01, 1.17786124e-01, -7.42332637e-03, 6.70552254e-06],
dtype=float32),
kernel: array([[ 1.70731947e-01, -9.90723372e-02, -2.29330808e-02,
1.22201458e-01, -1.06555134e-01, -3.66996974e-02,
1.41981944e-01, 8.84072036e-02],
[ 1.80482566e-01, 1.18618101e-01, 5.27178943e-01,
-1.53569669e-01, -3.86155099e-02, -1.22957200e-01,
6.27236068e-03, 1.60065144e-02],
[ 1.45004794e-01, 5.34007728e-01, -3.44348907e-01,
-9.43183005e-02, 1.35729849e-01, -8.20837915e-03,
1.19165242e-01, -4.11043108e-01],
[ 2.49569073e-01, 2.14901850e-01, 3.24754179e-01,
-4.91983056e-01, -1.14351839e-01, -2.11404264e-02,
7.69451857e-02, 3.31384748e-01],
[ 4.12468165e-02, -1.21441126e-01, -3.10934186e-01,
3.62142444e-01, -2.74272680e-01, -5.16952693e-01,
-5.41899055e-02, 5.59365571e-01],
[-4.59907204e-02, -6.31701499e-02, 1.12813368e-01,
3.72401834e-01, 1.19809762e-01, 2.30254814e-01,
-1.38893276e-01, 4.47092503e-02],
[-6.63732141e-02, -3.46694767e-01, -4.84580398e-01,
1.17096156e-01, -2.02452645e-01, -3.72330904e-01,
5.67477584e-01, -2.42807895e-01],
[ 3.36939037e-01, -1.21135429e-01, 3.77379209e-01,
4.15782034e-01, 4.28560078e-02, -3.28819275e-01,
4.96320784e-01, -1.31850213e-01],
[ 3.29444051e-01, 7.28795379e-02, 2.03807220e-01,
1.12708807e-01, 3.64434421e-01, 1.61256164e-01,
-2.09103152e-01, 5.57109714e-04],
[ 1.63463175e-01, 4.33325768e-04, -4.44926351e-01,
-9.06594098e-02, 3.35962057e-01, 4.89929318e-01,
-1.45875633e-01, -2.39341617e-01],
[-2.70898372e-01, 1.98495671e-01, 1.63841411e-01,
-3.97940278e-01, 8.93494636e-02, 3.55310917e-01,
-4.32752073e-03, -2.55927563e-01],
[ 2.71851391e-01, -6.50364608e-02, 5.75378686e-02,
8.61035287e-02, 4.62560952e-02, 6.84097558e-02,
-3.49434435e-01, 3.20657253e-01],
[ 3.87204051e-01, 1.02552503e-01, -3.67724121e-01,
-1.37631506e-01, -2.76333094e-03, -9.74716246e-03,
-2.03522891e-02, -3.78593743e-01],
[ 1.50336102e-01, -1.60084844e-01, -4.73217756e-01,
2.41285011e-01, 4.34440225e-02, -3.39211166e-01,
1.99615315e-01, -1.64225787e-01],
[-9.00761038e-02, 1.48902372e-01, -3.05288136e-02,
4.30284500e-01, -3.87649029e-01, 5.50720513e-01,
-9.88744646e-02, -3.92888784e-01]], dtype=float32),
},
Dense_3: {
bias: array([0.23489459], dtype=float32),
kernel: array([[-0.07592051],
[ 0.6634396 ],
[ 0.193967 ],
[-0.21856818],
[ 0.44744045],
[ 0.7258886 ],
[-0.07619607],
[-0.20487107]], dtype=float32),
},
},
})
最后,让我们验证模型。
[8]:
# Evaluate the parameters revealed from SPU on the same held-out split.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test=X_test, y_test=y_test)
print(f'auc={auc}')
auc=0.9927939731411726
实验到此结束。