ModelSamplerEstimator

class pymc_marketing.pytensor_utils.ModelSamplerEstimator(*, tune=1000, draws=1000, chains=1, sequential_chains=1, seed=None)

Estimate computational characteristics of a PyMC model using JAX/NumPyro.

This utility measures the average evaluation time of the model’s logp and gradients and estimates the number of integrator steps taken by NUTS during warmup + sampling. It then compiles the information into a single-row pandas DataFrame with helpful metadata to guide planning and benchmarking.
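As a rough illustration of what "average evaluation time" means here, the sketch below times repeated calls of a function and divides by the call count. It is not the library's implementation: `average_eval_time` and `toy_logp_and_grad` are hypothetical names, and the toy function is a plain-Python stand-in for a compiled logp+dlogp function. When timing real JAX functions, the result would typically also need to be synchronized (for example with `block_until_ready()`) so that asynchronous dispatch does not distort the measurement.

```python
import time
from typing import Callable


def average_eval_time(
    fn: Callable[[], object], n_warmup: int = 3, n_repeats: int = 50
) -> float:
    """Return the mean wall-clock seconds per call of ``fn`` (hypothetical helper)."""
    # Warm-up calls are excluded from timing so one-off costs
    # (compilation, caching) do not inflate the average.
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn()
    return (time.perf_counter() - start) / n_repeats


def toy_logp_and_grad():
    """Stand-in for logp+dlogp: unnormalized standard normal at x = 0.5."""
    x = 0.5
    logp = -0.5 * x * x  # log-density up to an additive constant
    dlogp = -x           # gradient of logp with respect to x
    return logp, dlogp


seconds_per_eval = average_eval_time(toy_logp_and_grad)
print(f"{seconds_per_eval:.3e} s per evaluation")
```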

Parameters:
tune : int, default 1000

Number of warmup iterations to use when estimating NUTS steps.

draws : int, default 1000

Number of sampling iterations to use when estimating NUTS steps.

chains : int, default 1

Intended number of chains (metadata only; not used in JAX runs here).

sequential_chains : int, default 1

Number of chains expected to run sequentially on the target environment. Used to scale the wall-clock time estimate.

seed : int | None, default None

Random seed used for the step estimation runs.
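To make the role of `sequential_chains` concrete, the sketch below shows one plausible way the pieces could combine into a wall-clock estimate: time per evaluation, multiplied by the estimated number of NUTS steps, multiplied by the number of chains that run one after another. The exact formula the class uses is not stated on this page, so treat this as an assumption; `estimate_wall_clock_seconds` is a hypothetical helper, not part of `pymc_marketing`, and the input numbers are made up for illustration.

```python
def estimate_wall_clock_seconds(
    seconds_per_eval: float, num_steps: int, sequential_chains: int = 1
) -> float:
    """Hypothetical estimate: eval time x NUTS steps x sequential chains.

    Assumes each integrator step costs about one logp+dlogp evaluation and
    that chains running sequentially add their run times together.
    """
    if seconds_per_eval < 0 or num_steps < 0 or sequential_chains < 1:
        raise ValueError("inputs must be non-negative, with sequential_chains >= 1")
    return seconds_per_eval * num_steps * sequential_chains


# Illustrative inputs only: 0.2 ms per evaluation, 30,000 total steps,
# and two chains that must run back to back (roughly 12 seconds).
estimate = estimate_wall_clock_seconds(2e-4, 30_000, sequential_chains=2)
print(f"{estimate:.1f} s")
```

Under this assumption, chains that run in parallel do not lengthen the estimate, which would be consistent with `chains` being described as metadata only.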

Examples

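from pymc_marketing.pytensor_utils import ModelSamplerEstimator

# `model` is an existing PyMC model defined elsewhere.
est = ModelSamplerEstimator(
    tune=1000, draws=1000, chains=4, sequential_chains=1, seed=1
)
df = est.run(model)  # single-row pandas DataFrame
print(df)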

Methods

ModelSamplerEstimator.__init__(*[, tune, ...])

ModelSamplerEstimator.estimate_model_eval_time(model)

Estimate average evaluation time (seconds) of logp+dlogp using JAX.

ModelSamplerEstimator.estimate_num_steps_sampling(...)

Estimate total number of NUTS steps during warmup + sampling using NumPyro.

ModelSamplerEstimator.run(model)

Execute the estimation pipeline and return a single-row DataFrame.

Attributes

default_mcmc_kwargs

Default keyword arguments for a NumPyro MCMC runner.

default_nuts_kwargs

Default keyword arguments for a NumPyro NUTS kernel.