Source code for synthyverse.generators.tabsyn_generator.tabsyn

from ..base import TabularBaseGenerator
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer
from torch import nn
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from tqdm import tqdm
from .vae import Model_VAE, Encoder_model, Decoder_model
from .diffusion import MLPDiffusion, Model, sample


CE_LOSS_FN = nn.CrossEntropyLoss()


[docs] class TabSynGenerator(TabularBaseGenerator): """TabSyn: a latent diffusion model for tabular data. Trains a VAE to learn a latent representation of tabular data, then fits a diffusion model in that latent space. Paper: "Mixed-type tabular data synthesis with score-based diffusion in latent space" by Zhang et al. (2023). Based on the paper's original implementation: https://github.com/amazon-science/tabsyn/ Args: target_column (str): Name of the target column used for stratified validation splitting. vae_lr (float): Learning rate for VAE training. Default: 1e-3. vae_wd (float): Weight decay used by the VAE optimizer. Default: 0. vae_d_token (int): Token embedding dimension used by the VAE. Default: 4. vae_n_head (int): Number of attention heads in the VAE transformer blocks. Default: 1. vae_factor (int): Expansion factor used in VAE feed-forward layers. Default: 32. vae_num_layers (int): Number of VAE encoder/decoder layers. Default: 2. vae_batch_size (int): Batch size used to train the VAE. Default: 4096. vae_num_epochs (int): Maximum number of VAE epochs. Default: 4000. vae_min_beta (float): Minimum KL coefficient for VAE KL annealing. Default: 1e-5. vae_max_beta (float): Initial/maximum KL coefficient for VAE KL annealing. Default: 1e-2. vae_lambda (float): Multiplicative decay factor applied to beta when validation plateaus. Default: 0.7. diffusion_batch_size (int): Batch size used to train the latent diffusion model. Default: 4096. diffusion_num_epochs (int): Maximum number of diffusion training epochs. Default: 10001. diffusion_dim_t (int): Time embedding dimension used by the diffusion denoiser. Default: 1024. diffusion_lr (float): Learning rate for diffusion training. Default: 1e-3. diffusion_wd (float): Weight decay used by the diffusion optimizer. Default: 0. diffusion_sampling_steps (int): Number of reverse diffusion steps used for generation. Default: 50. diffusion_patience (int): Patience in diffusion training for early stopping. Default: 500. num_workers (int): Number of workers for PyTorch data loaders. Increase to speed up training - but may cause issues when locally training on Windows OS. Default: 0. random_state (int): Random seed for reproducibility. Default: 0. **kwargs: Additional arguments passed to `TabularBaseGenerator`. Example: >>> import pandas as pd >>> from synthyverse.generators import TabSynGenerator >>> >>> # Load data and define categorical columns >>> X = pd.read_csv("data.csv") >>> discrete_features = ["target", "category_col"] >>> >>> # Create generator (requires target column) >>> generator = TabSynGenerator( ... target_column="target", ... vae_num_epochs=100, ... diffusion_num_epochs=500, ... random_state=42 ... ) >>> >>> # Fit and generate >>> generator.fit(X, discrete_features) >>> X_syn = generator.generate(1000) """ name = "tabsyn" needs_target_column = True needs_validation_set = True def __init__( self, target_column: str, vae_lr: float = 1e-3, vae_wd: float = 0, vae_d_token: int = 4, vae_n_head: int = 1, vae_factor: int = 32, vae_num_layers: int = 2, vae_batch_size: int = 4096, vae_num_epochs: int = 4000, vae_min_beta: float = 1e-5, vae_max_beta: float = 1e-2, vae_lambda: float = 0.7, diffusion_batch_size: int = 4096, diffusion_num_epochs: int = 10_000 + 1, diffusion_dim_t: int = 1024, diffusion_lr: float = 1e-3, diffusion_wd: float = 0, diffusion_sampling_steps: int = 50, diffusion_patience: int = 500, num_workers: int = 0, # number of workers in pytorch dataloader (>0 can give issues on windows) random_state: int = 0, **kwargs, ): super().__init__(random_state=random_state, **kwargs) self.random_state = random_state self.target_column = target_column self.num_workers = num_workers self.vae_params = { "LR": vae_lr, "WD": vae_wd, "D_TOKEN": vae_d_token, "N_HEAD": vae_n_head, "FACTOR": vae_factor, "NUM_LAYERS": vae_num_layers, "BATCH_SIZE": vae_batch_size, "NUM_EPOCHS": vae_num_epochs, "MIN_BETA": vae_min_beta, "MAX_BETA": vae_max_beta, "LAMBDA": vae_lambda, } self.diffusion_params = { "BATCH_SIZE": diffusion_batch_size, "NUM_EPOCHS": diffusion_num_epochs, "DIM_T": diffusion_dim_t, "LR": diffusion_lr, "WD": diffusion_wd, "PATIENCE": diffusion_patience, "SAMPLING_STEPS": diffusion_sampling_steps, } self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def _fit_model( self, X: pd.DataFrame, X_val: pd.DataFrame = None, discrete_features: list = [] ): x = X.copy() self.ori_columns = X.columns self.discrete_features = discrete_features if X_val is None: stratify = ( x[self.target_column] if self.target_column in discrete_features else None ) x, x_val = train_test_split( x, test_size=0.1, random_state=self.random_state, stratify=stratify ) else: x_val = X_val.copy() # scale numerical features (categoricals are already integer encoded by the basegenerator) self.scaler = QuantileTransformer( output_distribution="normal", n_quantiles=max(min(len(x) // 30, 1000), 10), subsample=int(1e9), random_state=self.random_state, ) self.numerical_features = [ col for col in x.columns if col not in self.discrete_features ] self.scaler.fit(x[self.numerical_features]) x[self.numerical_features] = self.scaler.transform(x[self.numerical_features]) x_val[self.numerical_features] = self.scaler.transform( x_val[self.numerical_features] ) # filter out validation obs with cats not seen in training data drop_mask = [] for col_idx in range(x[discrete_features].shape[1]): drop_mask.append( x_val[discrete_features].values[:, col_idx] > x[discrete_features].values[:, col_idx].max() ) drop_mask = np.column_stack(drop_mask) drop_mask = np.sum(drop_mask, axis=1) > 0 x_val = x_val[~drop_mask] print( f"Number of observations dropped in validation set due to unseen categories: {sum(drop_mask)}" ) # train the VAE train_z = self._train_vae(x, x_val) # no longer need original data del x, x_val # train the diffusion model self.model = self._train_diffusion(train_z) def _generate_data(self, n: int): self.model.eval() with torch.no_grad(): x = sample( self.model.denoise_fn_D, n, self.sample_dim, self.device, num_steps=self.diffusion_params["SAMPLING_STEPS"], ) x = x * 2 + self.train_z_mean.to(self.device) # x = x.float().cpu().numpy() x = x.reshape(x.shape[0], -1, self.token_dim) norm_input = self.pre_decoder(x.float()) x_hat_num, x_hat_cat = norm_input x_cat = [] for pred in x_hat_cat: x_cat.append(pred.argmax(dim=-1)) syn_num = x_hat_num.detach().cpu().numpy() syn_cat = torch.stack(x_cat).t().cpu().numpy() syn_cat = np.int64(syn_cat) syn = np.concatenate([syn_num, syn_cat], axis=1) cols = self.numerical_features + self.discrete_features syn = pd.DataFrame(syn, columns=cols) syn = syn[self.ori_columns] syn[self.numerical_features] = self.scaler.inverse_transform( syn[self.numerical_features] ) return syn def _train_vae(self, x, x_val): d_numerical = len(x.columns.tolist()) - len(self.discrete_features) categories = [] for col in self.discrete_features: categories.append(x[col].nunique()) vae = Model_VAE( self.vae_params["NUM_LAYERS"], d_numerical, categories, self.vae_params["D_TOKEN"], n_head=self.vae_params["N_HEAD"], factor=self.vae_params["FACTOR"], bias=True, ).to(self.device) pre_encoder = Encoder_model( self.vae_params["NUM_LAYERS"], d_numerical, categories, self.vae_params["D_TOKEN"], self.vae_params["N_HEAD"], self.vae_params["FACTOR"], bias=True, ).to(self.device) self.pre_decoder = Decoder_model( self.vae_params["NUM_LAYERS"], d_numerical, categories, self.vae_params["D_TOKEN"], self.vae_params["N_HEAD"], self.vae_params["FACTOR"], bias=True, ).to(self.device) pre_encoder.eval() self.pre_decoder.eval() optimizer = torch.optim.Adam( vae.parameters(), lr=self.vae_params["LR"], weight_decay=self.vae_params["WD"], ) scheduler = ReduceLROnPlateau( optimizer, mode="min", factor=0.95, patience=10, verbose=True ) current_lr = optimizer.param_groups[0]["lr"] patience = 0 batch_size = min(self.vae_params["BATCH_SIZE"], len(x)) beta = self.vae_params["MAX_BETA"] x_tr_num = x[ [col for col in x.columns if col not in self.discrete_features] ].values x_tr_num = torch.from_numpy(x_tr_num).float() x_tr_cat = x[self.discrete_features].values x_tr_cat = torch.from_numpy(x_tr_cat).long() train_loader = DataLoader( TensorDataset(x_tr_num, x_tr_cat), batch_size=batch_size, shuffle=True, num_workers=self.num_workers, # 4 ) x_val_num = x_val[ [col for col in x.columns if col not in self.discrete_features] ].values x_val_num = torch.from_numpy(x_val_num).float() x_val_cat = x_val[self.discrete_features].values x_val_cat = torch.from_numpy(x_val_cat).long() x_val_num = x_val_num.to(self.device) x_val_cat = x_val_cat.to(self.device) best_val_loss = float("inf") pbar = tqdm(range(self.vae_params["NUM_EPOCHS"])) for _ in pbar: vae.train() for batch_num, batch_cat in train_loader: batch_num = batch_num.to(self.device) batch_cat = batch_cat.to(self.device) optimizer.zero_grad() Recon_X_num, Recon_X_cat, mu_z, std_z = vae(batch_num, batch_cat) loss_mse, loss_ce, loss_kld, train_acc = self._compute_loss( batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z ) loss = loss_mse + loss_ce + beta * loss_kld loss.backward() optimizer.step() vae.eval() with torch.no_grad(): Recon_X_num, Recon_X_cat, mu_z, std_z = vae(x_val_num, x_val_cat) val_mse_loss, val_ce_loss, val_kl_loss, val_acc = self._compute_loss( x_val_num, x_val_cat, Recon_X_num, Recon_X_cat, mu_z, std_z ) val_loss = val_mse_loss.item() * 0 + val_ce_loss.item() scheduler.step(val_loss) new_lr = optimizer.param_groups[0]["lr"] if new_lr != current_lr: current_lr = new_lr if val_loss < best_val_loss: best_val_loss = val_loss best_vae = vae.state_dict() patience = 0 else: patience += 1 if patience == 10: if beta > self.vae_params["MIN_BETA"]: beta = beta * self.vae_params["LAMBDA"] vae.load_state_dict(best_vae) # getting latent embeddings with torch.no_grad(): pre_encoder.load_weights(vae) self.pre_decoder.load_weights(vae) pre_encoder.to(self.device), self.pre_decoder.to(self.device) pre_encoder.eval(), self.pre_decoder.eval() train_z = [] for batch_num, batch_cat in train_loader: batch_num = batch_num.to(self.device) batch_cat = batch_cat.to(self.device) train_z.append(pre_encoder(batch_num, batch_cat).detach().cpu().numpy()) train_z = np.concatenate(train_z) return train_z def _train_diffusion(self, train_z): train_z = torch.from_numpy(train_z).float() train_z = train_z[:, 1:, :] B, num_tokens, self.token_dim = train_z.shape in_dim = num_tokens * self.token_dim train_z = train_z.view(B, in_dim) self.sample_dim = train_z.shape[1] self.train_z_mean = train_z.mean(0) in_dim = train_z.shape[1] mean, std = train_z.mean(0), train_z.std(0) train_data = (train_z - mean) / 2 train_loader = DataLoader( train_data, batch_size=self.diffusion_params["BATCH_SIZE"], shuffle=True, num_workers=self.num_workers, ) denoise_fn = MLPDiffusion(in_dim, self.diffusion_params["DIM_T"]).to( self.device ) model = Model(denoise_fn=denoise_fn, hid_dim=train_data.shape[1]).to( self.device ) optimizer = torch.optim.Adam( model.parameters(), lr=self.diffusion_params["LR"], weight_decay=self.diffusion_params["WD"], ) scheduler = ReduceLROnPlateau( optimizer, mode="min", factor=0.9, patience=20, verbose=True ) model.train() best_loss = float("inf") patience = 0 for epoch in tqdm(range(self.diffusion_params["NUM_EPOCHS"])): batch_loss = 0.0 len_input = 0 for inputs in train_loader: inputs = inputs.to(self.device) loss = model(inputs) loss = loss.mean() batch_loss += loss.item() * len(inputs) len_input += len(inputs) optimizer.zero_grad() loss.backward() optimizer.step() curr_loss = batch_loss / len_input scheduler.step(curr_loss) if curr_loss < best_loss: best_loss = curr_loss patience = 0 best_model = model.state_dict() else: patience += 1 if patience == 500: print("Early stopping") break model.load_state_dict(best_model) return model def _compute_loss(self, X_num, X_cat, Recon_X_num, Recon_X_cat, mu_z, logvar_z): mse_loss = (X_num - Recon_X_num).pow(2).mean() ce_loss = 0 acc = 0 total_num = 0 for idx, x_cat in enumerate(Recon_X_cat): if x_cat is not None: ce_loss += CE_LOSS_FN(x_cat, X_cat[:, idx]) x_hat = x_cat.argmax(dim=-1) acc += (x_hat == X_cat[:, idx]).float().sum() total_num += x_hat.shape[0] ce_loss /= idx + 1 acc /= total_num # loss = mse_loss + ce_loss temp = 1 + logvar_z - mu_z.pow(2) - logvar_z.exp() loss_kld = -0.5 * torch.mean(temp.mean(-1).mean()) return mse_loss, ce_loss, loss_kld, acc