Source code for synthyverse.generators.realtabformer_generator.real_tabformer
from ..base import TabularBaseGenerator
import pandas as pd
from realtabformer import REaLTabFormer
[docs]
class RealTabFormerGenerator(TabularBaseGenerator):
    """Realistic Relational and Tabular Data using Transformers.

    Fine-tunes GPT-2 for tabular synthetic data generation.
    Uses the realtabformer pypi package implementation.

    Paper: "Realtabformer: Generating realistic relational and tabular data
    using transformers" by Solatorio et al. (2023).

    Args:
        workspace (str): Directory for storing checkpoints and samples.
        epochs (int): Number of training epochs. Default: 1000.
        batch_size (int): Batch size for training. Default: 8.
        mask_rate (float): Masking rate for training. Default: 0.
        early_stopping_patience (int): Patience for early stopping. Default: 5.
        early_stopping_threshold (float): Threshold for early stopping.
            Default: 0.
        random_state (int): Random seed for reproducibility. Default: 0.
        **kwargs: Additional arguments passed to TabularBaseGenerator.

    Example:
        >>> import pandas as pd
        >>> from synthyverse.generators import RealTabFormerGenerator
        >>>
        >>> # Load data
        >>> X = pd.read_csv("data.csv")
        >>> discrete_features = ["category_col"]
        >>>
        >>> # Create generator (requires workspace)
        >>> generator = RealTabFormerGenerator(
        ...     workspace="./realtabformer_workspace",
        ...     epochs=1000,
        ...     batch_size=8,
        ...     random_state=42
        ... )
        >>>
        >>> # Fit and generate
        >>> generator.fit(X, discrete_features)
        >>> X_syn = generator.generate(1000)
    """

    # Short identifier of this generator (presumably used to look it up by
    # name elsewhere in the package -- confirm against the base class).
    name = "realtabformer"
    # Signals that a ``workspace`` directory must be supplied on construction.
    needs_workspace = True

    def __init__(
        self,
        workspace: str,
        epochs: int = 1000,
        batch_size: int = 8,
        mask_rate: float = 0,
        early_stopping_patience: int = 5,
        early_stopping_threshold: float = 0,
        random_state: int = 0,
        **kwargs,
    ):
        """Build the underlying REaLTabFormer model in single-table mode.

        See the class docstring for the meaning of each argument.
        """
        super().__init__(random_state=random_state, **kwargs)
        # The same workspace directory holds both the training checkpoints
        # and any samples the library decides to persist.
        self.model = REaLTabFormer(
            model_type="tabular",
            checkpoints_dir=workspace,
            samples_save_dir=workspace,
            epochs=epochs,
            batch_size=batch_size,
            early_stopping_patience=early_stopping_patience,
            early_stopping_threshold=early_stopping_threshold,
            mask_rate=mask_rate,
            random_state=random_state,
        )

    def _fit_model(
        self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None
    ) -> None:
        """Train the wrapped REaLTabFormer model on ``X``.

        Args:
            X: Training data.
            discrete_features: Names of the discrete columns. Accepted for
                interface compatibility only; it is not forwarded to the
                model.
            X_val: Optional validation data. Accepted for interface
                compatibility only; it is not used here.
        """
        # Only the training frame is forwarded. The second positional
        # parameter of ``REaLTabFormer.fit`` is ``in_df`` (the parent table of
        # a relational model), so passing ``discrete_features`` there would
        # bind the list to the wrong parameter.
        self.model.fit(X)

    def _generate_data(self, n: int) -> pd.DataFrame:
        """Sample ``n`` synthetic rows from the trained model.

        Args:
            n: Number of rows to generate.

        Returns:
            The samples produced by ``REaLTabFormer.sample``.
        """
        # save_samples=False: return the samples to the caller without asking
        # the library to persist them.
        return self.model.sample(n, save_samples=False)