Logistic Regression with SPU#

The following code is for demonstration only. It is NOT production-ready due to system security concerns; please DO NOT use it directly in production.

SPU is a domain-specific compiler and runtime suite that provides provably secure computation services. The SPU compiler uses XLA as its front-end IR, which supports diverse AI frameworks (such as TensorFlow, JAX and PyTorch). The SPU compiler translates XLA into an IR that can be interpreted by the SPU runtime. Currently, the SPU team highly recommends using JAX as the frontend.
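
To see what "XLA as the front-end IR" means in practice, the short sketch below (not part of the original lab) lowers a plain JAX function to XLA HLO text, which is the kind of IR the SPU compiler consumes. It uses jax.xla_computation, whose availability depends on your JAX version; newer releases expose the same information via jax.jit(...).lower(...).

import jax
import jax.numpy as jnp


def forward(w, x):
    # A tiny linear model, just enough to produce some XLA.
    return jnp.dot(x, w)


# Trace the function and print the XLA HLO it lowers to.
hlo = jax.xla_computation(forward)(jnp.ones((3,)), jnp.ones((2, 3)))
print(hlo.as_hlo_text())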

Learning Objectives:#

After doing this lab, you’ll know how to:

  • Write a Logistic Regression training program with JAX.

  • Convert a JAX program to an SPU (MPC) program painlessly.

In this lab, we use the Breast Cancer dataset: the task is to decide whether a tumor is malignant or benign based on 30 features. In the MPC program, two parties train the model jointly, with each party providing half of the features (15).

But first, let's set the MPC setting aside and write a Logistic Regression training program with JAX directly.

Train a Model with JAX#

Load the Dataset#

We are going to split the whole dataset into train and test subsets after normalization with breast_cancer.

  • If train is True, the train subsets are returned. To simulate training on a vertically split dataset, party_id is provided.

  • Otherwise, the test subsets are returned.

[1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer


def breast_cancer(party_id=None, train: bool = True) -> (np.ndarray, np.ndarray):
    scaler = Normalizer(norm='max')
    x, y = load_breast_cancer(return_X_y=True)
    x = scaler.fit_transform(x)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )

    if train:
        if party_id:
            if party_id == 1:
                # Party 1 holds the last 15 features and no label.
                return x_train[:, 15:], None
            else:
                # Party 2 holds the first 15 features and the label.
                return x_train[:, :15], y_train
        else:
            return x_train, y_train
    else:
        return x_test, y_test
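
As a quick plaintext sanity check (a small illustrative snippet, not part of the original notebook), you can load both halves and inspect the shapes of the vertical split:

# Plaintext sanity check of the vertical split (for illustration only).
x1, _ = breast_cancer(party_id=1)
x2, y = breast_cancer(party_id=2)
print(x1.shape, x2.shape, y.shape)  # expected: (455, 15) (455, 15) (455,)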

Define the Model#

First, let’s define the loss function, which is a negative log-likelihood in our case.

[2]:
import jax.numpy as jnp


def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))


# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)


# Training loss is the negative log-likelihood of the training examples.
def loss(W, b, inputs, targets):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs))
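
As a quick toy check (illustrative only, with made-up inputs), a zero-initialized model predicts a probability of 0.5 for every sample, so the loss should come out near log(2) ≈ 0.693:

# Toy check of the loss: with W = 0 and b = 0 the model predicts 0.5 everywhere,
# so the negative log-likelihood should be about log(2) ~= 0.693.
toy_x = jnp.ones((4, 30))
toy_y = jnp.array([0.0, 1.0, 1.0, 0.0])
print(loss(jnp.zeros((30,)), 0.0, toy_x, toy_y))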

Second, let’s define a single train step with SGD optimizer. Just to remind you, x1 represents 15 features from one party while x2 represents the other 15 features from the other party.

[3]:
from jax import value_and_grad


def train_step(W, b, x1, x2, y, learning_rate):
    x = jnp.concatenate([x1, x2], axis=1)
    loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b, x, y)
    W -= learning_rate * Wb_grad[0]
    b -= learning_rate * Wb_grad[1]
    return loss_value, W, b

Last, let’s build everything together as a fit method which returns the model and losses of each epoch.

[4]:
def fit(W, b, x1, x2, y, epochs=1, learning_rate=1e-2):
    losses = jnp.array([])
    for _ in range(epochs):
        l, W, b = train_step(W, b, x1, x2, y, learning_rate=learning_rate)
        losses = jnp.append(losses, l)
    return losses, W, b

Validate the Model#

We can use AUC to validate a binary classification model.

[5]:
from sklearn.metrics import roc_auc_score


def validate_model(W, b, X_test, y_test):
    y_pred = predict(W, b, X_test)
    return roc_auc_score(y_test, y_pred)

If you are interested, we can also plot the loss after each epoch of training.

[6]:
import matplotlib.pyplot as plt


def plot_losses(losses):
    plt.plot(np.arange(len(losses)), losses)
    plt.xlabel('epoch')
    plt.ylabel('loss')

Have a try!#

Let’s put everything we have together and train a LR model!

[7]:
%matplotlib inline

# Load the data
x1, _ = breast_cancer(party_id=1, train=True)
x2, y = breast_cancer(party_id=2, train=True)

# Initialize parameters and hyperparameters
W = jnp.zeros((30,))
b = 0.0
epochs = 10
learning_rate = 1e-2

# Train the model
losses, W, b = fit(W, b, x1, x2, y, epochs=epochs, learning_rate=learning_rate)

# Plot the loss
plot_losses(losses)

# Validate the model
X_test, y_test = breast_cancer(train=False)
auc = validate_model(W, b, X_test, y_test)
print(f'auc={auc}')

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
auc=0.9320340648542417
[Plot: training loss per epoch]

Remember the plot and the AUC here, since we are going to do a similar thing with SPU!

Train a Model with SPU#

In this part, we are going to show you how to do similar training securely with MPC!

Init the Environment#

We are going to initialize three virtual devices in our physical environment.

  • alice, bob: two PYU devices for local plaintext computation.

  • spu: an SPU device composed of alice and bob for MPC secure computation.

[8]:
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob'], num_cpus=8, log_to_driver=True)

alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))


Load the Dataset#

We instruct alice and bob to load the train subsets respectively.

[9]:
x1, _ = alice(breast_cancer)(party_id=1)
x2, y = bob(breast_cancer)(party_id=2)

x1, x2, y

[9]:
(<secretflow.device.device.pyu.PYUObject at 0x7f5dd0790490>,
 <secretflow.device.device.pyu.PYUObject at 0x7f5f3c1b5df0>,
 <secretflow.device.device.pyu.PYUObject at 0x7f5f3c1b5d00>)
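
Note that x1, x2 and y are only references (PYUObjects); the actual arrays stay with alice and bob. In this single-machine simulation you can peek at a shape with sf.reveal for debugging (a debugging-only sketch; never do this with real secret data):

# Debugging only: sf.reveal pulls the plaintext back to the driver.
print(sf.reveal(x1).shape, sf.reveal(y).shape)  # expected: (455, 15) (455,)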

Before training, we need to pass the hyperparameters and all the data to the SPU device. SecretFlow provides two methods:

  • secretflow.to: transfer a PythonObject or DeviceObject to a specific device.

  • DeviceObject.to: transfer the DeviceObject to a specific device.

[10]:
device = spu

W = jnp.zeros((30,))
b = 0.0

W_, b_, x1_, x2_, y_ = (
    sf.to(device, W),
    sf.to(device, b),
    x1.to(device),
    x2.to(device),
    y.to(device),
)

Train the Model#

Now we are ready to train an LR model with SPU. Since the Python training loop in fit is unrolled at compile time, epochs is passed as a static argument, and user_specified_num_returns tells the SPU compiler that fit returns three values. After training, the losses and the model are SPUObjects, which remain secret.

[11]:
losses, W_, b_ = device(
    fit,
    static_argnames=['epochs'],
    num_returns_policy=sf.device.SPUCompilerNumReturnsPolicy.FROM_USER,
    user_specified_num_returns=3,
)(W_, b_, x1_, x2_, y_, epochs=10, learning_rate=1e-2)

losses, W_, b_

[11]:
(<secretflow.device.device.spu.SPUObject at 0x7f5f3c1c5430>,
 <secretflow.device.device.spu.SPUObject at 0x7f5f3c1c52e0>,
 <secretflow.device.device.spu.SPUObject at 0x7f5f3c1c53a0>)

Reveal the result#

In order to check the losses and the model, we need to convert the SPUObjects (secret) into Python objects (plaintext). SecretFlow provides sf.reveal to convert any DeviceObject to a Python object.

Be careful with sf.reveal, since it may result in secret leakage.

[12]:
%matplotlib inline

losses = sf.reveal(losses)

plot_losses(losses)

[Plot: training loss per epoch (SPU training)]

Finally, let’s validate the model with AUC.

[13]:
auc = validate_model(sf.reveal(W_), sf.reveal(b_), X_test, y_test)
print(f'auc={auc}')

auc=0.939731411726171

You may find that the model from the SPU training program achieves an AUC very close to that of the JAX program.
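
As a side note, revealing W_ and b_ brings the whole model back to plaintext. A minimal alternative sketch (assuming the spu device, predict and the test data above; not part of the original lab) keeps the model secret and reveals only the predicted scores for evaluation:

# Evaluate without revealing the model: run predict inside SPU,
# then reveal only the predicted scores.
y_pred_ = spu(predict)(W_, b_, sf.to(spu, X_test))
print(f'auc={roc_auc_score(y_test, sf.reveal(y_pred_))}')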

This is the end of the lab.