# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""keras global evaluation metrics
"""
from dataclasses import dataclass
from typing import List
from abc import ABC, abstractmethod
import logging
# We do not simply subclass or wrap TensorFlow metrics because TensorFlow
# metrics cannot be serialized, but these metrics must be sent from workers
# to the server.
class Metric(ABC):
    """Base class for metrics that can be aggregated across parties.

    Subclasses are meant to hold plain Python values rather than TensorFlow
    metric objects, so that they can be serialized and sent between workers
    and the server. Instances are combined with ``+`` and must also accept
    ``0 + metric`` so that the builtin ``sum`` (which starts from 0) can fold
    a sequence of them.
    """

    @abstractmethod
    def result(self):
        """Return the metric value computed from the aggregated state."""

    @abstractmethod
    def __radd__(self, other):
        """Support ``other + self``; ``sum`` calls this with ``other == 0``."""

    @abstractmethod
    def __add__(self, other):
        """Combine the state of this metric with the state of ``other``."""
@dataclass
class Default(Metric):
    """Generic metric aggregated naively as a running total and count.

    ``result`` hands ``total`` and ``count`` to ``tf.keras.metrics.Mean`` and
    logs a warning that only naive aggregation was performed.

    Attributes:
        name: metric name; only metrics with the same name can be added.
        total: sum of the metric values.
        count: number of values accumulated into ``total``.
    """

    name: str
    total: float
    count: float

    def __add__(self, other):
        assert self.name == other.name
        total = self.total + other.total
        count = self.count + other.count
        return Default(self.name, total, count)

    def __radd__(self, other):
        # ``sum`` starts from 0, so ``0 + metric`` must yield a copy of metric.
        assert other == 0
        return Default(self.name, self.total, self.count)

    def result(self):
        # ``logging.warn`` is a deprecated alias; ``logging.warning`` is the
        # supported spelling with identical behavior.
        logging.warning(
            "Please pay attention to local metrics, global only do naive aggregation "
        )
        import tensorflow as tf

        metric = tf.keras.metrics.Mean()
        metric.total = self.total
        metric.count = self.count
        return metric.result()
@dataclass
class Mean(Metric):
    """Federated counterpart of ``keras.metrics.Mean``.

    Attributes:
        name: metric name; only means with the same name can be added.
        total: sum of metrics
        count: num of samples
    """

    name: str
    total: float
    count: float

    def __radd__(self, other):
        # ``sum`` begins with 0; treat ``0 + mean`` as a copy of ``mean``.
        assert other == 0
        return Mean(name=self.name, total=self.total, count=self.count)

    def __add__(self, other: 'Mean'):
        assert self.name == other.name
        return Mean(
            name=self.name,
            total=self.total + other.total,
            count=self.count + other.count,
        )

    def result(self):
        import tensorflow as tf

        keras_mean = tf.keras.metrics.Mean()
        keras_mean.total, keras_mean.count = self.total, self.count
        return keras_mean.result()
class AUC(Metric):
    """Federated keras.metrics.AUC.

    Attributes:
        name: metric name; only AUCs with the same name can be added.
        thresholds: threshold of buckets, same as tf.keras.metrics.AUC. Must
            contain 0 and 1: ``result`` drops the first and last entries
            before passing the rest to ``tf.keras.metrics.AUC``.
        true_positives: per-threshold number of true-positive samples.
        true_negatives: per-threshold number of true-negative samples.
        false_positives: per-threshold number of false-positive samples.
        false_negatives: per-threshold number of false-negative samples.
        curve: type of AUC curve, same as ``tf.keras.metrics.AUC``; it can be
            'ROC' or 'PR'. Defaults to ``AUCCurve.ROC``.

    The four count sequences may be plain Python lists or array-like objects
    whose ``+`` is already element-wise (e.g. NumPy arrays). In both cases
    adding two AUCs sums the counts threshold by threshold.
    """

    def __init__(
        self,
        name: str,
        thresholds: List[float],
        true_positives: List[float],
        true_negatives: List[float],
        false_positives: List[float],
        false_negatives: List[float],
        curve=None,
    ):
        self.name = name
        self.thresholds = thresholds
        self.true_positives = true_positives
        self.true_negatives = true_negatives
        self.false_positives = false_positives
        self.false_negatives = false_negatives
        if curve is not None:
            self.curve = curve
        else:
            # Imported lazily so TensorFlow is only needed when no curve is given.
            # NOTE(review): this is TensorFlow's internal Keras module path;
            # confirm its AUCCurve is accepted by the tf.keras.metrics.AUC
            # constructed in ``result``.
            from tensorflow.python.keras.utils.metrics_utils import AUCCurve

            self.curve: AUCCurve = AUCCurve.ROC

    @staticmethod
    def _add_counts(lhs, rhs):
        """Return the element-wise sum of two per-threshold count sequences.

        ``+`` on Python lists/tuples concatenates instead of adding, so those
        are summed pairwise into a new list. Other operands (e.g. NumPy
        arrays) are combined with ``+``, which is already element-wise.

        Raises:
            ValueError: if list/tuple operands have different lengths.
        """
        if isinstance(lhs, (list, tuple)) or isinstance(rhs, (list, tuple)):
            if len(lhs) != len(rhs):
                raise ValueError(
                    f'Count lengths are different: {len(lhs)} vs {len(rhs)}!'
                )
            return [x + y for x, y in zip(lhs, rhs)]
        return lhs + rhs

    def __radd__(self, other):
        # ``sum`` starts from 0, so ``0 + auc`` must return a copy of ``auc``.
        assert other == 0
        return AUC(
            self.name,
            self.thresholds,
            self.true_positives,
            self.true_negatives,
            self.false_positives,
            self.false_negatives,
            self.curve,
        )

    def __add__(self, other: 'AUC'):
        assert self.name == other.name
        assert self.curve == other.curve, 'Curves are different!'
        assert len(self.thresholds) == len(other.thresholds) and all(
            i == j for i, j in zip(self.thresholds, other.thresholds)
        ), 'Thresholds are different!'
        return AUC(
            self.name,
            self.thresholds,
            self._add_counts(self.true_positives, other.true_positives),
            self._add_counts(self.true_negatives, other.true_negatives),
            self._add_counts(self.false_positives, other.false_positives),
            self._add_counts(self.false_negatives, other.false_negatives),
            self.curve,
        )

    def result(self):
        import tensorflow as tf

        # tf.keras.metrics.AUC adds the boundary thresholds {-epsilon,
        # 1 + epsilon} by itself, so the two boundary points stored here are
        # removed before constructing it.
        metric = tf.keras.metrics.AUC(
            thresholds=self.thresholds[1:-1], curve=self.curve
        )
        metric.true_positives = self.true_positives
        metric.true_negatives = self.true_negatives
        metric.false_positives = self.false_positives
        metric.false_negatives = self.false_negatives
        return metric.result()
@dataclass
class Precision(Metric):
    """Federated keras.metrics.Precision.

    Attributes:
        name: metric name; only precisions with the same name can be added.
        thresholds: value of threshold, float or list, in [0, 1].
        true_positives: num samples of true positive
        false_positives: num samples of false positive
    """

    name: str
    thresholds: float
    true_positives: float
    false_positives: float

    def __radd__(self, other):
        # ``sum`` begins with 0; ``0 + precision`` returns a copy.
        assert other == 0
        return Precision(
            name=self.name,
            thresholds=self.thresholds,
            true_positives=self.true_positives,
            false_positives=self.false_positives,
        )

    def __add__(self, other: 'Precision'):
        assert self.name == other.name
        # NOTE(review): the left operand's thresholds are kept and
        # ``other.thresholds`` is never compared (unlike AUC) — confirm intended.
        return Precision(
            name=self.name,
            thresholds=self.thresholds,
            true_positives=self.true_positives + other.true_positives,
            false_positives=self.false_positives + other.false_positives,
        )

    def result(self):
        import tensorflow as tf

        keras_precision = tf.keras.metrics.Precision()
        keras_precision.thresholds = self.thresholds
        keras_precision.true_positives = self.true_positives
        keras_precision.false_positives = self.false_positives
        return keras_precision.result()
@dataclass
class Recall(Metric):
    """Federated keras.metrics.Recall.

    Attributes:
        name: metric name; only recalls with the same name can be added.
        thresholds: value of threshold, float or list, in [0, 1].
        true_positives: num samples of true positive
        false_negatives: num samples of false negative
    """

    name: str
    thresholds: float
    true_positives: float
    false_negatives: float

    def __radd__(self, other):
        # ``sum`` begins with 0; ``0 + recall`` returns a copy.
        assert other == 0
        return Recall(
            name=self.name,
            thresholds=self.thresholds,
            true_positives=self.true_positives,
            false_negatives=self.false_negatives,
        )

    def __add__(self, other: 'Recall'):
        assert self.name == other.name
        # NOTE(review): the left operand's thresholds are kept and
        # ``other.thresholds`` is never compared (unlike AUC) — confirm intended.
        return Recall(
            name=self.name,
            thresholds=self.thresholds,
            true_positives=self.true_positives + other.true_positives,
            false_negatives=self.false_negatives + other.false_negatives,
        )

    def result(self):
        import tensorflow as tf

        keras_recall = tf.keras.metrics.Recall()
        keras_recall.thresholds = self.thresholds
        keras_recall.true_positives = self.true_positives
        keras_recall.false_negatives = self.false_negatives
        return keras_recall.result()
def aggregate_metrics(local_metrics: List[List]) -> List:
    """Combine the metric lists reported by every party into global metrics.

    Args:
        local_metrics: one list of metrics per party. The i-th entries of all
            parties' lists are summed together, so the lists are expected to be
            aligned in the same order; ``zip`` stops at the shortest list.

    Returns:
        A list whose i-th element is the sum of the i-th metrics of all
        parties (``sum`` starts from 0, which each metric's ``__radd__``
        accepts).
    """
    return list(map(sum, zip(*local_metrics)))