CTGAN
- class synthyverse.generators.ctgan_generator.CTGANGenerator(embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=0.0002, generator_decay=1e-06, discriminator_lr=0.0002, discriminator_decay=1e-06, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=True, epochs=300, pac=10, cuda=True, random_state=0)
Bases: BaseGenerator
Registry name: ctgan
Conditional Tabular GAN (CTGAN).
Conditions the generator on discrete columns and applies mode-specific normalization to numerical columns.
Uses the implementation from the ctgan package, which is also used in the Synthetic Data Vault.
Paper: “Modeling tabular data using conditional gan” by Xu et al. (2019).
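Mode-specific normalization, as described in the paper, encodes each numerical value as a scalar offset from one mode of a Gaussian mixture plus a one-hot indicator of that mode. The following is a minimal NumPy sketch of that encoding, assuming the mixture parameters (`means`, `stds`, `weights`) have already been fitted; the function name is hypothetical and this is not the ctgan package's code. The paper samples a mode in proportion to the weighted densities; the sketch takes the most likely mode so that the output is deterministic.

```python
import numpy as np


def mode_specific_normalize(values, means, stds, weights):
    """Hypothetical sketch: encode a numerical column as (alpha, one-hot mode).

    means/stds/weights describe a Gaussian mixture assumed to be already
    fitted to the column (CTGAN fits a variational mixture per column).
    """
    values = np.asarray(values, dtype=float)[:, None]  # shape (n, 1)
    means, stds, weights = (np.asarray(a, dtype=float) for a in (means, stds, weights))

    # Weighted Gaussian density of every value under every mode, shape (n, k).
    dens = weights * np.exp(-0.5 * ((values - means) / stds) ** 2) / (stds * np.sqrt(2 * np.pi))

    # One mode per value. The paper samples it from these densities;
    # this sketch takes the argmax so results are reproducible.
    mode = dens.argmax(axis=1)

    # Scalar part: offset from the chosen mode's mean, in units of 4 standard deviations.
    alpha = (values[:, 0] - means[mode]) / (4 * stds[mode])

    # Mode part: one-hot indicator of the chosen mode.
    beta = np.eye(len(means))[mode]
    return alpha, beta


# Two well-separated modes; 0.5 falls in the first, 10.0 in the second.
alpha, beta = mode_specific_normalize(
    [0.5, 10.0], means=[0.0, 10.0], stds=[1.0, 2.0], weights=[0.5, 0.5]
)
```

The generator then produces `alpha` and `beta` for each numerical column, and decoding reverses the mapping as `means[mode] + 4 * stds[mode] * alpha`.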
- Parameters:
embedding_dim (int) – Size of the random noise vector passed to the generator. Default: 128.
generator_dim (tuple) – Hidden-layer sizes of the generator; one residual layer is created per entry. Default: (256, 256).
discriminator_dim (tuple) – Hidden-layer sizes of the discriminator; one linear layer is created per entry. Default: (256, 256).
generator_lr (float) – Learning rate for generator optimizer. Default: 2e-4.
generator_decay (float) – Weight decay for generator optimizer. Default: 1e-6.
discriminator_lr (float) – Learning rate for discriminator optimizer. Default: 2e-4.
discriminator_decay (float) – Weight decay for discriminator optimizer. Default: 1e-6.
batch_size (int) – Batch size for training. Default: 500.
discriminator_steps (int) – Number of discriminator steps per generator step. Default: 1.
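log_frequency (bool) – Whether to use the log frequency of categorical levels (rather than their raw frequency) when sampling the conditional vector during training. Default: True.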
verbose (bool) – Whether to print training progress. Default: True.
epochs (int) – Number of training epochs. Default: 300.
pac (int) – Number of samples grouped together into a single discriminator input (PacGAN-style packing); batch_size should be a multiple of pac. Default: 10.
cuda (bool) – Whether to use CUDA if available. Default: True.
random_state (int) – Random seed for reproducibility. Default: 0.
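The one-line descriptions of `log_frequency` and `pac` do not show what the two settings do to the data. Below is a minimal NumPy sketch of both, modeled on how the ctgan package handles them as far as we understand it. The function names are hypothetical and are not part of synthyverse: `category_sampling_probs` computes how often each category of a discrete column is chosen as the training condition, and `pack` shows how `pac` rows are concatenated into one discriminator input.

```python
import numpy as np


def category_sampling_probs(counts, log_frequency=True):
    """Hypothetical sketch: probability of picking each category as the condition.

    With log_frequency=True the raw counts are damped with log(count + 1),
    so rare categories are conditioned on more often than their share of the data.
    """
    freq = np.asarray(counts, dtype=float)
    if log_frequency:
        freq = np.log(freq + 1)
    return freq / freq.sum()


def pack(batch, pac=10):
    """Hypothetical sketch: group `pac` consecutive rows into one discriminator input.

    A (batch_size, dim) array becomes (batch_size // pac, dim * pac).
    """
    batch = np.asarray(batch)
    if batch.shape[0] % pac != 0:
        raise ValueError("batch size must be divisible by pac")
    return batch.reshape(-1, batch.shape[1] * pac)


# A skewed categorical column: 90 / 9 / 1 rows.
p_raw = category_sampling_probs([90, 9, 1], log_frequency=False)  # 0.9, 0.09, 0.01
p_log = category_sampling_probs([90, 9, 1], log_frequency=True)   # rare level boosted

# With the defaults batch_size=500 and pac=10, the discriminator sees 50 packed rows.
packed = pack(np.zeros((500, 7)), pac=10)
```

Packing lets the discriminator judge several samples at once, which is meant to make mode collapse easier to detect. It is also why the batch size has to divide evenly by `pac`.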
Example
>>> import pandas as pd
>>> from synthyverse.generators import CTGANGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator
>>> generator = CTGANGenerator(
...     epochs=300,
...     batch_size=500,
...     cuda=True,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)
- fit(X, discrete_features, X_val=None)
Fit the generator to tabular data.
- Parameters:
X (DataFrame) – Training data in the generator’s input space.
discrete_features (list) – Names of categorical/discrete columns in X.
X_val (Optional[DataFrame]) – Optional validation data in the same schema as X.
- Returns:
The fitted generator.
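The usage example does not pass `X_val`, and this page does not say what the generator does with it. If you want to supply one, the following pandas sketch shows one way to hold out a validation frame with the same schema as `X`, after checking that every name in `discrete_features` is actually a column. `split_for_fit` and the toy frame are hypothetical, and the `fit` call is left as a comment.

```python
import numpy as np
import pandas as pd


def split_for_fit(X, discrete_features, val_frac=0.2, random_state=0):
    """Hypothetical helper: hold out part of X as a validation frame.

    Checks that every name in discrete_features is a column of X, then
    returns (X_train, X_val). Both keep all columns of X.
    """
    missing = [c for c in discrete_features if c not in X.columns]
    if missing:
        raise ValueError(f"discrete_features not found in X: {missing}")
    # Shuffle row positions (not index labels, which may repeat).
    order = np.random.default_rng(random_state).permutation(len(X))
    n_val = int(round(len(X) * val_frac))
    return X.iloc[order[n_val:]], X.iloc[order[:n_val]]


# Toy data for illustration only.
X = pd.DataFrame({"category_col": list("abab" * 5), "amount": range(20)})
X_train, X_val = split_for_fit(X, ["category_col"])
# generator.fit(X_train, ["category_col"], X_val=X_val)
```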
- generate(n)
Generate synthetic tabular data.
- Parameters:
n (int) – Number of synthetic rows to generate.
- Returns:
Synthetic data in the generator’s model space.
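A quick sanity check on the output of `generate(n)` can catch schema problems early. The sketch below assumes the result is a pandas DataFrame with the same columns as the training data. The page describes the output as being in the generator's "model space", so confirm that this holds for your setup. CTGAN draws discrete values from the categories it saw during training, so any unseen category points to a mismatch elsewhere in the pipeline. `check_synthetic` and the toy frames are hypothetical.

```python
import pandas as pd


def check_synthetic(X, X_syn, n, discrete_features):
    """Hypothetical sanity check on generate(n) output; returns a list of problems."""
    problems = []
    if len(X_syn) != n:
        problems.append(f"expected {n} rows, got {len(X_syn)}")
    if set(X_syn.columns) != set(X.columns):
        problems.append("columns differ from the training data")
    for col in discrete_features:
        if col not in X_syn.columns:
            continue
        # Categories in the synthetic column that never occur in the training column.
        unseen = set(X_syn[col].dropna()) - set(X[col].dropna())
        if unseen:
            problems.append(f"{col}: unseen categories {sorted(map(str, unseen))}")
    return problems


# Toy frames for illustration only; "z" never appears in the training data.
X = pd.DataFrame({"category_col": ["a", "b", "a"], "amount": [1.0, 2.0, 3.0]})
X_syn = pd.DataFrame({"category_col": ["a", "z"], "amount": [0.5, 2.5]})
problems = check_synthetic(X, X_syn, n=2, discrete_features=["category_col"])
```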
- classmethod load(path)
Load a generator persisted with the default pickle layout.
- save(path)
Persist the generator state with the default pickle layout.
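This page says only that `save` and `load` use "the default pickle layout" and does not document the layout itself. As an illustration of the general pattern only, here is a standard-library sketch of a pickle-based `save` instance method paired with a `load` classmethod. `PickleSaveLoad` and `Toy` are hypothetical stand-ins, not synthyverse classes. Pickle can execute arbitrary code when loading, so only load files from trusted sources.

```python
import os
import pickle
import tempfile


class PickleSaveLoad:
    """Hypothetical mixin: persist the whole object with pickle."""

    def save(self, path):
        # Serialize the full object, including any fitted state it holds.
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path):
        # Only unpickle trusted files: pickle can run arbitrary code.
        with open(path, "rb") as f:
            obj = pickle.load(f)
        if not isinstance(obj, cls):
            raise TypeError(f"{path} does not contain a {cls.__name__}")
        return obj


class Toy(PickleSaveLoad):
    """Stand-in for a fitted generator."""

    def __init__(self, epochs):
        self.epochs = epochs


path = os.path.join(tempfile.mkdtemp(), "toy.pkl")
Toy(epochs=300).save(path)
restored = Toy.load(path)
```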