SPU 训练神经网络#

请先阅读实验 Logistic Regression On SPU。

在实验 Logistic Regression On SPU 中,我们展示了如何使用 SecretFlow/SPU 将明文 JAX 训练程序转换为安全 MPC 训练程序。

本实验的思路与之非常相似,但这次我们将使用神经网络模型。

我们将沿用相同的数据集和全部设置。

首先,让我们定义明文模型。

以下代码仅作为示例,请勿在生产环境直接使用。

本教程所需的资源多于 SecretFlow 的最低要求(8c16g)。

使用 JAX/FLAX 训练模型#

加载数据集#

以下内容复制自实验 Logistic Regression On SPU,此处不再赘述。

[ ]:
import sys

!{sys.executable} -m pip install flax==0.6.0
[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    """Load the breast-cancer dataset, optionally as one party's vertical slice.

    Features are min-max scaled using the global minimum and maximum of the
    whole feature matrix, then split 80/20 into train/test with a fixed seed
    (``random_state=42``), so every call sees the same split.

    Args:
        party_id: ``1`` returns the first 15 feature columns of the training
            set and no labels; any other truthy value returns the remaining
            feature columns together with the training labels. ``None`` (or
            any falsy value) returns the full training set.
        train: when ``False`` the full test set ``(x_test, y_test)`` is
            returned and ``party_id`` is ignored.

    Returns:
        A ``(features, labels)`` pair; ``labels`` is ``None`` for party 1.
    """
    # NOTE(review): the tuple-style return annotation above is kept on purpose;
    # it may be what lets ``alice(breast_cancer)(...)`` be unpacked into two
    # objects on a SecretFlow PYU -- confirm before changing its form.
    x, y = load_breast_cancer(return_X_y=True)
    x = (x - np.min(x)) / (np.max(x) - np.min(x))
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if not train:
        return x_test, y_test
    if not party_id:
        return x_train, y_train
    if party_id == 1:
        # Party 1 holds no labels, so the label slot is explicitly None rather
        # than an undefined name that only resolves inside IPython.
        return x_train[:, :15], None
    return x_train[:, 15:], y_train

定义模型#

我们将使用 4 层 MLP 模型和 ReLU 激活函数。

[2]:
from typing import Sequence
import flax.linen as nn


# Widths of the successive Dense layers: every entry except the last is a
# ReLU hidden layer, the last entry is the (linear) output width.
FEATURES = [30, 15, 8, 1]


class MLP(nn.Module):
    """Fully-connected network: ReLU hidden layers followed by a linear output.

    Attributes:
        features: width of each Dense layer, in order; the final entry is the
            output width.
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        output_width = self.features[-1]
        # Layers are created in order, so flax names them Dense_0, Dense_1, ...
        for width in hidden_widths:
            x = nn.relu(nn.Dense(width)(x))
        return nn.Dense(output_width)(x)

接下来,我们定义预测函数、损失函数和训练方法。

[3]:
import jax.numpy as jnp


def predict(params, x):
    """Run the MLP forward pass with ``params`` on the inputs ``x``.

    The model is re-declared locally rather than reusing the module-level
    ``MLP``/``FEATURES``; see the TODO below for why.
    """
    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,
    # which is not the case in a normal python program.
    from typing import Sequence
    import flax.linen as nn

    layer_widths = [30, 15, 8, 1]

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, inputs):
            hidden = inputs
            for width in self.features[:-1]:
                hidden = nn.relu(nn.Dense(width)(hidden))
            return nn.Dense(self.features[-1])(hidden)

    model = MLP(layer_widths)
    return model.apply(params, x)


def half_mse(y, pred):
    """Return ``mean((y - pred) ** 2) / 2`` with ``pred`` reshaped to ``y``'s shape.

    The MLP's single-unit output layer yields predictions of shape
    ``(batch, 1)``, while the labels are ``(batch,)``. Subtracting those
    directly would broadcast to a ``(batch, batch)`` matrix and average the
    error of every prediction against every label. Reshaping first pairs each
    prediction with its own label.
    """
    pred = jnp.reshape(pred, jnp.shape(y))
    diff = y - pred
    return jnp.mean(jnp.multiply(diff, diff) / 2.0)


def loss_func(params, x, y):
    """Half mean-squared error of the MLP's predictions on ``x`` against ``y``."""
    return half_mse(y, predict(params, x))


def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):
    """Train ``params`` with plain mini-batch gradient descent.

    The two parties' feature blocks are concatenated column-wise and the rows
    are split into consecutive mini-batches. Each epoch takes one gradient
    step per mini-batch; epochs run inside ``jax.lax.fori_loop``.

    Args:
        x1, x2: feature columns held by the two parties (same number of rows).
        y: training labels, one per row.
        params: initial model parameters (the pytree produced by
            ``model_init``).
        n_batch: approximate number of rows per mini-batch (despite the name,
            this is a batch *size*, not a batch count).
        n_epochs: number of passes over all mini-batches.
        step_size: learning rate.

    Returns:
        The updated parameter pytree.
    """
    import jax  # ``jax`` is otherwise only imported in a later notebook cell

    x = jnp.concatenate((x1, x2), axis=1)
    # array_split needs an integer section count; keep at least one batch so
    # inputs with fewer than n_batch rows still train. Using the same count for
    # features and labels keeps the batches aligned.
    n_sections = max(1, len(x) // n_batch)
    xs = jnp.array_split(x, n_sections, axis=0)
    ys = jnp.array_split(y, n_sections, axis=0)

    def body_fun(_, loop_carry):
        params = loop_carry
        for x_batch, y_batch in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, x_batch, y_batch)
            params = jax.tree_util.tree_map(
                lambda p, g: p - step_size * g, params, grads
            )
        return params

    return jax.lax.fori_loop(0, n_epochs, body_fun, params)


def model_init(n_batch=10, seed=1):
    """Initialise MLP parameters from a dummy all-ones input batch.

    Args:
        n_batch: number of rows in the dummy batch; it only fixes the input
            shape passed to ``init``.
        seed: PRNG seed for the initial weights. The default of 1 reproduces
            the previously hard-coded seed.

    Returns:
        The parameter pytree returned by ``MLP(FEATURES).init``.
    """
    import jax  # ``jax`` is otherwise only imported in a later notebook cell

    model = MLP(FEATURES)
    # FEATURES[0] (30) doubles as the input width: the two parties hold
    # 15 + 15 feature columns.
    dummy_batch = jnp.ones((n_batch, FEATURES[0]))
    return model.init(jax.random.PRNGKey(seed), dummy_batch)

验证模型#

我们使用 AUC 作为验证指标。

[4]:
from sklearn.metrics import roc_auc_score


def validate_model(params, X_test, y_test):
    """Score the model on a held-out set with ROC AUC.

    Predictions come from ``predict(params, X_test)`` and are compared
    against ``y_test`` with ``roc_auc_score``.
    """
    return roc_auc_score(y_test, predict(params, X_test))

放在一起#

让我们把所有东西放在一起,训练一个明文 NN 模型!

[5]:
import jax

# Load each party's share of the training data; party 1's label slot is
# discarded.
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Hyperparameters (n_batch is the mini-batch size).
n_batch, n_epochs, step_size = 10, 10, 0.01

# Train in plaintext, starting from freshly initialised parameters.
init_params = model_init(n_batch)
params = train_auto_grad(
    x1, x2, y, init_params, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)

# Evaluate on the held-out split.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9927939731411726

明文模型已训练完成。接下来,让我们使用 SPU 训练同一个神经网络模型!

使用 SPU 训练模型#

[6]:
import secretflow as sf

# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

# Stop any runtime left over from an earlier run, then start a fresh local
# one with the two parties.
sf.shutdown()
sf.init(['alice', 'bob'], address='local')

alice = sf.PYU('alice')
bob = sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

# Each party loads its own slice of the training data on its own PYU.
x1, _ = alice(breast_cancer)(party_id=1, train=True)
x2, y = bob(breast_cancer)(party_id=2, train=True)
init_params = model_init(n_batch)

# Move every training input onto the SPU (MPC) device.
device = spu
x1_ = x1.to(device)
x2_ = x2.to(device)
y_ = y.to(device)
init_params_ = sf.to(alice, init_params).to(device)

# The hyperparameters are plain Python values, so they are marked static.
params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(
    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size
)
2023-04-28 14:35:12,293 INFO worker.py:1538 -- Started a local Ray instance.
(_run pid=175587) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=175587) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=175587) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=175587) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=176242) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=176242) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=176242) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=176242) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=180401) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=180401) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=180401) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=180401) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=177052) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=177052) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=177052) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=177052) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(SPURuntime pid=187368) 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime pid=187367) 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127

让我们查看 SPU 训练得到的模型参数。

[7]:
# `params_spu` already holds the parameters trained on SPU in the previous
# cell. Reveal those to plaintext for inspection rather than re-running the
# whole MPC training (which would also start from the plaintext `init_params`
# instead of the SPU-resident `init_params_`).
params = sf.reveal(params_spu)
print(params)
FrozenDict({
    params: {
        Dense_0: {
            bias: array([ 6.7055225e-06,  6.7055225e-06,  6.7055225e-06, -8.4322095e-03,
                    4.7300994e-02,  4.5412779e-04,  6.7055225e-06,  4.5442879e-03,
                    6.7055225e-06, -3.4062415e-02, -8.3989948e-03,  6.7055225e-06,
                    6.7055225e-06,  5.6699291e-02, -4.8456341e-03,  6.7055225e-06,
                    3.5731569e-02,  6.3510090e-03,  3.0306578e-03,  3.2686546e-02,
                    6.7055225e-06, -2.1292433e-02, -7.7798963e-03,  6.7055225e-06,
                    2.8470993e-02,  6.7055225e-06, -3.0836165e-03,  4.5374036e-05,
                    1.4400020e-02,  2.0861626e-02], dtype=float32),
            kernel: array([[-0.14870723, -0.23531294, -0.1493704 , -0.01558255, -0.13322462,
                     0.1917662 , -0.03679654, -0.03744406, -0.1417609 ,  0.03231682,
                     0.12653404, -0.4025072 , -0.1689485 ,  0.21399944, -0.13844648,
                     0.10585822, -0.11602122,  0.38625073,  0.05966607,  0.06318197,
                     0.0779368 , -0.01318966, -0.28804308, -0.09602153,  0.11111972,
                    -0.08543564,  0.07547122, -0.04118884, -0.38266844,  0.23767346],
                   [ 0.17795108,  0.2294012 , -0.24440196, -0.14849591,  0.33701816,
                     0.0258413 , -0.04214501,  0.41052908,  0.32439357, -0.16435765,
                     0.08169709,  0.05259326,  0.3113483 ,  0.2931838 ,  0.12270276,
                    -0.38752455, -0.38534215, -0.06536001, -0.25914845, -0.3322725 ,
                    -0.31587672, -0.29117638, -0.06018265,  0.2297913 ,  0.10114388,
                    -0.01309358,  0.17881514, -0.23215818,  0.3828069 ,  0.03806593],
                   [ 0.1428942 , -0.02135262,  0.16819091,  0.08982845, -0.38852412,
                    -0.04850367,  0.13870972, -0.05800854,  0.28472927, -0.12711032,
                     0.25702882,  0.09648418,  0.11670431, -0.1896179 , -0.03994708,
                    -0.09573121,  0.07308613,  0.14650668,  0.09226419,  0.03892082,
                    -0.24624617,  0.03725916, -0.01914255, -0.25209764,  0.17078896,
                     0.24982187, -0.0028675 , -0.09844984,  0.20797251,  0.08843645],
                   [-0.22511783, -0.0044653 , -0.04557581, -0.04286373, -0.13053825,
                    -0.3426896 , -0.00925103,  0.09015252, -0.2824888 ,  0.22022144,
                     0.11647445,  0.04475737, -0.05021369,  0.29519165, -0.23622867,
                     0.05994891,  0.2596493 ,  0.18784739,  0.14603132,  0.2965685 ,
                     0.03959064,  0.16071922, -0.11333965, -0.06968905,  0.26477575,
                    -0.317869  ,  0.08121799,  0.25563055, -0.05901612,  0.19531868],
                   [-0.24254663,  0.07968816, -0.06768736, -0.11746876,  0.1875621 ,
                    -0.06137984, -0.05366111, -0.06934479,  0.07924516,  0.02541035,
                    -0.31857365,  0.28704768, -0.06027508, -0.30148876, -0.17660952,
                     0.07973847,  0.1614199 ,  0.3279493 ,  0.20515053,  0.30348837,
                     0.2711059 ,  0.276556  ,  0.07071564,  0.20800509, -0.07333609,
                    -0.10324922,  0.01553461,  0.31758228,  0.31677115, -0.06809102],
                   [ 0.07325926,  0.06064408,  0.0530773 ,  0.17844556,  0.18787359,
                     0.17704393,  0.08110972,  0.01482402, -0.04424939,  0.06166127,
                     0.28827167,  0.05878101,  0.26427427,  0.12087436, -0.02181949,
                    -0.15166327, -0.04630022,  0.00738053,  0.2839891 ,  0.10080083,
                    -0.3035335 , -0.31350654, -0.17609106,  0.11223568,  0.1156193 ,
                    -0.27605468, -0.06867941,  0.06136122,  0.3082044 , -0.28000844],
                   [-0.25858068, -0.01556443,  0.27713627, -0.38400537,  0.39872903,
                    -0.12919384, -0.02736983,  0.17572944,  0.13031955,  0.15870668,
                    -0.02625516,  0.29411823,  0.03559025,  0.03587727, -0.2966054 ,
                    -0.16969463,  0.0300006 ,  0.16187829, -0.17532285, -0.08767432,
                    -0.04854703, -0.10537073,  0.08301418, -0.04356302, -0.25446534,
                    -0.09856299, -0.04166624, -0.04677388, -0.3353408 , -0.11825959],
                   [ 0.27912897, -0.07000226, -0.02481516,  0.04389155, -0.08830354,
                    -0.00139034,  0.08731189, -0.24834795,  0.15356407, -0.12887374,
                    -0.00434314, -0.00279981, -0.07792975, -0.1029453 ,  0.2409295 ,
                    -0.25699303,  0.2918012 , -0.19479287, -0.27555436,  0.01553042,
                     0.12703311,  0.1288091 ,  0.15366644,  0.1431344 ,  0.06207459,
                    -0.11137639,  0.05906925, -0.11649235,  0.01587239, -0.20323639],
                   [ 0.06792891,  0.08563136, -0.09104523,  0.17886826,  0.07520616,
                    -0.13827898,  0.33567435, -0.14805417, -0.03184932,  0.39237124,
                    -0.1335338 ,  0.19828805,  0.05121414, -0.04607381, -0.12948062,
                    -0.22250798,  0.12677568,  0.39128548, -0.11602047,  0.00093162,
                    -0.07845107,  0.17064299,  0.2707931 ,  0.06743585,  0.07426128,
                    -0.00924093, -0.0035352 , -0.3685534 , -0.12302665,  0.22056273],
                   [ 0.02833928, -0.12450014,  0.17981096,  0.15364204,  0.05483492,
                     0.19171704, -0.0949284 ,  0.06867886, -0.07678194, -0.01938733,
                     0.05701402, -0.39338416,  0.05287948,  0.3794972 ,  0.24641661,
                    -0.1212198 ,  0.04000506, -0.38034967,  0.19541413, -0.0905077 ,
                     0.3206088 ,  0.01485404, -0.03493308,  0.11109039, -0.33723742,
                    -0.30601716, -0.11324729,  0.1596858 ,  0.06751473,  0.1008921 ],
                   [ 0.16805576,  0.19498089, -0.09763785, -0.14558062,  0.10152206,
                    -0.31742054, -0.11583678, -0.2865575 , -0.10120936, -0.13012367,
                     0.19799586, -0.06929106,  0.00183079, -0.06139433, -0.23812771,
                     0.14183812,  0.41206583, -0.11150262, -0.07695962, -0.03937718,
                    -0.05823223, -0.25616592,  0.17551638, -0.05776715,  0.04627597,
                     0.12046237,  0.31444448, -0.1823728 , -0.16253875, -0.09766676],
                   [-0.06190741, -0.11557767,  0.07265058,  0.12529932,  0.20684099,
                     0.15767016, -0.08056761, -0.19449666,  0.02133167,  0.23543602,
                    -0.17700855, -0.35116544, -0.22017023,  0.03137846,  0.10100484,
                    -0.40086156, -0.13380852, -0.06593318, -0.14122422, -0.17200904,
                    -0.0666105 ,  0.09940979, -0.03091712,  0.25939053,  0.06447808,
                    -0.2506336 , -0.0349206 ,  0.08023839,  0.25556827, -0.2408923 ],
                   [ 0.00898188, -0.40073588, -0.06301974,  0.06183384,  0.3735768 ,
                     0.03177406,  0.27502847, -0.28810993, -0.2024756 ,  0.16113877,
                    -0.21794656,  0.10632099,  0.00266866, -0.27301037,  0.07529524,
                     0.07778189,  0.02633543,  0.09457737, -0.28337651,  0.0255892 ,
                     0.17133063,  0.04773571, -0.01299471, -0.0919252 ,  0.22021984,
                    -0.1989678 ,  0.34153467,  0.08680797, -0.08852738, -0.0090448 ],
                   [ 0.12035008,  0.12541962, -0.36259866,  0.22371957, -0.07335131,
                     0.10498597,  0.00436583, -0.08324738,  0.22863485, -0.14954014,
                     0.08159503, -0.3141421 ,  0.08762485, -0.05525228,  0.08568875,
                    -0.02316961, -0.2230854 ,  0.02858485,  0.10418503,  0.09759469,
                     0.08704272,  0.01555008,  0.17367665,  0.08375961, -0.01750728,
                    -0.06537268,  0.05048656, -0.22944517,  0.05722432,  0.25090805],
                   [-0.39710748,  0.10012694, -0.07080103, -0.16264898, -0.13910918,
                     0.16161054,  0.16022125, -0.00788775,  0.05428429,  0.16593601,
                     0.22370476,  0.36696965,  0.06149913, -0.04857542,  0.3345247 ,
                    -0.07260554,  0.1938989 , -0.06002848, -0.30302036,  0.17182748,
                     0.29064724,  0.21397091,  0.04791559,  0.09810503,  0.1033058 ,
                     0.12732479, -0.0579783 ,  0.15246823, -0.3666319 ,  0.1779919 ],
                   [ 0.01064624,  0.0888928 ,  0.26858085,  0.34396815,  0.06943932,
                     0.30761874, -0.15886313,  0.00265385,  0.04297891, -0.06383656,
                    -0.01197957, -0.10140778,  0.03901416, -0.02126652,  0.13493209,
                    -0.16070978, -0.27638012, -0.11028586,  0.12214845, -0.2560637 ,
                    -0.08863154,  0.03597671, -0.1732396 ,  0.12559041,  0.14788477,
                     0.09702435,  0.17843248,  0.08070756,  0.0718791 ,  0.08296195],
                   [ 0.14691886,  0.13540354, -0.05013047, -0.2566406 , -0.2376638 ,
                     0.21672072,  0.1372795 , -0.03882806,  0.39052176,  0.0047731 ,
                     0.14544334, -0.0696618 , -0.15187763,  0.06678917, -0.24012098,
                     0.31160212,  0.06627946, -0.2530402 , -0.20175886, -0.22604358,
                     0.1381416 , -0.14101216,  0.3429103 ,  0.12955913,  0.2845845 ,
                     0.06188303, -0.22960348,  0.2912202 , -0.08082792, -0.3445377 ],
                   [-0.01824994,  0.12698065,  0.11829151, -0.08935194, -0.04362963,
                    -0.06175369, -0.1114524 , -0.06696388, -0.34100425, -0.25512362,
                    -0.1483988 , -0.20127416, -0.00367533,  0.05239835,  0.06488706,
                     0.08272076,  0.05891787,  0.2134408 , -0.13793291,  0.30933803,
                    -0.09876332, -0.15072244, -0.10377637,  0.03409749,  0.0937078 ,
                    -0.22452421,  0.3597254 ,  0.24009626, -0.03083205, -0.10381168],
                   [-0.14538439,  0.17941016,  0.01639399, -0.2706253 , -0.02600642,
                    -0.03973   , -0.0325162 ,  0.03153259, -0.15472709, -0.09655666,
                     0.04076509,  0.1300038 , -0.19558378, -0.17638195,  0.12240331,
                    -0.26903665,  0.2714493 , -0.07004572, -0.07335924,  0.03825237,
                     0.22632292,  0.3012138 ,  0.02217355, -0.30002278, -0.06066401,
                    -0.07689169, -0.37136257,  0.19665234, -0.10525645, -0.27408272],
                   [ 0.05384398,  0.03158583, -0.00409974, -0.04451011, -0.10076478,
                    -0.06426084,  0.3136195 , -0.13606365,  0.1243284 , -0.10924114,
                    -0.03940558,  0.22020963, -0.07174113,  0.08709462,  0.04955287,
                     0.36317343,  0.00659794, -0.15838777,  0.09210019, -0.17414865,
                    -0.14202411,  0.3834263 ,  0.02247368,  0.00736032, -0.02805607,
                    -0.15887989,  0.03910746, -0.0943727 ,  0.21787158,  0.01440434],
                   [-0.09300622, -0.19802521, -0.31412005,  0.17171307,  0.1331803 ,
                    -0.14113024, -0.21318011, -0.16237472,  0.09434846,  0.14660788,
                     0.01858762, -0.02211154, -0.14670722,  0.39278403, -0.20136856,
                     0.10904545, -0.02885009, -0.15209475,  0.1743193 ,  0.0778787 ,
                     0.09585676,  0.10286772,  0.1895318 , -0.15744607,  0.0972386 ,
                    -0.26544875, -0.05130047,  0.08041063,  0.05855417,  0.24786705],
                   [ 0.05508722,  0.23071642,  0.00278442, -0.05163229, -0.13318591,
                     0.17231207, -0.0383717 ,  0.17234325, -0.12098849, -0.12200612,
                    -0.165717  , -0.08695543, -0.01522441, -0.31668693,  0.196136  ,
                    -0.20849878,  0.34565175,  0.252592  ,  0.03059202, -0.23635055,
                    -0.02455017, -0.07401715,  0.18046305,  0.08005303,  0.02341022,
                     0.05160871,  0.0830403 , -0.10961437,  0.2051303 ,  0.05485763],
                   [-0.29294217,  0.01583408, -0.00052598,  0.07539546,  0.17627907,
                     0.16075702,  0.00591798, -0.02526975, -0.2719347 , -0.2642147 ,
                     0.17578189,  0.26844388, -0.16066906,  0.00551553, -0.41348425,
                     0.1321568 ,  0.2071938 , -0.09202607, -0.32119918,  0.03001858,
                    -0.03515013, -0.11420041,  0.00692059,  0.06027223,  0.31073922,
                     0.31373912,  0.15468763,  0.23844069,  0.20547047,  0.165754  ],
                   [-0.1317544 , -0.09716719, -0.2110814 ,  0.30688593,  0.13689038,
                     0.25466746, -0.23185365,  0.265381  , -0.20205005,  0.26761973,
                    -0.01471928, -0.17001429, -0.00165382,  0.10118251,  0.28316593,
                    -0.10187137,  0.02500786,  0.09213623, -0.06184761,  0.051311  ,
                    -0.13956325,  0.29834348,  0.16425882, -0.20013842,  0.10159607,
                    -0.09226643, -0.09284794, -0.24736227,  0.28198415,  0.18465933],
                   [-0.19596493, -0.26223665, -0.02396852,  0.1405711 , -0.05117449,
                     0.09832071,  0.10009323,  0.08764507, -0.20915532, -0.04817107,
                     0.11512975, -0.0107393 ,  0.06286559, -0.14394692,  0.1831078 ,
                     0.18601051, -0.01792853, -0.010507  ,  0.2988264 ,  0.02924132,
                     0.1502285 , -0.02573505,  0.10515428,  0.32683268, -0.06475027,
                    -0.07946308, -0.33095527, -0.33394814,  0.14654751, -0.18609025],
                   [-0.04332547,  0.18820217,  0.03160366,  0.11940409, -0.22678787,
                     0.09432799, -0.08720809,  0.25600654, -0.14890012,  0.09946848,
                     0.18772584, -0.19526623,  0.0827599 , -0.14669879, -0.12541471,
                    -0.13776924,  0.09574251, -0.2980466 ,  0.10541511, -0.11811657,
                    -0.23554784, -0.01769215, -0.29761636,  0.04322377, -0.04169539,
                     0.04331157, -0.10865457,  0.3526432 ,  0.27452517,  0.01664442],
                   [ 0.17763771, -0.07080895, -0.12558904, -0.13398908, -0.22847766,
                    -0.20403627,  0.07889682,  0.13384837, -0.31691694, -0.13476555,
                    -0.08197045,  0.02778772,  0.02476428,  0.10588782, -0.25830707,
                    -0.24311969,  0.03762388,  0.05451898, -0.13534577, -0.10997833,
                    -0.3139264 ,  0.05126831, -0.00060226, -0.15891929, -0.17077953,
                     0.2362888 ,  0.08467598,  0.01052356,  0.08872832,  0.16418251],
                   [ 0.37544525,  0.0681546 ,  0.07722013, -0.40396348, -0.05511683,
                     0.00878677,  0.33257678,  0.18474084, -0.0799066 ,  0.20011736,
                     0.14146338, -0.15846273, -0.15961201, -0.18772689, -0.17597765,
                    -0.13404477,  0.21314138, -0.13090074,  0.10695033,  0.28710032,
                     0.13358802,  0.3303852 , -0.2687422 , -0.22376198,  0.29356587,
                    -0.03488064,  0.14832059,  0.12624982, -0.20833445,  0.05823356],
                   [ 0.17862728, -0.12085862, -0.07798004,  0.16461669,  0.13114056,
                     0.1119384 , -0.02916402, -0.01834482,  0.03708343, -0.39161655,
                     0.04380961,  0.12685701, -0.20311095,  0.14991562,  0.08968998,
                    -0.1430527 ,  0.3768945 ,  0.2545389 ,  0.09408659,  0.30030465,
                     0.00201878, -0.03300162, -0.31967437,  0.08429171, -0.10358454,
                     0.15462488,  0.15204427, -0.00353977, -0.15648344,  0.03190795],
                   [-0.3349845 , -0.18704857,  0.12660322,  0.27142197, -0.04179126,
                    -0.01659705, -0.15886122,  0.14643206,  0.10317151,  0.139131  ,
                     0.26203057,  0.03828669,  0.17041986,  0.28139216,  0.03020249,
                    -0.21715921,  0.05988631,  0.20941454,  0.27820507, -0.30283943,
                     0.21741417,  0.06876856, -0.0162366 , -0.09319973,  0.16716208,
                    -0.05672812, -0.01678701, -0.33967227,  0.04148872,  0.24174951]],
                  dtype=float32),
        },
        Dense_1: {
            bias: array([-9.28075612e-03, -4.38565016e-03,  6.73728883e-02,  6.70552254e-06,
                   -1.11967325e-02,  1.06356591e-02, -2.26502120e-02, -3.45642865e-03,
                    3.70997190e-02,  9.00812894e-02,  6.70552254e-06,  6.70552254e-06,
                    6.70552254e-06, -1.53630227e-02,  3.88986021e-02], dtype=float32),
            kernel: array([[-0.21578309, -0.08008368, -0.34167936, -0.03616343, -0.04043388,
                    -0.19278756,  0.07816273,  0.3847432 , -0.27097666,  0.03089739,
                    -0.11206758,  0.12151396,  0.38484663,  0.12947203,  0.03026646],
                   [ 0.3035894 ,  0.14900179,  0.02244793,  0.17264117,  0.0011169 ,
                    -0.1606707 ,  0.17210394, -0.19850568, -0.00882789,  0.06376703,
                    -0.09706031, -0.27143008,  0.32902688,  0.01248117, -0.20562333],
                   [ 0.01422736,  0.25237322,  0.26592904, -0.07876748,  0.02570754,
                     0.13746765, -0.3037846 , -0.30066282,  0.228537  ,  0.07397157,
                    -0.05444951,  0.06826244, -0.11475235, -0.04363853, -0.00258049],
                   [-0.03807974,  0.36382473, -0.05991563,  0.1660564 , -0.18014075,
                     0.17624326, -0.24441232, -0.31741685, -0.06890935, -0.04919542,
                     0.13665393, -0.05236177,  0.12887959,  0.2582429 , -0.06479871],
                   [ 0.01582411, -0.00546367,  0.06451672,  0.00377437,  0.05299711,
                     0.09622552, -0.33355796,  0.15772232, -0.00315991,  0.3426076 ,
                    -0.01920256, -0.0157837 ,  0.10247016, -0.02410382,  0.14005862],
                   [-0.1542452 , -0.18916714, -0.12516193, -0.15350978, -0.20895821,
                    -0.03576617,  0.0180776 ,  0.16850933,  0.05937128,  0.03776275,
                     0.07396616,  0.03354299,  0.06906249,  0.15164083, -0.2608541 ],
                   [ 0.0905385 ,  0.3133642 , -0.17575558,  0.05339232, -0.19663817,
                     0.22920834,  0.21465397,  0.14934285,  0.30395645, -0.2403111 ,
                     0.11673826, -0.0449398 , -0.0359446 ,  0.3089288 , -0.01469092],
                   [-0.09350568, -0.09241582,  0.29311585,  0.13808654,  0.14410885,
                     0.11155026,  0.19201808, -0.22068372,  0.0091573 , -0.00837527,
                    -0.10839485, -0.0492375 ,  0.15357326,  0.3894407 , -0.15209287],
                   [ 0.0029071 ,  0.18366341,  0.03765289, -0.01738603,  0.18317957,
                     0.00410259,  0.09655431,  0.07968767,  0.21980065,  0.22737293,
                    -0.15136166,  0.20435053,  0.11874333, -0.3370184 ,  0.11251831],
                   [-0.03699435,  0.05359124, -0.00424996, -0.00427449, -0.20195475,
                    -0.12829332,  0.06293778,  0.13848272, -0.17896764, -0.38953093,
                    -0.07185236,  0.22985502, -0.11224222,  0.04145651, -0.3817618 ],
                   [ 0.23528674,  0.1663438 , -0.08346738,  0.20346904, -0.20409097,
                    -0.07192652,  0.11208971,  0.24518102,  0.23959732, -0.1391288 ,
                    -0.02638906, -0.11256091, -0.27086872, -0.00492385,  0.13006589],
                   [-0.05570208, -0.34653068,  0.298495  , -0.16680127,  0.06143057,
                     0.09288131,  0.1472318 , -0.12598082, -0.01329006, -0.26823848,
                     0.08741044,  0.10009366,  0.1264808 ,  0.13802043,  0.2563799 ],
                   [ 0.01380032, -0.19647142,  0.14879738,  0.0388497 , -0.14403345,
                     0.3500362 , -0.03261025, -0.11959814, -0.35041225, -0.09013529,
                     0.16815332, -0.17363463, -0.26452613,  0.18936844, -0.30342007],
                   [ 0.15264955, -0.16593191,  0.2803555 , -0.02613318,  0.09317887,
                    -0.1145407 , -0.02915843,  0.09115867,  0.16309327,  0.16567504,
                    -0.16353543, -0.02392778,  0.21730614, -0.37557966,  0.36441088],
                   [ 0.27545726, -0.0511765 ,  0.03052394,  0.38374472,  0.18914919,
                    -0.30549794, -0.1365143 , -0.09850363, -0.08355592, -0.17305706,
                     0.00163533,  0.27035654, -0.01430997,  0.01418965, -0.23040168],
                   [-0.11281115, -0.08904882,  0.05267188, -0.03345008,  0.17955093,
                     0.15272899, -0.05194864,  0.10906464,  0.21673168, -0.05776855,
                     0.29315004, -0.272271  ,  0.22718571, -0.04166271,  0.08242701],
                   [ 0.11221845,  0.15372409, -0.13822478, -0.1822467 , -0.26139548,
                     0.22891735, -0.12165104, -0.20519899,  0.39132354,  0.19772069,
                     0.00470251, -0.04089966, -0.17769708,  0.22472003,  0.24131916],
                   [-0.08916578, -0.13332213,  0.11583962, -0.3159293 , -0.05461061,
                    -0.03293931,  0.17573284, -0.03388457, -0.04562169, -0.00728241,
                    -0.20086795, -0.04282369, -0.06481767,  0.00174528,  0.08415282],
                   [ 0.10727057,  0.15353328,  0.09634779,  0.01951087, -0.00729644,
                    -0.25289363, -0.23461105,  0.35619986,  0.1761693 , -0.18046483,
                    -0.25238073, -0.05560882, -0.20357345, -0.13479468,  0.14422338],
                   [-0.32128555,  0.01506339,  0.3208    ,  0.3084674 ,  0.06561027,
                    -0.20671532,  0.07110539,  0.09107907, -0.05795462,  0.06884666,
                    -0.24340847,  0.09923317, -0.39770418, -0.1435657 ,  0.18189654],
                   [ 0.17829058, -0.3734911 , -0.344891  , -0.18513158, -0.1252987 ,
                    -0.359349  , -0.21523046,  0.4066509 , -0.06088345, -0.12821774,
                     0.30891037, -0.05408247,  0.13263397,  0.01792602,  0.22215459],
                   [-0.15978388, -0.19274645, -0.39842924, -0.13794418,  0.24811234,
                    -0.30259767,  0.25340182,  0.36628515, -0.04467097,  0.2068153 ,
                     0.10091661, -0.17184901, -0.01158652,  0.28761518,  0.07140049],
                   [-0.38753265, -0.2148714 , -0.34941888, -0.37459916,  0.00249913,
                    -0.38012785, -0.26021895,  0.06027205, -0.05131304,  0.24082436,
                     0.20541278, -0.09037189, -0.1668249 ,  0.24143052, -0.26692837],
                   [-0.20972599, -0.01015192,  0.16557814,  0.20875406, -0.19013   ,
                    -0.31780058, -0.0311262 , -0.06458683,  0.39772552, -0.26640862,
                     0.31138209, -0.06382139, -0.39696902,  0.10767588,  0.01154487],
                   [ 0.18972392,  0.0260205 ,  0.10645114, -0.21743847, -0.26412916,
                     0.15006566,  0.13827245, -0.21839432, -0.0661045 ,  0.27946475,
                    -0.10810606, -0.32184672, -0.03605315,  0.04213592, -0.01746434],
                   [ 0.04842271, -0.17424968,  0.1226363 ,  0.3272752 , -0.08305582,
                    -0.31486374,  0.10645151,  0.09955801,  0.07176954, -0.20580491,
                     0.04142599, -0.00282551,  0.15971068,  0.19535801, -0.218681  ],
                   [ 0.18084396,  0.0928266 , -0.27906552, -0.3218416 ,  0.08461788,
                    -0.13167469, -0.22216263,  0.06937028,  0.10845792, -0.15438123,
                    -0.02529359,  0.03964274, -0.01773006,  0.04081337,  0.15702999],
                   [ 0.17229266, -0.27421105,  0.03916015, -0.10643585,  0.15348236,
                    -0.40775347, -0.14518811, -0.19719355,  0.15164377,  0.08712097,
                    -0.01809341,  0.03163807, -0.31661576, -0.08889712, -0.3158114 ],
                   [-0.09769286, -0.0287679 ,  0.35823774, -0.2710501 ,  0.32775387,
                     0.08072445,  0.30246186, -0.19245733, -0.17830896,  0.2923805 ,
                     0.09355405, -0.2524669 ,  0.12927876,  0.38659224, -0.394949  ],
                   [ 0.06278192,  0.08469887, -0.00950807,  0.10956495,  0.09936713,
                    -0.19083263,  0.21161489,  0.3930601 ,  0.00441524,  0.20089766,
                    -0.13769451,  0.27256852, -0.0958501 , -0.05921303,  0.33085537]],
                  dtype=float32),
        },
        Dense_2: {
            bias: array([-1.85250044e-02,  1.07492432e-01,  6.70552254e-06, -2.40181834e-02,
                    1.02186531e-01,  1.17786124e-01, -7.42332637e-03,  6.70552254e-06],
                  dtype=float32),
            kernel: array([[ 1.70731947e-01, -9.90723372e-02, -2.29330808e-02,
                     1.22201458e-01, -1.06555134e-01, -3.66996974e-02,
                     1.41981944e-01,  8.84072036e-02],
                   [ 1.80482566e-01,  1.18618101e-01,  5.27178943e-01,
                    -1.53569669e-01, -3.86155099e-02, -1.22957200e-01,
                     6.27236068e-03,  1.60065144e-02],
                   [ 1.45004794e-01,  5.34007728e-01, -3.44348907e-01,
                    -9.43183005e-02,  1.35729849e-01, -8.20837915e-03,
                     1.19165242e-01, -4.11043108e-01],
                   [ 2.49569073e-01,  2.14901850e-01,  3.24754179e-01,
                    -4.91983056e-01, -1.14351839e-01, -2.11404264e-02,
                     7.69451857e-02,  3.31384748e-01],
                   [ 4.12468165e-02, -1.21441126e-01, -3.10934186e-01,
                     3.62142444e-01, -2.74272680e-01, -5.16952693e-01,
                    -5.41899055e-02,  5.59365571e-01],
                   [-4.59907204e-02, -6.31701499e-02,  1.12813368e-01,
                     3.72401834e-01,  1.19809762e-01,  2.30254814e-01,
                    -1.38893276e-01,  4.47092503e-02],
                   [-6.63732141e-02, -3.46694767e-01, -4.84580398e-01,
                     1.17096156e-01, -2.02452645e-01, -3.72330904e-01,
                     5.67477584e-01, -2.42807895e-01],
                   [ 3.36939037e-01, -1.21135429e-01,  3.77379209e-01,
                     4.15782034e-01,  4.28560078e-02, -3.28819275e-01,
                     4.96320784e-01, -1.31850213e-01],
                   [ 3.29444051e-01,  7.28795379e-02,  2.03807220e-01,
                     1.12708807e-01,  3.64434421e-01,  1.61256164e-01,
                    -2.09103152e-01,  5.57109714e-04],
                   [ 1.63463175e-01,  4.33325768e-04, -4.44926351e-01,
                    -9.06594098e-02,  3.35962057e-01,  4.89929318e-01,
                    -1.45875633e-01, -2.39341617e-01],
                   [-2.70898372e-01,  1.98495671e-01,  1.63841411e-01,
                    -3.97940278e-01,  8.93494636e-02,  3.55310917e-01,
                    -4.32752073e-03, -2.55927563e-01],
                   [ 2.71851391e-01, -6.50364608e-02,  5.75378686e-02,
                     8.61035287e-02,  4.62560952e-02,  6.84097558e-02,
                    -3.49434435e-01,  3.20657253e-01],
                   [ 3.87204051e-01,  1.02552503e-01, -3.67724121e-01,
                    -1.37631506e-01, -2.76333094e-03, -9.74716246e-03,
                    -2.03522891e-02, -3.78593743e-01],
                   [ 1.50336102e-01, -1.60084844e-01, -4.73217756e-01,
                     2.41285011e-01,  4.34440225e-02, -3.39211166e-01,
                     1.99615315e-01, -1.64225787e-01],
                   [-9.00761038e-02,  1.48902372e-01, -3.05288136e-02,
                     4.30284500e-01, -3.87649029e-01,  5.50720513e-01,
                    -9.88744646e-02, -3.92888784e-01]], dtype=float32),
        },
        Dense_3: {
            bias: array([0.23489459], dtype=float32),
            kernel: array([[-0.07592051],
                   [ 0.6634396 ],
                   [ 0.193967  ],
                   [-0.21856818],
                   [ 0.44744045],
                   [ 0.7258886 ],
                   [-0.07619607],
                   [-0.20487107]], dtype=float32),
        },
    },
})

最后,让我们验证模型。

[8]:
# Score the revealed SPU-trained parameters on the same fixed held-out split.
X_test, y_test = breast_cancer(train=False)
auc = validate_model(params=params, X_test=X_test, y_test=y_test)
print(f'auc={auc}')
auc=0.9927939731411726

实验到此结束。