Neural Network with SPU#
If you have not done so already, please read the lab Logistic Regression On SPU first.
In the lab Logistic Regression On SPU, we showed how to use SecretFlow/SPU to convert a plaintext JAX training program into a secure MPC training program.
The idea in this lab is similar, but this time we will work with a neural network model.
We are going to use the same dataset and settings as in the lab Logistic Regression On SPU.
First, let’s build the plaintext model.
The following code is for demonstration only. It is NOT suitable for production due to system security concerns; please DO NOT use it directly in production.
This tutorial needs more resources than 8c16g (8 CPU cores and 16 GB of memory), which is the minimum requirement of SecretFlow.
Train a Model with JAX/Flax#
Load the Dataset#
The code below is copied from the lab Logistic Regression On SPU, so we will not explain it again.
[ ]:
import sys
!{sys.executable} -m pip install flax==0.6.0
[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    # Load the full dataset and min-max scale it with one global min and max.
    x, y = load_breast_cancer(return_X_y=True)
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if train:
        if party_id:
            # Vertical split: party 1 holds the first 15 features (no labels),
            # party 2 holds the remaining 15 features and the labels.
            if party_id == 1:
                return x_train[:, :15], None
            else:
                return x_train[:, 15:], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test
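To make the vertical partition concrete, the following NumPy sketch repeats the scaling and slicing done by `breast_cancer`. It uses random data of the same shape (569 samples, 30 features) as a hypothetical stand-in for the real dataset, so it runs without scikit-learn. Party 1 receives features 0 to 14 and no labels, and party 2 receives features 15 to 29 plus the labels. Joining the slices column-wise, which `train_auto_grad` later does with `jnp.concatenate`, recovers the full matrix.

```python
import numpy as np

# Hypothetical stand-in for load_breast_cancer(): random data of the same shape.
rng = np.random.default_rng(0)
x = rng.normal(size=(569, 30))
y = rng.integers(0, 2, size=569)

# Global min-max scaling as in breast_cancer(): a single min and max over the
# whole matrix (not per feature), so every entry ends up in [0, 1].
x = (x - np.min(x)) / (np.max(x) - np.min(x))

# Vertical partition: party 1 gets features 0-14 and no labels,
# party 2 gets features 15-29 together with the labels.
x_party1 = x[:, :15]
x_party2, y_party2 = x[:, 15:], y

# Joining the slices column-wise recovers the original feature matrix.
x_joined = np.concatenate((x_party1, x_party2), axis=1)
print(x_party1.shape, x_party2.shape, x_joined.shape)  # (569, 15) (569, 15) (569, 30)
```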
Define the Model#
We are going to use a 4-layer MLP: three hidden Dense layers with ReLU activations, followed by a linear output layer with a single unit.
[2]:
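from typing import Sequence
import flax.linen as nn

# Output width of each Dense layer; the last layer produces a single score.
FEATURES = [30, 15, 8, 1]

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        # Hidden layers: Dense followed by ReLU.
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        # Output layer: Dense with no activation.
        x = nn.Dense(self.features[-1])(x)
        return x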
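Flax's `nn.Dense` stores a kernel of shape `(in_features, out_features)` and a bias of shape `(out_features,)`, and layers created inside an `@nn.compact` method are auto-named `Dense_0`, `Dense_1`, and so on. The plain-Python sketch below works out the shapes this MLP should have for the 30 input features. It does not call Flax, and `dense_shapes` is a hypothetical helper. You can compare its result with the parameter dump printed later in this lab.

```python
# Output width of each Dense layer in the MLP above; the data has 30 input features.
FEATURES = [30, 15, 8, 1]
N_INPUTS = 30


def dense_shapes(n_inputs, features):
    """Hypothetical helper: (kernel_shape, bias_shape) for each auto-named Dense layer."""
    shapes = {}
    for i, n_outputs in enumerate(features):
        shapes[f"Dense_{i}"] = ((n_inputs, n_outputs), (n_outputs,))
        n_inputs = n_outputs  # the next layer consumes this layer's output
    return shapes


shapes = dense_shapes(N_INPUTS, FEATURES)
for name, (kernel_shape, bias_shape) in shapes.items():
    print(name, "kernel", kernel_shape, "bias", bias_shape)

n_params = sum(k[0] * k[1] + b[0] for k, b in shapes.values())
print("total parameters:", n_params)  # 1532
```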
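Next, we define the prediction, loss, and training functions. Training runs plain mini-batch gradient descent on a squared-error loss, with gradients computed by `jax.value_and_grad`.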
[3]:
import jax
import jax.numpy as jnp


def predict(params, x):
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    from typing import Sequence
    import flax.linen as nn

    FEATURES = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x

    return MLP(FEATURES).apply(params, x)


def loss_func(params, x, y):
    # Note: predict returns shape (batch, 1) while y has shape (batch,),
    # so `y - y_pred` below broadcasts to (batch, batch). Reshape y to
    # (batch, 1) if a per-sample squared error is intended.
    pred = predict(params, x)

    def mse(y, pred):
        def squared_error(y, y_pred):
            return jnp.multiply(y - y_pred, y - y_pred) / 2.0

        return jnp.mean(squared_error(y, pred))

    return mse(y, pred)


def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    # Rejoin the two parties' feature slices column-wise.
    x = jnp.concatenate((x1, x2), axis=1)
    # n_batch acts as an approximate batch size: split into len(x) // n_batch chunks.
    xs = jnp.array_split(x, len(x) // n_batch, axis=0)
    ys = jnp.array_split(y, len(y) // n_batch, axis=0)

    def body_fun(_, loop_carry):
        # One epoch: a gradient-descent step on every mini-batch.
        params = loop_carry
        for x, y in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x, y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)
    return params


def model_init(n_batch=10):
    model = MLP(FEATURES)
    return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))
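Two details of `train_auto_grad` are easy to miss. First, `n_batch` works as an approximate batch size: the 455 training rows are split into `455 // 10 = 45` chunks, and some chunks hold one extra row. Second, each update is a plain gradient step, `p - step_size * g`, applied to every leaf of the parameter tree. The NumPy sketch below reproduces both using random data and a hypothetical linear model `pred = x @ w + b`. It writes the gradient of the halved mean squared error by hand, where the lab relies on `jax.value_and_grad`. The dict comprehension plays the role of `jax.tree_util.tree_map`.

```python
import numpy as np

# Batching as in train_auto_grad: 455 rows split into 455 // 10 = 45 chunks.
n_rows, n_batch = 455, 10
chunks = np.array_split(np.arange(n_rows), n_rows // n_batch)
chunk_sizes = sorted({len(c) for c in chunks})
print(len(chunks), chunk_sizes)  # 45 [10, 11]

# Hypothetical linear model and random data, used only to illustrate the update rule.
rng = np.random.default_rng(0)
x = rng.normal(size=(n_rows, 3))
y = x @ np.array([1.0, -2.0, 0.5]) + 0.3


def linear_loss(params, x, y):
    # Mean of (y - pred)^2 / 2, with y and pred both of shape (n,).
    pred = x @ params["w"] + params["b"]
    return np.mean((y - pred) ** 2 / 2.0)


def linear_grads(params, x, y):
    # Hand-written gradient of linear_loss with respect to w and b.
    residual = x @ params["w"] + params["b"] - y
    return {"w": x.T @ residual / len(y), "b": np.mean(residual)}


params = {"w": np.zeros(3), "b": 0.0}
xs = np.array_split(x, n_rows // n_batch)
ys = np.array_split(y, n_rows // n_batch)
step_size, n_epochs = 0.01, 10
initial_loss = linear_loss(params, x, y)

for _ in range(n_epochs):  # jax.lax.fori_loop in the JAX version
    for xb, yb in zip(xs, ys):  # one gradient step per mini-batch
        grads = linear_grads(params, xb, yb)
        params = {k: params[k] - step_size * grads[k] for k in params}

final_loss = linear_loss(params, x, y)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```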
Validate the Model#
We use AUC (the area under the ROC curve) as the validation metric.
[4]:
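import numpy as np
from sklearn.metrics import roc_auc_score

def validate_model(params, X_test, y_test):
    # AUC depends only on how the scores rank the samples; flatten (n, 1) to (n,).
    y_pred = predict(params, X_test)
    return roc_auc_score(y_test, np.asarray(y_pred).ravel())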
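`roc_auc_score` computes the area under the ROC curve. This equals the probability that a randomly chosen positive sample gets a higher score than a randomly chosen negative one. Because only the ranking matters, the raw outputs of this MLP can be scored directly. The NumPy sketch below computes AUC from that pairwise definition and counts ties as one half. `pairwise_auc` is a hypothetical helper, not a scikit-learn function. It needs memory proportional to positives × negatives, so treat it as an illustration rather than a replacement for `roc_auc_score`.

```python
import numpy as np


def pairwise_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score).ravel()  # accepts (n,) or (n, 1) scores
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive vs. every negative
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size


print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Three of the four positive/negative pairs are ordered correctly here (only 0.35 < 0.4 is not), which gives 0.75.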
Build Together#
Let’s put everything together and train a plaintext NN model!
[5]:
import jax
# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)
# Hyperparameters
n_batch = 10  # approximate mini-batch size (see train_auto_grad)
n_epochs = 10
step_size = 0.01
# Train the model
init_params = model_init(n_batch)
params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)
# Test the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9927939731411726
Keep this AUC value in mind: we are now going to repeat the training with SPU and compare. Let’s do the magic!
Train a Model with SPU#
[6]:
import secretflow as sf
# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))
# In case you have a running secretflow runtime already.
sf.shutdown()
sf.init(['alice', 'bob'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
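# Each party loads its own slice of the data on its own PYU.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
# Initial parameters are created in plaintext, placed on alice, then moved into SPU.
init_params = model_init(n_batch)

device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(alice, init_params).to(device)

# Hyperparameters are passed as static arguments, i.e. compile-time constants
# (n_batch, for example, determines how the data is split into batches).
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)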
2023-04-28 14:35:12,293 INFO worker.py:1538 -- Started a local Ray instance.
(_run pid=175587) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=175587) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=175587) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=175587) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(SPURuntime pid=187368) 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime pid=187367) 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
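Under the hood, SPU runs this program as a secure multi-party computation on secret-shared data. Like most secret-sharing MPC frameworks, it computes on fixed-point numbers encoded in an integer ring rather than on floats. The sketch below is a conceptual illustration only, not SPU's protocol or API. It assumes two parties, the ring of integers modulo 2^64, and 18 fractional bits, all choices made for this example. It shows three things: one share is a uniformly random mask, so neither share alone reveals the value; addition can be done locally on shares; and values below the fixed-point resolution round to zero.

```python
import random

RING = 2**64      # assumption: shares live in the integers modulo 2^64
FRAC_BITS = 18    # assumption: number of fractional bits in the fixed-point encoding


def encode(value):
    # Fixed-point encode: scale by 2^FRAC_BITS, round, wrap into the ring.
    return round(value * 2**FRAC_BITS) % RING


def decode(element):
    # The upper half of the ring represents negative numbers (two's complement).
    if element >= RING // 2:
        element -= RING
    return element / 2**FRAC_BITS


def share(element, rng):
    # Additive 2-out-of-2 sharing: a uniformly random mask and the masked remainder.
    mask = rng.randrange(RING)
    return mask, (element - mask) % RING


def reconstruct(share0, share1):
    return (share0 + share1) % RING


rng = random.Random(0)
a0, a1 = share(encode(0.25), rng)  # party 0 holds a0, party 1 holds a1
b0, b1 = share(encode(-1.5), rng)
# Each party adds the shares it holds; no communication is needed for addition.
c0, c1 = (a0 + b0) % RING, (a1 + b1) % RING
print(decode(reconstruct(c0, c1)))  # -1.25
print(decode(encode(1e-6)))         # 0.0, below the 2^-18 resolution
```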
Let’s reveal the parameters trained by the SPU program and print them.
[7]:
# Reveal the secret-shared parameters trained in the previous cell.
params = sf.reveal(params_spu)
print(params)
FrozenDict({
params: {
Dense_0: {
bias: array([ 6.7055225e-06, 6.7055225e-06, 6.7055225e-06, -8.4322095e-03,
4.7300994e-02, 4.5412779e-04, 6.7055225e-06, 4.5442879e-03,
6.7055225e-06, -3.4062415e-02, -8.3989948e-03, 6.7055225e-06,
6.7055225e-06, 5.6699291e-02, -4.8456341e-03, 6.7055225e-06,
3.5731569e-02, 6.3510090e-03, 3.0306578e-03, 3.2686546e-02,
6.7055225e-06, -2.1292433e-02, -7.7798963e-03, 6.7055225e-06,
2.8470993e-02, 6.7055225e-06, -3.0836165e-03, 4.5374036e-05,
1.4400020e-02, 2.0861626e-02], dtype=float32),
kernel: array([[-0.14870723, -0.23531294, -0.1493704 , -0.01558255, -0.13322462,
0.1917662 , -0.03679654, -0.03744406, -0.1417609 , 0.03231682,
0.12653404, -0.4025072 , -0.1689485 , 0.21399944, -0.13844648,
0.10585822, -0.11602122, 0.38625073, 0.05966607, 0.06318197,
0.0779368 , -0.01318966, -0.28804308, -0.09602153, 0.11111972,
-0.08543564, 0.07547122, -0.04118884, -0.38266844, 0.23767346],
[ 0.17795108, 0.2294012 , -0.24440196, -0.14849591, 0.33701816,
0.0258413 , -0.04214501, 0.41052908, 0.32439357, -0.16435765,
0.08169709, 0.05259326, 0.3113483 , 0.2931838 , 0.12270276,
-0.38752455, -0.38534215, -0.06536001, -0.25914845, -0.3322725 ,
-0.31587672, -0.29117638, -0.06018265, 0.2297913 , 0.10114388,
-0.01309358, 0.17881514, -0.23215818, 0.3828069 , 0.03806593],
[ 0.1428942 , -0.02135262, 0.16819091, 0.08982845, -0.38852412,
-0.04850367, 0.13870972, -0.05800854, 0.28472927, -0.12711032,
0.25702882, 0.09648418, 0.11670431, -0.1896179 , -0.03994708,
-0.09573121, 0.07308613, 0.14650668, 0.09226419, 0.03892082,
-0.24624617, 0.03725916, -0.01914255, -0.25209764, 0.17078896,
0.24982187, -0.0028675 , -0.09844984, 0.20797251, 0.08843645],
[-0.22511783, -0.0044653 , -0.04557581, -0.04286373, -0.13053825,
-0.3426896 , -0.00925103, 0.09015252, -0.2824888 , 0.22022144,
0.11647445, 0.04475737, -0.05021369, 0.29519165, -0.23622867,
0.05994891, 0.2596493 , 0.18784739, 0.14603132, 0.2965685 ,
0.03959064, 0.16071922, -0.11333965, -0.06968905, 0.26477575,
-0.317869 , 0.08121799, 0.25563055, -0.05901612, 0.19531868],
[-0.24254663, 0.07968816, -0.06768736, -0.11746876, 0.1875621 ,
-0.06137984, -0.05366111, -0.06934479, 0.07924516, 0.02541035,
-0.31857365, 0.28704768, -0.06027508, -0.30148876, -0.17660952,
0.07973847, 0.1614199 , 0.3279493 , 0.20515053, 0.30348837,
0.2711059 , 0.276556 , 0.07071564, 0.20800509, -0.07333609,
-0.10324922, 0.01553461, 0.31758228, 0.31677115, -0.06809102],
[ 0.07325926, 0.06064408, 0.0530773 , 0.17844556, 0.18787359,
0.17704393, 0.08110972, 0.01482402, -0.04424939, 0.06166127,
0.28827167, 0.05878101, 0.26427427, 0.12087436, -0.02181949,
-0.15166327, -0.04630022, 0.00738053, 0.2839891 , 0.10080083,
-0.3035335 , -0.31350654, -0.17609106, 0.11223568, 0.1156193 ,
-0.27605468, -0.06867941, 0.06136122, 0.3082044 , -0.28000844],
[-0.25858068, -0.01556443, 0.27713627, -0.38400537, 0.39872903,
-0.12919384, -0.02736983, 0.17572944, 0.13031955, 0.15870668,
-0.02625516, 0.29411823, 0.03559025, 0.03587727, -0.2966054 ,
-0.16969463, 0.0300006 , 0.16187829, -0.17532285, -0.08767432,
-0.04854703, -0.10537073, 0.08301418, -0.04356302, -0.25446534,
-0.09856299, -0.04166624, -0.04677388, -0.3353408 , -0.11825959],
[ 0.27912897, -0.07000226, -0.02481516, 0.04389155, -0.08830354,
-0.00139034, 0.08731189, -0.24834795, 0.15356407, -0.12887374,
-0.00434314, -0.00279981, -0.07792975, -0.1029453 , 0.2409295 ,
-0.25699303, 0.2918012 , -0.19479287, -0.27555436, 0.01553042,
0.12703311, 0.1288091 , 0.15366644, 0.1431344 , 0.06207459,
-0.11137639, 0.05906925, -0.11649235, 0.01587239, -0.20323639],
[ 0.06792891, 0.08563136, -0.09104523, 0.17886826, 0.07520616,
-0.13827898, 0.33567435, -0.14805417, -0.03184932, 0.39237124,
-0.1335338 , 0.19828805, 0.05121414, -0.04607381, -0.12948062,
-0.22250798, 0.12677568, 0.39128548, -0.11602047, 0.00093162,
-0.07845107, 0.17064299, 0.2707931 , 0.06743585, 0.07426128,
-0.00924093, -0.0035352 , -0.3685534 , -0.12302665, 0.22056273],
[ 0.02833928, -0.12450014, 0.17981096, 0.15364204, 0.05483492,
0.19171704, -0.0949284 , 0.06867886, -0.07678194, -0.01938733,
0.05701402, -0.39338416, 0.05287948, 0.3794972 , 0.24641661,
-0.1212198 , 0.04000506, -0.38034967, 0.19541413, -0.0905077 ,
0.3206088 , 0.01485404, -0.03493308, 0.11109039, -0.33723742,
-0.30601716, -0.11324729, 0.1596858 , 0.06751473, 0.1008921 ],
[ 0.16805576, 0.19498089, -0.09763785, -0.14558062, 0.10152206,
-0.31742054, -0.11583678, -0.2865575 , -0.10120936, -0.13012367,
0.19799586, -0.06929106, 0.00183079, -0.06139433, -0.23812771,
0.14183812, 0.41206583, -0.11150262, -0.07695962, -0.03937718,
-0.05823223, -0.25616592, 0.17551638, -0.05776715, 0.04627597,
0.12046237, 0.31444448, -0.1823728 , -0.16253875, -0.09766676],
[-0.06190741, -0.11557767, 0.07265058, 0.12529932, 0.20684099,
0.15767016, -0.08056761, -0.19449666, 0.02133167, 0.23543602,
-0.17700855, -0.35116544, -0.22017023, 0.03137846, 0.10100484,
-0.40086156, -0.13380852, -0.06593318, -0.14122422, -0.17200904,
-0.0666105 , 0.09940979, -0.03091712, 0.25939053, 0.06447808,
-0.2506336 , -0.0349206 , 0.08023839, 0.25556827, -0.2408923 ],
[ 0.00898188, -0.40073588, -0.06301974, 0.06183384, 0.3735768 ,
0.03177406, 0.27502847, -0.28810993, -0.2024756 , 0.16113877,
-0.21794656, 0.10632099, 0.00266866, -0.27301037, 0.07529524,
0.07778189, 0.02633543, 0.09457737, -0.28337651, 0.0255892 ,
0.17133063, 0.04773571, -0.01299471, -0.0919252 , 0.22021984,
-0.1989678 , 0.34153467, 0.08680797, -0.08852738, -0.0090448 ],
[ 0.12035008, 0.12541962, -0.36259866, 0.22371957, -0.07335131,
0.10498597, 0.00436583, -0.08324738, 0.22863485, -0.14954014,
0.08159503, -0.3141421 , 0.08762485, -0.05525228, 0.08568875,
-0.02316961, -0.2230854 , 0.02858485, 0.10418503, 0.09759469,
0.08704272, 0.01555008, 0.17367665, 0.08375961, -0.01750728,
-0.06537268, 0.05048656, -0.22944517, 0.05722432, 0.25090805],
[-0.39710748, 0.10012694, -0.07080103, -0.16264898, -0.13910918,
0.16161054, 0.16022125, -0.00788775, 0.05428429, 0.16593601,
0.22370476, 0.36696965, 0.06149913, -0.04857542, 0.3345247 ,
-0.07260554, 0.1938989 , -0.06002848, -0.30302036, 0.17182748,
0.29064724, 0.21397091, 0.04791559, 0.09810503, 0.1033058 ,
0.12732479, -0.0579783 , 0.15246823, -0.3666319 , 0.1779919 ],
[ 0.01064624, 0.0888928 , 0.26858085, 0.34396815, 0.06943932,
0.30761874, -0.15886313, 0.00265385, 0.04297891, -0.06383656,
-0.01197957, -0.10140778, 0.03901416, -0.02126652, 0.13493209,
-0.16070978, -0.27638012, -0.11028586, 0.12214845, -0.2560637 ,
-0.08863154, 0.03597671, -0.1732396 , 0.12559041, 0.14788477,
0.09702435, 0.17843248, 0.08070756, 0.0718791 , 0.08296195],
[ 0.14691886, 0.13540354, -0.05013047, -0.2566406 , -0.2376638 ,
0.21672072, 0.1372795 , -0.03882806, 0.39052176, 0.0047731 ,
0.14544334, -0.0696618 , -0.15187763, 0.06678917, -0.24012098,
0.31160212, 0.06627946, -0.2530402 , -0.20175886, -0.22604358,
0.1381416 , -0.14101216, 0.3429103 , 0.12955913, 0.2845845 ,
0.06188303, -0.22960348, 0.2912202 , -0.08082792, -0.3445377 ],
[-0.01824994, 0.12698065, 0.11829151, -0.08935194, -0.04362963,
-0.06175369, -0.1114524 , -0.06696388, -0.34100425, -0.25512362,
-0.1483988 , -0.20127416, -0.00367533, 0.05239835, 0.06488706,
0.08272076, 0.05891787, 0.2134408 , -0.13793291, 0.30933803,
-0.09876332, -0.15072244, -0.10377637, 0.03409749, 0.0937078 ,
-0.22452421, 0.3597254 , 0.24009626, -0.03083205, -0.10381168],
[-0.14538439, 0.17941016, 0.01639399, -0.2706253 , -0.02600642,
-0.03973 , -0.0325162 , 0.03153259, -0.15472709, -0.09655666,
0.04076509, 0.1300038 , -0.19558378, -0.17638195, 0.12240331,
-0.26903665, 0.2714493 , -0.07004572, -0.07335924, 0.03825237,
0.22632292, 0.3012138 , 0.02217355, -0.30002278, -0.06066401,
-0.07689169, -0.37136257, 0.19665234, -0.10525645, -0.27408272],
[ 0.05384398, 0.03158583, -0.00409974, -0.04451011, -0.10076478,
-0.06426084, 0.3136195 , -0.13606365, 0.1243284 , -0.10924114,
-0.03940558, 0.22020963, -0.07174113, 0.08709462, 0.04955287,
0.36317343, 0.00659794, -0.15838777, 0.09210019, -0.17414865,
-0.14202411, 0.3834263 , 0.02247368, 0.00736032, -0.02805607,
-0.15887989, 0.03910746, -0.0943727 , 0.21787158, 0.01440434],
[-0.09300622, -0.19802521, -0.31412005, 0.17171307, 0.1331803 ,
-0.14113024, -0.21318011, -0.16237472, 0.09434846, 0.14660788,
0.01858762, -0.02211154, -0.14670722, 0.39278403, -0.20136856,
0.10904545, -0.02885009, -0.15209475, 0.1743193 , 0.0778787 ,
0.09585676, 0.10286772, 0.1895318 , -0.15744607, 0.0972386 ,
-0.26544875, -0.05130047, 0.08041063, 0.05855417, 0.24786705],
[ 0.05508722, 0.23071642, 0.00278442, -0.05163229, -0.13318591,
0.17231207, -0.0383717 , 0.17234325, -0.12098849, -0.12200612,
-0.165717 , -0.08695543, -0.01522441, -0.31668693, 0.196136 ,
-0.20849878, 0.34565175, 0.252592 , 0.03059202, -0.23635055,
-0.02455017, -0.07401715, 0.18046305, 0.08005303, 0.02341022,
0.05160871, 0.0830403 , -0.10961437, 0.2051303 , 0.05485763],
[-0.29294217, 0.01583408, -0.00052598, 0.07539546, 0.17627907,
0.16075702, 0.00591798, -0.02526975, -0.2719347 , -0.2642147 ,
0.17578189, 0.26844388, -0.16066906, 0.00551553, -0.41348425,
0.1321568 , 0.2071938 , -0.09202607, -0.32119918, 0.03001858,
-0.03515013, -0.11420041, 0.00692059, 0.06027223, 0.31073922,
0.31373912, 0.15468763, 0.23844069, 0.20547047, 0.165754 ],
[-0.1317544 , -0.09716719, -0.2110814 , 0.30688593, 0.13689038,
0.25466746, -0.23185365, 0.265381 , -0.20205005, 0.26761973,
-0.01471928, -0.17001429, -0.00165382, 0.10118251, 0.28316593,
-0.10187137, 0.02500786, 0.09213623, -0.06184761, 0.051311 ,
-0.13956325, 0.29834348, 0.16425882, -0.20013842, 0.10159607,
-0.09226643, -0.09284794, -0.24736227, 0.28198415, 0.18465933],
[-0.19596493, -0.26223665, -0.02396852, 0.1405711 , -0.05117449,
0.09832071, 0.10009323, 0.08764507, -0.20915532, -0.04817107,
0.11512975, -0.0107393 , 0.06286559, -0.14394692, 0.1831078 ,
0.18601051, -0.01792853, -0.010507 , 0.2988264 , 0.02924132,
0.1502285 , -0.02573505, 0.10515428, 0.32683268, -0.06475027,
-0.07946308, -0.33095527, -0.33394814, 0.14654751, -0.18609025],
[-0.04332547, 0.18820217, 0.03160366, 0.11940409, -0.22678787,
0.09432799, -0.08720809, 0.25600654, -0.14890012, 0.09946848,
0.18772584, -0.19526623, 0.0827599 , -0.14669879, -0.12541471,
-0.13776924, 0.09574251, -0.2980466 , 0.10541511, -0.11811657,
-0.23554784, -0.01769215, -0.29761636, 0.04322377, -0.04169539,
0.04331157, -0.10865457, 0.3526432 , 0.27452517, 0.01664442],
[ 0.17763771, -0.07080895, -0.12558904, -0.13398908, -0.22847766,
-0.20403627, 0.07889682, 0.13384837, -0.31691694, -0.13476555,
-0.08197045, 0.02778772, 0.02476428, 0.10588782, -0.25830707,
-0.24311969, 0.03762388, 0.05451898, -0.13534577, -0.10997833,
-0.3139264 , 0.05126831, -0.00060226, -0.15891929, -0.17077953,
0.2362888 , 0.08467598, 0.01052356, 0.08872832, 0.16418251],
[ 0.37544525, 0.0681546 , 0.07722013, -0.40396348, -0.05511683,
0.00878677, 0.33257678, 0.18474084, -0.0799066 , 0.20011736,
0.14146338, -0.15846273, -0.15961201, -0.18772689, -0.17597765,
-0.13404477, 0.21314138, -0.13090074, 0.10695033, 0.28710032,
0.13358802, 0.3303852 , -0.2687422 , -0.22376198, 0.29356587,
-0.03488064, 0.14832059, 0.12624982, -0.20833445, 0.05823356],
[ 0.17862728, -0.12085862, -0.07798004, 0.16461669, 0.13114056,
0.1119384 , -0.02916402, -0.01834482, 0.03708343, -0.39161655,
0.04380961, 0.12685701, -0.20311095, 0.14991562, 0.08968998,
-0.1430527 , 0.3768945 , 0.2545389 , 0.09408659, 0.30030465,
0.00201878, -0.03300162, -0.31967437, 0.08429171, -0.10358454,
0.15462488, 0.15204427, -0.00353977, -0.15648344, 0.03190795],
[-0.3349845 , -0.18704857, 0.12660322, 0.27142197, -0.04179126,
-0.01659705, -0.15886122, 0.14643206, 0.10317151, 0.139131 ,
0.26203057, 0.03828669, 0.17041986, 0.28139216, 0.03020249,
-0.21715921, 0.05988631, 0.20941454, 0.27820507, -0.30283943,
0.21741417, 0.06876856, -0.0162366 , -0.09319973, 0.16716208,
-0.05672812, -0.01678701, -0.33967227, 0.04148872, 0.24174951]],
dtype=float32),
},
Dense_1: {
bias: array([-9.28075612e-03, -4.38565016e-03, 6.73728883e-02, 6.70552254e-06,
-1.11967325e-02, 1.06356591e-02, -2.26502120e-02, -3.45642865e-03,
3.70997190e-02, 9.00812894e-02, 6.70552254e-06, 6.70552254e-06,
6.70552254e-06, -1.53630227e-02, 3.88986021e-02], dtype=float32),
kernel: array([[-0.21578309, -0.08008368, -0.34167936, -0.03616343, -0.04043388,
-0.19278756, 0.07816273, 0.3847432 , -0.27097666, 0.03089739,
-0.11206758, 0.12151396, 0.38484663, 0.12947203, 0.03026646],
[ 0.3035894 , 0.14900179, 0.02244793, 0.17264117, 0.0011169 ,
-0.1606707 , 0.17210394, -0.19850568, -0.00882789, 0.06376703,
-0.09706031, -0.27143008, 0.32902688, 0.01248117, -0.20562333],
[ 0.01422736, 0.25237322, 0.26592904, -0.07876748, 0.02570754,
0.13746765, -0.3037846 , -0.30066282, 0.228537 , 0.07397157,
-0.05444951, 0.06826244, -0.11475235, -0.04363853, -0.00258049],
[-0.03807974, 0.36382473, -0.05991563, 0.1660564 , -0.18014075,
0.17624326, -0.24441232, -0.31741685, -0.06890935, -0.04919542,
0.13665393, -0.05236177, 0.12887959, 0.2582429 , -0.06479871],
[ 0.01582411, -0.00546367, 0.06451672, 0.00377437, 0.05299711,
0.09622552, -0.33355796, 0.15772232, -0.00315991, 0.3426076 ,
-0.01920256, -0.0157837 , 0.10247016, -0.02410382, 0.14005862],
[-0.1542452 , -0.18916714, -0.12516193, -0.15350978, -0.20895821,
-0.03576617, 0.0180776 , 0.16850933, 0.05937128, 0.03776275,
0.07396616, 0.03354299, 0.06906249, 0.15164083, -0.2608541 ],
[ 0.0905385 , 0.3133642 , -0.17575558, 0.05339232, -0.19663817,
0.22920834, 0.21465397, 0.14934285, 0.30395645, -0.2403111 ,
0.11673826, -0.0449398 , -0.0359446 , 0.3089288 , -0.01469092],
[-0.09350568, -0.09241582, 0.29311585, 0.13808654, 0.14410885,
0.11155026, 0.19201808, -0.22068372, 0.0091573 , -0.00837527,
-0.10839485, -0.0492375 , 0.15357326, 0.3894407 , -0.15209287],
[ 0.0029071 , 0.18366341, 0.03765289, -0.01738603, 0.18317957,
0.00410259, 0.09655431, 0.07968767, 0.21980065, 0.22737293,
-0.15136166, 0.20435053, 0.11874333, -0.3370184 , 0.11251831],
[-0.03699435, 0.05359124, -0.00424996, -0.00427449, -0.20195475,
-0.12829332, 0.06293778, 0.13848272, -0.17896764, -0.38953093,
-0.07185236, 0.22985502, -0.11224222, 0.04145651, -0.3817618 ],
[ 0.23528674, 0.1663438 , -0.08346738, 0.20346904, -0.20409097,
-0.07192652, 0.11208971, 0.24518102, 0.23959732, -0.1391288 ,
-0.02638906, -0.11256091, -0.27086872, -0.00492385, 0.13006589],
[-0.05570208, -0.34653068, 0.298495 , -0.16680127, 0.06143057,
0.09288131, 0.1472318 , -0.12598082, -0.01329006, -0.26823848,
0.08741044, 0.10009366, 0.1264808 , 0.13802043, 0.2563799 ],
[ 0.01380032, -0.19647142, 0.14879738, 0.0388497 , -0.14403345,
0.3500362 , -0.03261025, -0.11959814, -0.35041225, -0.09013529,
0.16815332, -0.17363463, -0.26452613, 0.18936844, -0.30342007],
[ 0.15264955, -0.16593191, 0.2803555 , -0.02613318, 0.09317887,
-0.1145407 , -0.02915843, 0.09115867, 0.16309327, 0.16567504,
-0.16353543, -0.02392778, 0.21730614, -0.37557966, 0.36441088],
[ 0.27545726, -0.0511765 , 0.03052394, 0.38374472, 0.18914919,
-0.30549794, -0.1365143 , -0.09850363, -0.08355592, -0.17305706,
0.00163533, 0.27035654, -0.01430997, 0.01418965, -0.23040168],
[-0.11281115, -0.08904882, 0.05267188, -0.03345008, 0.17955093,
0.15272899, -0.05194864, 0.10906464, 0.21673168, -0.05776855,
0.29315004, -0.272271 , 0.22718571, -0.04166271, 0.08242701],
[ 0.11221845, 0.15372409, -0.13822478, -0.1822467 , -0.26139548,
0.22891735, -0.12165104, -0.20519899, 0.39132354, 0.19772069,
0.00470251, -0.04089966, -0.17769708, 0.22472003, 0.24131916],
[-0.08916578, -0.13332213, 0.11583962, -0.3159293 , -0.05461061,
-0.03293931, 0.17573284, -0.03388457, -0.04562169, -0.00728241,
-0.20086795, -0.04282369, -0.06481767, 0.00174528, 0.08415282],
[ 0.10727057, 0.15353328, 0.09634779, 0.01951087, -0.00729644,
-0.25289363, -0.23461105, 0.35619986, 0.1761693 , -0.18046483,
-0.25238073, -0.05560882, -0.20357345, -0.13479468, 0.14422338],
[-0.32128555, 0.01506339, 0.3208 , 0.3084674 , 0.06561027,
-0.20671532, 0.07110539, 0.09107907, -0.05795462, 0.06884666,
-0.24340847, 0.09923317, -0.39770418, -0.1435657 , 0.18189654],
[ 0.17829058, -0.3734911 , -0.344891 , -0.18513158, -0.1252987 ,
-0.359349 , -0.21523046, 0.4066509 , -0.06088345, -0.12821774,
0.30891037, -0.05408247, 0.13263397, 0.01792602, 0.22215459],
[-0.15978388, -0.19274645, -0.39842924, -0.13794418, 0.24811234,
-0.30259767, 0.25340182, 0.36628515, -0.04467097, 0.2068153 ,
0.10091661, -0.17184901, -0.01158652, 0.28761518, 0.07140049],
[-0.38753265, -0.2148714 , -0.34941888, -0.37459916, 0.00249913,
-0.38012785, -0.26021895, 0.06027205, -0.05131304, 0.24082436,
0.20541278, -0.09037189, -0.1668249 , 0.24143052, -0.26692837],
[-0.20972599, -0.01015192, 0.16557814, 0.20875406, -0.19013 ,
-0.31780058, -0.0311262 , -0.06458683, 0.39772552, -0.26640862,
0.31138209, -0.06382139, -0.39696902, 0.10767588, 0.01154487],
[ 0.18972392, 0.0260205 , 0.10645114, -0.21743847, -0.26412916,
0.15006566, 0.13827245, -0.21839432, -0.0661045 , 0.27946475,
-0.10810606, -0.32184672, -0.03605315, 0.04213592, -0.01746434],
[ 0.04842271, -0.17424968, 0.1226363 , 0.3272752 , -0.08305582,
-0.31486374, 0.10645151, 0.09955801, 0.07176954, -0.20580491,
0.04142599, -0.00282551, 0.15971068, 0.19535801, -0.218681 ],
[ 0.18084396, 0.0928266 , -0.27906552, -0.3218416 , 0.08461788,
-0.13167469, -0.22216263, 0.06937028, 0.10845792, -0.15438123,
-0.02529359, 0.03964274, -0.01773006, 0.04081337, 0.15702999],
[ 0.17229266, -0.27421105, 0.03916015, -0.10643585, 0.15348236,
-0.40775347, -0.14518811, -0.19719355, 0.15164377, 0.08712097,
-0.01809341, 0.03163807, -0.31661576, -0.08889712, -0.3158114 ],
[-0.09769286, -0.0287679 , 0.35823774, -0.2710501 , 0.32775387,
0.08072445, 0.30246186, -0.19245733, -0.17830896, 0.2923805 ,
0.09355405, -0.2524669 , 0.12927876, 0.38659224, -0.394949 ],
[ 0.06278192, 0.08469887, -0.00950807, 0.10956495, 0.09936713,
-0.19083263, 0.21161489, 0.3930601 , 0.00441524, 0.20089766,
-0.13769451, 0.27256852, -0.0958501 , -0.05921303, 0.33085537]],
dtype=float32),
},
Dense_2: {
bias: array([-1.85250044e-02, 1.07492432e-01, 6.70552254e-06, -2.40181834e-02,
1.02186531e-01, 1.17786124e-01, -7.42332637e-03, 6.70552254e-06],
dtype=float32),
kernel: array([[ 1.70731947e-01, -9.90723372e-02, -2.29330808e-02,
1.22201458e-01, -1.06555134e-01, -3.66996974e-02,
1.41981944e-01, 8.84072036e-02],
[ 1.80482566e-01, 1.18618101e-01, 5.27178943e-01,
-1.53569669e-01, -3.86155099e-02, -1.22957200e-01,
6.27236068e-03, 1.60065144e-02],
[ 1.45004794e-01, 5.34007728e-01, -3.44348907e-01,
-9.43183005e-02, 1.35729849e-01, -8.20837915e-03,
1.19165242e-01, -4.11043108e-01],
[ 2.49569073e-01, 2.14901850e-01, 3.24754179e-01,
-4.91983056e-01, -1.14351839e-01, -2.11404264e-02,
7.69451857e-02, 3.31384748e-01],
[ 4.12468165e-02, -1.21441126e-01, -3.10934186e-01,
3.62142444e-01, -2.74272680e-01, -5.16952693e-01,
-5.41899055e-02, 5.59365571e-01],
[-4.59907204e-02, -6.31701499e-02, 1.12813368e-01,
3.72401834e-01, 1.19809762e-01, 2.30254814e-01,
-1.38893276e-01, 4.47092503e-02],
[-6.63732141e-02, -3.46694767e-01, -4.84580398e-01,
1.17096156e-01, -2.02452645e-01, -3.72330904e-01,
5.67477584e-01, -2.42807895e-01],
[ 3.36939037e-01, -1.21135429e-01, 3.77379209e-01,
4.15782034e-01, 4.28560078e-02, -3.28819275e-01,
4.96320784e-01, -1.31850213e-01],
[ 3.29444051e-01, 7.28795379e-02, 2.03807220e-01,
1.12708807e-01, 3.64434421e-01, 1.61256164e-01,
-2.09103152e-01, 5.57109714e-04],
[ 1.63463175e-01, 4.33325768e-04, -4.44926351e-01,
-9.06594098e-02, 3.35962057e-01, 4.89929318e-01,
-1.45875633e-01, -2.39341617e-01],
[-2.70898372e-01, 1.98495671e-01, 1.63841411e-01,
-3.97940278e-01, 8.93494636e-02, 3.55310917e-01,
-4.32752073e-03, -2.55927563e-01],
[ 2.71851391e-01, -6.50364608e-02, 5.75378686e-02,
8.61035287e-02, 4.62560952e-02, 6.84097558e-02,
-3.49434435e-01, 3.20657253e-01],
[ 3.87204051e-01, 1.02552503e-01, -3.67724121e-01,
-1.37631506e-01, -2.76333094e-03, -9.74716246e-03,
-2.03522891e-02, -3.78593743e-01],
[ 1.50336102e-01, -1.60084844e-01, -4.73217756e-01,
2.41285011e-01, 4.34440225e-02, -3.39211166e-01,
1.99615315e-01, -1.64225787e-01],
[-9.00761038e-02, 1.48902372e-01, -3.05288136e-02,
4.30284500e-01, -3.87649029e-01, 5.50720513e-01,
-9.88744646e-02, -3.92888784e-01]], dtype=float32),
},
Dense_3: {
bias: array([0.23489459], dtype=float32),
kernel: array([[-0.07592051],
[ 0.6634396 ],
[ 0.193967 ],
[-0.21856818],
[ 0.44744045],
[ 0.7258886 ],
[-0.07619607],
[-0.20487107]], dtype=float32),
},
},
})
Lastly, let’s validate the model.
[8]:
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
auc=0.9927939731411726
The AUC of the model trained with SPU matches the plaintext AUC from earlier (0.9927939731411726). This is the end of the lab.