# Source code for synthyverse.generators.unmaskingtrees_generator.unmaskingtrees
from ..base import TabularBaseGenerator
import pandas as pd
from utrees import UnmaskingTrees
[docs]
class UnmaskingTreesGenerator(TabularBaseGenerator):
    """Unmasking Trees.

    Unmasking Trees is an autoregressive model which hierarchically partitions
    features into binary bins, to then recursively train XGBoost classifiers
    along the meta-tree hierarchy.

    We use the implementation from the utrees pypi package. Can be costly for
    large datasets.

    Paper: "Unmasking trees for tabular data" by C. McCarter (2024).

    Args:
        depth (int): Depth of the meta-tree. Default: 4.
        duplicate_K (int): Number of duplications for each sample. Default: 50.
        xgboost_kwargs (dict): Dictionary of additional XGBoost parameters.
            Default: None, which is treated as an empty dictionary. The
            dictionary is shallow-copied, so the caller's object is never
            shared with the generator.
        strategy (str): Strategy for quantization. Options: "quantile",
            "uniform", "kmeans", "kdiquantile". Default: "kdiquantile".
        softmax_temp (float): Temperature for softmax. Default: 1.
        cast_float32 (bool): Whether to cast to float32. Default: True.
        tabpfn (bool): Whether to use TabPFN. Default: False.
        random_state (int): Random seed for reproducibility. Default: 0.
        **kwargs: Additional arguments passed to TabularBaseGenerator.

    Example:
        >>> import pandas as pd
        >>> from synthyverse.generators import UnmaskingTreesGenerator
        >>>
        >>> # Load data
        >>> X = pd.read_csv("data.csv")
        >>> discrete_features = ["category_col"]
        >>>
        >>> # Create generator
        >>> generator = UnmaskingTreesGenerator(
        ...     depth=4,
        ...     duplicate_K=50,
        ...     strategy="kdiquantile",
        ...     random_state=42
        ... )
        >>>
        >>> # Fit and generate
        >>> generator.fit(X, discrete_features)
        >>> X_syn = generator.generate(1000)
    """

    # Identifier string for this generator.
    name = "unmaskingtrees"
    # Flag declaring that this generator accepts data with missing values.
    handles_missingness = True

    def __init__(
        self,
        depth: int = 4,
        duplicate_K: int = 50,
        xgboost_kwargs: dict = None,
        strategy: str = "kdiquantile",  # quantile, uniform, kmeans
        softmax_temp: float = 1,
        cast_float32: bool = True,
        tabpfn: bool = False,
        random_state: int = 0,
        **kwargs,
    ):
        """Store the hyperparameters later forwarded to ``UnmaskingTrees``.

        See the class docstring for the meaning of each argument.
        """
        super().__init__(random_state=random_state, **kwargs)

        # A ``{}`` default would be one dict shared by every instance (and by
        # the underlying model once unpacked in ``_fit_model``). Use ``None``
        # as the sentinel and keep a private copy instead.
        if xgboost_kwargs is None:
            xgboost_kwargs = {}
        else:
            xgboost_kwargs = dict(xgboost_kwargs)

        # Keyword arguments unpacked into ``UnmaskingTrees(...)`` at fit time.
        self.model_params = {
            "depth": depth,
            "duplicate_K": duplicate_K,
            "xgboost_kwargs": xgboost_kwargs,
            "softmax_temp": softmax_temp,
            "cast_float32": cast_float32,
            "tabpfn": tabpfn,
            "strategy": strategy,
            "random_state": random_state,
        }

    def _fit_model(
        self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None
    ):
        """Fit a fresh ``UnmaskingTrees`` model on ``X``.

        Args:
            X (pd.DataFrame): Training data. Its column labels are remembered
                so generated samples can be labelled identically.
            discrete_features (list): Column names to be treated as
                categorical; every other column is treated as continuous.
            X_val (pd.DataFrame): Not used by this generator.
        """
        # Remember the column labels/order for ``_generate_data``.
        self.ori_cols = X.columns

        # One type tag per column, in column order, as passed to the model.
        quantize_cols = [
            "categorical" if col in discrete_features else "continuous"
            for col in X.columns
        ]

        self.model = UnmaskingTrees(**self.model_params)
        self.model.fit(X.to_numpy(), quantize_cols)

    def _generate_data(self, n: int):
        """Generate ``n`` synthetic rows with the fitted model.

        Args:
            n (int): Number of samples requested from the model.

        Returns:
            pd.DataFrame: Model output labelled with the training columns.
        """
        samples = self.model.generate(n)
        return pd.DataFrame(samples, columns=self.ori_cols)