Introduction
bayesim provides a framework for Bayesian simulation studies through
its extensible fitter interface. The package includes built-in fitters
like MockFitter() and BrmsFitter(), but you
may need to create custom fitters to:
- Interface with different backends: Connect to Stan via cmdstanr, or to other probabilistic programming languages
- Implement custom models: Fit specialized models not covered by existing fitters
- Add preprocessing: Transform data or apply custom initialization strategies
- Control inference: Implement custom sampling algorithms or diagnostics
This vignette explains the fitter interface and demonstrates creating a cmdstanr fitter from scratch.
The Fitter Interface
The Abstract Fitter S7 Class
All fitters in bayesim extend the abstract Fitter S7
class. This class defines the interface contract that all fitters must
implement:
Properties
| Property | Type | Description |
|---|---|---|
| name | character | Identifier for the fitter (e.g., “stan”, “brms”) |
| supports_predictions | logical | Whether the fitter supports predict_fit() |
| supports_log_lik | logical | Whether the fitter supports log_lik() |
| supports_loo | logical | Whether the fitter supports loo() |
These properties allow metrics and simulation code to query fitter capabilities before attempting operations that may not be supported.
Required Generics
The following S7 generics must have methods implemented for your fitter class:
fit(fitter, data_bundle, fit_spec, seed, task_ctx)
The main fitting method. Returns a bayesim_fit_result S3
object.
Parameters:
- fitter: Your fitter object
- data_bundle: List containing train, test, response, true_params, vars_of_interest
- fit_spec: Single-row data frame with model specification from fit_grid
- seed: Integer seed for reproducibility
- task_ctx: List with task_id, data_idx, fit_idx, rep_idx
Returns: A bayesim_fit_result with fields:
- success: Logical indicating whether the fit succeeded
- fit: The backend fit object (or NULL)
- draws: Matrix of posterior draws (S × P) with column names
- diagnostics: Named list of diagnostic values
- timing: List with total, warmup, sample times
- warnings: Character vector of captured warnings
- error: NULL or error condition
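To make the shape of this return value concrete, the sketch below builds a plain list with the same fields. `mock_fit_result()` is a hypothetical stand-in written for illustration only; a real fitter would call `new_fit_result()`, whose validation rules are not reproduced here.

```r
# Hypothetical stand-in for new_fit_result(); illustration only
mock_fit_result <- function(success, fit = NULL, draws = NULL,
                            diagnostics = list(),
                            timing = list(total = 0, warmup = 0, sample = 0),
                            warnings = character(), error = NULL) {
  list(success = success, fit = fit, draws = draws,
       diagnostics = diagnostics, timing = timing,
       warnings = warnings, error = error)
}

# S x P matrix of draws: 4 draws of 2 named parameters
set.seed(1)
draws <- matrix(rnorm(8), nrow = 4, dimnames = list(NULL, c("alpha", "beta")))
res <- mock_fit_result(success = TRUE, draws = draws)
str(res, max.level = 1)
```

Note that `list()` keeps elements set to NULL, so `fit` and `error` are always present as fields even when empty.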
extract_draws(fitter, fit_result, variables = NULL)
Extract posterior draws from a fitted model.
Returns: Matrix with dimensions S × P (draws × parameters) with column names matching variable names.
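As a small illustration of the S × P convention, the sketch below subsets a toy draws matrix by variable name. `select_draws()` is a hypothetical helper, not part of bayesim; the point is that `drop = FALSE` keeps the result a matrix even when a single variable is requested.

```r
# Toy S x P draws matrix (S = 5 draws, P = 3 parameters)
set.seed(1)
draws <- cbind(alpha = rnorm(5), beta = rnorm(5), sigma = abs(rnorm(5)))

# Hypothetical helper: keep only the requested variables, always return a matrix
select_draws <- function(draws, variables = NULL) {
  if (is.null(variables)) return(draws)
  keep <- intersect(colnames(draws), variables)
  draws[, keep, drop = FALSE]  # drop = FALSE preserves the matrix shape
}

dim(select_draws(draws, "beta"))
colnames(select_draws(draws, c("sigma", "alpha")))  # column order follows the matrix
```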
predict_fit(fitter, fit_result, newdata = NULL, seed = NULL)
Generate predictions from a fitted model.
Returns: List containing:
- predicted_mean: Vector of mean predictions (length N)
- predicted_samples: Matrix of posterior predictive samples (N × S)
- predicted_sd: Vector of prediction standard deviations (length N)
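The three components are related: the mean and standard deviation are row-wise summaries of the N × S sample matrix. A minimal sketch with simulated numbers (the values are arbitrary and only illustrate the shapes):

```r
set.seed(1)
n_obs <- 3
n_draws <- 200
# N x S matrix: one row per observation, one column per posterior draw
predicted_samples <- matrix(rnorm(n_obs * n_draws, mean = 2), nrow = n_obs)

preds <- list(
  predicted_mean = rowMeans(predicted_samples),
  predicted_samples = predicted_samples,
  predicted_sd = apply(predicted_samples, 1, sd)
)
lengths(preds)
```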
log_lik(fitter, fit_result, newdata = NULL)
Compute pointwise log-likelihood values.
Returns: Matrix with dimensions N × S (observations × draws).
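The orientation matters when the matrix is handed to other tools. The sketch below computes a pointwise log-likelihood matrix for a normal linear model by hand, using made-up draws, and shows the N × S layout. The matrix method of `loo::loo()` expects draws in rows (S × N), so an N × S matrix has to be transposed before being passed to it.

```r
set.seed(1)
y <- c(0.2, -0.5, 1.3)              # N = 3 observations
x <- c(-1, 0, 1)
alpha <- rnorm(4, 0, 0.1)           # S = 4 made-up posterior draws
beta <- rnorm(4, 1, 0.1)
sigma <- rep(1, 4)

# mu[i, s] = alpha[s] + beta[s] * x[i]
mu <- outer(x, beta) + matrix(alpha, nrow = length(x), ncol = 4, byrow = TRUE)
sd_mat <- matrix(sigma, nrow = length(x), ncol = 4, byrow = TRUE)
ll <- matrix(dnorm(y, mean = mu, sd = sd_mat, log = TRUE), nrow = length(x))

dim(ll)     # N x S, as returned by log_lik()
dim(t(ll))  # S x N, the layout a draws-by-observations consumer expects
```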
loo(fitter, fit_result)
Compute leave-one-out cross-validation.
Returns: List with:
- elpd: Expected log predictive density
- p_loo: Effective number of parameters
- elpd_se: Standard error of ELPD
- pareto_k: Pareto k diagnostic values
diagnostics(fitter, fit_result)
Extract convergence and fit diagnostics.
Returns: Named list with diagnostics such as:
- rhat_max: Maximum R-hat statistic
- ess_bulk: Minimum bulk effective sample size
- ess_tail: Minimum tail effective sample size
- divergent: Number of divergent transitions
- max_treedepth: Number of max treedepth warnings
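Each entry is a scalar reduction over per-parameter or per-iteration quantities. The sketch below assembles such a list from toy inputs; the numbers are invented, and the `divergent__` and `treedepth__` column names follow the convention used by Stan's sampler output.

```r
# Toy per-iteration sampler output
sampler_diag <- data.frame(
  divergent__ = c(0, 0, 1, 0, 1, 0),
  treedepth__ = c(3, 10, 4, 10, 5, 2)
)
# Toy per-parameter convergence summaries
rhat <- c(alpha = 1.001, beta = 1.012, sigma = 1.004)
ess_bulk <- c(alpha = 1800, beta = 950, sigma = 2100)
ess_tail <- c(alpha = 1500, beta = 1200, sigma = 1900)
max_treedepth <- 10L

diag_list <- list(
  rhat_max = max(rhat),                 # worst R-hat across parameters
  ess_bulk = min(ess_bulk),             # smallest bulk ESS
  ess_tail = min(ess_tail),             # smallest tail ESS
  divergent = sum(sampler_diag$divergent__),
  max_treedepth = sum(sampler_diag$treedepth__ >= max_treedepth)
)
str(diag_list)
```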
Example: cmdstanr Fitter with Configurable Priors
This section demonstrates a cmdstanr fitter for simple linear
regression. The design pattern is that fit_spec
controls model aspects—in this case, the prior distributions.
This allows you to run simulation studies comparing different prior
specifications.
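The flow from grid to model can be sketched in a few lines of base R. Here `priors_from_spec()` is a hypothetical helper, and the loop only mimics how each row of a grid could be handed over as a single-row `fit_spec`; how the simulation engine actually iterates is not shown in this vignette.

```r
fit_grid <- data.frame(
  model = c("weak_prior", "strong_prior"),
  prior_slope_mean = c(0, 1),
  prior_slope_sd = c(10, 0.5)
)

# Hypothetical helper: turn one fit_spec row into the prior part of a data list
priors_from_spec <- function(fit_spec) {
  stopifnot(nrow(fit_spec) == 1)
  list(prior_slope_mean = fit_spec$prior_slope_mean,
       prior_slope_sd = fit_spec$prior_slope_sd)
}

for (i in seq_len(nrow(fit_grid))) {
  fit_spec <- fit_grid[i, , drop = FALSE]  # single-row data frame
  str(priors_from_spec(fit_spec))
}
```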
Why cmdstanr?
cmdstanr is a lightweight R interface to Stan, built on CmdStan:
- Direct Stan control: Compile and run Stan models directly
- Fast: Uses CmdStan, the command-line interface to Stan
- Active development: Better integration with the posterior package
- Transparent: Clear separation between model, data, and sampling
The Stan Model
We use a simple linear regression with configurable
priors passed as data. This is the key pattern: priors come
from fit_spec, not hardcoded in the model:
# Define the Stan model as a character string
# Note: Priors are passed as DATA, allowing fit_spec to control them
stan_model_code <- "
data {
int<lower=0> N;
vector[N] y;
vector[N] x;
// Prior parameters - controlled by fit_spec
real prior_intercept_mean;
real<lower=0> prior_intercept_sd;
real prior_slope_mean;
real<lower=0> prior_slope_sd;
}
parameters {
real alpha; // Intercept
real beta; // Slope
real<lower=0> sigma; // Error SD (fixed prior)
}
model {
// Configurable priors from data
alpha ~ normal(prior_intercept_mean, prior_intercept_sd);
beta ~ normal(prior_slope_mean, prior_slope_sd);
sigma ~ exponential(1); // Fixed prior for sigma
// Likelihood
y ~ normal(alpha + beta * x, sigma);
}
generated quantities {
vector[N] log_lik;
vector[N] y_rep;
for (n in 1:N) {
log_lik[n] = normal_lpdf(y[n] | alpha + beta * x[n], sigma);
y_rep[n] = normal_rng(alpha + beta * x[n], sigma);
}
}
"
CmdStanFitter Class Definition
The S7 class holds MCMC settings and caches the compiled model:
library(S7)
library(bayesim)
# Define the CmdStanFitter class
CmdStanFitter <- new_class(
"CmdStanFitter",
parent = Fitter,
properties = list(
# Required Fitter properties
name = new_property(class_character, default = "cmdstanr"),
supports_predictions = new_property(class_logical, default = TRUE),
supports_log_lik = new_property(class_logical, default = TRUE),
supports_loo = new_property(class_logical, default = TRUE),
# MCMC settings
chains = new_property(class_integer, default = 4L),
iter_sampling = new_property(class_integer, default = 1000L),
iter_warmup = new_property(class_integer, default = 1000L),
adapt_delta = new_property(class_double, default = 0.95),
max_treedepth = new_property(class_integer, default = 10L),
# Model caching
stan_file = new_property(class_character, default = NULL),
compiled_model = new_property(class_any, default = NULL)
)
)
The fit() Method
This is the core of the fitter: stan_data is built from
both data_bundle (the data)
and fit_spec (the priors):
method(fit, CmdStanFitter) <- function(
fitter,
data_bundle,
fit_spec,
seed,
task_ctx
) {
start_time <- Sys.time()
warnings <- character()
result <- tryCatch({
# Step 1: Compile model (with caching)
if (is.null(fitter@compiled_model)) {
# An unset character property is character(0), not NULL, so test its length
if (length(fitter@stan_file) == 0L) {
# write_stan_file() names the file by a hash of the code, so repeated
# calls reuse the same file and cmdstanr can skip recompilation
stan_file <- cmdstanr::write_stan_file(stan_model_code)
} else {
stan_file <- fitter@stan_file
}
model <- cmdstanr::cmdstan_model(stan_file)
} else {
model <- fitter@compiled_model
}
# Step 2: Prepare data - combine data_bundle with fit_spec
train_data <- data_bundle$train
response <- data_bundle$response
predictor <- setdiff(names(train_data), response)
if (length(predictor) != 1) {
stop("CmdStanFitter expects exactly one predictor")
}
# Build stan_data from BOTH data_bundle AND fit_spec
# KEY PATTERN: fit_spec controls model aspects (here: priors)
stan_data <- list(
# From data_bundle
N = nrow(train_data),
y = train_data[[response]],
x = train_data[[predictor]],
# From fit_spec - these control the priors!
prior_intercept_mean = fit_spec$prior_intercept_mean,
prior_intercept_sd = fit_spec$prior_intercept_sd,
prior_slope_mean = fit_spec$prior_slope_mean,
prior_slope_sd = fit_spec$prior_slope_sd
)
# Step 3: Run MCMC
fit_obj <- withCallingHandlers({
model$sample(
data = stan_data,
chains = fitter@chains,
iter_sampling = fitter@iter_sampling,
iter_warmup = fitter@iter_warmup,
seed = seed,
adapt_delta = fitter@adapt_delta,
max_treedepth = fitter@max_treedepth,
refresh = 0 # Suppress verbose output
)
}, warning = function(w) {
warnings <<- c(warnings, conditionMessage(w))
invokeRestart("muffleWarning")
})
# Step 4: Extract draws using posterior package
draws_df <- posterior::as_draws_df(fit_obj$draws())
param_names <- c("alpha", "beta", "sigma")
param_draws <- as.matrix(draws_df[, param_names])
# Step 5: Get diagnostics from cmdstanr
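# Name the summary functions so the result has columns rhat, ess_bulk, ess_tail
diag_summary <- fit_obj$summary(
variables = param_names,
rhat = posterior::rhat,
ess_bulk = posterior::ess_bulk,
ess_tail = posterior::ess_tail
)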
sampler_diag <- fit_obj$sampler_diagnostics(format = "df")
end_time <- Sys.time()
total_time <- as.numeric(difftime(end_time, start_time, units = "secs"))
# Step 6: Return fit result
new_fit_result(
success = TRUE,
fit = fit_obj,
draws = param_draws,
diagnostics = list(
rhat_max = max(diag_summary$rhat, na.rm = TRUE),
ess_bulk = min(diag_summary$ess_bulk, na.rm = TRUE),
ess_tail = min(diag_summary$ess_tail, na.rm = TRUE),
divergent = sum(sampler_diag$divergent__),
max_treedepth = sum(sampler_diag$treedepth__ >= fitter@max_treedepth)
),
timing = list(
total = total_time,
warmup = fit_obj$time()$chains$warmup,
sample = fit_obj$time()$chains$sampling
),
warnings = warnings,
error = NULL
)
}, error = function(e) {
end_time <- Sys.time()
new_fit_result(
success = FALSE,
fit = NULL,
draws = NULL,
diagnostics = list(),
timing = list(
total = as.numeric(difftime(end_time, start_time, units = "secs")),
warmup = 0,
sample = 0
),
warnings = warnings,
error = e
)
})
result
}
Other Required Methods
extract_draws()
method(extract_draws, CmdStanFitter) <- function(
fitter,
fit_result,
variables = NULL
) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(NULL)
}
# Use posterior package for clean extraction
draws_df <- posterior::as_draws_df(fit_result$fit$draws())
param_names <- c("alpha", "beta", "sigma")
if (!is.null(variables)) {
param_names <- intersect(param_names, variables)
}
as.matrix(draws_df[, param_names])
}
predict_fit()
method(predict_fit, CmdStanFitter) <- function(
fitter,
fit_result,
newdata = NULL,
seed = NULL
) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(NULL)
}
if (!is.null(seed)) set.seed(seed)
draws <- extract_draws(fitter, fit_result)
alpha_draws <- draws[, "alpha"]
beta_draws <- draws[, "beta"]
sigma_draws <- draws[, "sigma"]
# For newdata or training data
if (!is.null(newdata)) {
x_new <- newdata$x
} else {
# Use y_rep from generated quantities for training data
y_rep <- fit_result$fit$draws("y_rep", format = "matrix")
return(list(
predicted_mean = colMeans(y_rep),
predicted_samples = t(y_rep), # N x S
predicted_sd = apply(y_rep, 2, sd)
))
}
n_obs <- length(x_new)
n_draws <- length(alpha_draws)
# Compute predictions
linear_pred <- outer(x_new, beta_draws, `*`) +
matrix(alpha_draws, nrow = n_obs, ncol = n_draws, byrow = TRUE)
# Posterior predictive (with noise)
predicted_samples <- linear_pred +
matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws) *
matrix(sigma_draws, nrow = n_obs, ncol = n_draws, byrow = TRUE)
list(
predicted_mean = rowMeans(predicted_samples),
predicted_samples = predicted_samples,
predicted_sd = apply(predicted_samples, 1, sd)
)
}
log_lik()
method(log_lik, CmdStanFitter) <- function(
fitter,
fit_result,
newdata = NULL
) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(NULL)
}
if (!is.null(newdata)) {
warning("CmdStanFitter log_lik() for newdata not yet implemented")
}
# Extract log_lik from generated quantities
ll_matrix <- fit_result$fit$draws("log_lik", format = "matrix")
t(ll_matrix) # N x S
}
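The interface section lists diagnostics() among the required generics, but no CmdStanFitter method for it appears among the methods in this section. A minimal sketch, mirroring the LinearFitter method in the appendix and assuming fit() has already stored the diagnostics in the result, could look like this. The logic lives in a plain function (the name `cmdstan_diagnostics` is made up here) so it can be tried without bayesim loaded; the registration line is left as a comment.

```r
# Logic kept in a plain function so it can be tried on its own
cmdstan_diagnostics <- function(fitter, fit_result) {
  if (!isTRUE(fit_result$success) || is.null(fit_result$fit)) {
    return(list())
  }
  fit_result$diagnostics  # already computed inside fit()
}

# Registration (needs S7, bayesim, and the CmdStanFitter class defined above):
# method(diagnostics, CmdStanFitter) <- cmdstan_diagnostics

# Quick check with hand-built stand-ins for fit results
ok <- list(success = TRUE, fit = "backend fit",
           diagnostics = list(rhat_max = 1.01))
failed <- list(success = FALSE, fit = NULL, diagnostics = list())
cmdstan_diagnostics(NULL, ok)$rhat_max
length(cmdstan_diagnostics(NULL, failed))
```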
loo()
method(loo, CmdStanFitter) <- function(fitter, fit_result) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
}
ll <- log_lik(fitter, fit_result)
if (is.null(ll)) {
return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
}
loo_result <- tryCatch(
# log_lik() returns N x S, but loo::loo() expects draws in rows (S x N)
loo::loo(t(ll)),
error = function(e) NULL
)
if (is.null(loo_result)) {
return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
}
list(
elpd = loo_result$estimates["elpd_loo", "Estimate"],
p_loo = loo_result$estimates["p_loo", "Estimate"],
elpd_se = loo_result$estimates["elpd_loo", "SE"],
pareto_k = loo::pareto_k_values(loo_result)
)
}
Example Simulation: Weak vs Strong Priors
Now we demonstrate the key pattern: using fit_grid to
specify different priors, which flow through fit_spec into
the Stan model.
library(tibble)
# Define fit_grid with different prior specifications
# This shows how fit_spec controls model aspects
fit_grid <- tibble::tibble(
model = c("weak_prior", "strong_prior"),
prior_intercept_mean = c(0, 0),
prior_intercept_sd = c(10, 1), # Strong prior is more informative
prior_slope_mean = c(0, 1), # Strong prior "knows" true slope!
prior_slope_sd = c(10, 0.5) # Strong prior is tighter
)
print(fit_grid)
#> # A tibble: 2 × 5
#> model prior_intercept_mean prior_intercept_sd prior_slope_mean prior_slope_sd
#> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 weak_prior 0 10 0 10
#> 2 strong_prior 0 1 1 0.5
Data Generator
# Data generator for simple linear regression
linear_data_generator <- function(data_spec, seed, task_ctx) {
# Note: seed is a scalar task seed.
# The simulation engine also restores the task RNG stream before each call,
# so you typically don't need to call set.seed() here.
n <- data_spec$n
# True parameters (strong_prior happens to "know" these)
alpha_true <- data_spec$alpha
beta_true <- data_spec$beta
sigma_true <- data_spec$sigma
x <- rnorm(n)
y <- alpha_true + beta_true * x + rnorm(n, sd = sigma_true)
list(
train = data.frame(y = y, x = x),
test = NULL,
response = "y",
true_params = c(alpha = alpha_true, beta = beta_true, sigma = sigma_true),
vars_of_interest = c("alpha", "beta", "sigma"),
references = NULL,
meta = list()
)
}
Run the Simulation
# Create the fitter
fitter <- CmdStanFitter(
chains = 4L,
iter_sampling = 1000L,
iter_warmup = 1000L
)
# Simulation configuration
config <- simulation_config(
data_grid = tibble::tibble(
n = c(50, 100),
alpha = 0,
beta = 1, # True slope is 1
sigma = 1
),
fit_grid = fit_grid, # Different priors!
data_generator = linear_data_generator,
fitter = fitter,
metrics = list(
rmse_metric(),
bias_metric(),
coverage_metric(),
posterior_mean_metric()
),
n_replicates = 10L,
seed = 42L
)
# Run simulation
result <- run_simulation(config, progress = TRUE)
Analyze Results
# Compare posterior means across prior specifications
library(dplyr)
result$summary |>
select(data_idx, fit_idx, posterior_mean__alpha, posterior_mean__beta, posterior_mean__sigma)
#> # A tibble: 4 × 5
#> data_idx fit_idx posterior_mean__alpha posterior_mean__beta posterior_mean__sigma
#> <int> <int> <dbl> <dbl> <dbl>
#> 1 1 1 0.0234 0.982 0.998 # weak_prior, n=50
#> 2 1 2 -0.0412 1.012 1.023 # strong_prior, n=50 (slope closer to truth)
#> 3 2 1 0.0089 0.995 0.987 # weak_prior, n=100
#> 4 2 2 -0.0156 1.003 1.001 # strong_prior, n=100
# The strong prior is centered on the true slope, so its slope estimates land
# closer to 1; the intercept and sigma estimates are not uniformly better.
Key Takeaways
- fit_spec controls model aspects: In this example, priors come from fit_grid → fit_spec → stan_data. This pattern works for any model hyperparameters.
- cmdstanr + posterior is the modern stack: Use cmdstanr::cmdstan_model() for compilation and posterior::as_draws_df() for extraction.
- Pattern generalizes: The same approach works for:
  - Different likelihood functions
  - Hierarchical model structures
  - Regularization strengths
  - Any model configuration you want to vary in a simulation
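As a sketch of how the pattern generalizes, a fit_grid can cross several model aspects at once. The `likelihood` column below is hypothetical: the Stan model in this vignette has no such switch, so a fitter would need to branch on it (or select a different Stan file) for the column to have any effect.

```r
# Cross two prior widths with two (hypothetical) likelihood choices
fit_grid <- expand.grid(
  prior_slope_sd = c(0.5, 10),
  likelihood = c("normal", "student_t"),
  stringsAsFactors = FALSE
)
# Readable label for each configuration
fit_grid$model <- paste0(fit_grid$likelihood, "_sd", fit_grid$prior_slope_sd)
fit_grid
```

Each row then reaches fit() as one fit_spec, exactly like the two-row prior grid above.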
Testing Your Fitter
Once you’ve implemented your fitter, test it thoroughly to ensure it works correctly with bayesim.
Using validate_fitter()
bayesim provides a validate_fitter() function that
checks your fitter implements the required interface correctly:
library(bayesim)
# Create your fitter instance
fitter <- CmdStanFitter(
chains = 2L, # Fewer chains for faster testing
iter_sampling = 500L,
iter_warmup = 500L
)
# Basic validation (checks class, properties, and methods)
validate_fitter(fitter)
# Full validation with smoke test (actually calls all methods)
validate_fitter(fitter, smoke_test = TRUE, verbose = TRUE)
#> Checking S7 Fitter class...
#> [OK] Object is an S7 Fitter class
#> Checking required properties...
#> [OK] Property 'name' exists: cmdstanr
#> [OK] Property 'supports_predictions' exists: TRUE
#> ...
#> Validation passed!
The validate_fitter() function performs:
- Property checks: Verifies all required properties exist (name, supports_predictions, supports_log_lik, supports_loo)
- Method checks: Confirms all S7 methods are implemented (fit, extract_draws, predict_fit, log_lik, loo, diagnostics)
- Smoke test (when smoke_test = TRUE):
  - Creates simple test data
  - Calls fit() and verifies bayesim_fit_result structure
  - Calls extract_draws() and verifies matrix with colnames
  - If supported, calls predict_fit() and verifies output structure
  - If supported, calls log_lik() and verifies matrix dimensions
  - Calls diagnostics() and verifies list output
If any check fails, validate_fitter() throws an
informative error explaining what went wrong.
Manual Unit Tests
For more granular testing or to test specific scenarios, you can write manual tests:
# Create minimal test data
set.seed(42)
n <- 50
test_data_bundle <- list(
train = data.frame(
y = rnorm(n),
x = rnorm(n)
),
test = NULL,
response = "y",
true_params = c(alpha = 0, beta = 1, sigma = 1),
vars_of_interest = c("alpha", "beta", "sigma"),
references = NULL,
meta = list()
)
# fit_spec must include prior parameters!
test_fit_spec <- data.frame(
model = "test",
prior_intercept_mean = 0,
prior_intercept_sd = 10,
prior_slope_mean = 0,
prior_slope_sd = 10
)
test_task_ctx <- list(
task_id = "test_001",
data_idx = 1L,
fit_idx = 1L,
rep_idx = 1L
)
# Test fit() returns valid bayesim_fit_result
fit_result <- fit(fitter, test_data_bundle, test_fit_spec, seed = 42L, test_task_ctx)
stopifnot(inherits(fit_result, "bayesim_fit_result"))
stopifnot(isTRUE(fit_result$success))
cat("✓ fit() returns valid bayesim_fit_result\n")
# Test extract_draws() returns matrix with colnames
draws <- extract_draws(fitter, fit_result)
stopifnot(is.matrix(draws))
stopifnot(!is.null(colnames(draws)))
cat("✓ extract_draws() returns matrix with colnames\n")
# Test predict_fit() returns correct structure
preds <- predict_fit(fitter, fit_result)
stopifnot(!is.null(preds$predicted_mean))
stopifnot(!is.null(preds$predicted_samples))
stopifnot(!is.null(preds$predicted_sd))
cat("✓ predict_fit() returns correct structure\n")
cat("All unit tests passed!\n")
Best Practices
Error Handling in fit()
Always wrap fitting code in tryCatch() to gracefully
handle failures:
result <- tryCatch(
{
# Fitting code here
new_fit_result(success = TRUE, ...)
},
error = function(e) {
new_fit_result(
success = FALSE,
error = e,
timing = list(total = elapsed_time, warmup = 0, sample = 0),
warnings = warnings
)
}
)
This ensures that:
- Failed fits don’t crash the entire simulation
- Error information is preserved for debugging
- Timing information is still recorded
- Partial progress can be checkpointed
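A runnable version of this pattern, using only base R, might look like the sketch below. `safe_fit()` is a hypothetical wrapper that returns a plain list rather than a bayesim_fit_result; it also collects warnings raised before a failure, so they survive into the error branch.

```r
# Hypothetical wrapper; returns a plain list, not a bayesim_fit_result
safe_fit <- function(fit_fun) {
  start_time <- Sys.time()
  warnings <- character()
  elapsed <- function() {
    as.numeric(difftime(Sys.time(), start_time, units = "secs"))
  }
  tryCatch(
    {
      fit_obj <- withCallingHandlers(
        fit_fun(),
        warning = function(w) {
          warnings <<- c(warnings, conditionMessage(w))  # record, then silence
          invokeRestart("muffleWarning")
        }
      )
      list(success = TRUE, fit = fit_obj, error = NULL,
           timing = list(total = elapsed(), warmup = 0, sample = 0),
           warnings = warnings)
    },
    error = function(e) {
      list(success = FALSE, fit = NULL, error = e,
           timing = list(total = elapsed(), warmup = 0, sample = 0),
           warnings = warnings)
    }
  )
}

ok <- safe_fit(function() { warning("low ESS"); 42 })
bad <- safe_fit(function() { warning("first"); stop("sampler crashed") })
ok$warnings
conditionMessage(bad$error)
bad$warnings  # captured before the error occurred
```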
Timing the Fit
Always record timing information:
start_time <- Sys.time()
# ... fitting code ...
end_time <- Sys.time()
total_time <- as.numeric(difftime(end_time, start_time, units = "secs"))
The bayesim_fit_result expects timing as a list with
total, warmup, and sample
components. cmdstanr provides these directly via
fit_obj$time().
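One detail worth checking: in cmdstanr, `fit_obj$time()$chains` is a data frame with one row per chain, so its `warmup` and `sampling` columns are vectors. If downstream code expects scalar timings (an assumption; this vignette does not specify it), they can be reduced first. The sketch uses a hand-built stand-in with made-up numbers in place of a real fit object.

```r
# Toy stand-in shaped like the list returned by fit_obj$time()
time_info <- list(
  total = 3.1,
  chains = data.frame(chain_id = 1:4,
                      warmup = c(1.0, 1.2, 0.9, 1.1),
                      sampling = c(1.5, 1.4, 1.6, 1.5),
                      total = c(2.5, 2.6, 2.5, 2.6))
)

# One possible reduction: time summed across chains
timing <- list(
  total = time_info$total,
  warmup = sum(time_info$chains$warmup),
  sample = sum(time_info$chains$sampling)
)
str(timing)
```

Using `max()` instead of `sum()` would approximate wall-clock time when chains run in parallel; which one is appropriate depends on what the timing is used for.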
Handling Warnings
Capture warnings during fitting to prevent them from cluttering output while preserving them for diagnostics:
warnings <- character()
fit_obj <- withCallingHandlers(
{
fitting_function(...)
},
warning = function(w) {
warnings <<- c(warnings, conditionMessage(w))
invokeRestart("muffleWarning")
}
)
Checking Fitter Capabilities
Before calling methods that may not be supported, check the fitter’s properties:
if (fitter@supports_predictions) {
preds <- predict_fit(fitter, fit_result, newdata)
} else {
preds <- NULL
}
if (fitter@supports_loo) {
loo_result <- loo(fitter, fit_result)
} else {
loo_result <- NULL
}
Summary
Creating custom fitters in bayesim involves:
- Extending the Fitter class with your fitter-specific properties
- Implementing six S7 methods: fit(), extract_draws(), predict_fit(), log_lik(), loo(), and diagnostics()
- Returning properly structured bayesim_fit_result objects from fit()
- Using fit_spec to control model aspects (priors, hyperparameters, etc.)
- Handling errors gracefully with tryCatch() and preserving warnings with withCallingHandlers()
- Recording timing information for performance analysis
- Extracting real diagnostics from the backend (R-hat, ESS, divergences)
- Testing thoroughly to ensure integration with the simulation framework
The CmdStanFitter example in this vignette demonstrates
all these patterns. By following these patterns, you can extend bayesim
to work with any Bayesian modeling backend.
Appendix: LinearFitter (No External Dependencies)
For users without Stan installed, here’s a LinearFitter
that uses base R’s lm() (the draws additionally use MASS::mvrnorm(); MASS is a
recommended package bundled with standard R installations). This is a better
alternative to MockFitter when you need a no-Stan fitter:
- Unlike MockFitter, LinearFitter fits a real model to your data and performs actual statistical inference (fit_spec is accepted but not used by the code below)
- Provides approximate Bayesian posterior draws using sampling distributions
- Suitable for lightweight simulation studies where MCMC isn’t required
Note: This fitter doesn’t provide true Bayesian inference (no priors, no MCMC), but it’s useful for testing the simulation framework with realistic statistical behavior:
library(bayesim)
library(S7)
# Define the LinearFitter class
LinearFitter <- new_class(
"LinearFitter",
parent = Fitter,
properties = list(
name = new_property(class_character, default = "lm"),
supports_predictions = new_property(class_logical, default = TRUE),
supports_log_lik = new_property(class_logical, default = TRUE),
supports_loo = new_property(class_logical, default = FALSE),
n_draws = new_property(class_integer, default = 1000L)
)
)
# fit() method
method(fit, LinearFitter) <- function(
fitter,
data_bundle,
fit_spec,
seed,
task_ctx
) {
# Note: seed is a scalar task seed.
# The simulation engine also restores the task RNG stream before each call,
# so you typically don't need to call set.seed() here.
start_time <- Sys.time()
warnings <- character()
result <- tryCatch({
train_data <- data_bundle$train
response <- data_bundle$response
formula <- as.formula(paste(response, "~ ."))
fit_obj <- lm(formula, data = train_data)
fit_obj$bayesim_seed <- seed # Store seed for extract_draws
end_time <- Sys.time()
total_time <- as.numeric(difftime(end_time, start_time, units = "secs"))
new_fit_result(
success = TRUE,
fit = fit_obj,
draws = NULL, # Generated on-demand
diagnostics = list(
rhat_max = 1.0,
ess_bulk = fitter@n_draws,
ess_tail = fitter@n_draws,
divergent = 0L,
max_treedepth = 0L
),
timing = list(total = total_time, warmup = 0, sample = total_time),
warnings = warnings,
error = NULL
)
}, error = function(e) {
end_time <- Sys.time()
new_fit_result(
success = FALSE,
error = e,
timing = list(total = as.numeric(difftime(end_time, start_time, units = "secs")), warmup = 0, sample = 0),
warnings = warnings
)
})
result
}
# extract_draws() method
method(extract_draws, LinearFitter) <- function(
fitter,
fit_result,
variables = NULL
) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(NULL)
}
fit_obj <- fit_result$fit
set.seed(fit_obj$bayesim_seed)
coef_est <- coef(fit_obj)
vcov_mat <- vcov(fit_obj)
sigma_est <- summary(fit_obj)$sigma
n_draws <- fitter@n_draws
n_coef <- length(coef_est)
# Generate draws from sampling distribution
coef_draws <- MASS::mvrnorm(n_draws, coef_est, vcov_mat)
# Chi-squared approximation for sigma
n_obs <- nobs(fit_obj)
df <- n_obs - n_coef
sigma_draws <- sigma_est * sqrt(df / rchisq(n_draws, df))
draws <- cbind(coef_draws, sigma = sigma_draws)
if (!is.null(variables)) {
draws <- draws[, intersect(colnames(draws), variables), drop = FALSE]
}
draws
}
# predict_fit() method
method(predict_fit, LinearFitter) <- function(
fitter,
fit_result,
newdata = NULL,
seed = NULL
) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(NULL)
}
fit_obj <- fit_result$fit
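data <- if (is.null(newdata)) fit_obj$model else newdata
n_obs <- nrow(data)
draws <- extract_draws(fitter, fit_result)
# Seed after extract_draws(), which resets the RNG with the stored fit seed
if (!is.null(seed)) set.seed(seed)
n_draws <- nrow(draws)
# Drop the response so newdata without a response column still works
X <- model.matrix(delete.response(terms(fit_obj)), data)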
coef_names <- setdiff(colnames(draws), "sigma")
coef_draws <- draws[, coef_names, drop = FALSE]
sigma_draws <- draws[, "sigma"]
linear_pred <- X %*% t(coef_draws)
predicted_samples <- linear_pred +
matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws) *
rep(sigma_draws, each = n_obs)
list(
predicted_mean = rowMeans(predicted_samples),
predicted_samples = predicted_samples,
predicted_sd = apply(predicted_samples, 1, sd)
)
}
# log_lik() method
method(log_lik, LinearFitter) <- function(
fitter,
fit_result,
newdata = NULL
) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(NULL)
}
fit_obj <- fit_result$fit
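# Plain if/else avoids relying on %||%, which base R only provides from 4.4.0
data <- if (is.null(newdata)) fit_obj$model else newdata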
n_obs <- nrow(data)
response <- as.character(formula(fit_obj)[[2]])
y <- data[[response]]
draws <- extract_draws(fitter, fit_result)
n_draws <- nrow(draws)
X <- model.matrix(formula(fit_obj), data)
coef_names <- setdiff(colnames(draws), "sigma")
coef_draws <- draws[, coef_names, drop = FALSE]
sigma_draws <- draws[, "sigma"]
linear_pred <- X %*% t(coef_draws)
sigma_matrix <- rep(sigma_draws, each = n_obs)
matrix(dnorm(y, mean = as.vector(linear_pred), sd = sigma_matrix, log = TRUE),
nrow = n_obs, ncol = n_draws)
}
# loo() method
method(loo, LinearFitter) <- function(fitter, fit_result) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
}
n_obs <- nobs(fit_result$fit)
list(
elpd = NA_real_,
p_loo = length(coef(fit_result$fit)),
elpd_se = NA_real_,
pareto_k = rep(NA_real_, n_obs)
)
}
# diagnostics() method
method(diagnostics, LinearFitter) <- function(fitter, fit_result) {
if (!fit_result$success || is.null(fit_result$fit)) {
return(list())
}
fit_result$diagnostics
}
Test the LinearFitter:
# Quick test with LinearFitter
fitter <- LinearFitter()
test_data_bundle <- list(
train = data.frame(y = rnorm(50), x = rnorm(50)),
test = NULL,
response = "y",
true_params = c(`(Intercept)` = 0, x = 0, sigma = 1),
vars_of_interest = c("(Intercept)", "x", "sigma"),
references = NULL,
meta = list()
)
test_fit_spec <- data.frame(model = "test")
test_task_ctx <- list(task_id = "test", data_idx = 1L, fit_idx = 1L, rep_idx = 1L)
fit_result <- fit(fitter, test_data_bundle, test_fit_spec, seed = 42L, test_task_ctx)
stopifnot(fit_result$success)
draws <- extract_draws(fitter, fit_result)
stopifnot(is.matrix(draws))
stopifnot(all(c("(Intercept)", "x", "sigma") %in% colnames(draws)))
cat("LinearFitter test passed!\n")
#> LinearFitter test passed!
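The sigma draws in extract_draws() can also be checked in isolation. The sketch below repeats the scaled chi-squared step on a simulated regression (the data and the true value 1.5 are arbitrary choices for this check) and compares the draws with the point estimate from lm().

```r
set.seed(42)
n <- 200
x <- rnorm(n)
y <- 1 + 2 * x + rnorm(n, sd = 1.5)
fit_obj <- lm(y ~ x)

sigma_est <- summary(fit_obj)$sigma
df <- df.residual(fit_obj)  # n_obs - n_coef, as in extract_draws()
sigma_draws <- sigma_est * sqrt(df / rchisq(4000, df))

c(estimate = sigma_est, draw_median = median(sigma_draws))
quantile(sigma_draws, c(0.025, 0.975))
```

The draws are strictly positive and concentrate around the residual standard error, with a spread that shrinks as the residual degrees of freedom grow.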