# Source code for the secure aggregation module (masking with one-time pads).

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pandas as pd

import secretflow.utils.ndarray_encoding as ndarray_encoding
from secretflow.device import PYU, DeviceObject, PYUObject, proxy, reveal
from import Aggregator
from import is_nesting_list
from import DiffieHellman

class _Masker:
    def __init__(self, self_device: PYU, fxp_bits: int):
        self._device = self_device
        self._dh = DiffieHellman()
        self._pub_key, self._pri_key = self._dh.generate_key_pair()
        self._fxp_bits = fxp_bits

    def pub_key(self) -> int:
        return self._pub_key

    def gen_rng(self, pub_keys: Dict[PYU, int]) -> None:
        assert pub_keys, f'Public keys is None or empty.'
        self._rngs = {
            device: np.random.default_rng(
                int(self._dh.generate_secret(self._pri_key, peer_pub_key), base=16)
            for device, peer_pub_key in pub_keys.items()
            if device != self._device

    def mask(
        data: Union[
            List[Union[pd.DataFrame, pd.Series, np.ndarray]],
            Union[pd.DataFrame, pd.Series, np.ndarray],
    ) -> Tuple[Union[List[np.ndarray], np.ndarray], np.dtype]:
        assert data is not None, 'Data shall not be None or empty.'
        is_list = isinstance(data, list)
        if not is_list:
            data = [data]
        if weight is None:
            weight = 1
        masked_data = []
        dtype = None
        for datum in data:
            if isinstance(datum, (pd.DataFrame, pd.Series)):
                datum = datum.values
            assert isinstance(
                datum, np.ndarray
            ), f'Accept ndarray or dataframe/series only but got {type(datum)}'
            if dtype is None:
                dtype = datum.dtype
                assert (
                    datum.dtype == dtype
                ), f'Data should have same dtypes but got {datum.dtype} {dtype}.'
            is_float = np.issubdtype(datum.dtype, np.floating)
            if not is_float:
                assert np.issubdtype(
                    datum.dtype, np.integer
                ), f'Data type are neither integer nor float.'
                if datum.dtype != np.int64:
                    datum = datum.astype(np.int64)
            # Do mulitple before encoding to finite field.
            masked_datum: np.ndarray = (
                ndarray_encoding.encode(datum * weight, self._fxp_bits)
                if is_float
                else datum * weight
            for pyu, rng in self._rngs.items():
                if pyu == self._device:
                mask = rng.integers(
                if pyu > self._device:
                    masked_datum += mask
                    masked_datum -= mask

        if is_list:
            return masked_data, dtype
            return masked_data[0], dtype

class SecureAggregator(Aggregator):
    r"""The secure aggregation implementation of `Masking with One-Time Pads`.

    `Masking with One-Time Pads` negotiates a secret for every two
    participants, then uses the secret to hide its input `x`, and each
    participant outputs

    .. math::
        y_u = x_u + \sum_{u < v}S_{u,v} - \sum_{u > v}S_{u,v}\quad(mod\:R)

    the secrets are cancelled out after aggregation and then we can get the
    correct result.

    .. math::
        \sum y = \sum x

    For example, the participants Alice, Bob, and Carol each own
    :math:`x_1, x_2, x_3`, negotiate the secret
    :math:`s_{a,b}, s_{a,c}, s_{b,c}`, and then output:

    .. math::
        y_1 = x_1 + s_{a,b} + s_{a,c}

    .. math::
        y_2 = x_2 - s_{a,b} + s_{b,c}

    .. math::
        y_3 = x_3 - s_{a,c} - s_{b,c}

    then it is easy to get

    .. math::
        y_1 + y_2 + y_3 = x_1 + s_{a,b} + s_{a,c} + x_2 - s_{a,b} + s_{b,c}
        + x_3 - s_{a,c} - s_{b,c} = x_1 + x_2 + x_3

    Notes:
        `Masking with One-Time Pads` is based on semi-honest assumptions and
        does not support client dropping. For more information, please refer
        to `Practical Secure Aggregation for Privacy-Preserving Machine
        Learning <https://eprint.iacr.org/2017/281.pdf>`_

    Warnings:
        The SecureAggregator uses :py:class:`numpy.random.PCG64`. There are
        many discussions of whether PCG is a CSPRNG; we prefer a conservative
        strategy unless a further security analysis comes up. Therefore we
        recommend users to use a standardized CSPRNG in industrial scenarios.

    Examples:
        >>> # Alice and bob are both pyu instances.
        >>> aggregator = SecureAggregator(alice, [alice, bob])
        >>> a = alice(lambda : np.random.rand(2, 5))()
        >>> b = bob(lambda : np.random.rand(2, 5))()
        >>> sum_a_b = aggregator.sum([a, b], axis=0)
        >>> # Get the result.
        >>> sf.reveal(sum_a_b)
        array([[0.5954927 , 0.9381409 , 0.99397117, 1.551537  , 0.3269863 ],
               [1.288345  , 1.1820003 , 1.1769378 , 0.7396539 , 1.215364  ]],
              dtype=float32)
        >>> average_a_b = aggregator.average([a, b], axis=0)
        >>> sf.reveal(average_a_b)
        array([[0.29774636, 0.46907043, 0.49698558, 0.7757685 , 0.16349316],
               [0.6441725 , 0.5910001 , 0.5884689 , 0.3698269 , 0.607682  ]],
              dtype=float32)
    """

    def __init__(self, device: PYU, participants: List[PYU], fxp_bits: int = 18):
        """Create one masker per participant and exchange their public keys.

        Args:
            device: the device on which the masked inputs are combined.
            participants: devices contributing data; must be distinct.
            fxp_bits: number of fixed-point bits used to encode and decode
                floating point data.
        """
        assert len(set(participants)) == len(
            participants
        ), 'Should not have duplicated devices.'
        self._device = device
        self._participants = set(participants)
        self._fxp_bits = fxp_bits
        self._maskers = {
            pyu: _Masker(pyu, self._fxp_bits, device=pyu) for pyu in participants
        }
        # Every masker needs the public keys of all peers to derive the
        # pairwise secrets that seed its random generators.
        pub_keys = reveal(
            {pyu: masker.pub_key() for pyu, masker in self._maskers.items()}
        )
        for masker in self._maskers.values():
            masker.gen_rng(pub_keys)

    def _check_data(self, data: List[PYUObject]):
        """Assert that ``data`` holds exactly one object per participant."""
        assert data, 'The data should not be None or empty.'
        assert len(data) == len(
            self._maskers
        ), f'Length of the data not equals devices: {len(data)} vs {len(self._maskers)}'
        devices_of_data = {datum.device for datum in data}
        assert (
            devices_of_data == self._participants
        ), 'Devices of the data must be corresponding with this aggregator.'

    @classmethod
    def _is_list(cls, masked_data: Union[List, Any]) -> bool:
        """Tell whether every entry of ``masked_data`` is a list.

        Asserts that the entries are either all lists or all non-lists and,
        when they are lists, that they all have the length of the first one.
        """
        is_list = isinstance(masked_data[0], list)
        for masked_datum in masked_data[1:]:
            assert (
                isinstance(masked_datum, list) == is_list
            ), 'Some data are list where some others are not.'
            # Compare against the first entry of the whole collection; the
            # previous code compared an entry with its own first element.
            assert not is_list or len(masked_datum) == len(
                masked_data[0]
            ), 'Lengths of datum in data are different.'
        return is_list

    def sum(self, data: List[PYUObject], axis=None):
        """Sum the data of all participants without revealing any single input.

        Args:
            data: one object per participant, each located on its own device.
            axis: forwarded to ``np.sum`` over the collected per-participant
                inputs (``axis=0`` adds across participants).

        Returns:
            The result of running the combination on the aggregator device.
        """
        # Bind locally so the function shipped to the device does not need
        # to capture the whole aggregator through ``self``.
        fxp_bits = self._fxp_bits

        def _sum(*masked_data: List[np.ndarray], dtypes: List[np.dtype]):
            for dtype in dtypes[1:]:
                assert (
                    dtype == dtypes[0]
                ), f'Data should have same dtypes but got {dtype} {dtypes[0]}.'
            is_float = np.issubdtype(dtypes[0], np.floating)
            if is_nesting_list(masked_data):
                results = [np.sum(element, axis=axis) for element in zip(*masked_data)]
                if is_float:
                    return [
                        ndarray_encoding.decode(result, fxp_bits) for result in results
                    ]
                return results
            result = np.sum(masked_data, axis=axis)
            return ndarray_encoding.decode(result, fxp_bits) if is_float else result

        self._check_data(data)
        masked_data = [None] * len(data)
        dtypes = [None] * len(data)
        for i, datum in enumerate(data):
            masked_data[i], dtypes[i] = self._maskers[datum.device].mask(datum)
        # Only masked values and their dtypes leave the participants.
        masked_data = [d.to(self._device) for d in masked_data]
        dtypes = [dtype.to(self._device) for dtype in dtypes]
        return self._device(_sum)(*masked_data, dtypes=dtypes)

    def average(self, data: List[PYUObject], axis=None, weights=None):
        """Average the data of all participants, optionally weighted.

        Args:
            data: one object per participant, each located on its own device.
            axis: forwarded to ``np.sum`` over the collected per-participant
                inputs and over the weights.
            weights: ``None``, a single value applied to every input, or a
                list/tuple/ndarray with one weight per entry of ``data``. A
                weight that is a device object must live on the device of
                the matching data.

        Returns:
            The result of running the combination on the aggregator device.
        """
        # Bind locally so the function shipped to the device does not need
        # to capture the whole aggregator through ``self``.
        fxp_bits = self._fxp_bits

        def _average(*masked_data: List[np.ndarray], dtypes: List[np.dtype], weights):
            for dtype in dtypes[1:]:
                assert (
                    dtype == dtypes[0]
                ), f'Data should have same dtypes but got {dtype} {dtypes[0]}.'
            is_float = np.issubdtype(dtypes[0], np.floating)
            # One masked input arrives per participant, so the number of
            # inputs equals ``len(data)`` without capturing the outer list.
            # NOTE(review): with a single scalar weight the inputs are scaled
            # but the divisor stays the participant count -- confirm intent.
            sum_weights = np.sum(weights, axis=axis) if weights else len(masked_data)
            if is_nesting_list(masked_data):
                sum_data = [np.sum(element, axis=axis) for element in zip(*masked_data)]
                if is_float:
                    sum_data = [
                        ndarray_encoding.decode(sum_datum, fxp_bits)
                        for sum_datum in sum_data
                    ]
                return [element / sum_weights for element in sum_data]
            total = np.sum(masked_data, axis=axis)
            if is_float:
                total = ndarray_encoding.decode(total, fxp_bits)
            return total / sum_weights

        self._check_data(data)
        masked_data = [None] * len(data)
        dtypes = [None] * len(data)
        _weights = []
        if isinstance(weights, (list, tuple, np.ndarray)):
            assert len(weights) == len(
                data
            ), f'Length of the weights not equals data: {len(weights)} vs {len(data)}.'
            for i, w in enumerate(weights):
                if isinstance(w, DeviceObject):
                    assert (
                        w.device == data[i].device
                    ), 'Device of weight is not same with the corresponding data.'
                    _weights.append(w.to(self._device))
                else:
                    _weights.append(w)
            for i, (datum, weight) in enumerate(zip(data, weights)):
                masked_data[i], dtypes[i] = self._maskers[datum.device].mask(
                    datum, weight
                )
        else:
            for i, datum in enumerate(data):
                masked_data[i], dtypes[i] = self._maskers[datum.device].mask(
                    datum, weights
                )
        # Only masked values and their dtypes leave the participants.
        masked_data = [d.to(self._device) for d in masked_data]
        dtypes = [dtype.to(self._device) for dtype in dtypes]
        return self._device(_average)(*masked_data, dtypes=dtypes, weights=_weights)