Source code for secretflow.ml.nn.fl.sparse

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import numpy as np
import jaxlib


class STCSparse:
    """Top-k ternary sparsifier for a list of weight arrays.

    For every array, the ``sparse_rate`` fraction of entries with the smallest
    magnitude is set to zero. Each remaining entry is replaced by
    ``sign(entry) * mu``, where ``mu`` is the mean magnitude of the remaining
    entries, so an output array only holds the values ``-mu``, ``0`` and
    ``+mu``.

    Attributes:
        sparse_rate: Fraction of the entries of each array that is zeroed.
        name: Identifier of this sparsifier, always ``'STC'``.
    """

    def __init__(self, sparse_rate: float):
        """Store the sparsification rate.

        Args:
            sparse_rate: Fraction of entries to zero in each array. Values
                outside ``[0, 1]`` are accepted here; ``__call__`` clamps the
                resulting element count to the size of each array.
        """
        self.sparse_rate = sparse_rate
        self.name = 'STC'

    def __call__(
        self,
        weights: List[np.ndarray],
    ) -> List[np.ndarray]:
        """Compress every array in ``weights``.

        The inputs are never modified; new arrays are returned.

        Args:
            weights: Arrays (or array-likes convertible with ``np.array``,
                such as JAX arrays) to compress.

        Returns:
            One compressed array per input, each with the shape of its input.
        """
        compression_weights = []
        for weight in weights:
            # np.array() copies and converts any array-like to NumPy, so no
            # framework-specific type check is required and the caller's data
            # stays untouched.
            dense = np.array(weight)
            weight_shape = dense.shape
            weight_flat = dense.reshape(-1)
            weight_len = weight_flat.shape[0]

            # Number of entries to zero, clamped so that an out-of-range
            # sparse_rate can neither slice from the end of the index array
            # nor yield a negative divisor below.
            mask_num = round(self.sparse_rate * weight_len)
            mask_num = min(max(mask_num, 0), weight_len)

            if mask_num:
                smallest = np.argsort(np.abs(weight_flat))[:mask_num]
                weight_flat[smallest] = 0

            kept = weight_len - mask_num
            if kept:
                # Zeroed entries add nothing to the sum, so this is the mean
                # magnitude of the surviving entries.
                average_value = np.sum(np.absolute(weight_flat)) / kept
            else:
                average_value = 0.0

            weight_compress = average_value * np.sign(weight_flat)
            compression_weights.append(weight_compress.reshape(weight_shape))
        return compression_weights
class SCRSparse:
    """Structured sparsifier that zeroes low-magnitude slices of weight arrays.

    For 2-D arrays (dense layers) and 4-D arrays (CNN layers), slices along
    axis 0 and then along axis 1 are scored by their summed absolute value,
    normalised by the largest score on that axis. Every slice whose normalised
    score is at most ``threshold`` is set to zero. Arrays of any other rank
    are returned unchanged.

    Attributes:
        threshold: Normalised-score cut-off; slices scoring at or below it
            are zeroed.
        name: Identifier of this sparsifier, always ``'SCR'``.
    """

    def __init__(self, threshold: float):
        """Store the pruning threshold.

        Args:
            threshold: Cut-off applied to the normalised per-slice scores.
        """
        self.threshold = threshold
        self.name = 'SCR'

    def __call__(
        self,
        weights: List[np.ndarray],
    ) -> List[np.ndarray]:
        """Prune every array in ``weights``.

        The inputs are never modified; pruning is done on copies.

        Args:
            weights: Arrays (or array-likes convertible with ``np.array``)
                to prune.

        Returns:
            One array per input, with the same shape as its input.
        """
        compression_weights = []
        for weight in weights:
            # Work on a copy so the caller's arrays are left intact and
            # read-only or immutable inputs are supported.
            pruned = np.array(weight)
            if pruned.ndim in (2, 4):  # dense layer / CNN layer
                # Axis 1 is scored after axis 0 has been pruned, so the
                # order of these two calls matters.
                self._zero_weak_slices(pruned, axis=0)
                self._zero_weak_slices(pruned, axis=1)
            compression_weights.append(pruned)
        return compression_weights

    def _zero_weak_slices(self, weight: np.ndarray, axis: int) -> None:
        """Zero, in place, the slices along ``axis`` scoring <= threshold."""
        other_axes = tuple(i for i in range(weight.ndim) if i != axis)
        scores = np.sum(np.absolute(weight), axis=other_axes)
        if scores.size == 0:
            # Nothing to score; np.max would raise on an empty array.
            return
        peak = np.max(scores)
        if peak == 0:
            # All-zero tensor: the normalisation would be 0/0 and there is
            # nothing left to prune.
            return
        weak = self.get_dimension(scores / peak, self.threshold)
        selector = [slice(None)] * weight.ndim
        selector[axis] = np.asarray(weak).reshape(-1)
        weight[tuple(selector)] = 0.0

    def get_dimension(self, index_value, threshold):
        """Return the indices at which ``index_value <= threshold``.

        Args:
            index_value: Array of scores to compare.
            threshold: Cut-off value.

        Returns:
            The result of ``np.argwhere(index_value <= threshold)``, one row
            per matching position.
        """
        return np.argwhere(index_value <= threshold)