TabDDPM
- class synthyverse.generators.tabddpm_generator.TabDDPMGenerator(target_column, epochs=1000, lr=0.002, weight_decay=0.0001, batch_size=1024, num_timesteps=1000, gaussian_loss_type='mse', scheduler='cosine', log_interval=100, model_type='mlp', model_params={'dropout': 0.0, 'n_layers_hidden': 3, 'n_units_hidden': 256}, dim_embed=128, random_state=0, **kwargs)
Bases: TabularBaseGenerator
Tabular Denoising Diffusion Probabilistic Model (TabDDPM).
TabDDPM combines continuous diffusion for numerical features with multinomial diffusion for categorical features.
This generator uses the implementation from synthcity, modified so that discrete features can be specified manually.
Paper: “TabDDPM: Modelling Tabular Data with Diffusion Models” by Kotelnikov et al. (2023).
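The two forward (noising) processes named above can be illustrated with a minimal NumPy sketch. It is not the code used by this class: the function names are hypothetical, the cosine schedule follows the commonly used formulation without the usual beta clipping, and only the closed-form noising step is shown (no denoising network or training loop).

```python
import numpy as np


def cosine_alpha_bar(num_timesteps: int, s: float = 0.008) -> np.ndarray:
    """Cumulative signal level alpha_bar_t for t = 1..T (illustrative cosine schedule)."""
    steps = np.arange(num_timesteps + 1) / num_timesteps
    f = np.cos((steps + s) / (1 + s) * np.pi / 2) ** 2
    return f[1:] / f[0]


def noise_numerical(x0: np.ndarray, alpha_bar_t: float, rng: np.random.Generator) -> np.ndarray:
    """Gaussian diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps


def noise_categorical(x0_onehot: np.ndarray, alpha_bar_t: float, rng: np.random.Generator) -> np.ndarray:
    """Multinomial diffusion: x_t ~ Cat(alpha_bar_t * x_0 + (1 - alpha_bar_t) / K)."""
    k = x0_onehot.shape[1]
    probs = alpha_bar_t * x0_onehot + (1.0 - alpha_bar_t) / k
    # Inverse-CDF sampling, one category per row.
    cum = probs.cumsum(axis=1)
    u = rng.random((probs.shape[0], 1))
    idx = np.minimum((u > cum).sum(axis=1), k - 1)  # clamp against rounding error
    return np.eye(k)[idx]


rng = np.random.default_rng(0)
alpha_bar = cosine_alpha_bar(1000)

x_num = np.array([[0.5, -1.2], [2.0, 0.3]])  # two rows, two numerical features
x_cat = np.eye(3)[[0, 2]]                    # two rows, one 3-class feature (one-hot)

x_num_t = noise_numerical(x_num, alpha_bar[499], rng)   # noised at t = 500
x_cat_t = noise_categorical(x_cat, alpha_bar[499], rng)
```

As t grows, alpha_bar_t shrinks toward zero, so numerical features approach standard Gaussian noise and categorical features approach a uniform distribution over their K classes.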
- Parameters:
target_column (str) – Name of the target column.
epochs (int) – Number of training epochs. Default: 1000.
lr (float) – Learning rate. Default: 0.002.
weight_decay (float) – Weight decay for optimization. Default: 1e-4.
batch_size (int) – Batch size for training. Default: 1024.
num_timesteps (int) – Number of diffusion timesteps. Default: 1000.
gaussian_loss_type (str) – Type of Gaussian loss. Options: “mse”, “kl”. Default: “mse”.
scheduler (str) – Noise (beta) schedule of the diffusion process. Options: “cosine”, “linear”. Default: “cosine”.
log_interval (int) – Steps between logging. Default: 100.
model_type (str) – Type of model architecture. Options: “mlp”, “resnet”, “tabnet”. Default: “mlp”.
model_params (dict) – Dictionary of model parameters. Default: {“n_layers_hidden”: 3, “n_units_hidden”: 256, “dropout”: 0.0}.
dim_embed (int) – Embedding dimension. Default: 128.
random_state (int) – Random seed for reproducibility. Default: 0.
**kwargs – Additional arguments passed to TabularBaseGenerator.
Example
>>> import pandas as pd
>>> from synthyverse.generators import TabDDPMGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator (requires target column)
>>> generator = TabDDPMGenerator(
...     target_column="target",
...     epochs=1000,
...     scheduler="cosine",
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)
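Because the discrete features are passed to fit() by hand, it can help to build that list from the DataFrame rather than type it out. The helper below is a hypothetical sketch, not part of synthyverse: it treats non-numeric columns, boolean columns, and integer columns with few distinct values as discrete. The threshold `max_unique` is an arbitrary assumption, and whether the target column belongs in the list is not stated in this documentation, so the `exclude` argument leaves that choice to the caller.

```python
import pandas as pd
from pandas.api import types as pdt


def infer_discrete_features(df: pd.DataFrame, max_unique: int = 10, exclude=()) -> list:
    """Return column names that look discrete (hypothetical heuristic)."""
    discrete = []
    for col in df.columns:
        if col in exclude:
            continue
        s = df[col]
        non_numeric = not pdt.is_numeric_dtype(s)  # strings, categories, objects
        boolean = pdt.is_bool_dtype(s)
        low_card_int = pdt.is_integer_dtype(s) and s.nunique() <= max_unique
        if non_numeric or boolean or low_card_int:
            discrete.append(col)
    return discrete


# Toy data standing in for the contents of "data.csv"
df = pd.DataFrame(
    {
        "age": [23, 35, 47, 52],
        "income": [1.5, 2.7, 3.1, 4.8],
        "category_col": ["a", "b", "a", "c"],
        "target": [0, 1, 0, 1],
    }
)

discrete_features = infer_discrete_features(df, max_unique=3)
```

The resulting list should still be checked by hand before it is passed to fit(), since a heuristic like this can misclassify ordinal codes or identifiers.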