Source code for synthyverse.generators.tabargn_generator.tabargn

from ..base import TabularBaseGenerator

import pandas as pd
from mostlyai import engine


[docs] class TabARGNGenerator(TabularBaseGenerator): """Tabular AutoRegressive Generative Network (TabARGN). TabARGN uses masked transformers for tabular data generation. We use the implementation from the MostlyAI engine. Paper: "TabularARGN: A Flexible and Efficient Auto-Regressive Framework for Generating High-Fidelity Synthetic Data" by Tiwald et al. (2025). Args: workspace (str): Directory for storing intermediate files. max_epochs (int): Maximum number of training epochs. Default: 100. random_state (int): Random seed for reproducibility. Default: 0. **kwargs: Additional arguments passed to TabularBaseGenerator. Example: >>> import pandas as pd >>> from synthyverse.generators import TabARGNGenerator >>> >>> # Load data >>> X = pd.read_csv("data.csv") >>> discrete_features = ["category_col"] >>> >>> # Create generator (requires workspace) >>> generator = TabARGNGenerator( ... workspace="./tabargn_workspace", ... max_epochs=100, ... random_state=42 ... ) >>> >>> # Fit and generate >>> generator.fit(X, discrete_features) >>> X_syn = generator.generate(1000) """ name = "tabargn" needs_workspace = True def __init__( self, workspace: str, max_epochs: int = 100, random_state: int = 0, **kwargs, ): super().__init__(random_state=random_state, **kwargs) self.workspace = workspace self.max_epochs = max_epochs def _fit_model( self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None ): # set up workspace and default logging engine.init_logging() # execute the engine steps engine.split( # split data as PQT files for `trn` + `val` to `{ws}/OriginalData/tgt-data` workspace_dir=self.workspace, tgt_data=X, model_type="TABULAR", ) engine.analyze(workspace_dir=self.workspace) engine.encode(workspace_dir=self.workspace) engine.train(workspace_dir=self.workspace, max_epochs=self.max_epochs) def _generate_data(self, n: int): engine.generate(workspace_dir=self.workspace, sample_size=n) syn = pd.read_parquet(f"{self.workspace}/SyntheticData") return syn