# Source code for secretflow.data.base

# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union, List, Callable

import pandas as pd
from pandas.core.indexes.base import Index
from jax.tree_util import tree_map

from secretflow.data.io.util import to_csv_wrapper
from secretflow.device import PYUObject, reveal

class DataFrameBase(ABC):
    """Abstract base class for horizontal, vertical and mixed partitioned DataFrame.

    Every member below is abstract: a subclass can only be instantiated once
    it provides all of ``min``, ``max``, ``count``, ``values``,
    ``__getitem__`` and ``__setitem__``.
    """

    @abstractmethod
    def min(self):
        """Gets minimum value of all columns."""

    @abstractmethod
    def max(self):
        """Gets maximum value of all columns."""

    @abstractmethod
    def count(self):
        """Gets number of rows."""

    @abstractmethod
    def values(self):
        """Get underlying ndarray."""

    @abstractmethod
    def __getitem__(self, item):
        """Get columns."""

    @abstractmethod
    def __setitem__(self, key, value):
        """Set columns."""
@dataclass
class Partition(DataFrameBase):
    """Slice of data that makes up horizontal, vertical and mixed partitioned DataFrame.

    Most methods forward to the pandas method of the same name by dispatching
    it through ``self.data.device`` and wrapping the result in a new
    ``Partition``; see the pandas documentation for argument details.

    Attributes:
        data (PYUObject): Reference to pandas.DataFrame located in local node.
    """

    data: PYUObject = None

    def __partition_wrapper(self, fn: Callable, *args, **kwargs) -> 'Partition':
        """Call ``fn(data, *args, **kwargs)`` via the data's device and wrap the result."""
        return Partition(self.data.device(fn)(self.data, *args, **kwargs))

    def mean(self, *args, **kwargs) -> 'Partition':
        """Returns the mean of the values over the requested axis.

        Returns:
            Partition: mean values series.
        """
        return self.__partition_wrapper(pd.DataFrame.mean, *args, **kwargs)

    def var(self, *args, **kwargs) -> 'Partition':
        """Returns the variance of the values over the requested axis.

        Returns:
            Partition: variance values series.
        """
        return self.__partition_wrapper(pd.DataFrame.var, *args, **kwargs)

    def std(self, *args, **kwargs) -> 'Partition':
        """Returns the standard deviation of the values over the requested axis.

        Returns:
            Partition: standard deviation values series.
        """
        return self.__partition_wrapper(pd.DataFrame.std, *args, **kwargs)

    def sem(self, *args, **kwargs) -> 'Partition':
        """Returns the standard error of the mean over the requested axis.

        Returns:
            Partition: standard error of the mean series.
        """
        return self.__partition_wrapper(pd.DataFrame.sem, *args, **kwargs)

    def skew(self, *args, **kwargs) -> 'Partition':
        """Returns the skewness over the requested axis.

        Returns:
            Partition: skewness series.
        """
        return self.__partition_wrapper(pd.DataFrame.skew, *args, **kwargs)

    def kurtosis(self, *args, **kwargs) -> 'Partition':
        """Returns the kurtosis over the requested axis.

        Returns:
            Partition: kurtosis series.
        """
        return self.__partition_wrapper(pd.DataFrame.kurtosis, *args, **kwargs)

    def sum(self, *args, **kwargs) -> 'Partition':
        """Returns the sum of the values over the requested axis.

        Returns:
            Partition: sum values series.
        """
        return self.__partition_wrapper(pd.DataFrame.sum, *args, **kwargs)

    def replace(self, *args, **kwargs) -> 'Partition':
        """Replace values given in to_replace with value.

        Same as pandas.DataFrame.replace.
        Values of the DataFrame are replaced with other values dynamically.

        Returns:
            Partition: same shape except value replaced.
        """
        return self.__partition_wrapper(pd.DataFrame.replace, *args, **kwargs)

    def quantile(self, q=0.5, axis=0) -> 'Partition':
        """Returns values at the given quantile over requested axis.

        ``q`` is first clamped to [0, 1] and then snapped to a multiple of
        0.25, so only the quantiles 0, 0.25, 0.5, 0.75 and 1 are ever computed.

        Returns:
            Partition: quantile values series.
        """
        # q is between 0 and 1
        q = max(min(q, 1), 0)
        # limit q to one of 0, 0.25, 0.5, 0.75 and 1
        q = round(4 * q) / 4
        return self.__partition_wrapper(pd.DataFrame.quantile, q=q, axis=axis)

    def min(self, *args, **kwargs) -> 'Partition':
        """Returns the minimum of the values over the requested axis.

        Returns:
            Partition: minimum values series.
        """
        return self.__partition_wrapper(pd.DataFrame.min, *args, **kwargs)

    def mode(self, *args, **kwargs) -> 'Partition':
        """Returns the mode of the values over the requested axis.

        For data protection reasons, only one mode will be returned.

        Returns:
            Partition: mode values series.
        """

        def _mode(*args, **kwargs):
            # pandas returns one row per tied mode; keep only the first row.
            return pd.DataFrame.mode(*args, **kwargs).iloc[0, :]

        return self.__partition_wrapper(_mode, *args, **kwargs)

    def max(self, *args, **kwargs) -> 'Partition':
        """Returns the maximum of the values over the requested axis.

        Returns:
            Partition: maximum values series.
        """
        return self.__partition_wrapper(pd.DataFrame.max, *args, **kwargs)

    def count(self, *args, **kwargs) -> 'Partition':
        """Counts non-NA cells for each column or row.

        Returns:
            Partition: count values series.
        """
        return self.__partition_wrapper(pd.DataFrame.count, *args, **kwargs)

    def isna(self) -> 'Partition':
        """Detects missing values for an array-like object.

        Same as pandas.DataFrame.isna.

        Returns:
            Partition: Mask of bool values for each element in DataFrame
            that indicates whether an element is an NA value.
        """
        return self.__partition_wrapper(pd.DataFrame.isna)

    def __unwrap(self, args, kwargs):
        """Replace every Partition found in ``args``/``kwargs`` by its ``data``.

        The arguments are walked with ``tree_map`` so Partitions nested inside
        containers are unwrapped as well; all other values pass through.
        """

        def _strip(x):
            return x.data if isinstance(x, Partition) else x

        return tree_map(_strip, args), tree_map(_strip, kwargs)

    def pow(self, *args, **kwargs) -> 'Partition':
        """Gets Exponential power of (partition of) dataframe and other, element-wise (binary operator pow).

        Equivalent to dataframe ** other, but with support to substitute a
        fill_value for missing data in one of the inputs.
        With reverse version, rpow.
        Among flexible wrappers (add, sub, mul, div, mod, pow) to arithmetic
        operators: +, -, *, /, //, %, **.

        Reference: pd.DataFrame.pow
        """
        new_args, new_kwargs = self.__unwrap(args, kwargs)
        return self.__partition_wrapper(pd.DataFrame.pow, *new_args, **new_kwargs)

    def subtract(self, *args, **kwargs) -> 'Partition':
        """Gets Subtraction of (partition of) dataframe and other, element-wise (binary operator sub).

        Equivalent to dataframe - other, but with support to substitute a
        fill_value for missing data in one of the inputs.
        With reverse version, rsub.
        Among flexible wrappers (add, sub, mul, div, mod, pow) to arithmetic
        operators: +, -, *, /, //, %, **.

        Reference: pd.DataFrame.subtract
        """
        new_args, new_kwargs = self.__unwrap(args, kwargs)
        return self.__partition_wrapper(pd.DataFrame.subtract, *new_args, **new_kwargs)

    def round(self, *args, **kwargs) -> 'Partition':
        """Round the (partition of) DataFrame to a variable number of decimal places.

        Reference: pd.DataFrame.round
        """
        new_args, new_kwargs = self.__unwrap(args, kwargs)
        return self.__partition_wrapper(pd.DataFrame.round, *new_args, **new_kwargs)

    def select_dtypes(self, *args, **kwargs) -> 'Partition':
        """Returns a subset of the DataFrame's columns based on the column dtypes.

        Reference: pandas.DataFrame.select_dtypes
        """
        return self.__partition_wrapper(pd.DataFrame.select_dtypes, *args, **kwargs)

    @property
    def values(self):
        """Returns the underlying ndarray.

        Unlike ``index``/``columns``/``shape`` this property is not passed
        through ``reveal``; the device call result is returned as is.
        """
        return self.data.device(lambda df: df.values)(self.data)

    @property
    @reveal
    def index(self):
        """Returns the index (row labels) of the DataFrame."""
        return self.data.device(lambda df: df.index)(self.data)

    @property
    @reveal
    def dtypes(self):
        """Returns the dtypes in the DataFrame."""

        def _dtypes(df):
            # return series always.
            if isinstance(df, pd.DataFrame):
                return df.dtypes
            # A Series exposes a single ``dtype`` (there is no ``types``
            # attribute), so wrap it to keep the return type uniform.
            # NOTE(review): keying by the series name is assumed — confirm.
            return pd.Series({df.name: df.dtype})

        return self.data.device(_dtypes)(self.data)

    def astype(self, dtype, copy: bool = True, errors: str = "raise"):
        """
        Cast a pandas object to a specified dtype ``dtype``.

        All args are same as :py:meth:`pandas.DataFrame.astype`.
        """
        return Partition(
            self.data.device(pd.DataFrame.astype)(self.data, dtype, copy, errors)
        )

    @property
    @reveal
    def columns(self):
        """Returns the column labels of the DataFrame."""
        return self.data.device(lambda df: df.columns)(self.data)

    @property
    @reveal
    def shape(self):
        """Returns a tuple representing the dimensionality of the DataFrame."""
        return self.data.device(lambda df: df.shape)(self.data)

    def iloc(self, index: Union[int, slice, List[int]]) -> 'Partition':
        """Integer-location based indexing for selection by position.

        Args:
            index (Union[int, slice, List[int]]): rows index.

        Returns:
            Partition: Selected DataFrame.
        """
        return Partition(
            self.data.device(lambda df, index: df.iloc[index])(self.data, index)
        )

    def drop(
        self,
        labels=None,
        axis=0,
        index=None,
        columns=None,
        level=None,
        inplace=False,
        errors='raise',
    ) -> Union['Partition', None]:
        """See `pandas.DataFrame.drop`.

        Returns:
            A new Partition, or None when ``inplace`` is True (in which case
            ``self.data`` is rebound to the modified frame).
        """

        def _drop(df: pd.DataFrame, **kwargs):
            if inplace:
                # Work on a deep copy and mutate that copy in place
                # (kwargs carries inplace=True), then hand the copy back.
                new_df = df.copy(deep=True)
                new_df.drop(**kwargs)
                return new_df
            else:
                return df.drop(**kwargs)

        new_data = self.data.device(_drop)(
            self.data,
            labels=labels,
            axis=axis,
            index=index,
            columns=columns,
            level=level,
            inplace=inplace,
            errors=errors,
        )
        if inplace:
            self.data = new_data
        else:
            return Partition(new_data)

    def fillna(
        self,
        value=None,
        method=None,
        axis=None,
        inplace=False,
        limit=None,
        downcast=None,
    ) -> Union['Partition', None]:
        """See :py:meth:`pandas.DataFrame.fillna`.

        Returns:
            A new Partition, or None when ``inplace`` is True (in which case
            ``self.data`` is rebound to the modified frame).
        """

        def _fillna(df: pd.DataFrame, **kwargs):
            if inplace:
                # Work on a deep copy and mutate that copy in place
                # (kwargs carries inplace=True), then hand the copy back.
                new_df = df.copy(deep=True)
                new_df.fillna(**kwargs)
                return new_df
            else:
                return df.fillna(**kwargs)

        new_data = self.data.device(_fillna)(
            self.data,
            value=value,
            method=method,
            axis=axis,
            inplace=inplace,
            limit=limit,
            downcast=downcast,
        )
        if inplace:
            self.data = new_data
        else:
            return Partition(new_data)

    def rename(
        self,
        mapper=None,
        index=None,
        columns=None,
        axis=None,
        copy=True,
        inplace=False,
        level=None,
        errors='ignore',
    ) -> Union['Partition', None]:
        """See :py:meth:`pandas.DataFrame.rename`.

        Returns:
            A new Partition, or None when ``inplace`` is True (in which case
            ``self.data`` is rebound to the modified frame).
        """

        def _rename(df: pd.DataFrame, **kwargs):
            if inplace:
                # Work on a deep copy and mutate that copy in place
                # (kwargs carries inplace=True), then hand the copy back.
                new_df = df.copy(deep=True)
                new_df.rename(**kwargs)
                return new_df
            else:
                return df.rename(**kwargs)

        new_data = self.data.device(_rename)(
            self.data,
            mapper=mapper,
            index=index,
            columns=columns,
            axis=axis,
            copy=copy,
            inplace=inplace,
            level=level,
            errors=errors,
        )
        if inplace:
            self.data = new_data
        else:
            return Partition(new_data)

    def value_counts(self, *args, **kwargs) -> 'Partition':
        """Return a Series containing counts of unique values."""
        return self.__partition_wrapper(pd.DataFrame.value_counts, *args, **kwargs)

    def to_csv(self, filepath, **kwargs):
        """Save DataFrame to csv file."""
        return self.data.device(to_csv_wrapper)(self.data, filepath, **kwargs)

    @reveal
    def __len__(self):
        """Returns the number of rows."""
        return self.data.device(lambda df: len(df))(self.data)

    def __getitem__(self, item: Union[str, List[str]]) -> 'Partition':
        """Get columns from DataFrame.

        NOTE: Always return DataFrame even if specify single column.

        Args:
            item (Union[str, List[str]]): Columns to get.

        Returns:
            Partition: DataFrame.
        """
        item_list = item
        # Wrap a scalar label in a list so pandas yields a DataFrame
        # rather than a Series.
        if not isinstance(item, (list, tuple, Index)):
            item_list = [item_list]
        return self.__partition_wrapper(pd.DataFrame.__getitem__, item_list)

    def __setitem__(self, key, value):
        """Assign values to columns.

        Args:
            key (Union[str, List[str]]): columns to be assigned.
            value (Partition): assigned values.

        Raises:
            AssertionError: if ``value`` is a Partition whose data lives on a
                different device than this partition's data.
        """
        if isinstance(value, Partition):
            # Raised explicitly (instead of a bare ``assert``) so the check
            # still runs under ``python -O``; the exception type is unchanged.
            if not (self.data.device == value.data.device):
                raise AssertionError(
                    'Can not assign a partition with different device.'
                )
            value = value.data

        def _setitem(df: pd.DataFrame, key, value):
            # Deep copy DataFrame since ray object store is immutable.
            df = df.copy(deep=True)
            df[key] = value
            return df

        self.data = self.data.device(_setitem)(self.data, key, value)

    def copy(self):
        """Shallow copy."""
        return Partition(self.data)