Source code for synthyverse.generators.synthpop_generator.synth_pop
import pandas as pd
import numpy as np
from synthpop import DataProcessor, MissingDataHandler
from .synthpop_dir.cart import CARTMethod
from ..base import TabularBaseGenerator
[docs]
class SynthpopGenerator(TabularBaseGenerator):
    """From the popular Synthpop R package.

    Synthpop uses CART (Classification and Regression Trees) to model
    conditional marginal distributions. Synthetic data is autoregressively
    generated by sampling from leaf nodes.

    Uses the python-synthpop pypi package implementation.

    Paper: "synthpop: Bespoke creation of synthetic data in R" by Nowok et al. (2016).

    Args:
        smoothing (bool): Whether to use smoothing for continuous variables. Default: False.
        proper (bool): Whether to apply a resampling (proper) step during fitting. Default: False.
        minibucket (int): Minimum samples in the leaf nodes. Increase to reduce overfitting. Default: 5.
        tree_params (dict): Dictionary of additional parameters for tree construction
            (scikit-learn decision trees). Default: None, which is treated as an
            empty dict created fresh for each instance.
        random_state (int): Random seed for reproducibility. Default: 0.
        **kwargs: Additional arguments passed to TabularBaseGenerator.

    Example:
        >>> import pandas as pd
        >>> from synthyverse.generators import SynthpopGenerator
        >>>
        >>> # Load data
        >>> X = pd.read_csv("data.csv")
        >>> discrete_features = ["category_col"]
        >>>
        >>> # Create generator
        >>> generator = SynthpopGenerator(
        ...     smoothing=True,
        ...     proper=True,
        ...     minibucket=5,
        ...     random_state=42
        ... )
        >>>
        >>> # Fit and generate
        >>> generator.fit(X, discrete_features)
        >>> X_syn = generator.generate(1000)
    """

    name = "synthpop"

    def __init__(
        self,
        smoothing: bool = False,
        proper: bool = False,
        minibucket: int = 5,
        tree_params: dict = None,
        random_state: int = 0,
        **kwargs,
    ):
        super().__init__(random_state=random_state, **kwargs)
        self.smoothing = smoothing
        self.proper = proper
        self.minibucket = minibucket
        self.random_state = random_state
        # A literal `{}` default would be one dict shared by every instance;
        # build a fresh one per instance instead. A caller-supplied dict is
        # stored as-is (same object), exactly as before.
        self.tree_params = {} if tree_params is None else tree_params

    def _fit_model(
        self, X: pd.DataFrame, discrete_features: list, X_val: pd.DataFrame = None
    ):
        """Preprocess ``X`` and fit the CART model on the result.

        Column dtypes are obtained from ``MissingDataHandler.get_column_dtypes``
        and used both to build the ``DataProcessor`` and to configure
        ``CARTMethod``. The fitted processor and model are stored on the
        instance as ``self.processor`` and ``self.model``.

        Args:
            X: Training data.
            discrete_features: Accepted for interface compatibility; not read
                by this method.
            X_val: Accepted for interface compatibility; not read by this
                method.
        """
        metadata = MissingDataHandler().get_column_dtypes(X)

        # Keep the processor around: generation needs it to undo preprocessing.
        self.processor = DataProcessor(metadata)
        x_pr = self.processor.preprocess(X)

        self.model = CARTMethod(
            metadata,
            smoothing=self.smoothing,
            proper=self.proper,
            minibucket=self.minibucket,
            tree_params=self.tree_params,
            random_state=self.random_state,
        )
        self.model.fit(x_pr)

    def _generate_data(self, n: int):
        """Sample from the fitted model and map the result back via the processor.

        Args:
            n: Sample count passed to ``self.model.sample``.

        Returns:
            The output of ``self.processor.postprocess`` applied to the sample.
        """
        syn = self.model.sample(n)
        syn = self.processor.postprocess(syn)
        return syn