# Source code for secretflow.preprocessing.binning.kernels.quantile_summaries

#!/usr/bin/env python3
# *_* coding: utf-8 *_*

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from dataclasses import dataclass
from typing import List, Union

import numpy as np


@dataclass
class Stats:
    """One entry of a quantile summary.

    Attributes:
        value: the sampled value this entry stands for.
        w: weight of this entry.
        delta: rank uncertainty of this entry, ``delta = rmax - rmin``.
    """

    # Sampled value.
    value: float
    # Weight carried by this entry.
    w: int
    # rmax - rmin for this entry.
    delta: int
class QuantileSummaries:
    """Approximate quantile summary of one numeric column.

    The summary is an ordered list of ``Stats`` entries (``sampled``). It is
    built in one shot by ``fast_init``, shrunk by ``compress`` and read with
    ``query`` / ``value_to_rank`` / ``batch_query_value``.

    Attributes:
        compress_thres: upper bound on the number of entries ``fast_init``
            creates. ``fast_init`` lowers it to the number of usable data
            points when there are fewer of those.
        head_size: buffer size setting. It is stored but not read by any
            method of this class.
        error: error tolerance for binning, ``0 <= error < 1``,
            default ``1e-4``.
            floor((p - 2 * error) * N) <= rank(x) <= ceil((p + 2 * error) * N)
        abnormal_list: values that are removed from the data before the
            summary is built, so they do not participate in binning.
        head_sampled: pending-sample buffer. ``fast_init`` resets it to an
            empty list.
        sampled: the current list of ``Stats`` entries, ordered by value.
        count: number of data points represented by the summary.
        missing_count: initialised to 0. Not updated by any method of this
            class.
    """

    def __init__(
        self,
        compress_thres: int = 10000,
        head_size: int = 10000,
        error: float = 1e-4,
        abnormal_list: Union[List, None] = None,
    ):
        self.compress_thres = compress_thres
        self.head_size = head_size
        self.error = error
        self.head_sampled = []
        self.sampled = []
        self.count = 0
        self.missing_count = 0
        # A fresh list per instance; never share a mutable default.
        self.abnormal_list = [] if abnormal_list is None else abnormal_list

    def fast_init(self, col_data: np.ndarray):
        """Build the summary directly from a full column of data.

        Values equal to any item of ``abnormal_list`` are dropped first. The
        remaining values are sorted and cut into ``compress_thres``
        (almost) equally sized rank ranges; the last value of each range
        becomes one ``Stats`` entry whose weight is the size of the range.

        Args:
            col_data: one-dimensional array of values.

        Side effects:
            Replaces ``sampled``, clears ``head_sampled``, sets ``count`` to
            the number of values kept, and lowers ``compress_thres`` to that
            number when it is smaller.
        """
        col_data = np.asarray(col_data)
        for ab_item in self.abnormal_list:
            col_data = col_data[col_data != ab_item]

        # Clamp AFTER filtering: with more bins than remaining points the
        # rounded bin edges would contain 0, and ``sorted_data[0 - 1]`` would
        # silently pick the maximum (or fail on empty data).
        n_valid = len(col_data)
        if self.compress_thres > n_valid:
            self.compress_thres = n_valid

        # Cumulative rank at the right edge of every bin (left edge 0 dropped).
        bin_list = (
            np.linspace(0, n_valid, self.compress_thres + 1)
            .round()[1:]
            .astype(int)
        )
        sorted_data = np.sort(col_data)

        new_sampled = []
        pre_rank = 0
        for bin_t in bin_list:
            # Exact ranks are known here, hence delta == 0.
            new_sampled.append(Stats(sorted_data[bin_t - 1], bin_t - pre_rank, 0))
            pre_rank = bin_t

        self.sampled = new_sampled
        self.head_sampled = []
        self.count = n_valid
        if len(self.sampled) >= self.compress_thres:
            self.compress()

    def compress(self):
        """Compress ``sampled`` in place using threshold ``2 * error * count``."""
        merge_threshold = 2 * self.error * self.count
        self.sampled = self._compress_immut(merge_threshold)

    def query(self, quantile: float) -> float:
        """Return the value located at the given quantile.

        Args:
            quantile: float in [0.0, 1.0].

        Returns:
            The value of the matching summary entry, or 0 when the summary
            holds no data.

        Raises:
            ValueError: if ``quantile`` is outside [0.0, 1.0].
        """
        if self.head_sampled:
            self.compress()
        if quantile < 0 or quantile > 1:
            raise ValueError("Quantile should be in range [0.0, 1.0]")
        if self.count == 0:
            return 0
        if quantile <= self.error:
            return self.sampled[0].value
        if quantile >= 1 - self.error:
            return self.sampled[-1].value

        rank = math.ceil(quantile * self.count)
        target_error = math.ceil(self.error * self.count)

        # NOTE: accumulation starts at the second entry, so the weight of
        # the first entry is not part of min_rank. Kept as is because the
        # returned split points depend on it.
        min_rank = 0
        for cur_sample in self.sampled[1:-1]:
            min_rank += cur_sample.w
            max_rank = min_rank + cur_sample.delta
            if max_rank - target_error <= rank <= min_rank + target_error:
                return cur_sample.value
        return self.sampled[-1].value

    def value_to_rank(self, value: Union[float, int]) -> int:
        """Estimate the rank of ``value`` from the current summary.

        Weights of all leading entries whose value is strictly smaller than
        ``value`` are accumulated; the result is the midpoint of the
        resulting (min_rank, max_rank) pair.

        Args:
            value: the value to look up.

        Returns:
            The estimated rank.
        """
        min_rank, max_rank = 0, 0
        for sample in self.sampled:
            if not sample.value < value:
                break
            min_rank += sample.w
            max_rank = min_rank + sample.delta
        return (min_rank + max_rank) // 2

    def batch_query_value(self, values: List[float]) -> List[int]:
        """Estimate the rank of every value in an ascending list.

        The summary is compressed first, then walked once together with
        ``values``.

        Args:
            values: values sorted in ascending order, e.g. ``[13, 56, 79]``.
                An empty list yields an empty result.

        Returns:
            One estimated rank per entry of ``values``, in the same order.
        """
        self.compress()
        res = []
        min_rank, max_rank = 0, 0
        sample_idx = 0
        n_samples = len(self.sampled)
        for v in values:
            # Consume every entry strictly below v; the pointer never moves
            # back because values are ascending.
            while sample_idx < n_samples and self.sampled[sample_idx].value < v:
                sample = self.sampled[sample_idx]
                min_rank += sample.w
                max_rank = min_rank + sample.delta
                sample_idx += 1
            res.append((min_rank + max_rank) // 2)
        return res

    def _compress_immut(self, merge_threshold: float) -> List:
        """Return a compressed copy of ``sampled`` without modifying entries.

        Entries are scanned from the last towards the second one. An entry
        is folded into the current head while
        ``entry.w + head.w + head.delta < merge_threshold``. The first entry
        is never merged; it is kept when its value does not exceed the value
        of the last remaining head and the summary has more than one entry.

        Args:
            merge_threshold: merge bound described above.

        Returns:
            The new list of entries in ascending order. When ``sampled`` is
            empty, ``sampled`` itself is returned.
        """
        samples = self.sampled
        if not samples:
            return samples

        kept = []  # collected tail-first, reversed before returning
        head = samples[-1]
        head_w = head.w  # running weight of head, kept outside the object
        for sample in reversed(samples[1:-1]):
            if sample.w + head_w + head.delta < merge_threshold:
                head_w += sample.w
            else:
                kept.append(self._reweighted(head, head_w))
                head, head_w = sample, sample.w
        kept.append(self._reweighted(head, head_w))

        first = samples[0]
        if first.value <= head.value and len(samples) > 1:
            kept.append(first)
        kept.reverse()
        return kept

    @staticmethod
    def _reweighted(stat, w):
        """Return ``stat`` itself if its weight is ``w``, else a copy with weight ``w``."""
        if w == stat.w:
            return stat
        return Stats(stat.value, w, stat.delta)