NRGBoost
- class synthyverse.generators.nrgboost_generator.NRGBoostGenerator(num_trees=200, shrinkage=0.15, line_search=True, max_leaves=256, max_ratio_in_leaf=2, min_data_in_leaf=0, initial_uniform_mixture=0.1, categorical_split_one_vs_all=False, feature_frac=1, splitter='best', num_steps=100, num_sampling_rounds=None, temperature=1.0, num_sampling_threads=0, random_state=0, **kwargs)
Bases: TabularBaseGenerator
ENeRgy-based Generative Boosting (NRGBoost).
Turns gradient-boosted decision trees into energy-based generative models.
Uses the implementation from the nrgboost PyPI package.
Paper: “NRGBoost: Energy-Based Generative Boosted Trees” by J. Bravo (2024).
- Parameters:
num_trees (int) – Number of trees in the boosted ensemble. Default: 200.
shrinkage (float) – Shrinkage parameter for boosting. Default: 0.15.
line_search (bool) – Whether to use line search for step size optimization. Default: True.
max_leaves (int) – Maximum number of leaves per tree. Default: 256.
max_ratio_in_leaf (float) – Maximum ratio of data / model data per leaf. Default: 2.
min_data_in_leaf (float) – Minimum data points per leaf. Default: 0.
initial_uniform_mixture (float) – Mixture coefficient for the starting point of boosting: 0 means starting from the product of training marginals, 1 means starting from a uniform distribution. Default: 0.1.
categorical_split_one_vs_all (bool) – Whether to use one-vs-all splitting for categorical features. Default: False.
feature_frac (float) – Fraction of features to randomly consider for splitting each node. Default: 1.
splitter (str) – Determines how trees are grown. “best” is best first and “depth” is breadth first. Default: “best”.
num_steps (int) – Number of Gibbs sampling steps. Default: 100.
num_sampling_rounds (Optional[int]) – Use only the first n trees when sampling. Default: None.
temperature (float) – Temperature parameter for sampling. Default: 1.0.
num_sampling_threads (int) – Number of threads for parallel sampling (0 uses the OpenMP default). Default: 0.
random_state (int) – Random seed for reproducibility. Default: 0.
**kwargs – Additional arguments passed to TabularBaseGenerator.
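The initial_uniform_mixture parameter can be pictured with a small sketch for a single categorical feature. It assumes the starting distribution is a linear blend of the empirical training marginal and a uniform distribution; that blending formula and the helper name mixed_marginal are illustrative assumptions, not taken from the nrgboost source.

```python
from collections import Counter


def mixed_marginal(values, categories, uniform_mixture):
    """Blend one feature's empirical marginal with a uniform distribution.

    Hypothetical helper: assumes a linear mixture, which is an
    illustration and not necessarily what nrgboost does internally.
    """
    counts = Counter(values)
    n = len(values)
    k = len(categories)
    return {
        c: (1.0 - uniform_mixture) * counts[c] / n + uniform_mixture / k
        for c in categories
    }


column = ["a", "a", "a", "b"]
categories = ["a", "b", "c"]

# 0 keeps the training marginal, 1 gives a uniform distribution.
marginal_only = mixed_marginal(column, categories, uniform_mixture=0.0)
uniform_only = mixed_marginal(column, categories, uniform_mixture=1.0)
blended = mixed_marginal(column, categories, uniform_mixture=0.1)
```

Under this assumption, a small positive coefficient gives nonzero starting probability to categories that never occur in the training column, such as "c" here.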
Example
>>> import pandas as pd
>>> from synthyverse.generators import NRGBoostGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator
>>> generator = NRGBoostGenerator(
...     num_trees=200,
...     shrinkage=0.15,
...     num_steps=100,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)
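To give some intuition for the num_steps and temperature parameters, the following is a generic Gibbs sampler over a toy energy function of two binary variables. It is a minimal sketch and not the nrgboost sampler: the functions gibbs_sample and toy_energy are hypothetical, and the convention that probabilities are proportional to exp(-energy / temperature) is an assumption made for illustration.

```python
import math
import random


def gibbs_sample(energy, num_vars, num_steps, temperature, rng):
    """Run num_steps Gibbs sweeps over binary variables and return the final state."""
    state = [rng.randrange(2) for _ in range(num_vars)]
    for _ in range(num_steps):
        for i in range(num_vars):
            # Unnormalized conditional weights for x_i = 0 and x_i = 1,
            # holding every other variable fixed.
            weights = []
            for v in (0, 1):
                candidate = state[:i] + [v] + state[i + 1:]
                weights.append(math.exp(-energy(candidate) / temperature))
            p_one = weights[1] / (weights[0] + weights[1])
            state[i] = 1 if rng.random() < p_one else 0
    return state


def toy_energy(x):
    # Toy model: states where both bits agree have lower energy.
    return 0.0 if x[0] == x[1] else 2.0


rng = random.Random(0)
samples = [
    gibbs_sample(toy_energy, num_vars=2, num_steps=20, temperature=1.0, rng=rng)
    for _ in range(500)
]
agree_fraction = sum(s[0] == s[1] for s in samples) / len(samples)
```

In this toy setting, more sweeps move the chain further from its random starting state, and a higher temperature flattens the conditionals so that high-energy states are sampled more often.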