from ..base import TabularBaseGenerator
import pandas as pd
import torch
import os
import shutil
from .tabddpm_dir.plugin import TabDDPMPlugin
from synthcity.plugins.core.dataloader import GenericDataLoader
[docs]
class TabDDPMGenerator(TabularBaseGenerator):
    """Tabular Denoising Diffusion Probabilistic Model (TabDDPM).

    TabDDPM combines continuous diffusion for numerical features with
    multinomial diffusion for categorical features. We use the implementation
    from SynthCity, with some modifications to allow manual specification of
    discrete features.

    Paper: "Tabddpm: Modelling tabular data with diffusion models" by
    Kotelnikov et al. (2023).

    Args:
        target_column (str): Name of the target column.
        epochs (int): Number of training epochs. Default: 1000.
        lr (float): Learning rate. Default: 0.002.
        weight_decay (float): Weight decay for optimization. Default: 1e-4.
        batch_size (int): Batch size for training. Default: 1024.
        num_timesteps (int): Number of diffusion timesteps. Default: 1000.
        gaussian_loss_type (str): Type of Gaussian loss. Options: "mse", "kl".
            Default: "mse".
        scheduler (str): Learning rate scheduler type. Options: "cosine",
            "linear". Default: "cosine".
        log_interval (int): Steps between logging. Default: 100.
        model_type (str): Type of model architecture. Options: "mlp",
            "resnet", "tabnet". Default: "mlp".
        model_params (dict): Dictionary of model parameters. When omitted (or
            None), defaults to
            {"n_layers_hidden": 3, "n_units_hidden": 256, "dropout": 0.0}.
            The dictionary is copied, so each generator owns its own config.
        dim_embed (int): Embedding dimension. Default: 128.
        random_state (int): Random seed for reproducibility. Default: 0.
        **kwargs: Additional arguments passed to TabularBaseGenerator.

    Example:
        >>> import pandas as pd
        >>> from synthyverse.generators import TabDDPMGenerator
        >>>
        >>> # Load data
        >>> X = pd.read_csv("data.csv")
        >>> discrete_features = ["category_col"]
        >>>
        >>> # Create generator (requires target column)
        >>> generator = TabDDPMGenerator(
        ...     target_column="target",
        ...     epochs=1000,
        ...     scheduler="cosine",
        ...     random_state=42
        ... )
        >>>
        >>> # Fit and generate
        >>> generator.fit(X, discrete_features)
        >>> X_syn = generator.generate(1000)
    """

    name = "tabddpm"
    needs_target_column = True

    def __init__(
        self,
        target_column: str,
        epochs: int = 1000,
        lr: float = 0.002,
        weight_decay: float = 1e-4,
        batch_size: int = 1024,
        num_timesteps: int = 1000,
        gaussian_loss_type: str = "mse",
        scheduler: str = "cosine",
        log_interval: int = 100,
        model_type: str = "mlp",
        model_params: dict = None,
        dim_embed: int = 128,
        random_state: int = 0,
        **kwargs,
    ):
        """Store the hyperparameters and pick the training device.

        See the class docstring for the meaning of each argument.
        """
        super().__init__(random_state=random_state, **kwargs)

        # A dict literal in the signature would be shared by every instance;
        # build the default here and copy caller-supplied dicts so that later
        # mutation on one generator cannot leak into another.
        if model_params is None:
            model_params = {
                "n_layers_hidden": 3,
                "n_units_hidden": 256,
                "dropout": 0.0,
            }

        self.epochs = epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.num_timesteps = num_timesteps
        self.gaussian_loss_type = gaussian_loss_type
        self.scheduler = scheduler
        self.log_interval = log_interval
        self.model_type = model_type
        self.model_params = dict(model_params)
        self.dim_embed = dim_embed
        self.target_column = target_column
        self.random_state = random_state
        # Use the GPU when torch reports one, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _fit_model(
        self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None
    ):
        """Train a TabDDPM plugin on ``X`` and keep it as ``self.model``.

        Args:
            X: Training data; must contain ``self.target_column``.
            discrete_features: Names of the discrete columns. If the target
                column is listed here the plugin is configured for
                classification, otherwise for regression.
            X_val: Accepted for interface compatibility; not used here.
        """
        # Local import keeps this method self-contained.
        import tempfile

        # A discrete target means a classification task.
        is_classification = self.target_column in discrete_features

        # The target is handed to the loader separately, so it is excluded
        # from the discrete feature columns passed to the plugin.
        feature_discrete_columns = [
            column for column in discrete_features if column != self.target_column
        ]

        # A unique scratch directory avoids clashes between concurrent fits
        # and never touches a pre-existing directory owned by the user.
        workspace = tempfile.mkdtemp(prefix="tabddpm_workspace_")
        try:
            loader = GenericDataLoader(data=X, target_column=self.target_column)
            self.model = TabDDPMPlugin(
                is_classification=is_classification,
                n_iter=self.epochs,
                lr=self.lr,
                weight_decay=self.weight_decay,
                batch_size=self.batch_size,
                num_timesteps=self.num_timesteps,
                gaussian_loss_type=self.gaussian_loss_type,
                scheduler=self.scheduler,
                device=self.device,
                log_interval=self.log_interval,
                model_type=self.model_type,
                model_params=self.model_params,
                dim_embed=self.dim_embed,
                random_state=self.random_state,
                workspace=workspace,
            )
            self.model.fit(loader, discrete_columns=feature_discrete_columns)
        finally:
            # Always delete the workspace, even when training fails; cleanup
            # is best-effort so it cannot mask the original exception.
            shutil.rmtree(workspace, ignore_errors=True)

    def _generate_data(self, n: int):
        """Sample from the fitted plugin and return the result's dataframe.

        Args:
            n: Count forwarded to the plugin's ``generate`` call.

        Returns:
            Whatever ``self.model.generate(n).dataframe()`` yields
            (presumably a ``pandas.DataFrame`` of synthetic rows).
        """
        return self.model.generate(n).dataframe()

    def _cleanup_additional_state_for_save(self) -> None:
        """Drop recorded training curves from the fitted model before saving.

        Does nothing when no model has been fitted yet.
        """
        if not hasattr(self, "model"):
            return

        # Training curves can be large and are not needed for sampling.
        for attr in ("loss_history", "validation_history"):
            if hasattr(self.model, attr):
                setattr(self.model, attr, None)

        # The plugin may wrap an inner model that keeps its own histories
        # (note the different attribute name for the validation curve).
        if hasattr(self.model, "model"):
            for attr in ("loss_history", "val_history"):
                if hasattr(self.model.model, attr):
                    setattr(self.model.model, attr, None)