{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Neural Network with SPU\n", "\n", "> Please read the lab [Logistic Regression On SPU](./lr_with_spu.ipynb) first if you have not already.\n", "\n", "In the lab [Logistic Regression On SPU](./lr_with_spu.ipynb), we showed how to use SecretFlow/SPU to convert a plaintext JAX training program into a secure MPC training program.\n", "\n", "This lab follows the same idea, but this time we work with a neural network model.\n", "\n", "We use the same dataset and the same settings as the lab [Logistic Regression On SPU](./lr_with_spu.ipynb).\n", "\n", "First, let's build the plaintext model." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "> The following code is for demonstration only. It is **NOT for production** due to system security concerns. Please **DO NOT** use it directly in production." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "> This tutorial needs more resources than 8c16g (8 CPU cores, 16 GB of memory), which is the minimum requirement of SecretFlow." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Train a Model with JAX/FLAX\n", "\n", "### Load the Dataset\n", "\n", "The code below is taken from the lab [Logistic Regression On SPU](./lr_with_spu.ipynb), so we do not explain it again." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "!{sys.executable} -m pip install flax==0.6.0" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import load_breast_cancer\n", "from sklearn.model_selection import train_test_split\n", "\n", "\n", "def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):\n", "    x, y = load_breast_cancer(return_X_y=True)\n", "    # Min-max scale the whole feature matrix into [0, 1] using its global min and max.\n", "    x = (x - np.min(x)) / (np.max(x) - np.min(x))\n", "    x_train, x_test, y_train, y_test = train_test_split(\n", "        x, y, test_size=0.2, random_state=42\n", "    )\n", "\n", "    if train:\n", "        if party_id:\n", "            if party_id == 1:\n", "                # Party 1 holds the first 15 features and no labels.\n", "                return x_train[:, :15], None\n", "            else:\n", "                # Party 2 holds the remaining 15 features and the labels.\n", "                return x_train[:, 15:], y_train\n", "        else:\n", "            return x_train, y_train\n", "    else:\n", "        return x_test, y_test" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Model\n", "\n", "We are going to use a 4-layer [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) model with [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activations after each hidden layer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from typing import Sequence\n", "import flax.linen as nn\n", "\n", "\n", "FEATURES = [30, 15, 8, 1]\n", "\n", "\n", "class MLP(nn.Module):\n", "    features: Sequence[int]\n", "\n", "    @nn.compact\n", "    def __call__(self, x):\n", "        # Dense + ReLU for every hidden layer.\n", "        for feat in self.features[:-1]:\n", "            x = nn.relu(nn.Dense(feat)(x))\n", "        # Linear output layer with a single unit.\n", "        x = nn.Dense(self.features[-1])(x)\n", "        return x" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Then we define the prediction, loss, and training functions." 
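, "\n", "\n", "Before reading the training code, note how `train_auto_grad` below forms mini-batches: it asks `jnp.array_split` for `len(x) / n_batch` sections, rounded down to an integer, so `n_batch` effectively acts as the number of rows per batch rather than the number of batches. The following cell is a small sketch of that splitting on a dummy array. It assumes the 455-row training split that `train_test_split(..., test_size=0.2)` produces from the 569-sample dataset; the names `x_demo`, `batches` and `sizes` are illustrative only and are not part of the lab." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "\n", "# Dummy stand-in for the training features: 455 rows x 30 columns.\n", "x_demo = jnp.ones((455, 30))\n", "n_batch = 10\n", "\n", "# Same call pattern as in train_auto_grad: len(x) // n_batch sections.\n", "batches = jnp.array_split(x_demo, len(x_demo) // n_batch, axis=0)\n", "sizes = [b.shape[0] for b in batches]\n", "\n", "# 455 rows in 45 sections: array_split gives the first 5 sections one extra row.\n", "print(len(batches), sorted(set(sizes)), sum(sizes))  # 45 [10, 11] 455" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Every row is used exactly once per epoch, and each batch holds 10 or 11 rows. The full prediction, loss and training code follows."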
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "def predict(params, x):\n", "    # TODO(junfeng): investigate why need to have a duplicated definition in notebook,\n", "    # which is not the case in a normal python program.\n", "    from typing import Sequence\n", "    import flax.linen as nn\n", "\n", "    FEATURES = [30, 15, 8, 1]\n", "\n", "    class MLP(nn.Module):\n", "        features: Sequence[int]\n", "\n", "        @nn.compact\n", "        def __call__(self, x):\n", "            for feat in self.features[:-1]:\n", "                x = nn.relu(nn.Dense(feat)(x))\n", "            x = nn.Dense(self.features[-1])(x)\n", "            return x\n", "\n", "    return MLP(FEATURES).apply(params, x)\n", "\n", "\n", "def loss_func(params, x, y):\n", "    # predict returns shape (batch, 1); flatten it to (batch,) to match y, so the\n", "    # error is computed per sample instead of broadcasting to (batch, batch).\n", "    pred = predict(params, x).reshape(-1)\n", "\n", "    def mse(y, pred):\n", "        def squared_error(y, y_pred):\n", "            return jnp.multiply(y - y_pred, y - y_pred) / 2.0\n", "\n", "        return jnp.mean(squared_error(y, pred))\n", "\n", "    return mse(y, pred)\n", "\n", "\n", "def train_auto_grad(x1, x2, y, params, n_batch=10, n_epochs=10, step_size=0.01):\n", "    # Join the two parties' feature columns back into one matrix.\n", "    x = jnp.concatenate((x1, x2), axis=1)\n", "    # Split into len(x) // n_batch mini-batches, i.e. about n_batch rows each.\n", "    xs = jnp.array_split(x, len(x) // n_batch, axis=0)\n", "    ys = jnp.array_split(y, len(y) // n_batch, axis=0)\n", "\n", "    def body_fun(_, loop_carry):\n", "        # One epoch: a plain SGD step on every mini-batch.\n", "        params = loop_carry\n", "        for x, y in zip(xs, ys):\n", "            _, grads = jax.value_and_grad(loss_func)(params, x, y)\n", "            params = jax.tree_util.tree_map(\n", "                lambda p, g: p - step_size * g, params, grads\n", "            )\n", "        return params\n", "\n", "    params = jax.lax.fori_loop(0, n_epochs, body_fun, params)\n", "    return params\n", "\n", "\n", "def model_init(n_batch=10):\n", "    model = MLP(FEATURES)\n", "    return model.init(jax.random.PRNGKey(1), jnp.ones((n_batch, FEATURES[0])))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Validate the Model\n", "\n", "We use AUC (the area under the ROC curve) as the validation metric." 
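, "\n", "\n", "AUC depends only on how the scores rank the samples: it equals the fraction of (positive, negative) pairs in which the positive sample gets the higher score, with ties counting one half. That is why the raw, unbounded outputs of the MLP can be passed to `roc_auc_score` directly, without a sigmoid. The cell below is a small sanity check of this pairwise view on made-up toy labels and scores; `y_toy`, `s_toy` and the other names in it are illustrative only and are not part of the lab." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_auc_score\n", "\n", "# Made-up toy data: two negatives and two positives.\n", "y_toy = [0, 0, 1, 1]\n", "s_toy = [0.1, 0.4, 0.35, 0.8]\n", "\n", "# Pairwise AUC: compare every positive score with every negative score.\n", "pos = [s for s, t in zip(s_toy, y_toy) if t == 1]\n", "neg = [s for s, t in zip(s_toy, y_toy) if t == 0]\n", "wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)\n", "auc_pairs = wins / (len(pos) * len(neg))\n", "\n", "auc_sklearn = roc_auc_score(y_toy, s_toy)\n", "print(auc_pairs, auc_sklearn)  # 0.75 0.75" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The helper below applies the same metric to the predictions of the MLP on the test set."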
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_auc_score\n", "\n", "\n", "def validate_model(params, X_test, y_test):\n", "    y_pred = predict(params, X_test)\n", "    return roc_auc_score(y_test, y_pred)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Build Together\n", "\n", "Let's put everything together and train a plaintext NN model!" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "auc=0.9927939731411726\n" ] } ], "source": [ "import jax\n", "\n", "# Load the data: party 1 holds the first 15 features, party 2 the rest plus the labels.\n", "x1, _ = breast_cancer(party_id=1, train=True)\n", "x2, y = breast_cancer(party_id=2, train=True)\n", "\n", "\n", "# Hyperparameters\n", "n_batch = 10  # approximate number of rows per mini-batch\n", "n_epochs = 10\n", "step_size = 0.01  # SGD learning rate\n", "\n", "\n", "# Train the model\n", "init_params = model_init(n_batch)\n", "params = train_auto_grad(x1, x2, y, init_params, n_batch, n_epochs, step_size)\n", "\n", "# Test the model\n", "X_test, y_test = breast_cancer(train=False)\n", "auc = validate_model(params, X_test, y_test)\n", "print(f'auc={auc}')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Keep this AUC value in mind: we are going to repeat the same training with SPU and compare the results. Let's do that magic!" 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Train a Model with SPU" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-04-28 14:35:12,293\tINFO worker.py:1538 -- Started a local Ray instance.\n", "\u001b[2m\u001b[36m(_run pid=175587)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n", "\u001b[2m\u001b[36m(_run pid=175587)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n", "\u001b[2m\u001b[36m(_run pid=175587)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n", "\u001b[2m\u001b[36m(_run pid=175587)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", "\u001b[2m\u001b[36m(_run pid=176242)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n", "\u001b[2m\u001b[36m(_run pid=176242)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n", "\u001b[2m\u001b[36m(_run pid=176242)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n", "\u001b[2m\u001b[36m(_run pid=176242)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", "\u001b[2m\u001b[36m(_run pid=180401)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n", "\u001b[2m\u001b[36m(_run pid=180401)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n", "\u001b[2m\u001b[36m(_run pid=180401)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n", "\u001b[2m\u001b[36m(_run pid=180401)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", "\u001b[2m\u001b[36m(_run pid=177052)\u001b[0m INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n", "\u001b[2m\u001b[36m(_run pid=177052)\u001b[0m INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host\n", "\u001b[2m\u001b[36m(_run pid=177052)\u001b[0m INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n", "\u001b[2m\u001b[36m(_run pid=177052)\u001b[0m WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(SPURuntime pid=187368)\u001b[0m 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n", "\u001b[2m\u001b[36m(SPURuntime pid=187367)\u001b[0m 2023-04-28 14:35:22.143 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127\n" ] } ], "source": [ "import secretflow as sf\n", "\n", "# Check the version of your SecretFlow\n", "print('The version of SecretFlow: {}'.format(sf.__version__))\n", "\n", "# In case you have a running secretflow runtime already.\n", "sf.shutdown()\n", "\n", "sf.init(['alice', 'bob'], address='local')\n", "\n", "alice, bob = sf.PYU('alice'), sf.PYU('bob')\n", "spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))\n", "\n", "# Each party loads its own vertical slice of the data locally; only bob holds the labels.\n", "x1, _ = alice(breast_cancer)(party_id=1, train=True)\n", "x2, y = bob(breast_cancer)(party_id=2, train=True)\n", "init_params = model_init(n_batch)\n", "\n", "\n", "# Move the private data and the initial parameters (placed on alice first) into SPU.\n", "device = spu\n", "x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)\n", "init_params_ = sf.to(alice, init_params).to(device)\n", "\n", "# n_batch, n_epochs and step_size are static arguments: they are treated as Python\n", "# constants when the function is traced and compiled, which n_batch needs because\n", "# it determines how the data is split into mini-batches.\n", "params_spu = spu(train_auto_grad, static_argnames=['n_batch', 'n_epochs', 'step_size'])(\n", "    x1_, x2_, y_, init_params_, n_batch=n_batch, n_epochs=n_epochs, step_size=step_size\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let's reveal the parameters trained by the SPU program and print them." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FrozenDict({\n", " params: {\n", " Dense_0: {\n", " bias: array([ 6.7055225e-06, 6.7055225e-06, 6.7055225e-06, -8.4322095e-03,\n", " 4.7300994e-02, 4.5412779e-04, 6.7055225e-06, 4.5442879e-03,\n", " 6.7055225e-06, -3.4062415e-02, -8.3989948e-03, 6.7055225e-06,\n", " 6.7055225e-06, 5.6699291e-02, -4.8456341e-03, 6.7055225e-06,\n", " 3.5731569e-02, 6.3510090e-03, 3.0306578e-03, 3.2686546e-02,\n", " 6.7055225e-06, -2.1292433e-02, -7.7798963e-03, 6.7055225e-06,\n", " 2.8470993e-02, 6.7055225e-06, -3.0836165e-03, 4.5374036e-05,\n", " 1.4400020e-02, 2.0861626e-02], dtype=float32),\n", " kernel: array([[-0.14870723, -0.23531294, -0.1493704 , -0.01558255, -0.13322462,\n", " 0.1917662 , -0.03679654, -0.03744406, -0.1417609 , 0.03231682,\n", " 0.12653404, -0.4025072 , -0.1689485 , 0.21399944, -0.13844648,\n", " 0.10585822, -0.11602122, 0.38625073, 0.05966607, 0.06318197,\n", " 0.0779368 , -0.01318966, -0.28804308, -0.09602153, 0.11111972,\n", " -0.08543564, 0.07547122, -0.04118884, -0.38266844, 0.23767346],\n", " [ 0.17795108, 0.2294012 , -0.24440196, -0.14849591, 0.33701816,\n", " 0.0258413 , -0.04214501, 0.41052908, 0.32439357, -0.16435765,\n", " 0.08169709, 0.05259326, 0.3113483 , 0.2931838 , 0.12270276,\n", " -0.38752455, -0.38534215, -0.06536001, -0.25914845, -0.3322725 ,\n", " -0.31587672, -0.29117638, -0.06018265, 0.2297913 , 0.10114388,\n", " -0.01309358, 0.17881514, -0.23215818, 0.3828069 , 0.03806593],\n", " [ 0.1428942 , -0.02135262, 0.16819091, 0.08982845, -0.38852412,\n", " -0.04850367, 0.13870972, -0.05800854, 0.28472927, -0.12711032,\n", " 0.25702882, 0.09648418, 0.11670431, -0.1896179 , -0.03994708,\n", " -0.09573121, 0.07308613, 0.14650668, 0.09226419, 0.03892082,\n", " -0.24624617, 0.03725916, -0.01914255, -0.25209764, 0.17078896,\n", " 0.24982187, -0.0028675 , -0.09844984, 0.20797251, 0.08843645],\n", " [-0.22511783, 
-0.0044653 , -0.04557581, -0.04286373, -0.13053825,\n", " -0.3426896 , -0.00925103, 0.09015252, -0.2824888 , 0.22022144,\n", " 0.11647445, 0.04475737, -0.05021369, 0.29519165, -0.23622867,\n", " 0.05994891, 0.2596493 , 0.18784739, 0.14603132, 0.2965685 ,\n", " 0.03959064, 0.16071922, -0.11333965, -0.06968905, 0.26477575,\n", " -0.317869 , 0.08121799, 0.25563055, -0.05901612, 0.19531868],\n", " [-0.24254663, 0.07968816, -0.06768736, -0.11746876, 0.1875621 ,\n", " -0.06137984, -0.05366111, -0.06934479, 0.07924516, 0.02541035,\n", " -0.31857365, 0.28704768, -0.06027508, -0.30148876, -0.17660952,\n", " 0.07973847, 0.1614199 , 0.3279493 , 0.20515053, 0.30348837,\n", " 0.2711059 , 0.276556 , 0.07071564, 0.20800509, -0.07333609,\n", " -0.10324922, 0.01553461, 0.31758228, 0.31677115, -0.06809102],\n", " [ 0.07325926, 0.06064408, 0.0530773 , 0.17844556, 0.18787359,\n", " 0.17704393, 0.08110972, 0.01482402, -0.04424939, 0.06166127,\n", " 0.28827167, 0.05878101, 0.26427427, 0.12087436, -0.02181949,\n", " -0.15166327, -0.04630022, 0.00738053, 0.2839891 , 0.10080083,\n", " -0.3035335 , -0.31350654, -0.17609106, 0.11223568, 0.1156193 ,\n", " -0.27605468, -0.06867941, 0.06136122, 0.3082044 , -0.28000844],\n", " [-0.25858068, -0.01556443, 0.27713627, -0.38400537, 0.39872903,\n", " -0.12919384, -0.02736983, 0.17572944, 0.13031955, 0.15870668,\n", " -0.02625516, 0.29411823, 0.03559025, 0.03587727, -0.2966054 ,\n", " -0.16969463, 0.0300006 , 0.16187829, -0.17532285, -0.08767432,\n", " -0.04854703, -0.10537073, 0.08301418, -0.04356302, -0.25446534,\n", " -0.09856299, -0.04166624, -0.04677388, -0.3353408 , -0.11825959],\n", " [ 0.27912897, -0.07000226, -0.02481516, 0.04389155, -0.08830354,\n", " -0.00139034, 0.08731189, -0.24834795, 0.15356407, -0.12887374,\n", " -0.00434314, -0.00279981, -0.07792975, -0.1029453 , 0.2409295 ,\n", " -0.25699303, 0.2918012 , -0.19479287, -0.27555436, 0.01553042,\n", " 0.12703311, 0.1288091 , 0.15366644, 0.1431344 , 0.06207459,\n", " -0.11137639, 
0.05906925, -0.11649235, 0.01587239, -0.20323639],\n", " [ 0.06792891, 0.08563136, -0.09104523, 0.17886826, 0.07520616,\n", " -0.13827898, 0.33567435, -0.14805417, -0.03184932, 0.39237124,\n", " -0.1335338 , 0.19828805, 0.05121414, -0.04607381, -0.12948062,\n", " -0.22250798, 0.12677568, 0.39128548, -0.11602047, 0.00093162,\n", " -0.07845107, 0.17064299, 0.2707931 , 0.06743585, 0.07426128,\n", " -0.00924093, -0.0035352 , -0.3685534 , -0.12302665, 0.22056273],\n", " [ 0.02833928, -0.12450014, 0.17981096, 0.15364204, 0.05483492,\n", " 0.19171704, -0.0949284 , 0.06867886, -0.07678194, -0.01938733,\n", " 0.05701402, -0.39338416, 0.05287948, 0.3794972 , 0.24641661,\n", " -0.1212198 , 0.04000506, -0.38034967, 0.19541413, -0.0905077 ,\n", " 0.3206088 , 0.01485404, -0.03493308, 0.11109039, -0.33723742,\n", " -0.30601716, -0.11324729, 0.1596858 , 0.06751473, 0.1008921 ],\n", " [ 0.16805576, 0.19498089, -0.09763785, -0.14558062, 0.10152206,\n", " -0.31742054, -0.11583678, -0.2865575 , -0.10120936, -0.13012367,\n", " 0.19799586, -0.06929106, 0.00183079, -0.06139433, -0.23812771,\n", " 0.14183812, 0.41206583, -0.11150262, -0.07695962, -0.03937718,\n", " -0.05823223, -0.25616592, 0.17551638, -0.05776715, 0.04627597,\n", " 0.12046237, 0.31444448, -0.1823728 , -0.16253875, -0.09766676],\n", " [-0.06190741, -0.11557767, 0.07265058, 0.12529932, 0.20684099,\n", " 0.15767016, -0.08056761, -0.19449666, 0.02133167, 0.23543602,\n", " -0.17700855, -0.35116544, -0.22017023, 0.03137846, 0.10100484,\n", " -0.40086156, -0.13380852, -0.06593318, -0.14122422, -0.17200904,\n", " -0.0666105 , 0.09940979, -0.03091712, 0.25939053, 0.06447808,\n", " -0.2506336 , -0.0349206 , 0.08023839, 0.25556827, -0.2408923 ],\n", " [ 0.00898188, -0.40073588, -0.06301974, 0.06183384, 0.3735768 ,\n", " 0.03177406, 0.27502847, -0.28810993, -0.2024756 , 0.16113877,\n", " -0.21794656, 0.10632099, 0.00266866, -0.27301037, 0.07529524,\n", " 0.07778189, 0.02633543, 0.09457737, -0.28337651, 0.0255892 ,\n", " 0.17133063, 
0.04773571, -0.01299471, -0.0919252 , 0.22021984,\n", " -0.1989678 , 0.34153467, 0.08680797, -0.08852738, -0.0090448 ],\n", " [ 0.12035008, 0.12541962, -0.36259866, 0.22371957, -0.07335131,\n", " 0.10498597, 0.00436583, -0.08324738, 0.22863485, -0.14954014,\n", " 0.08159503, -0.3141421 , 0.08762485, -0.05525228, 0.08568875,\n", " -0.02316961, -0.2230854 , 0.02858485, 0.10418503, 0.09759469,\n", " 0.08704272, 0.01555008, 0.17367665, 0.08375961, -0.01750728,\n", " -0.06537268, 0.05048656, -0.22944517, 0.05722432, 0.25090805],\n", " [-0.39710748, 0.10012694, -0.07080103, -0.16264898, -0.13910918,\n", " 0.16161054, 0.16022125, -0.00788775, 0.05428429, 0.16593601,\n", " 0.22370476, 0.36696965, 0.06149913, -0.04857542, 0.3345247 ,\n", " -0.07260554, 0.1938989 , -0.06002848, -0.30302036, 0.17182748,\n", " 0.29064724, 0.21397091, 0.04791559, 0.09810503, 0.1033058 ,\n", " 0.12732479, -0.0579783 , 0.15246823, -0.3666319 , 0.1779919 ],\n", " [ 0.01064624, 0.0888928 , 0.26858085, 0.34396815, 0.06943932,\n", " 0.30761874, -0.15886313, 0.00265385, 0.04297891, -0.06383656,\n", " -0.01197957, -0.10140778, 0.03901416, -0.02126652, 0.13493209,\n", " -0.16070978, -0.27638012, -0.11028586, 0.12214845, -0.2560637 ,\n", " -0.08863154, 0.03597671, -0.1732396 , 0.12559041, 0.14788477,\n", " 0.09702435, 0.17843248, 0.08070756, 0.0718791 , 0.08296195],\n", " [ 0.14691886, 0.13540354, -0.05013047, -0.2566406 , -0.2376638 ,\n", " 0.21672072, 0.1372795 , -0.03882806, 0.39052176, 0.0047731 ,\n", " 0.14544334, -0.0696618 , -0.15187763, 0.06678917, -0.24012098,\n", " 0.31160212, 0.06627946, -0.2530402 , -0.20175886, -0.22604358,\n", " 0.1381416 , -0.14101216, 0.3429103 , 0.12955913, 0.2845845 ,\n", " 0.06188303, -0.22960348, 0.2912202 , -0.08082792, -0.3445377 ],\n", " [-0.01824994, 0.12698065, 0.11829151, -0.08935194, -0.04362963,\n", " -0.06175369, -0.1114524 , -0.06696388, -0.34100425, -0.25512362,\n", " -0.1483988 , -0.20127416, -0.00367533, 0.05239835, 0.06488706,\n", " 0.08272076, 
0.05891787, 0.2134408 , -0.13793291, 0.30933803,\n", " -0.09876332, -0.15072244, -0.10377637, 0.03409749, 0.0937078 ,\n", " -0.22452421, 0.3597254 , 0.24009626, -0.03083205, -0.10381168],\n", " [-0.14538439, 0.17941016, 0.01639399, -0.2706253 , -0.02600642,\n", " -0.03973 , -0.0325162 , 0.03153259, -0.15472709, -0.09655666,\n", " 0.04076509, 0.1300038 , -0.19558378, -0.17638195, 0.12240331,\n", " -0.26903665, 0.2714493 , -0.07004572, -0.07335924, 0.03825237,\n", " 0.22632292, 0.3012138 , 0.02217355, -0.30002278, -0.06066401,\n", " -0.07689169, -0.37136257, 0.19665234, -0.10525645, -0.27408272],\n", " [ 0.05384398, 0.03158583, -0.00409974, -0.04451011, -0.10076478,\n", " -0.06426084, 0.3136195 , -0.13606365, 0.1243284 , -0.10924114,\n", " -0.03940558, 0.22020963, -0.07174113, 0.08709462, 0.04955287,\n", " 0.36317343, 0.00659794, -0.15838777, 0.09210019, -0.17414865,\n", " -0.14202411, 0.3834263 , 0.02247368, 0.00736032, -0.02805607,\n", " -0.15887989, 0.03910746, -0.0943727 , 0.21787158, 0.01440434],\n", " [-0.09300622, -0.19802521, -0.31412005, 0.17171307, 0.1331803 ,\n", " -0.14113024, -0.21318011, -0.16237472, 0.09434846, 0.14660788,\n", " 0.01858762, -0.02211154, -0.14670722, 0.39278403, -0.20136856,\n", " 0.10904545, -0.02885009, -0.15209475, 0.1743193 , 0.0778787 ,\n", " 0.09585676, 0.10286772, 0.1895318 , -0.15744607, 0.0972386 ,\n", " -0.26544875, -0.05130047, 0.08041063, 0.05855417, 0.24786705],\n", " [ 0.05508722, 0.23071642, 0.00278442, -0.05163229, -0.13318591,\n", " 0.17231207, -0.0383717 , 0.17234325, -0.12098849, -0.12200612,\n", " -0.165717 , -0.08695543, -0.01522441, -0.31668693, 0.196136 ,\n", " -0.20849878, 0.34565175, 0.252592 , 0.03059202, -0.23635055,\n", " -0.02455017, -0.07401715, 0.18046305, 0.08005303, 0.02341022,\n", " 0.05160871, 0.0830403 , -0.10961437, 0.2051303 , 0.05485763],\n", " [-0.29294217, 0.01583408, -0.00052598, 0.07539546, 0.17627907,\n", " 0.16075702, 0.00591798, -0.02526975, -0.2719347 , -0.2642147 ,\n", " 0.17578189, 
0.26844388, -0.16066906, 0.00551553, -0.41348425,\n", " 0.1321568 , 0.2071938 , -0.09202607, -0.32119918, 0.03001858,\n", " -0.03515013, -0.11420041, 0.00692059, 0.06027223, 0.31073922,\n", " 0.31373912, 0.15468763, 0.23844069, 0.20547047, 0.165754 ],\n", " [-0.1317544 , -0.09716719, -0.2110814 , 0.30688593, 0.13689038,\n", " 0.25466746, -0.23185365, 0.265381 , -0.20205005, 0.26761973,\n", " -0.01471928, -0.17001429, -0.00165382, 0.10118251, 0.28316593,\n", " -0.10187137, 0.02500786, 0.09213623, -0.06184761, 0.051311 ,\n", " -0.13956325, 0.29834348, 0.16425882, -0.20013842, 0.10159607,\n", " -0.09226643, -0.09284794, -0.24736227, 0.28198415, 0.18465933],\n", " [-0.19596493, -0.26223665, -0.02396852, 0.1405711 , -0.05117449,\n", " 0.09832071, 0.10009323, 0.08764507, -0.20915532, -0.04817107,\n", " 0.11512975, -0.0107393 , 0.06286559, -0.14394692, 0.1831078 ,\n", " 0.18601051, -0.01792853, -0.010507 , 0.2988264 , 0.02924132,\n", " 0.1502285 , -0.02573505, 0.10515428, 0.32683268, -0.06475027,\n", " -0.07946308, -0.33095527, -0.33394814, 0.14654751, -0.18609025],\n", " [-0.04332547, 0.18820217, 0.03160366, 0.11940409, -0.22678787,\n", " 0.09432799, -0.08720809, 0.25600654, -0.14890012, 0.09946848,\n", " 0.18772584, -0.19526623, 0.0827599 , -0.14669879, -0.12541471,\n", " -0.13776924, 0.09574251, -0.2980466 , 0.10541511, -0.11811657,\n", " -0.23554784, -0.01769215, -0.29761636, 0.04322377, -0.04169539,\n", " 0.04331157, -0.10865457, 0.3526432 , 0.27452517, 0.01664442],\n", " [ 0.17763771, -0.07080895, -0.12558904, -0.13398908, -0.22847766,\n", " -0.20403627, 0.07889682, 0.13384837, -0.31691694, -0.13476555,\n", " -0.08197045, 0.02778772, 0.02476428, 0.10588782, -0.25830707,\n", " -0.24311969, 0.03762388, 0.05451898, -0.13534577, -0.10997833,\n", " -0.3139264 , 0.05126831, -0.00060226, -0.15891929, -0.17077953,\n", " 0.2362888 , 0.08467598, 0.01052356, 0.08872832, 0.16418251],\n", " [ 0.37544525, 0.0681546 , 0.07722013, -0.40396348, -0.05511683,\n", " 0.00878677, 
0.33257678, 0.18474084, -0.0799066 , 0.20011736,\n", " 0.14146338, -0.15846273, -0.15961201, -0.18772689, -0.17597765,\n", " -0.13404477, 0.21314138, -0.13090074, 0.10695033, 0.28710032,\n", " 0.13358802, 0.3303852 , -0.2687422 , -0.22376198, 0.29356587,\n", " -0.03488064, 0.14832059, 0.12624982, -0.20833445, 0.05823356],\n", " [ 0.17862728, -0.12085862, -0.07798004, 0.16461669, 0.13114056,\n", " 0.1119384 , -0.02916402, -0.01834482, 0.03708343, -0.39161655,\n", " 0.04380961, 0.12685701, -0.20311095, 0.14991562, 0.08968998,\n", " -0.1430527 , 0.3768945 , 0.2545389 , 0.09408659, 0.30030465,\n", " 0.00201878, -0.03300162, -0.31967437, 0.08429171, -0.10358454,\n", " 0.15462488, 0.15204427, -0.00353977, -0.15648344, 0.03190795],\n", " [-0.3349845 , -0.18704857, 0.12660322, 0.27142197, -0.04179126,\n", " -0.01659705, -0.15886122, 0.14643206, 0.10317151, 0.139131 ,\n", " 0.26203057, 0.03828669, 0.17041986, 0.28139216, 0.03020249,\n", " -0.21715921, 0.05988631, 0.20941454, 0.27820507, -0.30283943,\n", " 0.21741417, 0.06876856, -0.0162366 , -0.09319973, 0.16716208,\n", " -0.05672812, -0.01678701, -0.33967227, 0.04148872, 0.24174951]],\n", " dtype=float32),\n", " },\n", " Dense_1: {\n", " bias: array([-9.28075612e-03, -4.38565016e-03, 6.73728883e-02, 6.70552254e-06,\n", " -1.11967325e-02, 1.06356591e-02, -2.26502120e-02, -3.45642865e-03,\n", " 3.70997190e-02, 9.00812894e-02, 6.70552254e-06, 6.70552254e-06,\n", " 6.70552254e-06, -1.53630227e-02, 3.88986021e-02], dtype=float32),\n", " kernel: array([[-0.21578309, -0.08008368, -0.34167936, -0.03616343, -0.04043388,\n", " -0.19278756, 0.07816273, 0.3847432 , -0.27097666, 0.03089739,\n", " -0.11206758, 0.12151396, 0.38484663, 0.12947203, 0.03026646],\n", " [ 0.3035894 , 0.14900179, 0.02244793, 0.17264117, 0.0011169 ,\n", " -0.1606707 , 0.17210394, -0.19850568, -0.00882789, 0.06376703,\n", " -0.09706031, -0.27143008, 0.32902688, 0.01248117, -0.20562333],\n", " [ 0.01422736, 0.25237322, 0.26592904, -0.07876748, 0.02570754,\n", " 
0.13746765, -0.3037846 , -0.30066282, 0.228537 , 0.07397157,\n", " -0.05444951, 0.06826244, -0.11475235, -0.04363853, -0.00258049],\n", " [-0.03807974, 0.36382473, -0.05991563, 0.1660564 , -0.18014075,\n", " 0.17624326, -0.24441232, -0.31741685, -0.06890935, -0.04919542,\n", " 0.13665393, -0.05236177, 0.12887959, 0.2582429 , -0.06479871],\n", " [ 0.01582411, -0.00546367, 0.06451672, 0.00377437, 0.05299711,\n", " 0.09622552, -0.33355796, 0.15772232, -0.00315991, 0.3426076 ,\n", " -0.01920256, -0.0157837 , 0.10247016, -0.02410382, 0.14005862],\n", " [-0.1542452 , -0.18916714, -0.12516193, -0.15350978, -0.20895821,\n", " -0.03576617, 0.0180776 , 0.16850933, 0.05937128, 0.03776275,\n", " 0.07396616, 0.03354299, 0.06906249, 0.15164083, -0.2608541 ],\n", " [ 0.0905385 , 0.3133642 , -0.17575558, 0.05339232, -0.19663817,\n", " 0.22920834, 0.21465397, 0.14934285, 0.30395645, -0.2403111 ,\n", " 0.11673826, -0.0449398 , -0.0359446 , 0.3089288 , -0.01469092],\n", " [-0.09350568, -0.09241582, 0.29311585, 0.13808654, 0.14410885,\n", " 0.11155026, 0.19201808, -0.22068372, 0.0091573 , -0.00837527,\n", " -0.10839485, -0.0492375 , 0.15357326, 0.3894407 , -0.15209287],\n", " [ 0.0029071 , 0.18366341, 0.03765289, -0.01738603, 0.18317957,\n", " 0.00410259, 0.09655431, 0.07968767, 0.21980065, 0.22737293,\n", " -0.15136166, 0.20435053, 0.11874333, -0.3370184 , 0.11251831],\n", " [-0.03699435, 0.05359124, -0.00424996, -0.00427449, -0.20195475,\n", " -0.12829332, 0.06293778, 0.13848272, -0.17896764, -0.38953093,\n", " -0.07185236, 0.22985502, -0.11224222, 0.04145651, -0.3817618 ],\n", " [ 0.23528674, 0.1663438 , -0.08346738, 0.20346904, -0.20409097,\n", " -0.07192652, 0.11208971, 0.24518102, 0.23959732, -0.1391288 ,\n", " -0.02638906, -0.11256091, -0.27086872, -0.00492385, 0.13006589],\n", " [-0.05570208, -0.34653068, 0.298495 , -0.16680127, 0.06143057,\n", " 0.09288131, 0.1472318 , -0.12598082, -0.01329006, -0.26823848,\n", " 0.08741044, 0.10009366, 0.1264808 , 0.13802043, 0.2563799 
],\n", " [ 0.01380032, -0.19647142, 0.14879738, 0.0388497 , -0.14403345,\n", " 0.3500362 , -0.03261025, -0.11959814, -0.35041225, -0.09013529,\n", " 0.16815332, -0.17363463, -0.26452613, 0.18936844, -0.30342007],\n", " [ 0.15264955, -0.16593191, 0.2803555 , -0.02613318, 0.09317887,\n", " -0.1145407 , -0.02915843, 0.09115867, 0.16309327, 0.16567504,\n", " -0.16353543, -0.02392778, 0.21730614, -0.37557966, 0.36441088],\n", " [ 0.27545726, -0.0511765 , 0.03052394, 0.38374472, 0.18914919,\n", " -0.30549794, -0.1365143 , -0.09850363, -0.08355592, -0.17305706,\n", " 0.00163533, 0.27035654, -0.01430997, 0.01418965, -0.23040168],\n", " [-0.11281115, -0.08904882, 0.05267188, -0.03345008, 0.17955093,\n", " 0.15272899, -0.05194864, 0.10906464, 0.21673168, -0.05776855,\n", " 0.29315004, -0.272271 , 0.22718571, -0.04166271, 0.08242701],\n", " [ 0.11221845, 0.15372409, -0.13822478, -0.1822467 , -0.26139548,\n", " 0.22891735, -0.12165104, -0.20519899, 0.39132354, 0.19772069,\n", " 0.00470251, -0.04089966, -0.17769708, 0.22472003, 0.24131916],\n", " [-0.08916578, -0.13332213, 0.11583962, -0.3159293 , -0.05461061,\n", " -0.03293931, 0.17573284, -0.03388457, -0.04562169, -0.00728241,\n", " -0.20086795, -0.04282369, -0.06481767, 0.00174528, 0.08415282],\n", " [ 0.10727057, 0.15353328, 0.09634779, 0.01951087, -0.00729644,\n", " -0.25289363, -0.23461105, 0.35619986, 0.1761693 , -0.18046483,\n", " -0.25238073, -0.05560882, -0.20357345, -0.13479468, 0.14422338],\n", " [-0.32128555, 0.01506339, 0.3208 , 0.3084674 , 0.06561027,\n", " -0.20671532, 0.07110539, 0.09107907, -0.05795462, 0.06884666,\n", " -0.24340847, 0.09923317, -0.39770418, -0.1435657 , 0.18189654],\n", " [ 0.17829058, -0.3734911 , -0.344891 , -0.18513158, -0.1252987 ,\n", " -0.359349 , -0.21523046, 0.4066509 , -0.06088345, -0.12821774,\n", " 0.30891037, -0.05408247, 0.13263397, 0.01792602, 0.22215459],\n", " [-0.15978388, -0.19274645, -0.39842924, -0.13794418, 0.24811234,\n", " -0.30259767, 0.25340182, 0.36628515, 
-0.04467097, 0.2068153 ,\n", " 0.10091661, -0.17184901, -0.01158652, 0.28761518, 0.07140049],\n", " [-0.38753265, -0.2148714 , -0.34941888, -0.37459916, 0.00249913,\n", " -0.38012785, -0.26021895, 0.06027205, -0.05131304, 0.24082436,\n", " 0.20541278, -0.09037189, -0.1668249 , 0.24143052, -0.26692837],\n", " [-0.20972599, -0.01015192, 0.16557814, 0.20875406, -0.19013 ,\n", " -0.31780058, -0.0311262 , -0.06458683, 0.39772552, -0.26640862,\n", " 0.31138209, -0.06382139, -0.39696902, 0.10767588, 0.01154487],\n", " [ 0.18972392, 0.0260205 , 0.10645114, -0.21743847, -0.26412916,\n", " 0.15006566, 0.13827245, -0.21839432, -0.0661045 , 0.27946475,\n", " -0.10810606, -0.32184672, -0.03605315, 0.04213592, -0.01746434],\n", " [ 0.04842271, -0.17424968, 0.1226363 , 0.3272752 , -0.08305582,\n", " -0.31486374, 0.10645151, 0.09955801, 0.07176954, -0.20580491,\n", " 0.04142599, -0.00282551, 0.15971068, 0.19535801, -0.218681 ],\n", " [ 0.18084396, 0.0928266 , -0.27906552, -0.3218416 , 0.08461788,\n", " -0.13167469, -0.22216263, 0.06937028, 0.10845792, -0.15438123,\n", " -0.02529359, 0.03964274, -0.01773006, 0.04081337, 0.15702999],\n", " [ 0.17229266, -0.27421105, 0.03916015, -0.10643585, 0.15348236,\n", " -0.40775347, -0.14518811, -0.19719355, 0.15164377, 0.08712097,\n", " -0.01809341, 0.03163807, -0.31661576, -0.08889712, -0.3158114 ],\n", " [-0.09769286, -0.0287679 , 0.35823774, -0.2710501 , 0.32775387,\n", " 0.08072445, 0.30246186, -0.19245733, -0.17830896, 0.2923805 ,\n", " 0.09355405, -0.2524669 , 0.12927876, 0.38659224, -0.394949 ],\n", " [ 0.06278192, 0.08469887, -0.00950807, 0.10956495, 0.09936713,\n", " -0.19083263, 0.21161489, 0.3930601 , 0.00441524, 0.20089766,\n", " -0.13769451, 0.27256852, -0.0958501 , -0.05921303, 0.33085537]],\n", " dtype=float32),\n", " },\n", " Dense_2: {\n", " bias: array([-1.85250044e-02, 1.07492432e-01, 6.70552254e-06, -2.40181834e-02,\n", " 1.02186531e-01, 1.17786124e-01, -7.42332637e-03, 6.70552254e-06],\n", " dtype=float32),\n", " kernel: 
array([[ 1.70731947e-01, -9.90723372e-02, -2.29330808e-02,\n", " 1.22201458e-01, -1.06555134e-01, -3.66996974e-02,\n", " 1.41981944e-01, 8.84072036e-02],\n", " [ 1.80482566e-01, 1.18618101e-01, 5.27178943e-01,\n", " -1.53569669e-01, -3.86155099e-02, -1.22957200e-01,\n", " 6.27236068e-03, 1.60065144e-02],\n", " [ 1.45004794e-01, 5.34007728e-01, -3.44348907e-01,\n", " -9.43183005e-02, 1.35729849e-01, -8.20837915e-03,\n", " 1.19165242e-01, -4.11043108e-01],\n", " [ 2.49569073e-01, 2.14901850e-01, 3.24754179e-01,\n", " -4.91983056e-01, -1.14351839e-01, -2.11404264e-02,\n", " 7.69451857e-02, 3.31384748e-01],\n", " [ 4.12468165e-02, -1.21441126e-01, -3.10934186e-01,\n", " 3.62142444e-01, -2.74272680e-01, -5.16952693e-01,\n", " -5.41899055e-02, 5.59365571e-01],\n", " [-4.59907204e-02, -6.31701499e-02, 1.12813368e-01,\n", " 3.72401834e-01, 1.19809762e-01, 2.30254814e-01,\n", " -1.38893276e-01, 4.47092503e-02],\n", " [-6.63732141e-02, -3.46694767e-01, -4.84580398e-01,\n", " 1.17096156e-01, -2.02452645e-01, -3.72330904e-01,\n", " 5.67477584e-01, -2.42807895e-01],\n", " [ 3.36939037e-01, -1.21135429e-01, 3.77379209e-01,\n", " 4.15782034e-01, 4.28560078e-02, -3.28819275e-01,\n", " 4.96320784e-01, -1.31850213e-01],\n", " [ 3.29444051e-01, 7.28795379e-02, 2.03807220e-01,\n", " 1.12708807e-01, 3.64434421e-01, 1.61256164e-01,\n", " -2.09103152e-01, 5.57109714e-04],\n", " [ 1.63463175e-01, 4.33325768e-04, -4.44926351e-01,\n", " -9.06594098e-02, 3.35962057e-01, 4.89929318e-01,\n", " -1.45875633e-01, -2.39341617e-01],\n", " [-2.70898372e-01, 1.98495671e-01, 1.63841411e-01,\n", " -3.97940278e-01, 8.93494636e-02, 3.55310917e-01,\n", " -4.32752073e-03, -2.55927563e-01],\n", " [ 2.71851391e-01, -6.50364608e-02, 5.75378686e-02,\n", " 8.61035287e-02, 4.62560952e-02, 6.84097558e-02,\n", " -3.49434435e-01, 3.20657253e-01],\n", " [ 3.87204051e-01, 1.02552503e-01, -3.67724121e-01,\n", " -1.37631506e-01, -2.76333094e-03, -9.74716246e-03,\n", " -2.03522891e-02, -3.78593743e-01],\n", " [ 
1.50336102e-01, -1.60084844e-01, -4.73217756e-01,\n", " 2.41285011e-01, 4.34440225e-02, -3.39211166e-01,\n", " 1.99615315e-01, -1.64225787e-01],\n", " [-9.00761038e-02, 1.48902372e-01, -3.05288136e-02,\n", " 4.30284500e-01, -3.87649029e-01, 5.50720513e-01,\n", " -9.88744646e-02, -3.92888784e-01]], dtype=float32),\n", " },\n", " Dense_3: {\n", " bias: array([0.23489459], dtype=float32),\n", " kernel: array([[-0.07592051],\n", " [ 0.6634396 ],\n", " [ 0.193967 ],\n", " [-0.21856818],\n", " [ 0.44744045],\n", " [ 0.7258886 ],\n", " [-0.07619607],\n", " [-0.20487107]], dtype=float32),\n", " },\n", " },\n", "})\n" ] } ], "source": [ "# Reveal the secret-shared result of the SPU training above as plaintext parameters.\n", "params = sf.reveal(params_spu)\n", "print(params)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, let's validate the SPU-trained model on the plaintext test set and compare its AUC with the plaintext result." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "auc=0.9927939731411726\n" ] } ], "source": [ "X_test, y_test = breast_cancer(train=False)\n", "auc = validate_model(params, X_test, y_test)\n", "print(f'auc={auc}')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "This is the end of the lab." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.15" }, "vscode": { "interpreter": { "hash": "db45a4cb4cd37a8de684dfb7fcf899b68fccb8bd32d97c5ad13e5de1245c0986" } } }, "nbformat": 4, "nbformat_minor": 2 }