NRGBoost

class synthyverse.generators.nrgboost_generator.NRGBoostGenerator(num_trees=200, shrinkage=0.15, line_search=True, max_leaves=256, max_ratio_in_leaf=2, min_data_in_leaf=0, initial_uniform_mixture=0.1, categorical_split_one_vs_all=False, feature_frac=1, splitter='best', num_steps=100, num_sampling_rounds=None, temperature=1.0, num_sampling_threads=0, random_state=0, **kwargs)

Bases: TabularBaseGenerator

ENeRgy-based Generative Boosting (NRGBoost).

Turns gradient-boosted decision trees into energy-based generative models.

Uses the implementation from the nrgboost PyPI package.

Paper: “NRGBoost: Energy-Based Generative Boosted Trees” by J. Bravo (2024).

Parameters:
  • num_trees (int) – Number of trees in the boosted ensemble. Default: 200.

  • shrinkage (float) – Shrinkage parameter for boosting. Default: 0.15.

  • line_search (bool) – Whether to use line search for step size optimization. Default: True.

  • max_leaves (int) – Maximum number of leaves per tree. Default: 256.

  • max_ratio_in_leaf (float) – Maximum allowed ratio of data to model data in each leaf. Default: 2.

  • min_data_in_leaf (float) – Minimum data points per leaf. Default: 0.

  • initial_uniform_mixture (float) – Mixture coefficient for the starting point of boosting: 0 means starting from the product of training marginals, 1 means starting from a uniform distribution. Default: 0.1.

  • categorical_split_one_vs_all (bool) – Whether to use one-vs-all splitting for categorical features. Default: False.

  • feature_frac (float) – Fraction of features randomly considered when splitting each node. Default: 1.

  • splitter (str) – Determines how trees are grown: “best” grows them best-first, “depth” grows them breadth-first. Default: “best”.

  • num_steps (int) – Number of Gibbs sampling steps. Default: 100.

  • num_sampling_rounds (Optional[int]) – Use only the first n trees when sampling. Default: None.

  • temperature (float) – Temperature parameter for sampling. Default: 1.0.

  • num_sampling_threads (int) – Number of threads for parallel sampling (0 uses the OpenMP default). Default: 0.

  • random_state (int) – Random seed for reproducibility. Default: 0.

  • **kwargs – Additional arguments passed to TabularBaseGenerator.
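
To make the meaning of initial_uniform_mixture more concrete, the sketch below builds a starting distribution for two toy categorical columns by mixing the product of their training marginals with a uniform distribution. It is an illustration only: the helper names (marginal, initial_distribution) and the toy data are hypothetical, and it assumes the mixture is taken over the joint density, which this page does not state. It is not the nrgboost implementation.

```python
from collections import Counter
from itertools import product


def marginal(column):
    """Empirical marginal distribution of one categorical column."""
    counts = Counter(column)
    total = len(column)
    return {value: n / total for value, n in counts.items()}


def initial_distribution(columns, uniform_mixture):
    """Mix the product of training marginals with a uniform distribution.

    uniform_mixture=0 returns the product of marginals;
    uniform_mixture=1 returns the uniform distribution over the joint domain.
    (Assumed reading of initial_uniform_mixture, for illustration.)
    """
    marginals = [marginal(col) for col in columns]
    domains = [sorted(m) for m in marginals]

    # Number of cells in the joint domain, used for the uniform component.
    n_cells = 1
    for domain in domains:
        n_cells *= len(domain)

    dist = {}
    for cell in product(*domains):
        p_marginals = 1.0
        for m, value in zip(marginals, cell):
            p_marginals *= m[value]
        dist[cell] = (1 - uniform_mixture) * p_marginals + uniform_mixture / n_cells
    return dist


# Hypothetical toy data: two categorical columns with four rows each.
colour = ["red", "red", "red", "blue"]
size = ["S", "L", "L", "L"]

p0 = initial_distribution([colour, size], uniform_mixture=0.1)
```

With the default coefficient of 0.1, every cell of the joint domain receives some probability mass even when one of its marginals is small, so the starting point is never exactly zero on a combination that the marginals allow.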

Example

>>> import pandas as pd
>>> from synthyverse.generators import NRGBoostGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator
>>> generator = NRGBoostGenerator(
...     num_trees=200,
...     shrinkage=0.15,
...     num_steps=100,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)
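
After generating, a quick sanity check is to compare how often each category appears in the real and synthetic data. The sketch below computes the total variation distance between two lists of category labels. The helper names (category_frequencies, total_variation) are hypothetical and not part of synthyverse, and the toy lists stand in for real columns.

```python
from collections import Counter


def category_frequencies(values):
    """Relative frequency of each category in a list of labels."""
    counts = Counter(values)
    total = len(values)
    return {key: n / total for key, n in counts.items()}


def total_variation(real_values, synthetic_values):
    """Total variation distance between two empirical categorical
    distributions: 0 means identical frequencies, 1 means disjoint support."""
    p = category_frequencies(real_values)
    q = category_frequencies(synthetic_values)
    categories = set(p) | set(q)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in categories)


# Hypothetical stand-ins for a real and a synthetic categorical column.
real = ["a", "a", "b", "c"]
synthetic = ["a", "b", "b", "c"]

tv = total_variation(real, synthetic)
```

Applied to the example above, and assuming generate returns a pandas DataFrame with the same columns as X (not confirmed on this page), the call would be total_variation(X["category_col"].tolist(), X_syn["category_col"].tolist()).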