ModelSamplerEstimator
- class pymc_marketing.pytensor_utils.ModelSamplerEstimator(*, tune=1000, draws=1000, chains=1, sequential_chains=1, seed=None)
Estimate computational characteristics of a PyMC model using JAX/NumPyro.
This utility measures the average evaluation time of the model’s logp and gradients and estimates the number of integrator steps taken by NUTS during warmup + sampling. It then compiles the information into a single-row pandas DataFrame with helpful metadata to guide planning and benchmarking.
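The timing half of this measurement can be sketched in plain Python. The sketch calls a log-density and its gradient repeatedly and averages the elapsed time. It is a minimal stand-in, not the library's implementation. The standard-normal `logp_and_grad` model and the `average_eval_time` helper are hypothetical. The real estimator times a JAX-compiled version of the model's logp and gradient.

```python
import math
import time
from typing import Callable, List, Sequence, Tuple


def logp_and_grad(x: Sequence[float]) -> Tuple[float, List[float]]:
    """Hypothetical toy model: standard-normal log-density and its gradient."""
    logp = -0.5 * sum(xi * xi for xi in x) - 0.5 * len(x) * math.log(2 * math.pi)
    grad = [-xi for xi in x]  # derivative of -x**2 / 2 is -x
    return logp, grad


def average_eval_time(fn: Callable, point, n: int = 1000, warmup: int = 10) -> float:
    """Return the mean wall-clock seconds per call of fn(point)."""
    for _ in range(warmup):
        # untimed calls; with JAX these would absorb one-off JIT compilation
        fn(point)
    start = time.perf_counter()
    for _ in range(n):
        fn(point)
    return (time.perf_counter() - start) / n


point = [0.1] * 50
seconds = average_eval_time(logp_and_grad, point)
print(f"{seconds:.2e} s per logp+gradient evaluation")
```

When timing actual JAX functions, also call `block_until_ready()` on the outputs inside the timed loop. JAX dispatches work asynchronously, so a plain loop can otherwise measure only dispatch time.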
- Parameters:
- tune
int, default 1000
Number of warmup iterations to use when estimating NUTS steps.
- draws
int, default 1000
Number of sampling iterations to use when estimating NUTS steps.
- chains
int, default 1
Intended number of chains (metadata only; not used in the JAX runs).
- sequential_chains
int, default 1
Number of chains expected to run sequentially in the target environment. Used to scale the wall-clock time estimate.
- seed
int | None, default None
Random seed for the step-estimation runs.
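This hedged sketch shows one plausible way the measured quantities combine into a wall-clock figure. It assumes roughly one logp+gradient evaluation per leapfrog step, so seconds per evaluation times the step count gives the time for one chain. It then multiplies by `sequential_chains`, because chains that run one after another add their times while parallel chains do not. The formula, the `estimate_wall_clock` helper, and the field names are assumptions for illustration. The columns that `run()` actually returns may differ.

```python
def estimate_wall_clock(
    eval_time_s: float,
    num_steps: int,
    *,
    tune: int = 1000,
    draws: int = 1000,
    chains: int = 1,
    sequential_chains: int = 1,
) -> dict:
    """Hypothetical helper: build a single-row summary of the time estimate.

    Assumes one logp+gradient evaluation per leapfrog step and that
    `num_steps` counts the steps of a single chain over warmup + sampling.
    """
    per_chain_s = eval_time_s * num_steps
    return {
        "tune": tune,
        "draws": draws,
        "chains": chains,  # metadata only, mirroring the documented parameter
        "sequential_chains": sequential_chains,
        "eval_time_s": eval_time_s,
        "num_steps": num_steps,
        # chains that run one after another add their wall-clock time
        "estimated_wall_clock_s": per_chain_s * sequential_chains,
    }


# 2e-4 s per evaluation, 30,000 steps per chain, 2 chains run back to back
row = estimate_wall_clock(2e-4, 30_000, chains=4, sequential_chains=2)
print(row["estimated_wall_clock_s"])
```

With pandas installed, `pd.DataFrame([row])` turns such a dict into a single-row DataFrame.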
Examples
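from pymc_marketing.pytensor_utils import ModelSamplerEstimator

# `model` is an existing pymc.Model
est = ModelSamplerEstimator(
    tune=1000, draws=1000, chains=4, sequential_chains=1, seed=1
)
df = est.run(model)
print(df)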
Methods
ModelSamplerEstimator.__init__(*[, tune, ...])
Estimate average evaluation time (seconds) of logp+dlogp using JAX.
Estimate total number of NUTS steps during warmup + sampling using NumPyro.
ModelSamplerEstimator.run(model): Execute the estimation pipeline and return a single-row DataFrame.
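The step count has to be estimated rather than computed in advance. NUTS chooses its trajectory length adaptively, so each iteration takes a different number of leapfrog steps: at least 1 and at most 2**max_tree_depth - 1. The worked bound below assumes NumPyro's default `max_tree_depth=10` and shows how wide that range is for one chain. The `step_bounds` helper is hypothetical.

```python
from typing import Tuple


def step_bounds(tune: int = 1000, draws: int = 1000, max_tree_depth: int = 10) -> Tuple[int, int]:
    """Hypothetical helper: min and max total leapfrog steps for one NUTS chain.

    Each iteration takes at least 1 step and at most 2**max_tree_depth - 1.
    """
    iterations = tune + draws
    return iterations, iterations * (2**max_tree_depth - 1)


low, high = step_bounds()
print(low, high)  # 2000 2046000
```

A roughly thousandfold spread between the bounds makes a measured count far more useful than a guess. In NumPyro, per-iteration counts can be collected by passing `extra_fields=("num_steps",)` to `MCMC.run` and reading `MCMC.get_extra_fields()`. Whether this class uses exactly that mechanism is not stated here.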
Attributes
default_mcmc_kwargs: Default keyword arguments for a NumPyro MCMC runner.
default_nuts_kwargs: Default keyword arguments for a NumPyro NUTS kernel.
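A common way to use default-keyword attributes like these is to copy them and layer caller overrides on top. Copying keeps the shared defaults from being mutated. The sketch shows that pattern with a hypothetical `merge_kwargs` helper and placeholder values. `target_accept_prob` (NUTS) and `num_warmup`, `num_samples`, `progress_bar` (MCMC) are real NumPyro argument names. The values are illustrative and are not the class's actual defaults.

```python
from typing import Optional


def merge_kwargs(defaults: dict, overrides: Optional[dict] = None) -> dict:
    """Hypothetical helper: return a new dict with overrides applied on top of defaults."""
    merged = dict(defaults)  # shallow copy so the shared defaults are never mutated
    merged.update(overrides or {})
    return merged


# Placeholder values for illustration only; not the class's real defaults.
nuts_defaults = {"target_accept_prob": 0.8}
mcmc_defaults = {"num_warmup": 1000, "num_samples": 1000, "progress_bar": False}

nuts_kwargs = merge_kwargs(nuts_defaults, {"target_accept_prob": 0.9})
mcmc_kwargs = merge_kwargs(mcmc_defaults)
print(nuts_kwargs, mcmc_kwargs)
```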