# Source code for synthyverse.generators.tabddpm_generator.tabddpm

import os
import shutil
import tempfile
from typing import Optional

import pandas as pd
import torch
from synthcity.plugins.core.dataloader import GenericDataLoader

from ..base import TabularBaseGenerator
from .tabddpm_dir.plugin import TabDDPMPlugin


class TabDDPMGenerator(TabularBaseGenerator):
    """Tabular Denoising Diffusion Probabilistic Model (TabDDPM).

    TabDDPM combines continuous diffusion for numerical features with
    multinomial diffusion for categorical features. We use the implementation
    from SynthCity, with some modifications to allow manual specification of
    discrete features.

    Paper: "Tabddpm: Modelling tabular data with diffusion models" by
    Kotelnikov et al. (2023).

    Args:
        target_column (str): Name of the target column. If it appears in the
            ``discrete_features`` passed to ``fit``, the model is trained in
            classification mode; otherwise in regression mode.
        epochs (int): Number of training epochs (forwarded to the plugin as
            ``n_iter``). Default: 1000.
        lr (float): Learning rate. Default: 0.002.
        weight_decay (float): Weight decay for optimization. Default: 1e-4.
        batch_size (int): Batch size for training. Default: 1024.
        num_timesteps (int): Number of diffusion timesteps. Default: 1000.
        gaussian_loss_type (str): Type of Gaussian loss. Options: "mse",
            "kl". Default: "mse".
        scheduler (str): Learning rate scheduler type. Options: "cosine",
            "linear". Default: "cosine".
        log_interval (int): Steps between logging. Default: 100.
        model_type (str): Type of model architecture. Options: "mlp",
            "resnet", "tabnet". Default: "mlp".
        model_params (dict, optional): Dictionary of model parameters. The
            generator stores its own shallow copy, so the caller's dict is
            never shared between instances. Default: None, meaning
            {"n_layers_hidden": 3, "n_units_hidden": 256, "dropout": 0.0}.
        dim_embed (int): Embedding dimension. Default: 128.
        random_state (int): Random seed for reproducibility. Default: 0.
        **kwargs: Additional arguments passed to TabularBaseGenerator.

    Example:
        >>> import pandas as pd
        >>> from synthyverse.generators import TabDDPMGenerator
        >>>
        >>> # Load data
        >>> X = pd.read_csv("data.csv")
        >>> discrete_features = ["category_col"]
        >>>
        >>> # Create generator (requires target column)
        >>> generator = TabDDPMGenerator(
        ...     target_column="target",
        ...     epochs=1000,
        ...     scheduler="cosine",
        ...     random_state=42
        ... )
        >>>
        >>> # Fit and generate
        >>> generator.fit(X, discrete_features)
        >>> X_syn = generator.generate(1000)
    """

    name = "tabddpm"
    needs_target_column = True

    # Used when ``model_params`` is not supplied. It is copied per instance
    # (never handed out directly) so generators cannot share or mutate it.
    _DEFAULT_MODEL_PARAMS = {
        "n_layers_hidden": 3,
        "n_units_hidden": 256,
        "dropout": 0.0,
    }

    def __init__(
        self,
        target_column: str,
        epochs: int = 1000,
        lr: float = 0.002,
        weight_decay: float = 1e-4,
        batch_size: int = 1024,
        num_timesteps: int = 1000,
        gaussian_loss_type: str = "mse",
        scheduler: str = "cosine",
        log_interval: int = 100,
        model_type: str = "mlp",
        model_params: Optional[dict] = None,
        dim_embed: int = 128,
        random_state: int = 0,
        **kwargs,
    ):
        super().__init__(random_state=random_state, **kwargs)
        self.epochs = epochs
        self.lr = lr
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.num_timesteps = num_timesteps
        self.gaussian_loss_type = gaussian_loss_type
        self.scheduler = scheduler
        self.log_interval = log_interval
        self.model_type = model_type
        # Copy so neither the class default nor the caller's dict is aliased
        # by this instance (a mutable default argument would be shared).
        self.model_params = dict(
            self._DEFAULT_MODEL_PARAMS if model_params is None else model_params
        )
        self.dim_embed = dim_embed
        self.target_column = target_column
        self.random_state = random_state
        # Train on the GPU when one is available, otherwise on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _fit_model(
        self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None
    ):
        """Train the TabDDPM plugin on ``X`` and store it as ``self.model``.

        Args:
            X: Training data; must contain ``self.target_column``.
            discrete_features: Names of discrete columns. Whether the target
                column is listed here decides classification vs. regression
                mode; the target itself is excluded from the
                ``discrete_columns`` handed to the plugin. ``None`` is treated
                as an empty list.
            X_val: Accepted for interface compatibility; not used.
        """
        discrete_features = list(discrete_features or [])
        is_classification = self.target_column in discrete_features
        loader = GenericDataLoader(data=X, target_column=self.target_column)

        # The plugin needs a scratch directory while training. A fresh,
        # uniquely named temporary directory means concurrent fits cannot
        # collide, an unrelated pre-existing folder in the working directory
        # is never deleted, and cleanup happens even if training raises.
        workspace = tempfile.mkdtemp(prefix="tabddpm_workspace_")
        try:
            self.model = TabDDPMPlugin(
                is_classification=is_classification,
                n_iter=self.epochs,
                lr=self.lr,
                weight_decay=self.weight_decay,
                batch_size=self.batch_size,
                num_timesteps=self.num_timesteps,
                gaussian_loss_type=self.gaussian_loss_type,
                scheduler=self.scheduler,
                device=self.device,
                log_interval=self.log_interval,
                model_type=self.model_type,
                model_params=self.model_params,
                dim_embed=self.dim_embed,
                random_state=self.random_state,
                workspace=workspace,
            )
            self.model.fit(
                loader,
                discrete_columns=[
                    col for col in discrete_features if col != self.target_column
                ],
            )
        finally:
            # Best-effort cleanup: a failure here must not mask a training error.
            shutil.rmtree(workspace, ignore_errors=True)

    def _generate_data(self, n: int):
        """Sample ``n`` synthetic rows from the fitted plugin as a DataFrame."""
        return self.model.generate(n).dataframe()

    def _cleanup_additional_state_for_save(self) -> None:
        """Drop training-history attributes before the generator is saved.

        Clears ``loss_history``/``validation_history`` on the plugin and
        ``loss_history``/``val_history`` on its inner ``model`` (when those
        attributes exist). Does nothing if the generator was never fitted.
        """
        if not hasattr(self, "model"):
            return
        # Training curves can be large and are not needed for sampling.
        for attr in ("loss_history", "validation_history"):
            if hasattr(self.model, attr):
                setattr(self.model, attr, None)
        if hasattr(self.model, "model"):
            for attr in ("loss_history", "val_history"):
                if hasattr(self.model.model, attr):
                    setattr(self.model.model, attr, None)