SecretFlow in Practice: Developing MPC Algorithms for a Real-World Scenario#


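We recommend creating a new environment with conda:

conda create -n sf python=3.8

Then install secretflow directly with pip:

pip install -U secretflow

This tutorial is based on secretflow 0.8.2b2.

This example shows how to build a real application on top of SecretFlow and the SPU privacy-preserving computation device. We recommend reading the previous tutorial, spu_basics, first to become familiar with the basic SPU concepts.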

Task Introduction#

Vehicle Insurance Claim Fraud Detection

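The dataset comes from Kaggle and contains:

  • Vehicle data: attributes, model, accident details, etc.

  • Policy details: policy type, tenure, etc.

The goal is to detect whether a claim application is fraudulent. The field FraudFound_P (0 or 1) is the prediction target, which makes this a typical binary classification task.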

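Experiment Goals#

In this experiment, we use an open-source dataset to build privacy-preserving logistic regression, neural network, and XGB models on SecretFlow. The workflow consists of the following steps:

1. Data loading
2. Data insights
3. Data preprocessing
4. Model construction
5. Model training and prediction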

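Prerequisites#

Starting the Ray Cluster (Multi-Node Deployment)#

For a multi-node deployment, the Ray cluster must be started before SecretFlow. Run the following commands on the head node and on the worker node, respectively.

P.S. After starting the cluster, you can run ray status to check that it started correctly.

Head node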

RAY_DISABLE_REMOTE_CODE=true \
ray start --head --node-ip-address="head_ip" --port="head_port" --resources='{"alice": 20}' --include-dashboard=False

Worker node

RAY_DISABLE_REMOTE_CODE=true \
ray start --address="head_ip:head_port" --resources='{"bob": 20}'
[1]:
# The code below initializes secretflow in multi-node mode; it needs the head node's IP and port
# head_ip = "xxx"
# head_port = "xxx"
# sf.init(address=f'{head_ip}:{head_port}')

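Standalone Deployment#

Here we use a standalone deployment as a demonstration. Calling sf.init() instantiates a Ray cluster with 5 nodes, which correspond to 5 physical devices.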

[2]:
import secretflow as sf

# Check the version of your SecretFlow
print('The version of SecretFlow: {}'.format(sf.__version__))

sf.shutdown()
# Standalone Mode
sf.init(['alice', 'bob', 'carol', 'davy', 'eric'], address='local')
2023-05-24 14:55:43,444 INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 

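Defining the Plaintext Computation Device: PYU#

With the 5 nodes above started, we now define SecretFlow's logical devices. Here alice, bob, and carol act as data providers that can run plaintext computation locally, i.e. they are PYU (PYthon runtime Unit) devices.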

[3]:
alice = sf.PYU('alice')
bob = sf.PYU('bob')
carol = sf.PYU('carol')

print(alice)
alice

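Defining the Ciphertext Computation Device: SPU (3PC)#

Next, taking SPU (Secure Processing Unit) as an example, we pick 3 physical nodes to form an MPC-based privacy-preserving computation device (the example below uses the three-party ABY3 protocol).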

[4]:
import spu
from secretflow.utils.testing import unused_tcp_port

aby3_cluster_def = {
    'nodes': [
        {
            'party': 'alice',
            'address': f'127.0.0.1:{unused_tcp_port()}',
        },
        {'party': 'bob', 'id': 'local:1', 'address': f'127.0.0.1:{unused_tcp_port()}'},
        {
            'party': 'carol',
            'address': f'127.0.0.1:{unused_tcp_port()}',
        },
    ],
    'runtime_config': {
        'protocol': spu.spu_pb2.ABY3,
        'field': spu.spu_pb2.FM64,
    },
}

my_spu = sf.SPU(aby3_cluster_def)
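For intuition about what a three-party device does with a value, the following minimal sketch shows replicated secret sharing over a 64-bit ring, the kind of sharing that ABY3 builds on. It is plain Python for illustration only and does not reflect SPU's actual implementation; the names RING, share, add, and reconstruct are hypothetical.

```python
import secrets

RING = 2 ** 64  # 64-bit ring, mirroring the FM64 field chosen above


def share(x):
    """Split x into three additive shares; party i keeps shares i and i+1."""
    x0 = secrets.randbelow(RING)
    x1 = secrets.randbelow(RING)
    x2 = (x - x0 - x1) % RING
    return [(x0, x1), (x1, x2), (x2, x0)]


def add(a, b):
    """Add two shared values; every party only touches its own pair."""
    return [
        ((a0 + b0) % RING, (a1 + b1) % RING)
        for (a0, a1), (b0, b1) in zip(a, b)
    ]


def reconstruct(shares):
    """Recover the secret from the first component of each party's pair."""
    return sum(pair[0] for pair in shares) % RING


x_shared = share(2)
y_shared = share(40)
print(reconstruct(x_shared))                 # 2
print(reconstruct(add(x_shared, y_shared)))  # 42
```

Each party holds two of the three shares, so a single party learns nothing about the secret, while any two parties together hold all three shares. Addition needs no communication; multiplication does, which is one reason MPC computation is slower than plaintext computation.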

Data Loading#

Load Data (Mock)#

With SecretFlow's logical devices defined, we now demonstrate how to read data. Here we use a mock data-loading function, get_data_mock(), as an illustration.

[5]:
def get_data_mock():
    return 2


x_plaintext = get_data_mock()
print(f"x_plaintext: {x_plaintext}")
x_plaintext: 2

Reading the data on a specified PYU device

[6]:
x_alice_pyu = alice(get_data_mock)()

print(f"Plaintext Python Object: {x_plaintext}, PYU object: {x_alice_pyu}")
print(f"Reveal PYU object: {sf.reveal(x_alice_pyu)}")
Plaintext Python Object: 2, PYU object: <secretflow.device.device.pyu.PYUObject object at 0x7fcfba18db20>
Reveal PYU object: 2

PYU -> SPU data conversion

[7]:
x_alice_spu = x_alice_pyu.to(my_spu)
print(f"SPU object: {x_alice_spu}")

print(f"Reveal SPU object: {sf.reveal(x_alice_spu)}")
SPU object: <secretflow.device.device.spu.SPUObject object at 0x7fd000234ca0>
Reveal SPU object: 2

Load Data (Distributed)#

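Next we consider reading data for a realistic application scenario, in which the full dataset is vertically partitioned across different parties.

For demonstration purposes, we take a centralized plaintext dataset and split it vertically. Let us first look at the features of this dataset.

Loading the Full Plaintext Dataset#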

[8]:
import os

"""
Create dir to save dataset files
This will create a directory `data` to store the dataset file
"""
if not os.path.exists('data'):
    os.mkdir('data')

"""
The original data is from Kaggle: https://www.kaggle.com/datasets/shivamb/vehicle-claim-fraud-detection.
We use this data for demonstration purposes only.
"""
path = "https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/vehicle_nsurance_claim/fraud_oracle.csv"
if not os.path.exists('data/fraud_oracle.csv'):
    res = os.system('cd data && wget {}'.format(path))
    if res != 0:
        raise Exception('File: {} download fails!'.format(path))
else:
    print('File already downloaded.')
File already downloaded.
[9]:
from sklearn.model_selection import train_test_split
import pandas as pd

"""
This should point to the data downloaded from Kaggle.
By default, the .csv file shall be in the data directory
"""
full_data_path = 'data/fraud_oracle.csv'
df = pd.read_csv(full_data_path)
df.head()
[9]:
   Month  WeekOfMonth  DayOfWeek    Make  AccidentArea  DayOfWeekClaimed  MonthClaimed  WeekOfMonthClaimed     Sex  MaritalStatus  ...  AgeOfVehicle  AgeOfPolicyHolder  PoliceReportFiled  WitnessPresent  AgentType  NumberOfSuppliments  AddressChange_Claim  NumberOfCars  Year  BasePolicy
0    Dec            5  Wednesday   Honda         Urban           Tuesday           Jan                   1  Female         Single  ...       3 years           26 to 30                 No              No   External                 none               1 year        3 to 4  1994   Liability
1    Jan            3  Wednesday   Honda         Urban            Monday           Jan                   4    Male         Single  ...       6 years           31 to 35                Yes              No   External                 none            no change     1 vehicle  1994   Collision
2    Oct            5     Friday   Honda         Urban          Thursday           Nov                   2    Male        Married  ...       7 years           41 to 50                 No              No   External                 none            no change     1 vehicle  1994   Collision
3    Jun            2   Saturday  Toyota         Rural            Friday           Jul                   1    Male        Married  ...   more than 7           51 to 65                Yes              No   External          more than 5            no change     1 vehicle  1994   Liability
4    Jan            5     Monday   Honda         Urban           Tuesday           Feb                   2  Female         Single  ...       5 years           31 to 35                 No              No   External                 none            no change     1 vehicle  1994   Collision

5 rows × 33 columns

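Vertically Splitting the Data Across Three Parties#

We first split the data to simulate a three-party scenario with vertically partitioned data:

  • alice holds the first 10 attributes

  • bob holds the middle 10 attributes

  • carol holds all remaining attributes and the label

To make it easy to align samples across the parties, we also add a new feature, UID, that identifies each sample.

We split the full dataset into a training set and a test set with sklearn in advance, so that the trained models can be validated later.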

[10]:
train_alice_path = "data/alice_train.csv"
train_bob_path = "data/bob_train.csv"
train_carol_path = "data/carol_train.csv"

test_alice_path = "data/alice_test.csv"
test_bob_path = "data/bob_test.csv"
test_carol_path = "data/carol_test.csv"


def load_dataset_full(data_path):
    df = pd.read_csv(data_path)
    df = df.drop([0])
    df = df.loc[df['DayOfWeekClaimed'] != '0']
    y = df['FraudFound_P']
    X = df.drop(columns='FraudFound_P')
    return X, y


def split_data():
    x, y = load_dataset_full(full_data_path)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=10
    )

    print(x_train.shape)
    train_alice_csv = x_train.iloc[:, :10]
    train_bob_csv = x_train.iloc[:, 10:20]
    train_carol_csv = pd.concat([x_train.iloc[:, 20:], y_train], axis=1)

    train_alice_csv.to_csv(train_alice_path, index_label='UID')
    train_bob_csv.to_csv(train_bob_path, index_label='UID')
    train_carol_csv.to_csv(train_carol_path, index_label='UID')

    print(x_test.shape)
    test_alice_csv = x_test.iloc[:, :10]
    test_bob_csv = x_test.iloc[:, 10:20]
    test_carol_csv = pd.concat([x_test.iloc[:, 20:], y_test], axis=1)

    test_alice_csv.to_csv(test_alice_path, index_label='UID')
    test_bob_csv.to_csv(test_bob_path, index_label='UID')
    test_carol_csv.to_csv(test_carol_path, index_label='UID')


split_data()
(10792, 32)
(4626, 32)
[11]:
alice_train_df = pd.read_csv(train_alice_path)
alice_train_df.head()
[11]:
     UID  Month  WeekOfMonth  DayOfWeek    Make  AccidentArea  DayOfWeekClaimed  MonthClaimed  WeekOfMonthClaimed     Sex  MaritalStatus
0   2853    Mar            4     Sunday  Toyota         Urban            Friday           Apr                   1    Male        Married
1   7261    Apr            4   Saturday   Honda         Urban            Monday           Apr                   4    Male        Married
2   9862    Jun            4     Sunday  Toyota         Rural            Monday           Jun                   4  Female         Single
3  14037    Mar            2     Monday   Mazda         Urban            Monday           Mar                   2    Male         Single
4  10199    Jun            3     Friday   Mazda         Urban           Tuesday           Jun                   4  Female         Single
[12]:
bob_train_df = pd.read_csv(train_bob_path)
bob_train_df.head()
[12]:
     UID  Age          Fault          PolicyType  VehicleCategory     VehiclePrice  PolicyNumber  RepNumber  Deductible  DriverRating  Days_Policy_Accident
0   2853   39  Policy Holder  Sedan - All Perils            Sedan   20000 to 29000          2854          8         400             2          more than 30
1   7261   58  Policy Holder   Sedan - Liability            Sport   20000 to 29000          7262          4         400             4          more than 30
2   9862   28  Policy Holder  Sedan - All Perils            Sedan  less than 20000          9863          5         400             4          more than 30
3  14037   28  Policy Holder   Sedan - Collision            Sedan   20000 to 29000         14038         11         400             4          more than 30
4  10199   35  Policy Holder   Sedan - Collision            Sedan   20000 to 29000         10200         12         400             4          more than 30
[13]:
carol_train_df = pd.read_csv(train_carol_path)
carol_train_df.head()
[13]:
     UID  Days_Policy_Claim  PastNumberOfClaims  AgeOfVehicle  AgeOfPolicyHolder  PoliceReportFiled  WitnessPresent  AgentType  NumberOfSuppliments  AddressChange_Claim  NumberOfCars  Year  BasePolicy  FraudFound_P
0   2853       more than 30                   1       7 years           36 to 40                 No              No   External          more than 5            no change     1 vehicle  1994  All Perils             0
1   7261       more than 30                none   more than 7           51 to 65                 No              No   External               1 to 2            no change     1 vehicle  1995   Liability             0
2   9862       more than 30                none       7 years           31 to 35                 No              No   External                 none            no change     1 vehicle  1995  All Perils             0
3  14037       more than 30                   1       6 years           31 to 35                 No              No   External                 none            no change     1 vehicle  1996   Collision             0
4  10199       more than 30                none       5 years           31 to 35                 No              No   Internal                 none            no change     1 vehicle  1995   Collision             0

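Loading the Three-Party Data#

Note: this API requires you to explicitly specify the key used to align samples across the parties, as well as the device that performs PSI.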

[14]:
from secretflow.data.vertical import read_csv as v_read_csv

train_ds = v_read_csv(
    {alice: train_alice_path, bob: train_bob_path, carol: train_carol_path},
    keys='UID',
    drop_keys='UID',
    spu=my_spu,
)
test_ds = v_read_csv(
    {alice: test_alice_path, bob: test_bob_path, carol: test_carol_path},
    keys='UID',
    drop_keys='UID',
    spu=my_spu,
)
print(train_ds)
print(train_ds.columns)
VDataFrame(partitions={alice: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fd000234460>), bob: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fcfb9f9e310>), carol: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fcfb9f9ebb0>)}, aligned=True)
Index(['Month', 'WeekOfMonth', 'DayOfWeek', 'Make', 'AccidentArea',
       'DayOfWeekClaimed', 'MonthClaimed', 'WeekOfMonthClaimed', 'Sex',
       'MaritalStatus', 'Age', 'Fault', 'PolicyType', 'VehicleCategory',
       'VehiclePrice', 'PolicyNumber', 'RepNumber', 'Deductible',
       'DriverRating', 'Days_Policy_Accident', 'Days_Policy_Claim',
       'PastNumberOfClaims', 'AgeOfVehicle', 'AgeOfPolicyHolder',
       'PoliceReportFiled', 'WitnessPresent', 'AgentType',
       'NumberOfSuppliments', 'AddressChange_Claim', 'NumberOfCars', 'Year',
       'BasePolicy', 'FraudFound_P'],
      dtype='object')
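To make the role of keys='UID' and drop_keys='UID' in the call above more concrete, here is a plaintext sketch of what key-based sample alignment is expected to produce: only rows whose key appears at every party are kept, rows are put in the same order, and the key column is dropped afterwards. This is plain pandas with toy data and a hypothetical helper, align_on_key. A real PSI protocol computes the intersection without revealing the non-intersecting IDs, which this sketch does not attempt.

```python
import pandas as pd


def align_on_key(frames, key):
    """Keep only rows whose key appears in every frame, sorted by key."""
    common = set.intersection(*(set(frame[key]) for frame in frames))
    aligned = []
    for frame in frames:
        kept = frame[frame[key].isin(common)].sort_values(key)
        # Drop the key after alignment, mirroring drop_keys in the call above
        aligned.append(kept.drop(columns=key).reset_index(drop=True))
    return aligned


# Toy partitions (hypothetical data): UID 1 and UID 4 are held by one party only
alice_df = pd.DataFrame({"UID": [3, 1, 2], "Make": ["Honda", "Toyota", "Mazda"]})
bob_df = pd.DataFrame({"UID": [2, 3, 4], "Age": [28, 39, 58]})

alice_aligned, bob_aligned = align_on_key([alice_df, bob_df], key="UID")
print(alice_aligned)
print(bob_aligned)
```

After alignment, row i at every party refers to the same sample, which is what the downstream vertical models rely on.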

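Data Insights#

Built on the higher-level VDataFrame abstraction, SecretFlow provides a variety of data analysis APIs, for example computing statistics and querying or modifying specific columns.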

[15]:
print(train_ds['WeekOfMonth'].count())

print(train_ds['WeekOfMonth'].max())
print(train_ds['WeekOfMonth'].min())
WeekOfMonth    10792
dtype: int64
WeekOfMonth    5
dtype: int64
WeekOfMonth    1
dtype: int64
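One way to picture column-wise statistics on vertically partitioned data: each column lives entirely at one party, so a statistic such as max or min can be computed locally by the party that owns the column and the per-party results concatenated. The sketch below illustrates this idea with plain pandas and toy data. It is an assumption-laden illustration, not a description of VDataFrame's internals, and vertical_stat is a hypothetical helper.

```python
import pandas as pd

# Hypothetical toy partitions: each party holds different columns of the same rows
partitions = {
    "alice": pd.DataFrame(
        {"WeekOfMonth": [5, 3, 1], "Make": ["Honda", "Honda", "Toyota"]}
    ),
    "bob": pd.DataFrame({"Age": [21, 34, 47]}),
}


def vertical_stat(parts, stat):
    """Apply a pandas reduction per partition and concatenate the results."""
    return pd.concat(
        [getattr(df, stat)(numeric_only=True) for df in parts.values()]
    )


print(vertical_stat(partitions, "max"))
print(vertical_stat(partitions, "min"))
```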

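Data Preprocessing#

With the data loaded, we now show how to preprocess data that is actually held by multiple parties in SecretFlow.

Label Encoder#

For unordered binary values, we can use label encoding to convert them to a 0/1 representation.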

[16]:
from secretflow.preprocessing import LabelEncoder

cols = [
    'AccidentArea',
    'Sex',
    'Fault',
    'PoliceReportFiled',
    'WitnessPresent',
    'AgentType',
]
for col in cols:
    print(f"Col name {col}: {df[col].unique()}")

train_ds_v1 = train_ds.copy()
test_ds_v1 = test_ds.copy()

label_encoder = LabelEncoder()
for col in cols:
    label_encoder.fit(train_ds_v1[col])
    train_ds_v1[col] = label_encoder.transform(train_ds_v1[col])
    test_ds_v1[col] = label_encoder.transform(test_ds_v1[col])
Col name AccidentArea: ['Urban' 'Rural']
Col name Sex: ['Female' 'Male']
Col name Fault: ['Policy Holder' 'Third Party']
Col name PoliceReportFiled: ['No' 'Yes']
Col name WitnessPresent: ['No' 'Yes']
Col name AgentType: ['External' 'Internal']
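As a plaintext illustration of label encoding, the sketch below fits a mapping on training values and applies it to test values. It assigns codes by sorted order, which is how scikit-learn's LabelEncoder behaves; whether SecretFlow's encoder uses exactly the same ordering is an assumption here. The function names are hypothetical.

```python
def fit_label_mapping(values):
    """Map each distinct value to its rank in sorted order."""
    return {value: index for index, value in enumerate(sorted(set(values)))}


def apply_label_mapping(values, mapping):
    """Encode values; a category unseen during fitting raises KeyError."""
    return [mapping[value] for value in values]


train_col = ["Urban", "Rural", "Urban", "Urban"]
test_col = ["Rural", "Urban"]

# Fit on the training data only, then reuse the same mapping for the test data
mapping = fit_label_mapping(train_col)
print(mapping)                                 # {'Rural': 0, 'Urban': 1}
print(apply_label_mapping(test_col, mapping))  # [0, 1]
```

Fitting on the training set and reusing the mapping on the test set, as the cell above does per column, keeps the two encodings consistent.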

(Ordinal) Categorical Features#

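For ordinal categorical data, we build a mapping that converts each category to a number that preserves its order. Note that the mappings below use representative values (for example, an interval midpoint such as 22.5 for "15 to 30") rather than consecutive integers from 0 to n-1.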

[17]:
cols1 = [
    "Days_Policy_Accident",
    "Days_Policy_Claim",
    "AgeOfPolicyHolder",
    "AddressChange_Claim",
    "NumberOfCars",
]
col_disc = [
    {
        "Days_Policy_Accident": {
            "more than 30": 31,
            "15 to 30": 22.5,
            "none": 0,
            "1 to 7": 4,
            "8 to 15": 11.5,
        }
    },
    {
        "Days_Policy_Claim": {
            "more than 30": 31,
            "15 to 30": 22.5,
            "8 to 15": 11.5,
            "none": 0,
        }
    },
    {
        "AgeOfPolicyHolder": {
            "26 to 30": 28,
            "31 to 35": 33,
            "41 to 50": 45.5,
            "51 to 65": 58,
            "21 to 25": 23,
            "36 to 40": 38,
            "16 to 17": 16.5,
            "over 65": 66,
            "18 to 20": 19,
        }
    },
    {
        "AddressChange_Claim": {
            "1 year": 1,
            "no change": 0,
            "4 to 8 years": 6,
            "2 to 3 years": 2.5,
            "under 6 months": 0.5,
        }
    },
    {
        "NumberOfCars": {
            "3 to 4": 3.5,
            "1 vehicle": 1,
            "2 vehicles": 2,
            "5 to 8": 6.5,
            "more than 8": 9,
        }
    },
]

cols2 = [
    "Month",
    "DayOfWeek",
    "DayOfWeekClaimed",
    "MonthClaimed",
    "PastNumberOfClaims",
    "NumberOfSuppliments",
    "VehiclePrice",
    "AgeOfVehicle",
]
col_ordering = [
    {
        "Month": {
            "Jan": 1,
            "Feb": 2,
            "Mar": 3,
            "Apr": 4,
            "May": 5,
            "Jun": 6,
            "Jul": 7,
            "Aug": 8,
            "Sep": 9,
            "Oct": 10,
            "Nov": 11,
            "Dec": 12,
        }
    },
    {
        "DayOfWeek": {
            "Monday": 1,
            "Tuesday": 2,
            "Wednesday": 3,
            "Thursday": 4,
            "Friday": 5,
            "Saturday": 6,
            "Sunday": 7,
        }
    },
    {
        "DayOfWeekClaimed": {
            "Monday": 1,
            "Tuesday": 2,
            "Wednesday": 3,
            "Thursday": 4,
            "Friday": 5,
            "Saturday": 6,
            "Sunday": 7,
        }
    },
    {
        "MonthClaimed": {
            "Jan": 1,
            "Feb": 2,
            "Mar": 3,
            "Apr": 4,
            "May": 5,
            "Jun": 6,
            "Jul": 7,
            "Aug": 8,
            "Sep": 9,
            "Oct": 10,
            "Nov": 11,
            "Dec": 12,
        }
    },
    {"PastNumberOfClaims": {"none": 0, "1": 1, "2 to 4": 2, "more than 4": 5}},
    {"NumberOfSuppliments": {"none": 0, "1 to 2": 1, "3 to 5": 3, "more than 5": 6}},
    {
        "VehiclePrice": {
            "more than 69000": 69001,
            "20000 to 29000": 24500,
            "30000 to 39000": 34500,
            "less than 20000": 19999,
            "40000 to 59000": 49500,
            "60000 to 69000": 64500,
        }
    },
    {
        "AgeOfVehicle": {
            "3 years": 3,
            "6 years": 6,
            "7 years": 7,
            "more than 7": 8,
            "5 years": 5,
            "new": 0,
            "4 years": 4,
            "2 years": 2,
        }
    },
]

from secretflow.data.vertical import VDataFrame


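def replace(df, col_maps):
    """Replace categorical values with numbers in the columns this partition holds."""
    df = df.copy()

    def func_(df, col_map):
        # Each col_map is a single-entry dict: {column name: {category: value}}
        col_name = list(col_map.keys())[0]
        col_dict = list(col_map.values())[0]
        # Skip columns that are not part of this partition
        if col_name not in df.columns:
            return
        new_list = []
        for i in df[col_name]:
            new_list.append(col_dict[i])
        df[col_name] = new_list

    for col_map in col_maps:
        func_(df, col_map)
    return df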


col_maps = col_disc + col_ordering

train_ds_v2 = train_ds_v1.copy()
test_ds_v2 = test_ds_v1.copy()

# NOTE: reveal is used here for demonstration only!
print(f"orig ds in alice:\n {sf.reveal(train_ds_v2.partitions[alice].data)}")
train_ds_v2 = train_ds_v2.apply_func(replace, col_maps=col_maps)

print(f"orig ds in alice:\n {sf.reveal(train_ds_v2.partitions[alice].data)}")
test_ds_v2 = test_ds_v2.apply_func(replace, col_maps=col_maps)
orig ds in alice:
       Month  WeekOfMonth DayOfWeek       Make  AccidentArea DayOfWeekClaimed  \
0       Jan            1    Friday     Toyota             1        Wednesday
1       Jan            1    Monday    Pontiac             1           Monday
2       Dec            1    Friday     Toyota             1        Wednesday
3       Oct            2    Monday  Chevrolet             1           Monday
4       Sep            4   Tuesday    Pontiac             1          Tuesday
...     ...          ...       ...        ...           ...              ...
10787   Dec            3   Tuesday      Mazda             1        Wednesday
10788   Sep            3   Tuesday       Ford             1        Wednesday
10789   Aug            5  Thursday       Ford             1           Friday
10790   Feb            2  Thursday    Pontiac             1          Tuesday
10791   May            5    Monday  Chevrolet             1          Tuesday

      MonthClaimed  WeekOfMonthClaimed  Sex MaritalStatus
0              Jan                   4    1       Married
1              Jan                   3    1       Married
2              Dec                   2    1       Married
3              Oct                   2    1       Married
4              Sep                   4    0        Single
...            ...                 ...  ...           ...
10787          Dec                   3    1       Married
10788          Oct                   1    1       Married
10789          Sep                   5    1       Married
10790          Feb                   3    1        Single
10791          May                   5    1       Married

[10792 rows x 10 columns]
orig ds in alice:
        Month  WeekOfMonth  DayOfWeek       Make  AccidentArea  \
0          1            1          5     Toyota             1
1          1            1          1    Pontiac             1
2         12            1          5     Toyota             1
3         10            2          1  Chevrolet             1
4          9            4          2    Pontiac             1
...      ...          ...        ...        ...           ...
10787     12            3          2      Mazda             1
10788      9            3          2       Ford             1
10789      8            5          4       Ford             1
10790      2            2          4    Pontiac             1
10791      5            5          1  Chevrolet             1

       DayOfWeekClaimed  MonthClaimed  WeekOfMonthClaimed  Sex MaritalStatus
0                     3             1                   4    1       Married
1                     1             1                   3    1       Married
2                     3            12                   2    1       Married
3                     1            10                   2    1       Married
4                     2             9                   4    0        Single
...                 ...           ...                 ...  ...           ...
10787                 3            12                   3    1       Married
10788                 3            10                   1    1       Married
10789                 5             9                   5    1       Married
10790                 2             2                   3    1        Single
10791                 2             5                   5    1       Married

[10792 rows x 10 columns]
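The replace helper in the cell above loops over the rows of each column. A vectorized variant based on pandas Series.map is sketched below. Like the original, it skips columns that are not in the given frame and fails loudly on a category that has no mapping. The names map_ordinal, toy, and toy_maps are hypothetical, and the sketch runs on a plain pandas DataFrame rather than on a VDataFrame.

```python
import pandas as pd


def map_ordinal(df, col_maps):
    """Return a copy of df with each listed column replaced by its mapped value."""
    out = df.copy()
    for col_map in col_maps:
        for col_name, mapping in col_map.items():
            if col_name not in out.columns:
                continue  # this column is not part of the given frame
            mapped = out[col_name].map(mapping)
            # Series.map yields NaN for values without a mapping entry
            missing = out.loc[mapped.isna(), col_name].unique()
            if len(missing) > 0:
                raise KeyError(
                    f"unmapped categories in {col_name}: {list(missing)}"
                )
            out[col_name] = mapped
    return out


toy = pd.DataFrame({"Month": ["Jan", "Dec"], "Make": ["Honda", "Toyota"]})
toy_maps = [{"Month": {"Jan": 1, "Dec": 12}}, {"AgeOfVehicle": {"new": 0}}]
print(map_ordinal(toy, toy_maps))
```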

(Nominal) Categorical Features#

For nominal (unordered) categorical data, we directly apply a one-hot encoder to produce a 0/1 encoding.
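As a plaintext illustration of one-hot encoding, the sketch below uses pandas get_dummies on toy data. It is not SecretFlow's OneHotEncoder; it only shows the idea of learning the encoded columns from the training data and forcing the test data onto the same layout. The toy frames and their values are hypothetical.

```python
import pandas as pd

train = pd.DataFrame(
    {"Make": ["Honda", "Toyota", "Mazda"], "Year": [1994, 1995, 1996]}
)
test = pd.DataFrame({"Make": ["Toyota", "Ford"], "Year": [1994, 1996]})

# Learn the one-hot columns from the training data only
train_enc = pd.get_dummies(train, columns=["Make"], dtype=float)

# Encode the test data, then force it onto the training layout:
# categories missing from test become all-zero columns, unseen ones are dropped
test_enc = pd.get_dummies(test, columns=["Make"], dtype=float)
test_enc = test_enc.reindex(columns=train_enc.columns, fill_value=0.0)

print(train_enc.columns.tolist())
print(test_enc)
```

Each category becomes its own 0/1 column named after the original column and the category, so no artificial order is imposed on the categories.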

Onehot Encoder#

[18]:
from secretflow.preprocessing import OneHotEncoder

onehot_cols = ['Make', 'MaritalStatus', 'PolicyType', 'VehicleCategory', 'BasePolicy']

onehot_encoder = OneHotEncoder()
onehot_encoder.fit(train_ds_v2[onehot_cols])

enc_feats = onehot_encoder.transform(train_ds_v2[onehot_cols])
feature_names = enc_feats.columns
train_ds_v3 = train_ds_v2.drop(columns=onehot_cols)
train_ds_v3[feature_names] = enc_feats


enc_feats = onehot_encoder.transform(test_ds_v2[onehot_cols])
test_ds_v3 = test_ds_v2.drop(columns=onehot_cols)
test_ds_v3[feature_names] = enc_feats

print(f"orig ds in alice:\n {sf.reveal(train_ds_v3.partitions[alice].data)}")
orig ds in alice:
        Month  WeekOfMonth  DayOfWeek  AccidentArea  DayOfWeekClaimed  \
0          1            1          5             1                 3
1          1            1          1             1                 1
2         12            1          5             1                 3
3         10            2          1             1                 1
4          9            4          2             1                 2
...      ...          ...        ...           ...               ...
10787     12            3          2             1                 3
10788      9            3          2             1                 3
10789      8            5          4             1                 5
10790      2            2          4             1                 2
10791      5            5          1             1                 2

       MonthClaimed  WeekOfMonthClaimed  Sex  Make_Accura  Make_BMW  ...  \
0                 1                   4    1          0.0       0.0  ...
1                 1                   3    1          0.0       0.0  ...
2                12                   2    1          0.0       0.0  ...
3                10                   2    1          0.0       0.0  ...
4                 9                   4    0          0.0       0.0  ...
...             ...                 ...  ...          ...       ...  ...
10787            12                   3    1          0.0       0.0  ...
10788            10                   1    1          0.0       0.0  ...
10789             9                   5    1          0.0       0.0  ...
10790             2                   3    1          0.0       0.0  ...
10791             5                   5    1          0.0       0.0  ...

       Make_Pontiac  Make_Porche  Make_Saab  Make_Saturn  Make_Toyota  \
0               0.0          0.0        0.0          0.0          1.0
1               1.0          0.0        0.0          0.0          0.0
2               0.0          0.0        0.0          0.0          1.0
3               0.0          0.0        0.0          0.0          0.0
4               1.0          0.0        0.0          0.0          0.0
...             ...          ...        ...          ...          ...
10787           0.0          0.0        0.0          0.0          0.0
10788           0.0          0.0        0.0          0.0          0.0
10789           0.0          0.0        0.0          0.0          0.0
10790           1.0          0.0        0.0          0.0          0.0
10791           0.0          0.0        0.0          0.0          0.0

       Make_VW  MaritalStatus_Divorced  MaritalStatus_Married  \
0          0.0                     0.0                    1.0
1          0.0                     0.0                    1.0
2          0.0                     0.0                    1.0
3          0.0                     0.0                    1.0
4          0.0                     0.0                    0.0
...        ...                     ...                    ...
10787      0.0                     0.0                    1.0
10788      0.0                     0.0                    1.0
10789      0.0                     0.0                    1.0
10790      0.0                     0.0                    0.0
10791      0.0                     0.0                    1.0

       MaritalStatus_Single  MaritalStatus_Widow
0                       0.0                  0.0
1                       0.0                  0.0
2                       0.0                  0.0
3                       0.0                  0.0
4                       1.0                  0.0
...                     ...                  ...
10787                   0.0                  0.0
10788                   0.0                  0.0
10789                   0.0                  0.0
10790                   1.0                  0.0
10791                   0.0                  0.0

[10792 rows x 31 columns]
[19]:
train_ds_final = train_ds_v3.copy()
test_ds_final = test_ds_v3.copy()

X_train = train_ds_final.drop(columns=['FraudFound_P'])
y_train = train_ds_final['FraudFound_P']
X_test = test_ds_final.drop(columns='FraudFound_P')
y_test = test_ds_final['FraudFound_P']

print("data load done")
data load done

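Data Object Conversion#

Here we convert each PYUObject into an SPUObject so that it can be fed into the SPU device, which performs privacy-preserving computation based on MPC protocols.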
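
As rough intuition for what moving data into an MPC device means, the sketch below shows toy additive secret sharing among three parties in plain Python. This is an assumption-laden illustration, not SPU's implementation: the protocol SPU actually runs depends on how `my_spu` is configured, and `MODULUS`, `share`, `reconstruct`, and `add_shared` are hypothetical names.

```python
import secrets

MODULUS = 2**64  # assumed ring size for this toy example


def share(value, n_parties=3):
    """Split an integer into n additive shares that sum to value mod MODULUS."""
    shares = [secrets.randbelow(MODULUS) for _ in range(n_parties - 1)]
    last = (value - sum(shares)) % MODULUS
    return shares + [last]


def reconstruct(shares):
    """Recombine all shares to recover the original value."""
    return sum(shares) % MODULUS


def add_shared(a_shares, b_shares):
    """Add two shared values locally, share by share, without revealing either."""
    return [(a + b) % MODULUS for a, b in zip(a_shares, b_shares)]


x_shares = share(12)  # hypothetical private value held by one party
y_shares = share(30)  # hypothetical private value held by another party
print(reconstruct(add_shared(x_shares, y_shares)))
```

Each individual share looks like a random number; only combining all of them, which is what `sf.reveal` conceptually corresponds to, recovers the value.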

[20]:
import jax
import jax.numpy as jnp

"""
Convert the VDataFrame object to SPUObject
"""


def vdataframe_to_spu(vdf: VDataFrame):
    spu_partitions = []
    for device in [alice, bob, carol]:
        spu_partitions.append(vdf.partitions[device].data.to(my_spu))
    base_partition = spu_partitions[0]
    for i in range(1, len(spu_partitions)):
        base_partition = my_spu(lambda x, y: jnp.concatenate([x, y], axis=1))(
            base_partition, spu_partitions[i]
        )
    return base_partition


X_train_spu = vdataframe_to_spu(X_train)
y_train_spu = y_train.partitions[carol].data.to(my_spu)
X_test_spu = vdataframe_to_spu(X_test)
y_test_spu = y_test.partitions[carol].data.to(my_spu)
print(f"X_train type: {X_train}\n\nX_train_spu type: {X_train_spu}")

"""
NOTE: Revealing the data is for demo purposes only! Do not do this in production.
"""
X_train_plaintext = sf.reveal(X_train_spu)
y_train_plaintext = sf.reveal(y_train_spu)
X_test_plaintext = sf.reveal(X_test_spu)
y_test_plaintext = sf.reveal(y_test_spu)

print(f'X_train_plaintext: \n{X_train_plaintext}')
(_run pid=2294419) [2023-05-24 14:55:57.641] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2294389) [2023-05-24 14:55:57.703] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2294392) [2023-05-24 14:55:57.670] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(_run pid=2294392) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2294392) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2294392) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2294392) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2294392) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2294392) WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2294389) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2294389) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2294389) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2294389) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2294389) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2294389) WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2294562) [2023-05-24 14:55:57.944] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
X_train type: VDataFrame(partitions={alice: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fcfb9f7b0a0>), bob: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fcfb9f7b070>), carol: Partition(data=<secretflow.device.device.pyu.PYUObject object at 0x7fcfb9f61910>)}, aligned=True)

X_train_spu type: <secretflow.device.device.spu.SPUObject object at 0x7fcfb9f07610>
[2023-05-24 14:55:58.124] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
X_train_plaintext:
[[ 1.  1.  5. ...  1.  0.  0.]
 [ 1.  1.  1. ...  0.  0.  1.]
 [12.  1.  5. ...  0.  1.  0.]
 ...
 [ 8.  5.  4. ...  1.  0.  0.]
 [ 2.  2.  4. ...  0.  1.  0.]
 [ 5.  5.  1. ...  0.  1.  0.]]
(_run pid=2294562) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2294562) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2294562) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
(_run pid=2294562) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2294562) INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
(_run pid=2294562) WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

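Model Construction#

With the data loaded, we now build the models. This demo provides three of them:

  • LR: logistic regression

  • NN: a neural network model

  • XGB: an XGBoost tree model

Note that this example mainly demonstrates the workflow for developing algorithms on SecretFlow; the models (LR, NN) have not been tuned. We report both plaintext and secure (SPU) results. The two are largely consistent, which indicates that SecretFlow's secure computation can match the accuracy of plaintext computation.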
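
All comparisons in this section are reported as AUC via `roc_auc_score`. As a reminder of what that number measures, here is a minimal sketch that computes AUC as the fraction of (positive, negative) pairs ranked correctly. The helper `pairwise_auc` and the toy inputs are hypothetical, and the quadratic loop is meant for illustration rather than for large datasets.

```python
def pairwise_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical toy labels and scores, not taken from the dataset.
print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

An AUC of 0.5 corresponds to random ranking, so values close to 0.5 mean the model barely separates fraudulent from non-fraudulent claims.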

LR ( jax ) using SPU#

[21]:
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import (
    Conv,
    MaxPool,
    AvgPool,
    Flatten,
    Dense,
    Relu,
    Sigmoid,
    LogSoftmax,
    Softmax,
    BatchNorm,
)


def sigmoid(x):
    # Min-max rescale the logits to [0, 1] before applying the logistic function.
    x = (x - jnp.min(x)) / (jnp.max(x) - jnp.min(x))
    return 1 / (1 + jnp.exp(-x))


# Outputs probability of a label being true.
def predict_lr(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)


# Training loss is the negative log-likelihood of the training examples.
def loss_lr(W, b, inputs, targets):
    preds = predict_lr(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs))


def train_step(W, b, X, y, learning_rate):
    loss_value, Wb_grad = jax.value_and_grad(loss_lr, (0, 1))(W, b, X, y)
    W -= learning_rate * Wb_grad[0]
    b -= learning_rate * Wb_grad[1]
    return loss_value, W, b


def fit(W, b, X, y, epochs=1, learning_rate=1e-2, batch_size=128):
    losses = jnp.array([])

    # Split the data into len(X) // batch_size batches (the section count must be an integer).
    xs = jnp.array_split(X, len(X) // batch_size, axis=0)
    ys = jnp.array_split(y, len(y) // batch_size, axis=0)

    for _ in range(epochs):
        for batch_x, batch_y in zip(xs, ys):
            l, W, b = train_step(W, b, batch_x, batch_y, learning_rate=learning_rate)
            losses = jnp.append(losses, l)
    return losses, W, b
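
The `fit` function splits the data with `jnp.array_split` into a number of sections equal to the row count divided by `batch_size`, rounded down. The sketch below works out the resulting batch sizes in plain Python, assuming `jnp.array_split` follows the NumPy convention of giving the first few sections one extra row; `split_sizes` is a hypothetical helper.

```python
def split_sizes(n_rows, batch_size):
    """Sizes of the chunks produced by splitting n_rows into n_rows // batch_size sections."""
    n_sections = n_rows // batch_size  # requires n_rows >= batch_size
    base, extra = divmod(n_rows, n_sections)
    # The first `extra` sections receive one additional row each.
    return [base + 1] * extra + [base] * (n_sections - extra)


sizes = split_sizes(10792, 128)  # 10792 training rows, batch_size=128
print(len(sizes), sizes.count(129), sizes.count(128), sum(sizes))
```

Under this convention no rows are dropped: the leftover rows are spread over the leading batches, so batches hold 128 or 129 rows instead of a short final batch.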
[22]:
from jax import random
import sys
import time
import logging

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().setLevel(logging.INFO)

from sklearn.metrics import roc_auc_score


# Hyperparameter
key = random.PRNGKey(42)
W = jax.random.normal(key, shape=(64,))
b = 0.0
epochs = 1
learning_rate = 1e-2
batch_size = 128

"""
CPU-version plaintext computation
"""
losses_cpu, W_cpu, b_cpu = fit(
    W,
    b,
    X_train_plaintext,
    y_train_plaintext,
    epochs=epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)
y_pred_cpu = predict_lr(W_cpu, b_cpu, X_test_plaintext)
print(f"\033[31m(Jax LR CPU) auc: {roc_auc_score(y_test_plaintext, y_pred_cpu)}\033[0m")

"""
SPU-version secure computation
"""
W_, b_ = (
    sf.to(alice, W).to(my_spu),
    sf.to(alice, b).to(my_spu),
)
losses_spu, W_spu, b_spu = my_spu(
    fit,
    static_argnames=["epochs", "learning_rate", "batch_size"],
    num_returns_policy=sf.device.SPUCompilerNumReturnsPolicy.FROM_USER,
    user_specified_num_returns=3,
)(
    W_,
    b_,
    X_train_spu,
    y_train_spu,
    epochs=epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)

y_pred_spu = my_spu(predict_lr)(W_spu, b_spu, X_test_spu)
y_pred = sf.reveal(y_pred_spu)
print(f"\033[31m(Jax LR SPU) auc: {roc_auc_score(y_test_plaintext, y_pred)}\033[0m")
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
INFO:jax._src.lib.xla_bridge:Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(Jax LR CPU) auc: 0.524559290927313
(Jax LR SPU) auc: 0.48416274742946
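
One detail worth keeping in mind when reading these LR numbers: the `sigmoid` helper used by `predict_lr` rescales its input to [0, 1] across the whole batch before applying the logistic function, so every prediction falls between 0.5 and roughly 0.731. The plain-Python sketch below reproduces that behaviour; `rescaled_sigmoid` and the sample logits are hypothetical.

```python
import math


def rescaled_sigmoid(xs):
    """Plain-Python analogue of the tutorial's sigmoid: min-max rescale, then logistic."""
    lo, hi = min(xs), max(xs)
    scaled = [(x - lo) / (hi - lo) for x in xs]  # assumes hi > lo
    return [1.0 / (1.0 + math.exp(-s)) for s in scaled]


# Hypothetical logits for one batch.
probs = rescaled_sigmoid([-3.0, 0.0, 2.5, 7.0])
print(min(probs), max(probs))
```

Because both steps are monotonically increasing, the ordering of the predictions within a single call is preserved in plaintext arithmetic, and ordering is all that AUC depends on.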

NN ( jax + flax ) using SPU#

[23]:
import sys

!{sys.executable} -m pip install flax==0.6.0 -q
WARNING: There was an error checking the latest version of pip.

[24]:
from typing import Sequence
import flax.linen as nn


class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x


FEATURES = [1]
flax_nn = MLP(FEATURES)


def predict(params, x):
    from typing import Sequence
    import flax.linen as nn

    class MLP(nn.Module):
        features: Sequence[int]

        @nn.compact
        def __call__(self, x):
            for feat in self.features[:-1]:
                x = nn.relu(nn.Dense(feat)(x))
            x = nn.Dense(self.features[-1])(x)
            return x

    FEATURES = [1]
    flax_nn = MLP(FEATURES)
    return flax_nn.apply(params, x)


def loss_func(params, x, y):
    preds = predict(params, x)
    label_probs = preds * y + (1 - preds) * (1 - y)
    return -jnp.mean(jnp.log(label_probs))


def train_auto_grad(X, y, params, batch_size=10, epochs=10, learning_rate=0.01):
    # Split the data into len(X) // batch_size batches (the section count must be an integer).
    xs = jnp.array_split(X, len(X) // batch_size, axis=0)
    ys = jnp.array_split(y, len(y) // batch_size, axis=0)

    for _ in range(epochs):
        for batch_x, batch_y in zip(xs, ys):
            _, grads = jax.value_and_grad(loss_func)(params, batch_x, batch_y)
            params = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, params, grads
            )
    return params


epochs = 1
learning_rate = 1e-2
batch_size = 128

feature_dim = 64  # from the dataset
init_params = flax_nn.init(jax.random.PRNGKey(1), jnp.ones((batch_size, feature_dim)))

"""
CPU-version plaintext computation
"""
params = train_auto_grad(
    X_train_plaintext, y_train_plaintext, init_params, batch_size, epochs, learning_rate
)
y_pred = predict(params, X_test_plaintext)
print(f"\033[31m(Flax NN CPU) auc: {roc_auc_score(y_test_plaintext, y_pred)}\033[0m")

"""
SPU-version secure computation
"""
params_spu = sf.to(alice, init_params).to(my_spu)
params_spu = my_spu(
    train_auto_grad, static_argnames=['batch_size', 'epochs', 'learning_rate']
)(
    X_train_spu,
    y_train_spu,
    params_spu,
    batch_size=batch_size,
    epochs=epochs,
    learning_rate=learning_rate,
)
y_pred_spu = my_spu(predict)(params_spu, X_test_spu)
y_pred_ = sf.reveal(y_pred_spu)
print(f"\033[31m(Flax NN SPU) auc: {roc_auc_score(y_test_plaintext, y_pred_)}\033[0m")
(Flax NN CPU) auc: 0.5022025986877813
(Flax NN SPU) auc: 0.5022034401772514
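
The parameter update in `train_auto_grad` uses `jax.tree_util.tree_map` to apply plain gradient descent to every leaf of the parameter tree. The sketch below mimics that idea for nested dictionaries in plain Python. The `tree_map` here is a simplified stand-in rather than JAX's implementation, and the parameter names and values are hypothetical.

```python
def tree_map(fn, *trees):
    """Apply fn leaf-wise over nested dicts that share the same structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    return fn(*trees)


def sgd_step(params, grads, learning_rate):
    """One gradient-descent update: p <- p - learning_rate * g for every leaf."""
    return tree_map(lambda p, g: p - learning_rate * g, params, grads)


# Hypothetical parameter and gradient trees with scalar leaves.
params = {"Dense_0": {"kernel": 1.0, "bias": 0.5}}
grads = {"Dense_0": {"kernel": 0.25, "bias": -1.0}}
print(sgd_step(params, grads, learning_rate=0.5))
```

Because the update is expressed over the whole tree, the same training loop works unchanged when `FEATURES` is extended to add more layers.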

XGB ( jax ) using SPU#

[25]:
from secretflow.ml.boost.ss_xgb_v import Xgb
import time
from sklearn.metrics import roc_auc_score

"""
SPU-version Secure computation
"""
xgb = Xgb(my_spu)
params = {
    # <<< !!! >>> change args to your test settings.
    # for more detail, see Xgb.train.__doc__
    'num_boost_round': 10,
    'max_depth': 4,
    'learning_rate': 0.05,
    'sketch_eps': 0.05,
    'objective': 'logistic',
    'reg_lambda': 1,
    'subsample': 0.75,
    'colsample_by_tree': 1,
    'base_score': 0.5,
}

start = time.time()
model = xgb.train(params, X_train, y_train)
print(f"train time: {time.time() - start}")

start = time.time()
spu_yhat = model.predict(X_test)
print(f"predict time: {time.time() - start}")

yhat = sf.reveal(spu_yhat)
print(f"\033[31m(SS-XGB) auc: {roc_auc_score(y_test_plaintext, yhat)}\033[0m")
INFO:root:Create proxy actor <class 'secretflow.ml.boost.ss_xgb_v.core.tree_worker.XgbTreeWorker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.boost.ss_xgb_v.core.tree_worker.XgbTreeWorker'> with party bob.
INFO:root:Create proxy actor <class 'secretflow.ml.boost.ss_xgb_v.core.tree_worker.XgbTreeWorker'> with party carol.
INFO:root:fragment_count 1
INFO:root:prepare time 0.3021812438964844s
INFO:root:global_setup time 2.1790771484375s
INFO:root:build & infeed bucket_map fragments [0, 0]
INFO:root:build & infeed bucket_map time 0.32753467559814453s
INFO:root:init_pred time 0.04900479316711426s
INFO:root:epoch 0 tree_setup time 0.11261296272277832s
(_spu_compile pid=2294562) /* error: missing value */
(_spu_compile pid=2294562) {}:task_name:_spu_compile
(_spu_compile pid=2294562) /* error: missing value */
(_spu_compile pid=2294562) {}:task_name:_spu_compile
INFO:root:fragment[0, 0] gradient sum time 0.46968507766723633s
(SPURuntime pid=2300421) 2023-05-24 14:57:37.774 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 63
(SPURuntime pid=2300420) 2023-05-24 14:57:37.772 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 63
(SPURuntime pid=2300419) 2023-05-24 14:57:37.777 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 63
INFO:root:level 0 time 0.8138599395751953s
INFO:root:fragment[0, 0] gradient sum time 0.42415881156921387s
INFO:root:level 1 time 0.7801706790924072s
INFO:root:fragment[0, 0] gradient sum time 0.5980944633483887s
INFO:root:level 2 time 1.0563251972198486s
INFO:root:fragment[0, 0] gradient sum time 0.5459139347076416s
INFO:root:level 3 time 0.9727473258972168s
INFO:root:epoch 0 time 3.961634874343872s
(XgbTreeWorker pid=2312009) [2023-05-24 14:57:41.177] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(XgbTreeWorker pid=2312010) [2023-05-24 14:57:41.195] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(XgbTreeWorker pid=2312013) [2023-05-24 14:57:41.212] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
INFO:root:epoch 1 tree_setup time 0.1868457794189453s
(_spu_compile pid=2294389) /* error: missing value */
(_spu_compile pid=2294389) {}:task_name:_spu_compile
INFO:root:fragment[0, 0] gradient sum time 0.4389340877532959s
INFO:root:level 0 time 0.7621815204620361s
INFO:root:fragment[0, 0] gradient sum time 0.4361765384674072s
INFO:root:level 1 time 0.7648963928222656s
INFO:root:fragment[0, 0] gradient sum time 0.46132969856262207s
INFO:root:level 2 time 0.8321614265441895s
INFO:root:fragment[0, 0] gradient sum time 0.5284152030944824s
INFO:root:level 3 time 0.9844522476196289s
INFO:root:epoch 1 time 3.611001968383789s
INFO:root:epoch 2 tree_setup time 0.11684489250183105s
INFO:root:fragment[0, 0] gradient sum time 0.3915536403656006s
INFO:root:level 0 time 0.7235503196716309s
INFO:root:fragment[0, 0] gradient sum time 0.42185139656066895s
INFO:root:level 1 time 0.7636382579803467s
INFO:root:fragment[0, 0] gradient sum time 0.46842050552368164s
INFO:root:level 2 time 0.8526918888092041s
INFO:root:fragment[0, 0] gradient sum time 0.5360352993011475s
INFO:root:level 3 time 0.9925544261932373s
INFO:root:epoch 2 time 3.6581156253814697s
INFO:root:epoch 3 tree_setup time 0.17425537109375s
(_spu_compile pid=2294392) /* error: missing value */
(_spu_compile pid=2294392) {}:task_name:_spu_compile
INFO:root:fragment[0, 0] gradient sum time 0.4217956066131592s
INFO:root:level 0 time 0.7510805130004883s
INFO:root:fragment[0, 0] gradient sum time 0.4604473114013672s
INFO:root:level 1 time 0.9107728004455566s
INFO:root:fragment[0, 0] gradient sum time 0.46227431297302246s
INFO:root:level 2 time 0.8462221622467041s
INFO:root:fragment[0, 0] gradient sum time 0.5579540729522705s
INFO:root:level 3 time 1.005491018295288s
INFO:root:epoch 3 time 3.8129096031188965s
INFO:root:epoch 4 tree_setup time 0.16765856742858887s
(_spu_compile pid=2294419) /* error: missing value */
(_spu_compile pid=2294419) {}:task_name:_spu_compile
INFO:root:fragment[0, 0] gradient sum time 0.42427778244018555s
INFO:root:level 0 time 0.6625769138336182s
INFO:root:fragment[0, 0] gradient sum time 0.4184889793395996s
INFO:root:level 1 time 0.7157876491546631s
INFO:root:fragment[0, 0] gradient sum time 0.4058997631072998s
INFO:root:level 2 time 0.7082300186157227s
INFO:root:fragment[0, 0] gradient sum time 0.5008931159973145s
INFO:root:level 3 time 0.8738894462585449s
INFO:root:epoch 4 time 3.195892333984375s
INFO:root:epoch 5 tree_setup time 0.14134454727172852s
INFO:root:fragment[0, 0] gradient sum time 0.400066614151001s
INFO:root:level 0 time 0.641761064529419s
INFO:root:fragment[0, 0] gradient sum time 0.3746778964996338s
INFO:root:level 1 time 0.6471817493438721s
INFO:root:fragment[0, 0] gradient sum time 0.42211461067199707s
INFO:root:level 2 time 0.72641921043396s
INFO:root:fragment[0, 0] gradient sum time 0.46430206298828125s
INFO:root:level 3 time 0.8361771106719971s
INFO:root:epoch 5 time 3.0644524097442627s
INFO:root:epoch 6 tree_setup time 0.11562013626098633s
INFO:root:fragment[0, 0] gradient sum time 0.3859133720397949s
INFO:root:level 0 time 0.635554313659668s
INFO:root:fragment[0, 0] gradient sum time 0.41338205337524414s
INFO:root:level 1 time 0.6643311977386475s
INFO:root:fragment[0, 0] gradient sum time 0.3950026035308838s
INFO:root:level 2 time 0.6798880100250244s
INFO:root:fragment[0, 0] gradient sum time 0.45518994331359863s
INFO:root:level 3 time 0.8225488662719727s
INFO:root:epoch 6 time 3.0115859508514404s
INFO:root:epoch 7 tree_setup time 0.11349105834960938s
INFO:root:fragment[0, 0] gradient sum time 0.37856125831604004s
INFO:root:level 0 time 0.6111841201782227s
INFO:root:fragment[0, 0] gradient sum time 0.4059903621673584s
INFO:root:level 1 time 0.7024598121643066s
INFO:root:fragment[0, 0] gradient sum time 0.421400785446167s
INFO:root:level 2 time 0.7304990291595459s
INFO:root:fragment[0, 0] gradient sum time 0.4517638683319092s
INFO:root:level 3 time 0.8225312232971191s
INFO:root:epoch 7 time 3.069700002670288s
INFO:root:epoch 8 tree_setup time 0.1077120304107666s
INFO:root:fragment[0, 0] gradient sum time 0.3767871856689453s
INFO:root:level 0 time 0.6241645812988281s
INFO:root:fragment[0, 0] gradient sum time 0.374847412109375s
INFO:root:level 1 time 0.6185367107391357s
INFO:root:fragment[0, 0] gradient sum time 0.40407586097717285s
INFO:root:level 2 time 0.7317712306976318s
INFO:root:fragment[0, 0] gradient sum time 0.45026302337646484s
INFO:root:level 3 time 0.8427674770355225s
INFO:root:epoch 8 time 3.035346269607544s
INFO:root:epoch 9 tree_setup time 0.13258838653564453s
INFO:root:fragment[0, 0] gradient sum time 0.4179399013519287s
INFO:root:level 0 time 0.6501905918121338s
INFO:root:fragment[0, 0] gradient sum time 0.38472986221313477s
INFO:root:level 1 time 0.6701993942260742s
INFO:root:fragment[0, 0] gradient sum time 0.4303562641143799s
INFO:root:level 2 time 0.7147767543792725s
INFO:root:fragment[0, 0] gradient sum time 0.46015405654907227s
INFO:root:level 3 time 0.8475117683410645s
INFO:root:epoch 9 time 3.0256881713867188s
INFO:root:Create proxy actor <class 'secretflow.ml.boost.ss_xgb_v.core.tree_worker.XgbTreeWorker'> with party alice.
INFO:root:Create proxy actor <class 'secretflow.ml.boost.ss_xgb_v.core.tree_worker.XgbTreeWorker'> with party bob.
INFO:root:Create proxy actor <class 'secretflow.ml.boost.ss_xgb_v.core.tree_worker.XgbTreeWorker'> with party carol.
train time: 37.70984649658203
(XgbTreeWorker pid=2316828) [2023-05-24 14:58:14.123] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(XgbTreeWorker pid=2316834) [2023-05-24 14:58:14.274] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
(XgbTreeWorker pid=2316829) [2023-05-24 14:58:14.254] [info] [thread_pool.cc:30] Create a fixed thread pool with size 63
predict time: 3.1688623428344727
(SS-XGB) auc: 0.8239633480846438
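
The training log above repeatedly reports a "gradient sum" step. As background, the sketch below shows the standard second-order XGBoost quantities for a logistic objective in plain Python: the per-sample gradient and Hessian, and the leaf weight computed from their sums. It assumes the textbook formulation; how SS-XGB evaluates these quantities under MPC is not shown here, and the helper names and labels are hypothetical.

```python
import math


def logistic_grad_hess(margin, label):
    """First- and second-order derivatives of log loss w.r.t. the raw margin."""
    p = 1.0 / (1.0 + math.exp(-margin))
    return p - label, p * (1.0 - p)


def leaf_weight(grads, hesses, reg_lambda=1.0):
    """Leaf value -G / (H + lambda) for the samples falling into one leaf."""
    return -sum(grads) / (sum(hesses) + reg_lambda)


# base_score=0.5 corresponds to an initial margin of 0 for every sample.
labels = [0, 0, 0, 1]  # hypothetical labels of the samples in one leaf
gh = [logistic_grad_hess(0.0, y) for y in labels]
w = leaf_weight([g for g, _ in gh], [h for _, h in gh], reg_lambda=1.0)
print(gh[0], gh[-1], w)
```

The resulting leaf value is then scaled by `learning_rate` before being added to the running prediction, which is why a small rate such as 0.05 needs several boosting rounds to move the scores.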
[26]:
"""
Plaintext baseline
"""
import xgboost as SKxgb

params = {
    # <<< !!! >>> change args to your test settings.
    # for more detail, see Xgb.train.__doc__
    "n_estimators": 10,
    "max_depth": 4,
    'eval_metric': 'auc',
    "learning_rate": 0.05,
    "sketch_eps": 0.05,
    "objective": "binary:logistic",
    "reg_lambda": 1,
    "subsample": 0.75,
    "colsample_by_tree": 1,
    "base_score": 0.5,
}
dtrain = SKxgb.DMatrix(X_train_plaintext, label=y_train_plaintext)
bst = SKxgb.train(params, dtrain, params["n_estimators"])
dtest = SKxgb.DMatrix(X_test_plaintext)
y_pred = bst.predict(dtest)
print(f"\033[31m(Sklearn-XGB) auc: {roc_auc_score(y_test_plaintext, y_pred)}\033[0m")
[14:58:15] WARNING: ../src/learner.cc:767:
Parameters: { "n_estimators", "sketch_eps" } are not used.

(Sklearn-XGB) auc: 0.8231883362827539
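
To put the plaintext and SPU results side by side, the short sketch below computes the absolute AUC gap for each model from the values printed in the cell outputs of this tutorial. The dictionary layout is just a convenience for this illustration.

```python
# AUC values copied from the cell outputs in this tutorial.
reported_auc = {
    "LR": {"plaintext": 0.524559290927313, "spu": 0.48416274742946},
    "NN": {"plaintext": 0.5022025986877813, "spu": 0.5022034401772514},
    "XGB": {"plaintext": 0.8231883362827539, "spu": 0.8239633480846438},
}

gaps = {name: abs(v["spu"] - v["plaintext"]) for name, v in reported_auc.items()}
for name, gap in gaps.items():
    print(f"{name}: |AUC_spu - AUC_plaintext| = {gap:.6f}")
```

The NN and XGB gaps stay below 1e-3, while the untuned LR model differs by about 0.04, with both of its AUC values close to 0.5.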

The End#

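Explicitly call sf.shutdown() to shut down the instantiated cluster.

> Note: when the code is run from a .py file, there is no need to call shutdown explicitly; it is invoked implicitly once the program process exits.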

[27]:
sf.shutdown()

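Summary#

  • We showed how to develop an application for a real-world scenario on SecretFlow with privacy protection built in

  • We walked through data loading, preprocessing, modeling, and training on SecretFlow

  • As a next step, you can implement arbitrary computations of your own (written in JAX); support for TensorFlow and PyTorch is a work in progress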