Source code for secretflow.ml.nn.sl.backend.tensorflow.utils
#!/usr/bin/env python3
# *_* coding: utf-8 *_*
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
import tensorflow as tf
class custom_loss:
    """Decorator to define a loss function that takes extra keyword arguments.

    The decorated function is invoked as ``func(y_true, y_pred, **kwargs)``,
    where ``kwargs`` is the mapping most recently passed to
    :meth:`with_kwargs` (an empty dict until then). These keyword arguments
    must match the extra results returned by the model's forward pass.

    Examples:
        >>> import tensorflow as tf
        >>> # define model
        >>> class MyModel(tf.keras.Model):
        ...     def call(self, inputs, **kwargs):
        ...         # do forward pass
        ...         return None, y_pred, {'kwarg1': kwarg1, 'kwarg2': kwarg2}
        >>> # define loss function
        >>> @custom_loss
        ... def my_loss(y_true, y_pred, kwarg1=None, kwarg2=None):
        ...     # compute loss
        ...     pass
        >>> # compile model with custom loss function
        >>> model = MyModel(...)
        >>> model.compile(
        ...     loss=my_loss,
        ...     optimizer=tf.keras.optimizers.Adam(0.01),
        ...     metrics=['acc'],
        ... )

    Note:
        `custom_loss`, `my_loss` and `MyModel` need to be added to
        custom_objects when loading the model.
    """

    def __init__(self, func: Callable):
        """Wrap ``func``.

        Args:
            func: Loss function called as ``func(y_true, y_pred, **kwargs)``.
                Its ``__name__`` becomes ``self.name``, which is what
                :meth:`get_config` serializes.
        """
        self.name = func.__name__
        self.func = func
        # Extra keyword arguments forwarded to ``func``; replaced by ``with_kwargs``.
        self.kwargs = {}

    def with_kwargs(self, kwargs):
        """Set the keyword arguments used by subsequent calls of this loss.

        Args:
            kwargs: Mapping of extra keyword arguments for the wrapped
                function. A falsy value (e.g. ``None`` or ``{}``) resets the
                stored kwargs to an empty dict.
        """
        self.kwargs = kwargs if kwargs else {}

    def __call__(self, y_true, y_pred):
        """Return ``func(y_true, y_pred, **self.kwargs)``."""
        return self.func(y_true, y_pred, **self.kwargs)

    def get_config(self):
        """Return the serializable config; only the loss name is stored."""
        return {
            'name': self.name,
        }

    @classmethod
    def from_config(cls, config):
        """Look up the loss named ``config['name']`` among Keras custom objects.

        The registered object is returned as-is; no new instance is created.
        It is expected to already be the ``custom_loss``-wrapped function.

        Args:
            config: Dict produced by :meth:`get_config`.

        Raises:
            KeyError: If no custom object is registered under that name.
        """
        name = config['name']
        custom_objects = tf.keras.utils.get_custom_objects()
        try:
            # The object with func name has already been wrapped, so return it directly.
            return custom_objects[name]
        except KeyError:
            # A bare KeyError('<name>') gives no hint about the cause; keep the
            # same exception type for callers but explain how to fix it.
            raise KeyError(
                f"Loss '{name}' is not registered as a Keras custom object. "
                "Register the decorated loss function (e.g. via custom_objects) "
                "before loading the model."
            ) from None