Source code for secretflow.ml.linear.hess_sgd.model

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Tuple, Union, List

import jax.numpy as jnp
import numpy as np

from secretflow.data import FedNdarray, PartitionWay
from secretflow.data.vertical import VDataFrame
from secretflow.device import HEU, SPU, PYUObject, wait, SPUObject
from secretflow.device.device.type_traits import spu_fxp_precision
from secretflow.utils.sigmoid import SigType, sigmoid
from secretflow.ml.linear.linear_model import RegType, LinearModel


# hess-lr
class HESSLogisticRegression:
    """Logistic regression linear model for vertically partitioned datasets,
    trained with a mini-batch SGD solver that combines secret sharing and
    homomorphic encryption. HESS-SGD is short for HE & secret sharing SGD
    training.

    During training, the HEU is used to protect the weights and to compute
    the predicted y, while the SPU is used to compute the sigmoid and the
    gradient.

    SPU is a verifiable and measurable secure computing device that runs
    under various MPC protocols to provide provable security.
    More detail: https://spu.readthedocs.io/en/beta/

    HEU is a secure computing device that implements HE encryption and
    decryption, and provides numpy-like matrix operations, lowering the
    barrier to use.
    More detail: https://heu.readthedocs.io/en/latest/

    For more detail, please refer to the KDD'21 paper:
    https://dl.acm.org/doi/10.1145/3447548.3467210

    Args:
        spu : SPU
            SPU device.
        heu_x : HEU
            HEU device without label.
        heu_y : HEU
            HEU device with label.

    Notes:
        The training dataset should be normalized or standardized,
        otherwise the SGD solver will not converge.
    """
    def __init__(self, spu: SPU, heu_x: HEU, heu_y: HEU) -> None:
        assert (
            isinstance(spu, SPU) and len(spu.actors) == 2
        ), "only two parties are supported"
        assert (
            isinstance(heu_x, HEU) and heu_x.sk_keeper_name() in spu.actors
        ), "heu_x must be colocated with spu"
        assert (
            isinstance(heu_y, HEU) and heu_y.sk_keeper_name() in spu.actors
        ), "heu_y must be colocated with spu"
        assert (
            heu_x.sk_keeper_name() != heu_y.sk_keeper_name()
        ), "two distinct heu devices must be provided"
        assert (
            # This type should keep the same as the return type of
            # spu.Io.make_shares(). Since make_shares() always returns DT_I32,
            # HEU should work in DT_I32 mode.
            heu_x.cleartext_type == "DT_I32"
            and heu_y.cleartext_type == "DT_I32"
        ), "HEU encoding config must be set to DT_I32"

        self._spu = spu
        self._heu_x = heu_x
        self._heu_y = heu_y
        # number of fraction bits in SPU's fixed-point representation
        self._scale = spu_fxp_precision(spu.cluster_def['runtime_config']['field'])
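    # Illustrative note on the fixed-point scale (a sketch; the exact bit
    # count is an assumption and depends on the cluster's runtime_config):
    # with a 64-bit SPU field and default fixed-point settings,
    # spu_fxp_precision() returns 18 fraction bits, so a float such as 0.25
    # is encoded as int(0.25 * (1 << 18)) == 65536 before it enters HEU/SPU,
    # and is decoded by dividing by the same power of two on the way out.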
    def _data_check(
        self, x: Union[FedNdarray, VDataFrame]
    ) -> Tuple[PYUObject, PYUObject]:
        assert isinstance(
            x, (FedNdarray, VDataFrame)
        ), "x should be FedNdarray or VDataFrame"
        assert len(x.partitions) == 2, "x should be partitioned across two parties"
        x = x.values if isinstance(x, VDataFrame) else x
        assert (
            x.partition_way == PartitionWay.VERTICAL
        ), "only vertical datasets are supported in HESS LR"

        x_shapes = list(x.partition_shape().values())
        assert x_shapes[0][0] > x_shapes[0][1] and x_shapes[1][0] > x_shapes[1][1], (
            "too few samples: "
            "1. the model will not converge; "
            "2. the y label may leak to other parties."
        )

        x1, x2 = None, None
        for device, partition in x.partitions.items():
            if device.party == self._heu_x.sk_keeper_name():
                x1 = partition
            elif device.party == self._heu_y.sk_keeper_name():
                x2 = partition
            else:
                raise ValueError(f"unexpected x partition party {device.party}")

        return x1, x2

    def _args_check(self, x, y) -> Tuple[PYUObject, PYUObject, PYUObject]:
        x1, x2 = self._data_check(x)

        assert isinstance(
            y, (FedNdarray, VDataFrame)
        ), "y should be FedNdarray or VDataFrame"
        assert len(y.partitions) == 1, "y should be held by a single party"
        assert x2.device in y.partitions, "y should be colocated with x2"

        y = y.values if isinstance(y, VDataFrame) else y
        return x1, x2, y.partitions[x2.device]
    def fit(
        self,
        x: Union[FedNdarray, VDataFrame],
        y: Union[FedNdarray, VDataFrame],
        learning_rate=1e-3,
        epochs=1,
        batch_size=None,
    ):
        """Fit the linear model with mini-batch Stochastic Gradient Descent.

        Args:
            x : {FedNdarray, VDataFrame}
                Input data, must be colocated with SPU.
            y : {FedNdarray, VDataFrame}
                Target data, must be located on `self._heu_y`.
            learning_rate : float, default=1e-3
                Learning rate.
            epochs : int, default=1
                Number of epochs to train the model.
            batch_size : int, default=None
                Number of samples per gradient update.
                If None, batch_size defaults to the number of samples.
        """
        assert (
            isinstance(learning_rate, float) and learning_rate > 0
        ), "learning_rate should be a float > 0"
        assert isinstance(epochs, int) and epochs > 0, "epochs should be an integer > 0"
        assert batch_size is None or (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be None or an integer > 0"

        x1, x2, y = self._args_check(x, y)

        x_shape = x.partition_shape()
        dim1, dim2 = x_shape[x1.device][1], x_shape[x2.device][1]

        n_samples = x_shape[x1.device][0]
        batch_size = n_samples if batch_size is None else min(n_samples, batch_size)
        steps_per_epoch = (n_samples + batch_size - 1) // batch_size

        # The dataset and w are scaled by the driver, so devices only see int types.
        w1 = x1.device(np.zeros)(dim1 + 1, np.int_)  # + 1 for the intercept term
        w2 = x2.device(np.zeros)(dim2 + 1, np.int_)  # + 1 for the intercept term

        hw1 = w1.to(self._heu_y).encrypt()
        hw2 = w2.to(self._heu_x).encrypt()

        for epoch in range(epochs):
            for step in range(steps_per_epoch):
                logging.info(f"epoch {epoch}, step: {step} begin")
                batch_x1, batch_x2, batch_y = self._next_batch(
                    x1, x2, y, step, batch_size
                )

                # step 1. encode features and label as fixed-point integers
                scale = 1 << self._scale
                x1_ = self._encode(batch_x1, scale, True).to(self._heu_y)
                x2_ = self._encode(batch_x2, scale, True).to(self._heu_x)
                # NOTE: Gradient descent weight update:
                # W -= learning_rate * ((h - y) @ X^T)
                # Since all calculations are done with fixed-point integers,
                # the feature scale and learning_rate are fused here to avoid
                # encoding learning_rate separately.
                x1_t = self._encode(batch_x1, learning_rate * scale, True).to(
                    self._heu_y
                )
                x2_t = self._encode(batch_x2, learning_rate * scale, True).to(
                    self._heu_x
                )
                y_ = self._encode(batch_y, scale, False).to(self._spu)

                # step 2. compute the residual (h - y) with the sigmoid
                # minimax approximation
                p1, p2 = (x1_ @ hw1).to(self._spu), (x2_ @ hw2).to(self._spu)
                e = self._spu(self._sigmoid_minimax)(p1, p2, y_, self._scale)

                # step 3. update weights by gradient descent
                e1, e2 = e.to(self._heu_y), e.to(self._heu_x)
                hw1 -= e1 @ x1_t
                hw2 -= e2 @ x2_t

                # NOTE: Due to Ray's asynchronous scheduling, batches from
                # different steps may be encrypted simultaneously, which can
                # flood memory.
                wait([e])

        self._w = self._spu(self._truncate, static_argnames=('scale',))(
            hw1.to(self._spu), hw2.to(self._spu), scale=self._scale
        )
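    # Worked example of the fused scaling in fit() (illustrative numbers
    # only): with scale s = 18, a feature value 0.5 becomes 0.5 * 2**18 in
    # `x1_`, and learning_rate * 0.5 * 2**18 in `x1_t`. The residual `e`
    # returned by _sigmoid_minimax is single-scaled (2**s), so the update
    # hw -= e @ x_t leaves hw double-scaled (2**(2 * s)) with learning_rate
    # already applied; _truncate later removes both factors of 2**s.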
    def save_model(self) -> LinearModel:
        """Save the fitted model in LinearModel format."""
        assert hasattr(self, '_w'), 'please fit the model first'
        return LinearModel(self._w, RegType.Logistic, SigType.T1)
    def load_model(self, m: LinearModel) -> None:
        """Load a model in LinearModel format."""
        assert (
            isinstance(m.weights, SPUObject) and m.weights.device == self._spu
        ), 'weights should be saved in the same SPU'
        self._w = m.weights
        assert m.reg_type is RegType.Logistic, 'only Logistic regression is supported in HESS LR'
        assert m.sig_type is SigType.T1, 'only the T1 sigmoid is supported in HESS LR'
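    # Round-trip sketch (illustrative): `m = model.save_model()` yields a
    # LinearModel whose weights remain an SPUObject on the same SPU; a fresh
    # HESSLogisticRegression built over the same devices can then call
    # `load_model(m)` followed by `predict()` without retraining.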
    @staticmethod
    def _next_batch(x1, x2, y, step, batch_size):
        def _slice(x, step, batch_size):
            beg = step * batch_size
            end = min(beg + batch_size, len(x))
            return x[beg:end]

        batch_x1 = x1.device(_slice)(x1, step, batch_size)
        batch_x2 = x2.device(_slice)(x2, step, batch_size)
        batch_y = y.device(_slice)(y, step, batch_size)

        return batch_x1, batch_x2, batch_y

    @staticmethod
    def _sigmoid_minimax(p1, p2, y, scale):
        """Compute the residual h - y, with sigmoid approximated by the
        minimax form sigmoid(x) ~= 0.5 + 0.125 * x."""
        # NOTE: Predictions `p1` and `p2` are triple-scaled integers and
        # label `y` is a single-scaled integer, so `(p1 + p2)` has to be
        # right-shifted. The return value is a single-scaled integer.
        y = y.reshape((y.shape[0],))
        return (1 << (scale - 1)) + ((p1 + p2) >> (2 * scale + 3)) - y

    @staticmethod
    def _encode(x: PYUObject, scale: int, expand: bool):
        def encode(x, scale):
            if expand:
                b = np.ones((x.shape[0], 1))  # intercept column
                x = np.hstack((x, b))
            return (x * scale).astype(np.int64)

        return x.device(encode)(x, scale)

    @staticmethod
    def _truncate(w1: np.ndarray, w2: np.ndarray, scale: int) -> np.ndarray:
        # Weights are double-scaled after training: shift once here and
        # divide by the remaining scale after the float conversion.
        w1 = (w1 >> scale).astype(jnp.float32)
        w2 = (w2 >> scale).astype(jnp.float32)

        bias = jnp.full((1, 1), w1[-1] + w2[-1], dtype=jnp.float32)
        # drop each party's intercept term and reshape to a column vector
        w1 = jnp.resize(w1, (w1.shape[0] - 1, 1))
        w2 = jnp.resize(w2, (w2.shape[0] - 1, 1))

        w = jnp.concatenate((w1, w2, bias)) / (1 << scale)
        return w

    @staticmethod
    def _predict(
        x: List[np.ndarray],
        w: np.ndarray,
        total_batch: int,
        batch_size: int,
    ):
        """Predict on dataset x.

        Args:
            x: input datasets.
            w: model weights.
            total_batch: number of full batches in x.
            batch_size: number of samples used in one calculation.

        Return:
            pred scores.
        """
        x = jnp.concatenate(x, axis=1)
        num_feat = x.shape[1]
        samples = x.shape[0]
        assert w.shape[0] == num_feat + 1, "w shape mismatches x"
        assert len(w.shape) == 1 or w.shape[1] == 1, "w should be a 1D array or a column vector"
        w = w.reshape((w.shape[0], 1))  # ensure w is a column vector

        bias = w[-1, 0]
        w = jnp.resize(w, (num_feat, 1))  # drop the bias term

        def get_pred(x):
            pred = jnp.matmul(x, w) + bias
            return sigmoid(pred, SigType.T1)

        preds = []
        end = 0
        for idx in range(total_batch):
            begin = idx * batch_size
            end = (idx + 1) * batch_size
            preds.append(get_pred(x[begin:end, :]))

        if end < samples:
            preds.append(get_pred(x[end:samples, :]))

        return jnp.concatenate(preds, axis=0)
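    # Scale check for _sigmoid_minimax (a worked equation, not executed):
    # `hw` is double-scaled and `x_` is single-scaled, so p = x_ @ hw carries
    # a factor of 2**(3 * s). Right-shifting by 2 * s + 3 divides by
    # 2**(2 * s) * 8, leaving a single-scaled 0.125 * x, and 1 << (s - 1) is
    # a single-scaled 0.5, matching sigmoid(x) ~= 0.5 + 0.125 * x.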
    def predict(self, x: Union[FedNdarray, VDataFrame]) -> SPUObject:
        """Probability estimates.

        Args:
            x : {FedNdarray, VDataFrame}
                Predict samples.

        Returns:
            SPUObject: probability estimates of the samples.
        """
        ds = self._data_check(x)
        shapes = x.partition_shape()
        samples = list(shapes.values())[0][0]
        total_batch = samples // 1024

        spu_yhat = self._spu(
            self._predict,
            static_argnames=('total_batch', 'batch_size'),
        )(
            [d.to(self._spu) for d in ds],
            self._w,
            total_batch=total_batch,
            batch_size=1024,
        )

        return spu_yhat
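# A minimal end-to-end usage sketch (illustrative only; the party names and
# device configs are assumptions and depend on the deployment -- see the SPU
# and HEU docs linked in the class docstring):
#
#   import secretflow as sf
#
#   sf.init(['alice', 'bob'], address='local')
#   spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))
#   field = spu.cluster_def['runtime_config']['field']
#   heu_x = sf.HEU(heu_config_sk_keeper_alice, field)  # hypothetical config dict
#   heu_y = sf.HEU(heu_config_sk_keeper_bob, field)    # hypothetical config dict
#
#   model = HESSLogisticRegression(spu, heu_x, heu_y)
#   model.fit(x, y, learning_rate=0.1, epochs=3, batch_size=1024)
#   yhat = sf.reveal(model.predict(x))  # x: VDataFrame split between parties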