Forest Diffusion
- class synthyverse.generators.forestdiffusion_generator.ForestDiffusionGenerator(target_column, duplicate_K=100, noise_level=50, diffusion_type='flow', n_jobs=-1, max_depth=7, n_estimators=100, eta=0.3, tree_method='hist', reg_alpha=0.0, reg_lambda=0.0, subsample=1.0, num_leaves=31, eps=0.001, beta_min=0.1, beta_max=8, n_z=10, gpu_hist=False, random_state=0, n_batch=1, **kwargs)
Bases: TabularBaseGenerator
Forest Diffusion.
Diffusion model that uses XGBoost models to estimate the score function (or, with diffusion_type="flow", the flow vector field).
Uses the ForestDiffusion PyPI package implementation. Training can be computationally expensive on large datasets.
Paper: “Generating and imputing tabular data via diffusion and flow-based gradient-boosted trees” by Jolicoeur-Martineau et al. (2024).
- Parameters:
target_column (str) – Name of the target column.
duplicate_K (int) – Number of duplicates for each sample. Default: 100.
noise_level (int) – Noise level for diffusion. Default: 50.
n_batch (int) – Number of batches to use for XGBoost’s data iterator. Default: 1.
diffusion_type (str) – Type of diffusion. Options: “flow”, “vp”. Default: “flow”.
n_jobs (int) – Number of parallel jobs (-1 for all cores). Default: -1.
max_depth (int) – Maximum depth of trees. Default: 7.
n_estimators (int) – Number of tree estimators. Default: 100.
eta (float) – Learning rate. Default: 0.3.
tree_method (str) – Tree construction method. Options: “hist”, “approx”, “exact”. Default: “hist”.
reg_alpha (float) – L1 regularization. Default: 0.0.
reg_lambda (float) – L2 regularization. Default: 0.0.
subsample (float) – Subsample ratio. Default: 1.0.
num_leaves (int) – Number of leaves in trees. Default: 31.
eps (float) – Epsilon parameter. Default: 1e-3.
beta_min (float) – Minimum beta for diffusion. Default: 0.1.
beta_max (float) – Maximum beta for diffusion. Default: 8.
n_z (int) – Dimension of latent space. Default: 10.
gpu_hist (bool) – Whether to use GPU histogram. Default: False.
random_state (int) – Random seed for reproducibility. Default: 0.
**kwargs – Additional arguments passed to TabularBaseGenerator.
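To give a feel for how duplicate_K, noise_level, eps, beta_min and beta_max might interact, here is a minimal NumPy sketch of the kind of noisy training set a tree-based diffusion or flow model regresses on. It is an illustration under stated assumptions, not the code of the ForestDiffusion package or of this class: the helper names vp_marginal and build_noisy_sets are hypothetical, the time grid is assumed to run from eps to 1, and the "vp" branch uses the standard variance-preserving SDE marginal with a linear beta schedule.

```python
import numpy as np


def vp_marginal(t, beta_min=0.1, beta_max=8.0):
    """Mean coefficient and std of the standard VP-SDE marginal at time t.

    Assumes a linear schedule beta(t) = beta_min + t * (beta_max - beta_min).
    """
    log_mean = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    return np.exp(log_mean), np.sqrt(1.0 - np.exp(2.0 * log_mean))


def build_noisy_sets(X, duplicate_K=100, noise_level=50, diffusion_type="flow",
                     eps=1e-3, beta_min=0.1, beta_max=8.0, random_state=0):
    """Hypothetical helper: one (t, x_t, target) triple per noise level."""
    rng = np.random.default_rng(random_state)
    # Each data row is repeated duplicate_K times and paired with fresh noise.
    X1 = np.tile(X, (duplicate_K, 1))
    X0 = rng.standard_normal(X1.shape)
    sets = []
    # Assumed time grid: noise_level points from eps to 1.
    for t in np.linspace(eps, 1.0, noise_level):
        if diffusion_type == "flow":
            # Linear interpolation between noise (t=0) and data (t=1);
            # the regression target is the constant velocity X1 - X0.
            x_t = t * X1 + (1.0 - t) * X0
            target = X1 - X0
        elif diffusion_type == "vp":
            # Forward VP diffusion: data at small t, mostly noise at t=1.
            # Target here is the noise itself (an assumption of this sketch).
            mean_coeff, std = vp_marginal(t, beta_min, beta_max)
            x_t = mean_coeff * X1 + std * X0
            target = X0
        else:
            raise ValueError(f"unknown diffusion_type: {diffusion_type!r}")
        sets.append((t, x_t, target))
    return sets


if __name__ == "__main__":
    X = np.arange(15, dtype=float).reshape(5, 3)
    sets = build_noisy_sets(X, duplicate_K=4, noise_level=6)
    print(len(sets), sets[0][1].shape)
```

In this sketch each original row contributes duplicate_K × noise_level noisy rows, so the defaults (100 and 50) yield 5,000 noisy rows per input row. That gives a rough sense of why the method can become costly on large datasets.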
Example
>>> import pandas as pd
>>> from synthyverse.generators import ForestDiffusionGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator (requires target column)
>>> generator = ForestDiffusionGenerator(
...     target_column="target",
...     diffusion_type="flow",
...     n_jobs=-1,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)