TVAE

class synthyverse.generators.tvae_generator.TVAEGenerator(embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), l2scale=1e-05, batch_size=500, epochs=300, loss_factor=2, cuda=True, verbose=True, random_state=0)

Bases: BaseGenerator

Registry name: tvae

Tabular Variational Autoencoder (TVAE).

Like CTGAN, TVAE applies mode-specific normalization to numerical columns.

Uses the implementation from the ctgan package, which is also used in the Synthetic Data Vault.

Paper: “Modeling Tabular Data using Conditional GAN” by Xu et al. (2019).
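The mode-specific normalization mentioned above can be sketched roughly as follows. This is an illustration only, not the ctgan code: the helper mode_specific_normalize and the hand-written mixture parameters are hypothetical, and the sketch picks the most likely mode, whereas the paper fits a variational Gaussian mixture to each column and samples the mode.

```python
import math


def mode_specific_normalize(value, modes):
    """Encode one numeric value as (scalar, one-hot mode indicator).

    modes is a list of (weight, mean, std) tuples describing a Gaussian
    mixture assumed to have been fitted to the column beforehand.
    """
    # Weighted density of the value under each mixture component.
    densities = [
        w * math.exp(-0.5 * ((value - mu) / sd) ** 2) / (sd * math.sqrt(2 * math.pi))
        for w, mu, sd in modes
    ]
    # Simplification: take the most likely mode instead of sampling one.
    k = max(range(len(modes)), key=lambda i: densities[i])
    _, mu, sd = modes[k]
    # Normalize within the chosen mode; dividing by 4 * std keeps most
    # values of that mode inside [-1, 1].
    alpha = (value - mu) / (4 * sd)
    beta = [1 if i == k else 0 for i in range(len(modes))]
    return alpha, beta


# Hypothetical two-mode column: one mode near 0, one near 100.
modes = [(0.5, 0.0, 1.0), (0.5, 100.0, 10.0)]
alpha, beta = mode_specific_normalize(110.0, modes)
```

Each numerical value is thus represented by a scalar within its mode plus a one-hot vector naming the mode, which helps the model handle multi-modal columns.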

Parameters:
  • embedding_dim (int) – Dimension of the embedding layer. Default: 128.

  • compress_dims (tuple) – Tuple of dimensions for encoder layers. Default: (128, 128).

  • decompress_dims (tuple) – Tuple of dimensions for decoder layers. Default: (128, 128).

  • l2scale (float) – L2 regularization scale. Default: 1e-5.

  • batch_size (int) – Batch size for training. Default: 500.

  • epochs (int) – Number of training epochs. Default: 300.

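  • loss_factor (int) – Loss weighting factor, in the style of Beta-VAE; in the ctgan implementation it multiplies the reconstruction term of the loss. Default: 2.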

  • cuda (bool) – Whether to use CUDA if available. Default: True.

  • verbose (bool) – Whether to print training progress. Default: True.

  • random_state (int) – Random seed for reproducibility. Default: 0.
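To give a feel for what loss_factor controls, here is a minimal sketch of a weighted VAE objective for a single sample. It assumes the factor multiplies the reconstruction error while the KL term is left unweighted; the function weighted_vae_loss is hypothetical and only mirrors the general shape of such a loss, not the exact ctgan computation.

```python
import math


def weighted_vae_loss(recon_error, mu, logvar, loss_factor=2):
    """Return loss_factor * reconstruction error + KL divergence."""
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimensions.
    kld = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, logvar))
    # Assumption: the factor scales the reconstruction term only.
    return loss_factor * recon_error + kld


# A latent code that already matches the prior has zero KL divergence,
# so only the weighted reconstruction error remains.
loss = weighted_vae_loss(recon_error=1.5, mu=[0.0, 0.0], logvar=[0.0, 0.0])
```

Under this assumption, a larger loss_factor pushes the model toward faithful reconstruction at the expense of a latent distribution that stays close to the prior.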

Example

>>> import pandas as pd
>>> from synthyverse.generators import TVAEGenerator
>>>
>>> # Load data
>>> X = pd.read_csv("data.csv")
>>> discrete_features = ["category_col"]
>>>
>>> # Create generator
>>> generator = TVAEGenerator(
...     embedding_dim=128,
...     epochs=300,
...     cuda=True,
...     random_state=42
... )
>>>
>>> # Fit and generate
>>> generator.fit(X, discrete_features)
>>> X_syn = generator.generate(1000)

fit(X, discrete_features, X_val=None)

Fit the generator to tabular data.

Parameters:
  • X (DataFrame) – Training data in the generator’s input space.

  • discrete_features (list) – Names of categorical/discrete columns in X.

  • X_val (Optional[DataFrame]) – Optional validation data in the same schema as X.

Returns:

The fitted generator.

generate(n)

Generate synthetic tabular data.

Parameters:
  • n (int) – Number of synthetic rows to generate.

Returns:

Synthetic data in the generator’s model space.

classmethod load(path)

Load a generator persisted with the default pickle layout.

save(path)

Persist the generator state with the default pickle layout.
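
The save/load pair follows the usual pattern of an instance method that writes state and a classmethod that restores it. The sketch below shows that pattern with the standard pickle module. ToyGenerator is a hypothetical stand-in, and pickling the attribute dictionary is an assumption made for illustration; the actual on-disk layout used by BaseGenerator is not documented here.

```python
import os
import pickle
import tempfile


class ToyGenerator:
    """Hypothetical stand-in; not the synthyverse BaseGenerator."""

    def __init__(self, epochs=300, random_state=0):
        self.epochs = epochs
        self.random_state = random_state

    def save(self, path):
        # Persist only the instance attributes (assumed layout).
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)

    @classmethod
    def load(cls, path):
        with open(path, "rb") as f:
            state = pickle.load(f)
        # Rebuild the object without re-running __init__.
        obj = cls.__new__(cls)
        obj.__dict__.update(state)
        return obj


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "generator.pkl")
    ToyGenerator(epochs=10, random_state=42).save(path)
    restored = ToyGenerator.load(path)
```

As with any pickle-based format, files should only be loaded from trusted sources, since unpickling can execute arbitrary code.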