CTGAN

class synthyverse.generators.ctgan_generator.CTGANGenerator(embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=0.0002, generator_decay=1e-06, discriminator_lr=0.0002, discriminator_decay=1e-06, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=True, epochs=300, pac=10, cuda=True, random_state=0, **kwargs)

Bases: TabularBaseGenerator

Conditional Tabular GAN (CTGAN).

Conditions on discrete columns, and uses mode-specific normalization for numerical columns.

Uses the implementation from the ctgan package, which is also used in the Synthetic Data Vault.

Paper: “Modeling Tabular Data using Conditional GAN” by Xu et al. (2019).
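To make the mode-specific normalization mentioned above concrete, here is a minimal sketch of the encoding described in the paper: a numerical value is represented by a one-hot indicator of its mixture mode plus a scalar offset within that mode, scaled by four standard deviations. The function name and the pre-fitted mixture parameters (means, stds, weights) are hypothetical inputs for illustration; the ctgan package fits the mixture internally, and the paper samples the mode from its posterior probability, whereas this sketch simply takes the most likely mode.

```python
import math


def mode_specific_normalize(value, means, stds, weights):
    """Encode a scalar as (alpha, beta) given an already-fitted Gaussian mixture.

    alpha: offset of `value` from the chosen mode's mean, divided by 4 * std.
    beta:  one-hot list marking the chosen mode.
    All arguments are illustrative assumptions, not part of the synthyverse API.
    """
    # Weighted Gaussian density of each mode evaluated at `value`
    densities = [
        w * math.exp(-0.5 * ((value - m) / s) ** 2) / (s * math.sqrt(2 * math.pi))
        for m, s, w in zip(means, stds, weights)
    ]
    # Simplification: pick the most likely mode instead of sampling one
    k = max(range(len(densities)), key=densities.__getitem__)
    alpha = (value - means[k]) / (4 * stds[k])
    beta = [1 if i == k else 0 for i in range(len(means))]
    return alpha, beta


# A column with two modes, around 20 and around 50
alpha, beta = mode_specific_normalize(
    52.0, means=[20.0, 50.0], stds=[2.0, 5.0], weights=[0.5, 0.5]
)
```

Here the value 52.0 falls in the second mode, so beta marks that mode and alpha is a small positive offset from its mean.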

Parameters:
  • embedding_dim (int) – Size of the random noise vector passed to the generator. Default: 128.

  • generator_dim (tuple) – Output sizes of the generator's hidden layers; one layer is created per value. Default: (256, 256).

  • discriminator_dim (tuple) – Output sizes of the discriminator's hidden layers; one layer is created per value. Default: (256, 256).

  • generator_lr (float) – Learning rate for generator optimizer. Default: 2e-4.

  • generator_decay (float) – Weight decay for generator optimizer. Default: 1e-6.

  • discriminator_lr (float) – Learning rate for discriminator optimizer. Default: 2e-4.

  • discriminator_decay (float) – Weight decay for discriminator optimizer. Default: 1e-6.

  • batch_size (int) – Batch size for training. Default: 500.

  • discriminator_steps (int) – Number of discriminator steps per generator step. Default: 1.

  • log_frequency (bool) – Whether to use the log frequency of categorical levels in conditional sampling. Default: True.

  • verbose (bool) – Whether to print training progress. Default: True.

  • epochs (int) – Number of training epochs. Default: 300.

  • pac (int) – Number of samples grouped together (packed) into a single discriminator input. Default: 10.

  • cuda (bool) – Whether to use CUDA if available. Default: True.

  • random_state (int) – Random seed for reproducibility. Default: 0.

  • **kwargs – Additional arguments passed to TabularBaseGenerator.
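Two of the parameters above, log_frequency and pac, are easier to follow with numbers. The sketch below is illustrative only and does not call synthyverse or ctgan: both helper names are hypothetical, the log(count + 1) weighting is an assumption about how log-frequency sampling can be implemented, and the divisibility check reflects the expectation that a batch splits evenly into groups of pac samples.

```python
import math


def conditional_probs(counts, log_frequency=True):
    """Probability of picking each category when sampling a training condition.

    Hypothetical helper: with log_frequency=True, raw counts are dampened with
    log(count + 1) so that rare categories are sampled more often.
    """
    weights = [math.log(c + 1) if log_frequency else float(c) for c in counts]
    total = sum(weights)
    return [w / total for w in weights]


def pac_groups(batch_size, pac):
    """Number of packed groups the discriminator sees in one batch."""
    if batch_size % pac != 0:
        raise ValueError("batch_size must be divisible by pac")
    return batch_size // pac


counts = [990, 10]  # a heavily imbalanced discrete column
raw = conditional_probs(counts, log_frequency=False)
damped = conditional_probs(counts, log_frequency=True)
groups = pac_groups(batch_size=500, pac=10)
```

With the defaults batch_size=500 and pac=10, each batch forms 50 packed groups. In the imbalanced column, the rare category is chosen far more often under log-frequency weighting than under raw frequencies, which is the point of the option.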

Example

>>> import pandas as pd
>>> from synthyverse.generators import CTGANGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator
>>> generator = CTGANGenerator(
...     epochs=300,
...     batch_size=500,
...     cuda=True,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)