Source code for synthyverse.generators.arf_generator.arf

import pandas as pd
from arfpy.arf import arf
import numpy as np
from ..base import TabularBaseGenerator


[docs] class ARFGenerator(TabularBaseGenerator): """Adversarial Random Forest (ARF). ARF leverages random forests in alternating rounds of generation/discrimination to estimate densities and generate synthetic data. Uses the arfpy package implementation. Paper: "Adversarial random forests for density estimation and generative modeling" by Watson et al. (2023). Args: num_trees (int): Number of trees in the random forests. Default: 20. delta (float): Tolerance parameter for convergence. Default: 0.0. max_iters (int): Maximum number of adversarial iterations. Default: 10. early_stop (bool): Whether to use early stopping. Default: True. verbose (bool): Whether to print training progress. Default: True. min_node_size (int): Minimum leaf node samples in trees. Default: 5. retain_value_ranges (bool): Whether to clip numerical features to training ranges after generation. Default: False. random_state (int): Random seed for reproducibility. Default: 0. **kwargs: Additional arguments passed to TabularBaseGenerator. Example: >>> import pandas as pd >>> from synthyverse.generators import ARFGenerator >>> >>> # Load data >>> X = pd.read_csv("data.csv") >>> discrete_features = ["category_col"] >>> >>> # Create generator >>> generator = ARFGenerator( ... num_trees=50, ... max_iters=10, ... early_stop=True, ... random_state=42 ... ) >>> >>> # Fit and generate >>> generator.fit(X, discrete_features) >>> X_syn = generator.generate(1000) """ name = "arf" def __init__( self, num_trees: int = 20, delta: float = 0.0, max_iters: int = 10, early_stop: bool = True, verbose: bool = True, min_node_size: int = 5, retain_value_ranges: bool = False, # whether to retain numerical feature ranges random_state: int = 0, **kwargs, ): super().__init__(random_state=random_state, **kwargs) self.retain_value_ranges = retain_value_ranges self.model_params = { "num_trees": num_trees, "delta": delta, "max_iters": max_iters, "early_stop": early_stop, "verbose": verbose, "min_node_size": min_node_size, "random_state": random_state, } def _fit_model( self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None ): xx = X.copy() xx[discrete_features] = xx[discrete_features].astype(str) self.numerical_features = [ col for col in xx.columns if col not in discrete_features ] if self.retain_value_ranges: self.value_ranges = {} for col in self.numerical_features: self.value_ranges[col] = { "min": xx[col].min(), "max": xx[col].max(), } self.model = arf(xx, **self.model_params) self.model.forde() def _generate_data(self, n: int): syn = self.model.forge(n) if self.retain_value_ranges: for col in self.numerical_features: syn[col] = np.clip( syn[col], self.value_ranges[col]["min"], self.value_ranges[col]["max"], ) return syn