SecretFlow Component Guide#

Get the Component List#

Python API#

You can inspect the SecretFlow component list as follows:

from secretflow.component.entry import COMP_LIST

COMP_LIST is a CompListDef instance.

CLI#

Check the Current SecretFlow Version#

$ secretflow -v
SecretFlow version 0.8.3b1.

List All Components#

$ secretflow component ls
DOMAIN                                   NAME                                     VERSION
---------------------------------------------------------------------------------------------------------
feature                                  vert_woe_binning                         0.0.1
feature                                  vert_woe_substitution                    0.0.1
ml.eval                                  biclassification_eval                    0.0.1
ml.eval                                  prediction_bias_eval                     0.0.1
ml.eval                                  ss_pvalue                                0.0.1
ml.predict                               sgb_predict                              0.0.1
ml.predict                               ss_sgd_predict                           0.0.1
ml.predict                               ss_xgb_predict                           0.0.1
ml.train                                 sgb_train                                0.0.1
ml.train                                 ss_sgd_train                             0.0.1
ml.train                                 ss_xgb_train                             0.0.1
preprocessing                            feature_filter                           0.0.1
preprocessing                            psi                                      0.0.1
preprocessing                            train_test_split                         0.0.1
stats                                    ss_pearsonr                              0.0.1
stats                                    ss_vif                                   0.0.1
stats                                    table_statistics                         0.0.1
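
If you want to post-process this listing in a script, the sketch below parses the table text into tuples and groups the component names by domain. It assumes the three-column, whitespace-separated layout shown above and that you have captured the CLI output as a string yourself; the helper names `parse_listing` and `group_by_domain` are illustrative and not part of SecretFlow.

```python
from collections import defaultdict

# A few rows copied from the listing above; in practice this text would come
# from capturing the CLI output yourself.
LISTING = """\
DOMAIN                                   NAME                                     VERSION
---------------------------------------------------------------------------------------------------------
feature                                  vert_woe_binning                         0.0.1
ml.train                                 sgb_train                                0.0.1
ml.train                                 ss_xgb_train                             0.0.1
preprocessing                            psi                                      0.0.1
"""


def parse_listing(text):
    """Return (domain, name, version) tuples from the table text."""
    rows = []
    for line in text.splitlines():
        fields = line.split()
        # Skip blank lines, the header row, and the dashed separator.
        if len(fields) != 3 or fields[0] == "DOMAIN":
            continue
        rows.append(tuple(fields))
    return rows


def group_by_domain(rows):
    """Map each domain to the list of component names it contains."""
    grouped = defaultdict(list)
    for domain, name, _version in rows:
        grouped[domain].append(name)
    return dict(grouped)


components = parse_listing(LISTING)
by_domain = group_by_domain(components)
print(by_domain["ml.train"])
```

For structured access without text parsing, the COMP_LIST object from the Python API is likely the better starting point.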

Get the Definition of a Component#

You must specify a component in the following format: domain/name:version.
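
As a rough illustration of this identifier format, the sketch below splits an id into its three parts. The `CompId` type and `parse_comp_id` helper are hypothetical names introduced here, and the regular expression is an assumption about which characters are valid, not SecretFlow's own validation logic.

```python
import re
from typing import NamedTuple


class CompId(NamedTuple):
    domain: str
    name: str
    version: str


# Assumed grammar: no part may contain "/", ":" or whitespace. The domain may
# contain dots (e.g. "ml.train").
_COMP_ID_RE = re.compile(
    r"^(?P<domain>[^/:\s]+)/(?P<name>[^/:\s]+):(?P<version>[^/:\s]+)$"
)


def parse_comp_id(comp_id: str) -> CompId:
    """Split "domain/name:version" into a CompId, or raise ValueError."""
    match = _COMP_ID_RE.match(comp_id)
    if match is None:
        raise ValueError(f"expected domain/name:version, got {comp_id!r}")
    return CompId(**match.groupdict())


parsed = parse_comp_id("preprocessing/psi:0.0.1")
print(parsed)
```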

For example, to inspect the definition of the PSI component:

$ secretflow component inspect preprocessing/psi:0.0.1
You are inspecting definition of component with id [preprocessing/psi:0.0.1].
---------------------------------------------------------------------------------------------------------
{
"domain": "preprocessing",
"name": "psi",
"desc": "PSI between two parties.",
"version": "0.0.1",
"attrs": [
    {
    "name": "protocol",
    "desc": "PSI protocol.",
    "type": "AT_STRING",
    "atomic": {
        "isOptional": true,
        "defaultValue": {
        "s": "ECDH_PSI_2PC"
        },
        "allowedValues": {
        "ss": [
            "ECDH_PSI_2PC",
            "KKRT_PSI_2PC",
            "BC22_PSI_2PC"
        ]
        }
    }
    },
    {
    "name": "bucket_size",
    "desc": "Specify the hash bucket size used in PSI. Larger values consume more memory.",
    "type": "AT_INT",
    "atomic": {
        "isOptional": true,
        "defaultValue": {
        "i64": "1048576"
        },
        "lowerBoundEnabled": true,
        "lowerBound": {}
    }
    },
    {
    "name": "ecdh_curve_type",
    "desc": "Curve type for ECDH PSI.",
    "type": "AT_STRING",
    "atomic": {
        "isOptional": true,
        "defaultValue": {
        "s": "CURVE_FOURQ"
        },
        "allowedValues": {
        "ss": [
            "CURVE_25519",
            "CURVE_FOURQ",
            "CURVE_SM2",
            "CURVE_SECP256K1"
        ]
        }
    }
    }
],
"inputs": [
    {
    "name": "receiver_input",
    "desc": "Individual table for receiver",
    "types": [
        "sf.table.individual"
    ],
    "attrs": [
        {
        "name": "key",
        "desc": "Column(s) used to join. If not provided, ids of the dataset will be used."
        }
    ]
    },
    {
    "name": "sender_input",
    "desc": "Individual table for sender",
    "types": [
        "sf.table.individual"
    ],
    "attrs": [
        {
        "name": "key",
        "desc": "Column(s) used to join. If not provided, ids of the dataset will be used."
        }
    ]
    }
],
"outputs": [
    {
    "name": "psi_output",
    "desc": "Output vertical table",
    "types": [
        "sf.table.vertical_table"
    ]
    }
]
}
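
To work with a definition like this programmatically, you could load the JSON and pull out each attribute's default and allowed values. The sketch below uses an abbreviated copy of the PSI definition above; `summarize_attrs` is a hypothetical helper, and it assumes every `defaultValue` and `allowedValues` object carries a single key naming the value kind (`s`, `i64`, `ss`), as in the output shown.

```python
import json

# Abbreviated copy of the PSI definition printed above (two of the three
# attributes, no inputs or outputs) so that the example is self-contained.
DEFINITION = """
{
  "domain": "preprocessing",
  "name": "psi",
  "version": "0.0.1",
  "attrs": [
    {
      "name": "protocol",
      "type": "AT_STRING",
      "atomic": {
        "isOptional": true,
        "defaultValue": {"s": "ECDH_PSI_2PC"},
        "allowedValues": {"ss": ["ECDH_PSI_2PC", "KKRT_PSI_2PC", "BC22_PSI_2PC"]}
      }
    },
    {
      "name": "bucket_size",
      "type": "AT_INT",
      "atomic": {
        "isOptional": true,
        "defaultValue": {"i64": "1048576"},
        "lowerBoundEnabled": true,
        "lowerBound": {}
      }
    }
  ]
}
"""


def summarize_attrs(definition):
    """Return {attr name: (default, allowed values or None)}."""
    summary = {}
    for attr in definition.get("attrs", []):
        atomic = attr.get("atomic", {})
        # Both objects are keyed by the value kind, so take the only value.
        default = next(iter(atomic.get("defaultValue", {}).values()), None)
        allowed = next(iter(atomic.get("allowedValues", {}).values()), None)
        summary[attr["name"]] = (default, allowed)
    return summary


summary = summarize_attrs(json.loads(DEFINITION))
for name, (default, allowed) in summary.items():
    print(f"{name}: default={default!r}, allowed={allowed}")
```

Note that the 64-bit integer default arrives as the string "1048576" in this JSON, so convert it with int() before doing arithmetic on it.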

You can inspect all components at once as follows:

$ secretflow component inspect -a
...

You can save the list to a file as follows:

$ secretflow component inspect -a -f output.json
You are inspecting the compelete comp list.
---------------------------------------------------------------------------------------------------------
Saved to output.json.

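Evaluate a Node#

Python API#

The following example shows how to evaluate a node with the Python API.

We will test the PSI component with a very small dataset.

  1. Save the following bash script as generate_csv.sh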

#!/bin/bash

set -e
show_help() {
    echo "Usage: bash generate_csv.sh -c {col_name} -p {file_name}"
    echo "  -c"
    echo "          the column name of id."
    echo "  -p"
    echo "          the path of output csv."
}
if [[ "$#" -lt 1 ]]; then
    show_help
    exit
fi

while getopts ":c:p:" OPTION; do
    case $OPTION in
    c)
        COL_NAME=$OPTARG
        ;;
    p)
        FILE_PATH=$OPTARG
        ;;
    *)
        echo "Incorrect options provided"
        exit 1
        ;;
    esac
done


# header
echo "$COL_NAME" > "$FILE_PATH"

# generate 800 random int
for ((i=0; i<800; i++))
do
    # from 0 to 1000
    id=$(shuf -i 0-1000 -n 1)

    # check duplicates
    while grep -q "^$id$" "$FILE_PATH"
    do
        id=$(shuf -i 0-1000 -n 1)
    done

    # write
    echo "$id" >> "$FILE_PATH"
done

echo "Generated csv file is $FILE_PATH."

Then generate the input for the two parties.

mkdir -p /tmp/alice
bash generate_csv.sh -c id1 -p /tmp/alice/input.csv

mkdir -p /tmp/bob
bash generate_csv.sh -c id2 -p /tmp/bob/input.csv
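
If shuf is not available on your system, the same kind of input can be produced in Python. The following is a minimal sketch rather than part of SecretFlow: `generate_csv` is a hypothetical helper that writes a header followed by 800 distinct random integers between 0 and 1000, matching what the bash script above does. The demo writes to a temporary directory so that it does not overwrite existing files.

```python
import random
import tempfile
from pathlib import Path


def generate_csv(col_name, path, count=800, low=0, high=1000, seed=None):
    """Write a header plus `count` distinct random integers in [low, high]."""
    rng = random.Random(seed)
    # random.sample draws without replacement, so no duplicate check is needed.
    ids = rng.sample(range(low, high + 1), count)
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as f:
        f.write(f"{col_name}\n")
        f.writelines(f"{i}\n" for i in ids)
    return ids


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "alice" / "input.csv"
    generate_csv("id1", target, seed=0)
    lines = target.read_text().splitlines()

# Header name and number of generated ids.
print(lines[0], len(lines) - 1)
```

To mirror the commands above, you would call generate_csv("id1", "/tmp/alice/input.csv") and generate_csv("id2", "/tmp/bob/input.csv").
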

  2. Save the following Python code as psi_demo.py

import json

from secretflow.component.entry import comp_eval
from secretflow.spec.extend.cluster_pb2 import (
    SFClusterConfig,
    SFClusterDesc,
)
from secretflow.spec.v1.component_pb2 import Attribute
from secretflow.spec.v1.data_pb2 import (
    DistData,
    TableSchema,
    IndividualTable,
    StorageConfig,
)
from secretflow.spec.v1.evaluation_pb2 import NodeEvalParam
import click


@click.command()
@click.argument("party", type=str)
def run(party: str):
    desc = SFClusterDesc(
        parties=["alice", "bob"],
        devices=[
            SFClusterDesc.DeviceDesc(
                name="spu",
                type="spu",
                parties=["alice", "bob"],
                config=json.dumps(
                    {
                        "runtime_config": {"protocol": "REF2K", "field": "FM64"},
                        "link_desc": {
                            "connect_retry_times": 60,
                            "connect_retry_interval_ms": 1000,
                            "brpc_channel_protocol": "http",
                            "brpc_channel_connection_type": "pooled",
                            "recv_timeout_ms": 1200 * 1000,
                            "http_timeout_ms": 1200 * 1000,
                        },
                    }
                ),
            ),
            SFClusterDesc.DeviceDesc(
                name="heu",
                type="heu",
                parties=[],
                config=json.dumps(
                    {
                        "mode": "PHEU",
                        "schema": "paillier",
                        "key_size": 2048,
                    }
                ),
            ),
        ],
    )

    sf_cluster_config = SFClusterConfig(
        desc=desc,
        public_config=SFClusterConfig.PublicConfig(
            ray_fed_config=SFClusterConfig.RayFedConfig(
                parties=["alice", "bob"],
                addresses=[
                    "127.0.0.1:61041",
                    "127.0.0.1:61042",
                ],
            ),
            spu_configs=[
                SFClusterConfig.SPUConfig(
                    name="spu",
                    parties=["alice", "bob"],
                    addresses=[
                        "127.0.0.1:61045",
                        "127.0.0.1:61046",
                    ],
                )
            ],
        ),
        private_config=SFClusterConfig.PrivateConfig(
            self_party=party,
            ray_head_addr="local",  # "local" starts a new local Ray cluster instead of connecting to an existing one.
        ),
    )

    # check https://www.secretflow.org.cn/docs/spec/latest/zh-Hans/intro#nodeevalparam for details.
    sf_node_eval_param = NodeEvalParam(
        domain="preprocessing",
        name="psi",
        version="0.0.1",
        attr_paths=[
            "protocol",
            "sort",
            "bucket_size",
            "ecdh_curve_type",
            "input/receiver_input/key",
            "input/sender_input/key",
        ],
        attrs=[
            Attribute(s="ECDH_PSI_2PC"),
            Attribute(b=True),
            Attribute(i64=1048576),
            Attribute(s="CURVE_FOURQ"),
            Attribute(ss=["id1"]),
            Attribute(ss=["id2"]),
        ],
        inputs=[
            DistData(
                name="receiver_input",
                type="sf.table.individual",
                data_refs=[
                    DistData.DataRef(uri="input.csv", party="alice", format="csv"),
                ],
            ),
            DistData(
                name="sender_input",
                type="sf.table.individual",
                data_refs=[
                    DistData.DataRef(uri="input.csv", party="bob", format="csv"),
                ],
            ),
        ],
        output_uris=[
            "output.csv",
        ],
    )

    sf_node_eval_param.inputs[0].meta.Pack(
        IndividualTable(
            schema=TableSchema(
                id_types=["str"],
                ids=["id1"],
            ),
            line_count=-1,
        ),
    )

    sf_node_eval_param.inputs[1].meta.Pack(
        IndividualTable(
            schema=TableSchema(
                id_types=["str"],
                ids=["id2"],
            ),
            line_count=-1,
        ),
    )

    storage_config = StorageConfig(
        type="local_fs",
        local_fs=StorageConfig.LocalFSConfig(wd=f"/tmp/{party}"),
    )

    res = comp_eval(sf_node_eval_param, storage_config, sf_cluster_config)

    print(f'Node eval res is \n{res}')


if __name__ == "__main__":
    run()
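
  3. Run the script in two terminals, passing the party name as the only argument (alice in one terminal, bob in the other):

$ python psi_demo.py alice

$ python psi_demo.py bob

You should see the following output in both terminals: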

Node eval res is
outputs {
  name: "output.csv"
  type: "sf.table.vertical_table"
  system_info {
  }
  meta {
    type_url: "type.googleapis.com/secretflow.spec.v1.VerticalTable"
    value: "\n\n\n\003id1\"\003str\n\n\n\003id2\"\003str\020\211\005"
  }
  data_refs {
    uri: "output.csv"
    party: "alice"
    format: "csv"
  }
  data_refs {
    uri: "output.csv"
    party: "bob"
    format: "csv"
  }
}

  4. Check the results in /tmp/alice/output.csv and /tmp/bob/output.csv. Apart from the header, the contents of the two files should be identical.
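
One way to verify this is to compare the two files while skipping the header row. The sketch below is illustrative only: `same_body` is a hypothetical helper, and the demo uses small stand-in files in a temporary directory instead of the real outputs.

```python
import csv
import tempfile
from pathlib import Path


def read_rows_without_header(path):
    """Return all CSV rows after the first (header) row."""
    with open(path, newline="") as f:
        rows = list(csv.reader(f))
    return rows[1:]


def same_body(path_a, path_b):
    """True if both files hold the same rows in the same order, header aside."""
    return read_rows_without_header(path_a) == read_rows_without_header(path_b)


with tempfile.TemporaryDirectory() as tmp:
    alice = Path(tmp, "alice_output.csv")
    bob = Path(tmp, "bob_output.csv")
    other = Path(tmp, "other_output.csv")
    alice.write_text("id1\n3\n17\n42\n")
    bob.write_text("id2\n3\n17\n42\n")
    other.write_text("id2\n3\n17\n")
    match = same_body(alice, bob)
    mismatch = same_body(alice, other)

print(match, mismatch)
```

For the run above you would call same_body("/tmp/alice/output.csv", "/tmp/bob/output.csv").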

CLI#

You can also use the SecretFlow CLI to evaluate a node.

$ secretflow component run --log_file={log_file} --result_file={result_file_path} --eval_param={encoded_eval_param} --storage={encoded_storage_config} --cluster={encoded_cluster_def}
  • log_file: path of the log file.

  • result_file: path of the result file.

  • eval_param: Base64-encoded NodeEvalParam prototext.

  • storage: Base64-encoded StorageConfig prototext.

  • cluster: Base64-encoded SFClusterConfig prototext.

Because the prototext must be encoded before it can be passed to the CLI, we do not recommend using the CLI to evaluate nodes.
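
If you do need the CLI, the encoding step itself is short. The sketch below assumes that the CLI expects standard (not URL-safe) Base64 over the UTF-8 bytes of the prototext; that detail is an assumption and should be checked against your SecretFlow version. The sample text mirrors the StorageConfig used in psi_demo.py, and `encode_prototext` and `decode_prototext` are hypothetical helper names.

```python
import base64

# Text-format StorageConfig mirroring the one built in psi_demo.py.
STORAGE_PROTOTEXT = """\
type: "local_fs"
local_fs {
  wd: "/tmp/alice"
}
"""


def encode_prototext(prototext: str) -> str:
    """Assumed encoding: standard Base64 over the UTF-8 bytes of the text."""
    return base64.b64encode(prototext.encode("utf-8")).decode("ascii")


def decode_prototext(encoded: str) -> str:
    """Inverse of encode_prototext, handy for checking what was passed."""
    return base64.b64decode(encoded.encode("ascii")).decode("utf-8")


encoded_storage = encode_prototext(STORAGE_PROTOTEXT)
print(encoded_storage)
```

The resulting string would be passed as the value of --storage. The prototext for an existing protobuf message can be produced with google.protobuf.text_format.MessageToString.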

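Create a Component#

Python API#

If you want to create a new component in SecretFlow, you can refer to secretflow/component/preprocessing/train_test_split.py. It is about the simplest component there is.

The brief steps to build a SecretFlow component are as follows:

  1. Create a new file under the secretflow/component/ directory.

  2. Create a component with secretflow.component.component.Component: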

from secretflow.component.component import Component

train_test_split_comp = Component(
    "train_test_split",
    domain="preprocessing",
    version="0.0.1",
    desc="""Split datasets into random train and test subsets.
    Please check: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
    """,
)

  3. Define the attributes, inputs, and outputs.

from secretflow.component.component import IoType
from secretflow.component.data_utils import DistDataType

train_test_split_comp.float_attr(
    name="train_size",
    desc="Proportion of the dataset to include in the train subset.",
    is_list=False,
    is_optional=True,
    default_value=0.75,
    allowed_values=None,
    lower_bound=0.0,
    upper_bound=1.0,
    lower_bound_inclusive=True,
    upper_bound_inclusive=True,
)
train_test_split_comp.float_attr(
    name="test_size",
    desc="Proportion of the dataset to include in the test subset.",
    is_list=False,
    is_optional=True,
    default_value=0.25,
    allowed_values=None,
    lower_bound=0.0,
    upper_bound=1.0,
    lower_bound_inclusive=True,
    upper_bound_inclusive=True,
)
train_test_split_comp.int_attr(
    name="random_state",
    desc="Specify the random seed of the shuffling.",
    is_list=False,
    is_optional=True,
    default_value=1234,
)
train_test_split_comp.bool_attr(
    name="shuffle",
    desc="Whether to shuffle the data before splitting.",
    is_list=False,
    is_optional=True,
    default_value=True,
)
train_test_split_comp.io(
    io_type=IoType.INPUT,
    name="input_data",
    desc="Input dataset.",
    types=[DistDataType.VERTICAL_TABLE],
    col_params=None,
)
train_test_split_comp.io(
    io_type=IoType.OUTPUT,
    name="train",
    desc="Output train dataset.",
    types=[DistDataType.VERTICAL_TABLE],
    col_params=None,
)
train_test_split_comp.io(
    io_type=IoType.OUTPUT,
    name="test",
    desc="Output test dataset.",
    types=[DistDataType.VERTICAL_TABLE],
    col_params=None,
)

  4. Define the evaluation function.

from secretflow.spec.v1.data_pb2 import DistData

# Signature of eval_fn must be
#  func(*, ctx, attr_0, attr_1, ..., input_0, input_1, ..., output_0, output_1, ...) -> typing.Dict[str, DistData]
# All the arguments are keyword-only, so orders don't matter.
@train_test_split_comp.eval_fn
def train_test_split_eval_fn(
    *, ctx, train_size, test_size, random_state, shuffle, input_data, train, test
):
    # Please check more examples to learn component utils.
    # ctx includes some parsed cluster def and other useful meta.

    # The output of eval_fn is a map of DistDatas of which keys are output names.
    return {"train": DistData(), "test": DistData()}

  5. Add your new component to ALL_COMPONENTS in secretflow.component.entry.
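
To make the keyword-only signature rule from step 4 concrete, here is a toy sketch of how a decorator could check an evaluation function against the declared names. `MiniComponent` is a hypothetical stand-in written for this illustration; it is not SecretFlow's Component class and says nothing about how SecretFlow performs the check internally.

```python
import inspect


class MiniComponent:
    """Hypothetical stand-in, NOT secretflow.component.component.Component."""

    def __init__(self, name):
        self.name = name
        self.arg_names = ["ctx"]  # every eval function receives ctx
        self._eval_fn = None

    def declare(self, name):
        """Record the name of an attribute, input, or output."""
        self.arg_names.append(name)

    def eval_fn(self, fn):
        """Decorator: accept fn only if its signature matches the declarations."""
        params = inspect.signature(fn).parameters
        not_kw_only = [
            n for n, p in params.items()
            if p.kind is not inspect.Parameter.KEYWORD_ONLY
        ]
        if not_kw_only:
            raise TypeError(f"arguments must be keyword-only: {not_kw_only}")
        if set(params) != set(self.arg_names):
            raise TypeError(
                f"expected arguments {sorted(self.arg_names)}, got {sorted(params)}"
            )
        self._eval_fn = fn
        return fn

    def eval(self, **kwargs):
        """Call the registered function; argument order does not matter."""
        return self._eval_fn(**kwargs)


comp = MiniComponent("train_test_split")
for arg in ("train_size", "test_size", "random_state", "shuffle",
            "input_data", "train", "test"):
    comp.declare(arg)


@comp.eval_fn
def split_eval_fn(
    *, ctx, train_size, test_size, random_state, shuffle, input_data, train, test
):
    # Placeholder body: map each output name to the value it was given.
    return {"train": train, "test": test}


result = comp.eval(
    test="test.csv", train="train.csv", input_data="input.csv", shuffle=True,
    random_state=1234, test_size=0.25, train_size=0.75, ctx=None,
)
print(result)
```

Because every argument is keyword-only, the call can list them in any order, which is the property the comment in step 4 refers to.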