# Source code for synthyverse.generators.synthpop_generator.synth_pop

from typing import Optional

import pandas as pd
import numpy as np
from synthpop import DataProcessor, MissingDataHandler

from .synthpop_dir.cart import CARTMethod

from ..base import TabularBaseGenerator


class SynthpopGenerator(TabularBaseGenerator):
    """From the popular Synthpop R package.

    Synthpop uses CART (Classification and Regression Trees) to model
    conditional marginal distributions. Synthetic data is autoregressively
    generated by sampling from leaf nodes. Uses the python-synthpop PyPI
    package implementation.

    Paper: "synthpop: Bespoke creation of synthetic data in R" by Nowok et al.
    (2016).

    Args:
        smoothing (bool): Whether to use smoothing for continuous variables.
            Default: False.
        proper (bool): Whether to apply a resampling (proper) step during
            fitting. Default: False.
        minibucket (int): Minimum samples in the leaf nodes. Increase to
            reduce overfitting. Default: 5.
        tree_params (dict, optional): Additional parameters for tree
            construction (scikit-learn decision trees). ``None`` (the
            default) means no extra parameters. The dict is copied, so later
            changes to the caller's dict do not affect this generator.
        random_state (int): Random seed for reproducibility. Default: 0.
        **kwargs: Additional arguments passed to TabularBaseGenerator.

    Example:
        >>> import pandas as pd
        >>> from synthyverse.generators import SynthpopGenerator
        >>>
        >>> # Load data
        >>> X = pd.read_csv("data.csv")
        >>> discrete_features = ["category_col"]
        >>>
        >>> # Create generator
        >>> generator = SynthpopGenerator(
        ...     smoothing=True,
        ...     proper=True,
        ...     minibucket=5,
        ...     random_state=42,
        ... )
        >>>
        >>> # Fit and generate
        >>> generator.fit(X, discrete_features)
        >>> X_syn = generator.generate(1000)
    """

    name = "synthpop"

    def __init__(
        self,
        smoothing: bool = False,
        proper: bool = False,
        minibucket: int = 5,
        tree_params: Optional[dict] = None,
        random_state: int = 0,
        **kwargs,
    ):
        super().__init__(random_state=random_state, **kwargs)
        self.smoothing = smoothing
        self.proper = proper
        self.minibucket = minibucket
        self.random_state = random_state
        # A literal ``{}`` default is evaluated once and would be shared by
        # every instance; build a fresh (copied) dict per instance instead.
        self.tree_params = {} if tree_params is None else dict(tree_params)

    def _fit_model(
        self,
        X: pd.DataFrame,
        discrete_features: list,
        X_val: Optional[pd.DataFrame] = None,
    ):
        """Fit the CART synthesizer on ``X``.

        Column types are inferred from ``X`` itself via
        ``MissingDataHandler().get_column_dtypes``. ``discrete_features`` and
        ``X_val`` are accepted (presumably to match the TabularBaseGenerator
        interface) but are not used by this implementation.

        Sets ``self.processor`` (the DataProcessor used to preprocess ``X``)
        and ``self.model`` (the CARTMethod fitted on the preprocessed data).
        """
        metadata = MissingDataHandler().get_column_dtypes(X)
        self.processor = DataProcessor(metadata)
        x_pr = self.processor.preprocess(X)
        self.model = CARTMethod(
            metadata,
            smoothing=self.smoothing,
            proper=self.proper,
            minibucket=self.minibucket,
            tree_params=self.tree_params,
            random_state=self.random_state,
        )
        self.model.fit(x_pr)

    def _generate_data(self, n: int):
        """Sample ``n`` rows from the fitted model.

        The samples are mapped back through ``self.processor.postprocess``
        before being returned. Requires ``_fit_model`` to have been called
        first, since it creates ``self.model`` and ``self.processor``.
        """
        syn = self.model.sample(n)
        syn = self.processor.postprocess(syn)
        return syn