Source code for synthyverse.generators.ctgan_generator.ct_gan

from ctgan import CTGAN
from ..base import TabularBaseGenerator
import pandas as pd


[docs] class CTGANGenerator(TabularBaseGenerator): """Conditional Tabular GAN (CTGAN). Conditions on discrete columns, and uses mode-specific normalization for numerical columns. Uses the implementation from the ctgan package, which is also used in the Synthetic Data Vault. Paper: "Modeling tabular data using conditional gan" by Xu et al. (2019). Args: embedding_dim (int): Dimension of the embedding layer. Default: 128. generator_dim (tuple): Tuple of dimensions for generator layers. Default: (256, 256). discriminator_dim (tuple): Tuple of dimensions for discriminator layers. Default: (256, 256). generator_lr (float): Learning rate for generator optimizer. Default: 2e-4. generator_decay (float): Weight decay for generator optimizer. Default: 1e-6. discriminator_lr (float): Learning rate for discriminator optimizer. Default: 2e-4. discriminator_decay (float): Weight decay for discriminator optimizer. Default: 1e-6. batch_size (int): Batch size for training. Default: 500. discriminator_steps (int): Number of discriminator steps per generator step. Default: 1. log_frequency (bool): Whether to log training frequency. Default: True. verbose (bool): Whether to print training progress. Default: True. epochs (int): Number of training epochs. Default: 300. pac (int): Number of samples per class for PAC discriminator. Default: 10. cuda (bool): Whether to use CUDA if available. Default: True. random_state (int): Random seed for reproducibility. Default: 0. **kwargs: Additional arguments passed to TabularBaseGenerator. Example: >>> import pandas as pd >>> from synthyverse.generators import CTGANGenerator >>> >>> # Load data >>> X = pd.read_csv("data.csv") >>> discrete_features = ["category_col"] >>> >>> # Create generator >>> generator = CTGANGenerator( ... epochs=300, ... batch_size=500, ... cuda=True, ... random_state=42 ... ) >>> >>> # Fit and generate >>> generator.fit(X, discrete_features) >>> X_syn = generator.generate(1000) """ name = "ctgan" def __init__( self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=True, epochs=300, pac=10, cuda=True, random_state: int = 0, **kwargs, ): super().__init__(random_state=random_state, **kwargs) self.epochs = epochs self.batch_size = batch_size self.embedding_dim = embedding_dim self.generator_dim = generator_dim self.discriminator_dim = discriminator_dim self.generator_lr = generator_lr self.generator_decay = generator_decay self.discriminator_lr = discriminator_lr self.discriminator_decay = discriminator_decay self.discriminator_steps = discriminator_steps self.log_frequency = log_frequency self.verbose = verbose self.pac = pac self.cuda = cuda def _fit_model( self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None ): # round batch_size to be divisible by pac self.batch_size = self.batch_size // self.pac * self.pac self.model = CTGAN( embedding_dim=self.embedding_dim, generator_dim=self.generator_dim, discriminator_dim=self.discriminator_dim, generator_lr=self.generator_lr, generator_decay=self.generator_decay, discriminator_lr=self.discriminator_lr, discriminator_decay=self.discriminator_decay, batch_size=self.batch_size, discriminator_steps=self.discriminator_steps, log_frequency=self.log_frequency, verbose=self.verbose, epochs=self.epochs, pac=self.pac, cuda=self.cuda, ) self.model.fit(X, discrete_features) def _generate_data(self, n: int): return self.model.sample(n)