Source code for secretflow.security.aggregation.plain_aggregator

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import numpy as np

from secretflow.device import PYU, DeviceObject, PYUObject
from secretflow.security.aggregation.aggregator import Aggregator


class PlainAggregator(Aggregator):
    """Plaintext aggregator.

    The computation will be performed in plaintext.

    Warnings:
        PlainAggregator is for debugging purpose only.
        You should not use it in production.

    Examples:
        >>> # Alice and bob are both pyu instances.
        >>> aggregator = PlainAggregator(alice)
        >>> a = alice(lambda : np.random.rand(2, 5))()
        >>> b = bob(lambda : np.random.rand(2, 5))()
        >>> sum_a_b = aggregator.sum([a, b], axis=0)
        >>> # Get the result.
        >>> sf.reveal(sum_a_b)
        array([[0.5954927 , 0.9381409 , 0.99397117, 1.551537  , 0.32698634],
               [1.288345  , 1.1820003 , 1.1769378 , 0.7396539 , 1.215364  ]],
              dtype=float32)
        >>> average_a_b = aggregator.average([a, b], axis=0)
        >>> sf.reveal(average_a_b)
        array([[0.29774636, 0.46907046, 0.49698558, 0.7757685 , 0.16349317],
               [0.6441725 , 0.59100014, 0.5884689 , 0.36982694, 0.607682  ]],
              dtype=float32)
    """

    def __init__(self, device: PYU):
        """Create a plaintext aggregator.

        Args:
            device: the PYU on which every aggregation is computed; all inputs
                are moved to this device before aggregating.
        """
        assert isinstance(device, PYU), f'Accepts PYU only but got {type(device)}.'
        self.device = device

    @staticmethod
    def _get_dtype(arr):
        """Return the numpy dtype of ``arr``, or None if it cannot be determined.

        numpy arrays are recognized directly; ``tf.Tensor`` is recognized only
        when TensorFlow can be imported. Any other input yields None.
        """
        if isinstance(arr, np.ndarray):
            return arr.dtype
        # Keep the try body to the import itself: only a missing TensorFlow
        # installation should be treated as "dtype unknown".
        try:
            import tensorflow as tf
        except ImportError:
            return None
        if isinstance(arr, tf.Tensor):
            return arr.numpy().dtype
        return None

    def sum(self, data: List[DeviceObject], axis=None) -> PYUObject:
        """Sum of array elements over a given axis.

        Args:
            data: array of device objects.
            axis: optional. Same as the axis argument of :py:meth:`numpy.sum`.
                The inputs are stacked before summing, so axis 0 runs across
                the inputs in ``data``.

        Returns:
            a device object holds the sum.
        """
        assert data, 'Data to aggregate should not be None or empty!'
        data = [d.to(self.device) for d in data]

        def _sum(*data, axis):
            # When every input is a list/tuple of arrays, sum position by
            # position across inputs and return a list of the same length.
            if isinstance(data[0], (list, tuple)):
                return [np.sum(element, axis=axis) for element in zip(*data)]
            return np.sum(data, axis=axis)

        return self.device(_sum)(*data, axis=axis)

    def average(self, data: List[DeviceObject], axis=None, weights=None) -> PYUObject:
        """Compute the weighted average along the specified axis.

        Args:
            data: array of device objects.
            axis: optional. Same as the axis argument of :py:meth:`numpy.average`.
            weights: optional. Same as the weights argument of
                :py:meth:`numpy.average`. Device objects (either the whole
                ``weights`` or individual list/tuple entries) are moved to the
                aggregation device first.

        Returns:
            a device object holds the weighted average. When the inputs are
            floating-point (or complex) numpy arrays / TensorFlow tensors, the
            result is cast back to the input dtype (e.g. float32 stays float32).
            Integer and boolean inputs keep numpy's floating-point result so the
            mean is not truncated.
        """
        assert data, 'Data to aggregate should not be None or empty!'
        data = [d.to(self.device) for d in data]
        if isinstance(weights, (list, tuple)):
            weights = [
                w.to(self.device) if isinstance(w, DeviceObject) else w
                for w in weights
            ]
        elif isinstance(weights, DeviceObject):
            weights = weights.to(self.device)

        # Bind the static helper locally so the function handed to the device
        # does not close over ``self``.
        get_dtype = self._get_dtype

        def _average(*data, axis, weights):
            def _avg_keep_dtype(elements):
                avg = np.average(elements, axis=axis, weights=weights)
                res_dtype = get_dtype(elements[0])
                # Compare against None explicitly: in some older NumPy releases
                # ``bool(np.dtype)`` was ``len(dtype) > 0`` (the field count),
                # i.e. False for every non-structured dtype. Cast back only for
                # inexact dtypes; casting a mean to an int/bool dtype would
                # silently truncate it.
                if res_dtype is not None and np.issubdtype(res_dtype, np.inexact):
                    avg = avg.astype(res_dtype)
                return avg

            # When every input is a list/tuple of arrays, average position by
            # position across inputs and return a list of the same length.
            if isinstance(data[0], (list, tuple)):
                return [_avg_keep_dtype(elements) for elements in zip(*data)]
            return _avg_keep_dtype(data)

        return self.device(_average)(*data, axis=axis, weights=weights)