Forest Diffusion

class synthyverse.generators.forestdiffusion_generator.ForestDiffusionGenerator(target_column, duplicate_K=100, noise_level=50, diffusion_type='flow', n_jobs=-1, max_depth=7, n_estimators=100, eta=0.3, tree_method='hist', reg_alpha=0.0, reg_lambda=0.0, subsample=1.0, num_leaves=31, eps=0.001, beta_min=0.1, beta_max=8, n_z=10, gpu_hist=False, random_state=0, n_batch=1, **kwargs)

Bases: TabularBaseGenerator

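Diffusion-based generative model for tabular data that uses XGBoost regressors to estimate the score function (diffusion_type="vp") or the flow vector field (diffusion_type="flow").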

Uses the implementation from the ForestDiffusion PyPI package. Training can be expensive on large datasets, since each sample is duplicated duplicate_K times before the trees are fit.
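To get a rough feel for why large datasets are costly, the sketch below estimates how many rows the trees would see once every sample is duplicated duplicate_K times. The helper names and the bytes_per_value figure are hypothetical and not part of synthyverse; the real memory footprint depends on the underlying implementation and on n_batch.

```python
def duplicated_training_rows(n_rows, duplicate_K=100):
    """Rows seen by the trees if every sample is duplicated duplicate_K times."""
    return n_rows * duplicate_K


def approx_training_bytes(n_rows, n_columns, duplicate_K=100, bytes_per_value=4):
    """Back-of-the-envelope size of the duplicated matrix.

    bytes_per_value=4 assumes 32-bit floats, which is an assumption here.
    """
    return duplicated_training_rows(n_rows, duplicate_K) * n_columns * bytes_per_value


if __name__ == "__main__":
    rows = duplicated_training_rows(10_000)        # default duplicate_K=100
    size_gb = approx_training_bytes(10_000, 20) / 1e9
    print(f"{rows} duplicated rows, roughly {size_gb:.2f} GB")
```

If the estimate is too large for the available memory, lowering duplicate_K or raising n_batch are the two knobs documented below that are most likely to help.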

Paper: “Generating and imputing tabular data via diffusion and flow-based gradient-boosted trees” by Jolicoeur-Martineau et al. (2024).

Parameters:
  • target_column (str) – Name of the target column.

  • duplicate_K (int) – Number of duplicates for each sample. Default: 100.

  • noise_level (int) – Noise level for diffusion. Default: 50.

  • n_batch (int) – Number of batches to use for XGBoost’s data iterator. Default: 1.

  • diffusion_type (str) – Type of diffusion. Options: “flow”, “vp”. Default: “flow”.

  • n_jobs (int) – Number of parallel jobs (-1 for all cores). Default: -1.

  • max_depth (int) – Maximum depth of trees. Default: 7.

  • n_estimators (int) – Number of tree estimators. Default: 100.

  • eta (float) – Learning rate. Default: 0.3.

  • tree_method (str) – Tree construction method. Options: “hist”, “approx”, “exact”. Default: “hist”.

  • reg_alpha (float) – L1 regularization. Default: 0.0.

  • reg_lambda (float) – L2 regularization. Default: 0.0.

  • subsample (float) – Subsample ratio. Default: 1.0.

  • num_leaves (int) – Number of leaves in trees. Default: 31.

  • eps (float) – Epsilon parameter. Default: 1e-3.

  • beta_min (float) – Minimum beta for diffusion. Default: 0.1.

  • beta_max (float) – Maximum beta for diffusion. Default: 8.

  • n_z (int) – Dimension of latent space. Default: 10.

  • gpu_hist (bool) – Whether to use GPU histogram. Default: False.

  • random_state (int) – Random seed for reproducibility. Default: 0.

  • **kwargs – Additional arguments passed to TabularBaseGenerator.
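To make beta_min and beta_max more concrete, the sketch below evaluates the standard linear variance-preserving (VP) noise schedule for the default values. It assumes that diffusion_type="vp" follows this common formulation; the function names are illustrative and not part of synthyverse or ForestDiffusion.

```python
import math


def beta(t, beta_min=0.1, beta_max=8.0):
    """Linear noise rate beta(t) for t in [0, 1]."""
    return beta_min + t * (beta_max - beta_min)


def vp_marginal(t, beta_min=0.1, beta_max=8.0):
    """Mean coefficient and std of x_t given x_0 under the linear VP schedule."""
    log_mean = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean_coeff = math.exp(log_mean)
    std = math.sqrt(1.0 - math.exp(2.0 * log_mean))
    return mean_coeff, std


if __name__ == "__main__":
    for t in (0.0, 0.25, 0.5, 1.0):
        mean_coeff, std = vp_marginal(t)
        print(f"t={t:.2f}  beta={beta(t):.2f}  mean={mean_coeff:.3f}  std={std:.3f}")
```

Under this schedule a larger beta_max pushes the data closer to pure Gaussian noise at t=1, while beta_min controls how gently noise is added near t=0.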

Example

>>> import pandas as pd
>>> from synthyverse.generators import ForestDiffusionGenerator
>>>
>>> # Load data; it must contain the target column
>>> X = pd.read_csv("data.csv")
>>> # Names of the discrete (categorical) columns in X
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator (requires target column)
>>> generator = ForestDiffusionGenerator(
...     target_column="target",
...     diffusion_type="flow",
...     n_jobs=-1,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)
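Because fitting can take a long time, it may be worth validating string options before constructing the generator. The helper below is a hypothetical pre-check, not part of synthyverse. The allowed values for diffusion_type and tree_method are taken from the parameter list above, and the (0, 1] range for subsample follows XGBoost's convention.

```python
ALLOWED_DIFFUSION_TYPES = {"flow", "vp"}
ALLOWED_TREE_METHODS = {"hist", "approx", "exact"}


def check_forest_diffusion_options(diffusion_type="flow", tree_method="hist", subsample=1.0):
    """Return the options as a dict, or raise ValueError on an invalid value."""
    if diffusion_type not in ALLOWED_DIFFUSION_TYPES:
        raise ValueError(f"diffusion_type must be one of {sorted(ALLOWED_DIFFUSION_TYPES)}")
    if tree_method not in ALLOWED_TREE_METHODS:
        raise ValueError(f"tree_method must be one of {sorted(ALLOWED_TREE_METHODS)}")
    if not 0.0 < subsample <= 1.0:
        raise ValueError("subsample must lie in (0, 1]")
    return {
        "diffusion_type": diffusion_type,
        "tree_method": tree_method,
        "subsample": subsample,
    }


if __name__ == "__main__":
    options = check_forest_diffusion_options(diffusion_type="vp", subsample=0.8)
    print(options)
```

The returned dictionary can then be unpacked into the constructor, for example ForestDiffusionGenerator(target_column="target", **options).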