# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
from secretflow.device import PYU, DeviceObject, PYUObject
from secretflow.security.aggregation.aggregator import Aggregator
class PlainAggregator(Aggregator):
    """Plaintext aggregator.

    The computation will be performed in plaintext.

    Warnings:
        PlainAggregator is for debugging purpose only.
        You should not use it in production.

    Examples:
        >>> # Alice and bob are both pyu instances.
        >>> aggregator = PlainAggregator(alice)
        >>> a = alice(lambda : np.random.rand(2, 5))()
        >>> b = bob(lambda : np.random.rand(2, 5))()
        >>> sum_a_b = aggregator.sum([a, b], axis=0)
        >>> # Get the result.
        >>> sf.reveal(sum_a_b)
        array([[0.5954927 , 0.9381409 , 0.99397117, 1.551537  , 0.32698634],
               [1.288345  , 1.1820003 , 1.1769378 , 0.7396539 , 1.215364  ]],
              dtype=float32)
        >>> average_a_b = aggregator.average([a, b], axis=0)
        >>> sf.reveal(average_a_b)
        array([[0.29774636, 0.46907046, 0.49698558, 0.7757685 , 0.16349317],
               [0.6441725 , 0.59100014, 0.5884689 , 0.36982694, 0.607682  ]],
              dtype=float32)
    """

    def __init__(self, device: PYU):
        """Bind the aggregator to the device that performs the aggregation.

        Args:
            device: the PYU that every input is moved to before aggregating.

        Raises:
            AssertionError: if ``device`` is not a :py:class:`PYU`.
        """
        assert isinstance(device, PYU), f'Accepts PYU only but got {type(device)}.'
        self.device = device

    @staticmethod
    def _get_dtype(arr):
        """Return the numpy dtype of ``arr``, or ``None`` if it cannot be told.

        Args:
            arr: a ``numpy.ndarray``, a ``tensorflow.Tensor`` or anything else.

        Returns:
            ``arr.dtype`` for a numpy array, the dtype of ``arr.numpy()`` for a
            TensorFlow tensor, and ``None`` otherwise (including when
            TensorFlow is not installed).
        """
        if isinstance(arr, np.ndarray):
            return arr.dtype
        try:
            # TensorFlow is optional, so it is only imported on demand.
            import tensorflow as tf

            if isinstance(arr, tf.Tensor):
                return arr.numpy().dtype
        except ImportError:
            return None
        # Neither a numpy array nor a TensorFlow tensor.
        return None

    def sum(self, data: List[DeviceObject], axis=None) -> PYUObject:
        """Sum of array elements over a given axis.

        Every input is first moved to ``self.device``; the sum is computed
        there.

        Args:
            data: array of device objects.
            axis: optional. Same as the axis argument of :py:meth:`numpy.sum`.

        Returns:
            a device object holds the sum. When each input is itself a list or
            tuple, the result is a list with one sum per position.

        Raises:
            AssertionError: if ``data`` is None or empty.
        """
        assert data, 'Data to aggregate should not be None or empty!'
        data = [d.to(self.device) for d in data]

        def _sum(*data, axis):
            if isinstance(data[0], (list, tuple)):
                # Each participant supplied a sequence: aggregate the
                # participants' values position by position.
                return [np.sum(element, axis=axis) for element in zip(*data)]
            return np.sum(data, axis=axis)

        return self.device(_sum)(*data, axis=axis)

    def average(self, data: List[DeviceObject], axis=None, weights=None) -> PYUObject:
        """Compute the weighted average along the specified axis.

        Every input (and every weight that is a device object) is first moved
        to ``self.device``; the average is computed there. When the dtype of
        the first input can be determined, the result is cast back to it.

        Args:
            data: array of device objects.
            axis: optional. Same as the axis argument of :py:meth:`numpy.average`.
            weights: optional. Same as the weights argument of
                :py:meth:`numpy.average`. Elements of a list or tuple may be
                device objects.

        Returns:
            a device object holds the weighted average. When each input is
            itself a list or tuple, the result is a list with one average per
            position.

        Raises:
            AssertionError: if ``data`` is None or empty.
        """
        assert data, 'Data to aggregate should not be None or empty!'
        data = [d.to(self.device) for d in data]
        if isinstance(weights, (list, tuple)):
            weights = [
                w.to(self.device) if isinstance(w, DeviceObject) else w for w in weights
            ]

        # Bind the static helper to a local so that the function handed to the
        # device does not close over ``self``.
        get_dtype = self._get_dtype

        def _average(*data, axis, weights):
            if isinstance(data[0], (list, tuple)):
                results = []
                for elements in zip(*data):
                    avg = np.average(elements, axis=axis, weights=weights)
                    res_dtype = get_dtype(elements[0])
                    # Compare against None explicitly instead of relying on
                    # the truthiness of a dtype object.
                    if res_dtype is not None:
                        avg = avg.astype(res_dtype)
                    results.append(avg)
                return results

            res = np.average(data, axis=axis, weights=weights)
            res_dtype = get_dtype(data[0])
            return res.astype(res_dtype) if res_dtype is not None else res

        return self.device(_average)(*data, axis=axis, weights=weights)