
Introduction

bayesim provides a framework for Bayesian simulation studies through its extensible fitter interface. The package includes built-in fitters like MockFitter() and BrmsFitter(), but you may need to create custom fitters to:

  • Interface with different backends: Connect to Stan via cmdstanr, or to other probabilistic programming languages
  • Implement custom models: Fit specialized models not covered by existing fitters
  • Add preprocessing: Transform data or apply custom initialization strategies
  • Control inference: Implement custom sampling algorithms or diagnostics

This vignette explains the fitter interface and demonstrates creating a cmdstanr fitter from scratch.

The Fitter Interface

The Abstract Fitter S7 Class

All fitters in bayesim extend the abstract Fitter S7 class. This class defines the interface contract that all fitters must implement:

Properties

Property              Type       Description
name                  character  Identifier for the fitter (e.g., “stan”, “brms”)
supports_predictions  logical    Whether the fitter supports predict_fit()
supports_log_lik      logical    Whether the fitter supports log_lik()
supports_loo          logical    Whether the fitter supports loo()

These properties allow metrics and simulation code to query fitter capabilities before attempting operations that may not be supported.

Required Generics

The following S7 generics must have methods implemented for your fitter class:

fit(fitter, data_bundle, fit_spec, seed, task_ctx)

The main fitting method. Returns a bayesim_fit_result S3 object.

Parameters:

  • fitter: Your fitter object
  • data_bundle: List containing train, test, response, true_params, vars_of_interest
  • fit_spec: Single-row data frame with the model specification from fit_grid
  • seed: Integer seed for reproducibility
  • task_ctx: List with task_id, data_idx, fit_idx, rep_idx

Returns: A bayesim_fit_result with fields:

  • success: Logical indicating whether the fit succeeded
  • fit: The backend fit object (or NULL)
  • draws: Matrix of posterior draws (S × P) with column names
  • diagnostics: Named list of diagnostic values
  • timing: List with total, warmup, and sample times
  • warnings: Character vector of captured warnings
  • error: NULL or an error condition

extract_draws(fitter, fit_result, variables = NULL)

Extract posterior draws from a fitted model.

Returns: Matrix with dimensions S × P (draws × parameters) with column names matching variable names.

predict_fit(fitter, fit_result, newdata = NULL, seed = NULL)

Generate predictions from a fitted model.

Returns: List containing:

  • predicted_mean: Vector of mean predictions (length N)
  • predicted_samples: Matrix of posterior predictive samples (N × S)
  • predicted_sd: Vector of prediction standard deviations (length N)

log_lik(fitter, fit_result, newdata = NULL)

Compute pointwise log-likelihood values.

Returns: Matrix with dimensions N × S (observations × draws).

loo(fitter, fit_result)

Compute leave-one-out cross-validation.

Returns: List with:

  • elpd: Expected log predictive density
  • p_loo: Effective number of parameters
  • elpd_se: Standard error of ELPD
  • pareto_k: Pareto k diagnostic values

diagnostics(fitter, fit_result)

Extract convergence and fit diagnostics.

Returns: Named list with diagnostics such as:

  • rhat_max: Maximum R-hat statistic
  • ess_bulk: Minimum bulk effective sample size
  • ess_tail: Minimum tail effective sample size
  • divergent: Number of divergent transitions
  • max_treedepth: Number of iterations that hit the maximum tree depth
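For intuition about what rhat_max summarizes, here is a minimal base-R sketch of the classic split R-hat from Gelman et al.: each chain is split in half, and the between-sequence variance is compared with the within-sequence variance. It is a simplified stand-in; posterior::rhat(), which cmdstanr's summaries use, computes a rank-normalized refinement, so its values will differ somewhat. split_rhat() is a hypothetical helper, not part of bayesim.

```r
# Hypothetical helper: classic (non-rank-normalized) split R-hat
split_rhat <- function(draws) {
  # draws: iterations x chains matrix for a single parameter
  n <- floor(nrow(draws) / 2)
  # Split every chain into a first and second half (dropping a middle
  # iteration when the count is odd), giving 2 * chains sequences of length n
  first <- draws[seq_len(n), , drop = FALSE]
  second <- draws[nrow(draws) - n + seq_len(n), , drop = FALSE]
  halves <- cbind(first, second)
  chain_means <- colMeans(halves)
  W <- mean(apply(halves, 2, var))   # within-sequence variance
  B <- n * var(chain_means)          # between-sequence variance
  var_plus <- (n - 1) / n * W + B / n
  sqrt(var_plus / W)
}

set.seed(2)
# Four chains sampling the same distribution
mixed <- matrix(rnorm(4000), ncol = 4)
# Same, but chain 4 is stuck around a different value
stuck <- sweep(matrix(rnorm(4000), ncol = 4), 2, c(0, 0, 0, 3), "+")

split_rhat(mixed)  # close to 1
split_rhat(stuck)  # far above the usual 1.01 threshold
```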

Example: cmdstanr Fitter with Configurable Priors

This section demonstrates a cmdstanr fitter for simple linear regression. The design pattern is that fit_spec controls aspects of the model: in this case, the prior distributions. This lets you run simulation studies that compare different prior specifications.

Why cmdstanr?

cmdstanr is a lightweight R interface to Stan:

  • Direct Stan control: Compile and run Stan models directly
  • Built on CmdStan: Uses CmdStan, the command-line interface to Stan, which can be updated independently of the R package
  • posterior integration: Returns draws in the formats provided by the posterior package
  • Transparent: Clear separation between model, data, and sampling

The Stan Model

We use a simple linear regression with configurable priors passed as data. This is the key pattern: the priors come from fit_spec rather than being hard-coded in the model:

# Define the Stan model as a character string
# Note: Priors are passed as DATA, allowing fit_spec to control them
stan_model_code <- "
data {
  int<lower=0> N;
  vector[N] y;
  vector[N] x;
  // Prior parameters - controlled by fit_spec
  real prior_intercept_mean;
  real<lower=0> prior_intercept_sd;
  real prior_slope_mean;
  real<lower=0> prior_slope_sd;
}
parameters {
  real alpha;  // Intercept
  real beta;   // Slope
  real<lower=0> sigma;  // Error SD (fixed prior)
}
model {
  // Configurable priors from data
  alpha ~ normal(prior_intercept_mean, prior_intercept_sd);
  beta ~ normal(prior_slope_mean, prior_slope_sd);
  sigma ~ exponential(1);  // Fixed prior for sigma
  
  // Likelihood
  y ~ normal(alpha + beta * x, sigma);
}
generated quantities {
  vector[N] log_lik;
  vector[N] y_rep;
  for (n in 1:N) {
    log_lik[n] = normal_lpdf(y[n] | alpha + beta * x[n], sigma);
    y_rep[n] = normal_rng(alpha + beta * x[n], sigma);
  }
}
"

CmdStanFitter Class Definition

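The S7 class holds MCMC settings and, optionally, a pre-compiled model. S7 objects are copied on modification, so fit() cannot store a newly compiled model back into the fitter; pass one in when constructing the fitter to avoid recompiling: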

library(S7)
library(bayesim)

# Define the CmdStanFitter class
CmdStanFitter <- new_class(
  "CmdStanFitter",
  parent = Fitter,
  properties = list(
    # Required Fitter properties
    name = new_property(class_character, default = "cmdstanr"),
    supports_predictions = new_property(class_logical, default = TRUE),
    supports_log_lik = new_property(class_logical, default = TRUE),
    supports_loo = new_property(class_logical, default = TRUE),
    # MCMC settings
    chains = new_property(class_integer, default = 4L),
    iter_sampling = new_property(class_integer, default = 1000L),
    iter_warmup = new_property(class_integer, default = 1000L),
    adapt_delta = new_property(class_double, default = 0.95),
    max_treedepth = new_property(class_integer, default = 10L),
    # Model caching
    # class_any so the NULL default validates; otherwise a path to a .stan file
    stan_file = new_property(class_any, default = NULL),
    compiled_model = new_property(class_any, default = NULL)
  )
)

The fit() Method

This is where the magic happens: we build stan_data from both data_bundle (the data) and fit_spec (the priors):

method(fit, CmdStanFitter) <- function(
  fitter,
  data_bundle,
  fit_spec,
  seed,
  task_ctx
) {
  start_time <- Sys.time()
  warnings <- character()
  
  result <- tryCatch({
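    # Step 1: Get a compiled model. A model passed in via compiled_model is
    # reused; otherwise compile from stan_file or from stan_model_code.
    if (is.null(fitter@compiled_model)) {
      if (is.null(fitter@stan_file)) {
        # write_stan_file() names the file by a hash of the code, so repeated
        # calls reuse the same file and cmdstan_model() skips recompiling
        # when an up-to-date executable already exists
        stan_file <- cmdstanr::write_stan_file(stan_model_code)
      } else {
        stan_file <- fitter@stan_file
      }
      
      model <- cmdstanr::cmdstan_model(stan_file)
    } else {
      model <- fitter@compiled_model
    }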
    
    # Step 2: Prepare data - combine data_bundle with fit_spec
    train_data <- data_bundle$train
    response <- data_bundle$response
    predictor <- setdiff(names(train_data), response)
    
    if (length(predictor) != 1) {
      stop("CmdStanFitter expects exactly one predictor")
    }
    
    # Build stan_data from BOTH data_bundle AND fit_spec
    # KEY PATTERN: fit_spec controls model aspects (here: priors)
    stan_data <- list(
      # From data_bundle
      N = nrow(train_data),
      y = train_data[[response]],
      x = train_data[[predictor]],
      # From fit_spec - these control the priors!
      prior_intercept_mean = fit_spec$prior_intercept_mean,
      prior_intercept_sd = fit_spec$prior_intercept_sd,
      prior_slope_mean = fit_spec$prior_slope_mean,
      prior_slope_sd = fit_spec$prior_slope_sd
    )
    
    # Step 3: Run MCMC
    fit_obj <- withCallingHandlers({
      model$sample(
        data = stan_data,
        chains = fitter@chains,
        iter_sampling = fitter@iter_sampling,
        iter_warmup = fitter@iter_warmup,
        seed = seed,
        adapt_delta = fitter@adapt_delta,
        max_treedepth = fitter@max_treedepth,
        refresh = 0  # Suppress verbose output
      )
    }, warning = function(w) {
      warnings <<- c(warnings, conditionMessage(w))
      invokeRestart("muffleWarning")
    })
    
    # Step 4: Extract draws as an S x P matrix using the posterior package;
    # unclass() drops the draws_matrix class so `[` behaves like a base matrix
    param_names <- c("alpha", "beta", "sigma")
    param_draws <- unclass(
      posterior::as_draws_matrix(fit_obj$draws(variables = param_names))
    )
    
    # Step 5: Get diagnostics from cmdstanr; the measures are named as strings
    # so the summary columns are called rhat, ess_bulk and ess_tail
    diag_summary <- fit_obj$summary(
      variables = param_names,
      "rhat", "ess_bulk", "ess_tail"
    )
    
    sampler_diag <- fit_obj$sampler_diagnostics(format = "df")
    
    end_time <- Sys.time()
    total_time <- as.numeric(difftime(end_time, start_time, units = "secs"))
    
    # Step 6: Return fit result
    new_fit_result(
      success = TRUE,
      fit = fit_obj,
      draws = param_draws,
      diagnostics = list(
        rhat_max = max(diag_summary$rhat, na.rm = TRUE),
        ess_bulk = min(diag_summary$ess_bulk, na.rm = TRUE),
        ess_tail = min(diag_summary$ess_tail, na.rm = TRUE),
        divergent = sum(sampler_diag$divergent__),
        max_treedepth = sum(sampler_diag$treedepth__ >= fitter@max_treedepth)
      ),
      timing = list(
        total = total_time,
        # $time()$chains has one row per chain; sum to get total compute time
        warmup = sum(fit_obj$time()$chains$warmup),
        sample = sum(fit_obj$time()$chains$sampling)
      ),
      warnings = warnings,
      error = NULL
    )
    
  }, error = function(e) {
    end_time <- Sys.time()
    new_fit_result(
      success = FALSE,
      fit = NULL,
      draws = NULL,
      diagnostics = list(),
      timing = list(
        total = as.numeric(difftime(end_time, start_time, units = "secs")),
        warmup = 0,
        sample = 0
      ),
      warnings = warnings,
      error = e
    )
  })
  
  result
}

Other Required Methods

extract_draws()

method(extract_draws, CmdStanFitter) <- function(
  fitter,
  fit_result,
  variables = NULL
) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(NULL)
  }
  
  # Use the posterior package for clean extraction of an S x P matrix
  param_names <- c("alpha", "beta", "sigma")
  
  if (!is.null(variables)) {
    param_names <- intersect(param_names, variables)
  }
  
  # unclass() drops the draws_matrix class so `[` behaves like a base matrix
  unclass(posterior::as_draws_matrix(fit_result$fit$draws(variables = param_names)))
}

predict_fit()

method(predict_fit, CmdStanFitter) <- function(
  fitter,
  fit_result,
  newdata = NULL,
  seed = NULL
) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(NULL)
  }
  
  if (!is.null(seed)) set.seed(seed)
  
  draws <- extract_draws(fitter, fit_result)
  alpha_draws <- draws[, "alpha"]
  beta_draws <- draws[, "beta"]
  sigma_draws <- draws[, "sigma"]
  
  # For newdata or training data
  if (!is.null(newdata)) {
    x_new <- newdata$x  # assumes the predictor column is named "x"
  } else {
    # Use y_rep from generated quantities for training data
    y_rep <- fit_result$fit$draws("y_rep", format = "matrix")
    return(list(
      predicted_mean = colMeans(y_rep),
      predicted_samples = t(y_rep),  # N x S
      predicted_sd = apply(y_rep, 2, sd)
    ))
  }
  
  n_obs <- length(x_new)
  n_draws <- length(alpha_draws)
  
  # Compute predictions
  linear_pred <- outer(x_new, beta_draws, `*`) + 
    matrix(alpha_draws, nrow = n_obs, ncol = n_draws, byrow = TRUE)
  
  # Posterior predictive (with noise)
  predicted_samples <- linear_pred + 
    matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws) * 
    matrix(sigma_draws, nrow = n_obs, ncol = n_draws, byrow = TRUE)
  
  list(
    predicted_mean = rowMeans(predicted_samples),
    predicted_samples = predicted_samples,
    predicted_sd = apply(predicted_samples, 1, sd)
  )
}
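The newdata branch above combines a linear predictor with simulated observation noise. The base-R sketch below isolates that step so the N × S orientation is easy to check: rows are observations, columns are draws. posterior_predict_sketch() is a hypothetical helper that takes an S × 3 matrix with columns alpha, beta and sigma, the layout extract_draws() returns for this model.

```r
# Hypothetical helper: posterior predictive draws for new x values
posterior_predict_sketch <- function(draws, x_new) {
  n_obs <- length(x_new)
  n_draws <- nrow(draws)
  # Linear predictor: entry [i, s] is alpha_s + beta_s * x_i (N x S)
  mu <- outer(x_new, draws[, "beta"]) +
    matrix(draws[, "alpha"], nrow = n_obs, ncol = n_draws, byrow = TRUE)
  # Observation noise: column s is scaled by that draw's sigma_s
  noise <- matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws) *
    matrix(draws[, "sigma"], nrow = n_obs, ncol = n_draws, byrow = TRUE)
  mu + noise
}

set.seed(3)
S <- 2000
# Degenerate "posterior" with fixed values, so the expected means are known
draws <- cbind(alpha = rep(0.5, S), beta = rep(2, S), sigma = rep(0.1, S))
pred <- posterior_predict_sketch(draws, x_new = c(-1, 0, 1))
dim(pred)       # 3 2000
rowMeans(pred)  # close to -1.5, 0.5, 2.5
```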

log_lik()

method(log_lik, CmdStanFitter) <- function(
  fitter,
  fit_result,
  newdata = NULL
) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(NULL)
  }
  
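  if (!is.null(newdata)) {
    # Compute log p(y_i | alpha_s, beta_s, sigma_s) from the draws;
    # assumes newdata has columns named "y" and "x"
    draws <- extract_draws(fitter, fit_result)
    n_obs <- nrow(newdata)
    n_draws <- nrow(draws)
    mu <- outer(newdata$x, draws[, "beta"]) +
      matrix(draws[, "alpha"], nrow = n_obs, ncol = n_draws, byrow = TRUE)
    sigma <- matrix(draws[, "sigma"], nrow = n_obs, ncol = n_draws, byrow = TRUE)
    return(matrix(dnorm(newdata$y, mean = mu, sd = sigma, log = TRUE),
                  nrow = n_obs, ncol = n_draws))  # N x S
  }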
  
  # Extract log_lik from generated quantities
  ll_matrix <- fit_result$fit$draws("log_lik", format = "matrix")
  t(ll_matrix)  # N x S
}
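Pointwise log-likelihood uses the same N × S layout: entry [i, s] is log p(y_i | alpha_s, beta_s, sigma_s). The sketch below computes it from an S × 3 draws matrix with dnorm(); log_lik_sketch() is a hypothetical helper, not part of bayesim, and assumes the single-predictor model used in this vignette.

```r
# Hypothetical helper: N x S pointwise log-likelihood for y ~ normal(alpha + beta * x, sigma)
log_lik_sketch <- function(draws, x, y) {
  n_obs <- length(y)
  n_draws <- nrow(draws)
  mu <- outer(x, draws[, "beta"]) +
    matrix(draws[, "alpha"], nrow = n_obs, ncol = n_draws, byrow = TRUE)
  sigma <- matrix(draws[, "sigma"], nrow = n_obs, ncol = n_draws, byrow = TRUE)
  # y (length N) recycles down each column, so row i pairs with observation i;
  # matrix() makes the N x S shape explicit
  matrix(dnorm(y, mean = mu, sd = sigma, log = TRUE),
         nrow = n_obs, ncol = n_draws)
}

# Two draws, three observations
draws <- cbind(alpha = c(0, 1), beta = c(1, 2), sigma = c(1, 0.5))
ll <- log_lik_sketch(draws, x = c(0, 1, 2), y = c(0.2, 1.1, 2.5))
dim(ll)  # 3 2
# Entry [3, 2] is log N(2.5 | 1 + 2 * 2, 0.5)
all.equal(ll[3, 2], dnorm(2.5, mean = 5, sd = 0.5, log = TRUE))  # TRUE
```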

loo()

method(loo, CmdStanFitter) <- function(fitter, fit_result) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
  }
  
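  loo_result <- tryCatch(
    # CmdStanMCMC$loo() reads the log_lik generated quantity in draws x
    # observations order and computes relative efficiencies (r_eff) per chain.
    # loo::loo() on a matrix also expects S x N, so passing the N x S output
    # of log_lik() directly would swap draws and observations.
    fit_result$fit$loo(variables = "log_lik", r_eff = TRUE),
    error = function(e) NULL
  )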
  
  if (is.null(loo_result)) {
    return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
  }
  
  list(
    elpd = loo_result$estimates["elpd_loo", "Estimate"],
    p_loo = loo_result$estimates["p_loo", "Estimate"],
    elpd_se = loo_result$estimates["elpd_loo", "SE"],
    pareto_k = loo::pareto_k_values(loo_result)
  )
}
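PSIS-LOO itself is best left to the loo package, but the in-sample quantity it corrects is simple to compute and makes a useful sanity check: the log pointwise predictive density, lpd = sum over i of log((1/S) sum over s of exp(ll[i, s])). The loo package reports p_loo as lpd minus elpd_loo. lpd_sketch() below is a hypothetical helper that computes lpd from an N × S log-likelihood matrix using a numerically stable log-mean-exp.

```r
# Hypothetical helper: in-sample lpd from an N x S log-likelihood matrix
lpd_sketch <- function(ll) {
  # ll: observations x draws
  row_max <- apply(ll, 1, max)
  # Log-mean-exp per observation, shifted by the row maximum so exp()
  # does not underflow for very negative log-likelihoods
  pointwise <- row_max + log(rowMeans(exp(ll - row_max)))
  sum(pointwise)
}

lpd_sketch(matrix(-1.5, nrow = 4, ncol = 10))   # -6
lpd_sketch(matrix(-1000, nrow = 2, ncol = 5))   # -2000 (a naive version gives -Inf)
```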

diagnostics()

method(diagnostics, CmdStanFitter) <- function(fitter, fit_result) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(list())
  }
  fit_result$diagnostics
}

Example Simulation: Weak vs Strong Priors

Now we demonstrate the key pattern: using fit_grid to specify different priors, which flow through fit_spec into the Stan model.

library(tibble)

# Define fit_grid with different prior specifications
# This shows how fit_spec controls model aspects
fit_grid <- tibble::tibble(
  model = c("weak_prior", "strong_prior"),
  prior_intercept_mean = c(0, 0),
  prior_intercept_sd = c(10, 1),      # Strong prior is more informative
  prior_slope_mean = c(0, 1),          # Strong prior "knows" true slope!
  prior_slope_sd = c(10, 0.5)          # Strong prior is tighter
)

print(fit_grid)
#> # A tibble: 2 × 5
#>   model         prior_intercept_mean prior_intercept_sd prior_slope_mean prior_slope_sd
#>   <chr>                        <dbl>              <dbl>            <dbl>          <dbl>
#> 1 weak_prior                       0                 10              0             10  
#> 2 strong_prior                     0                  1              1              0.5
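To see how differently informative the two rows are, the sketch below treats each row of the grid as a one-row fit_spec (the shape fit() receives) and computes the prior probability that the slope lies within 0.5 of the true value 1. It rebuilds the same grid as a base data.frame so it runs without tibble; prior_mass_near_truth is an illustrative name, not a bayesim object.

```r
# Same prior grid as above, as a base data.frame
fit_grid <- data.frame(
  model = c("weak_prior", "strong_prior"),
  prior_intercept_mean = c(0, 0),
  prior_intercept_sd = c(10, 1),
  prior_slope_mean = c(0, 1),
  prior_slope_sd = c(10, 0.5)
)

# Each row, kept as a one-row data frame, plays the role of fit_spec
fit_specs <- split(fit_grid, seq_len(nrow(fit_grid)))

# Prior probability that the slope falls in [0.5, 1.5]
prior_mass_near_truth <- vapply(fit_specs, function(fit_spec) {
  pnorm(1.5, fit_spec$prior_slope_mean, fit_spec$prior_slope_sd) -
    pnorm(0.5, fit_spec$prior_slope_mean, fit_spec$prior_slope_sd)
}, numeric(1))
names(prior_mass_near_truth) <- fit_grid$model

round(prior_mass_near_truth, 3)  # weak_prior 0.040, strong_prior 0.683
```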

Data Generator

# Data generator for simple linear regression
linear_data_generator <- function(data_spec, seed, task_ctx) {
  # Note: seed is a scalar task seed.
  # The simulation engine also restores the task RNG stream before each call,
  # so you typically don't need to call set.seed() here.
  n <- data_spec$n
  
  # True parameters (strong_prior happens to "know" these)
  alpha_true <- data_spec$alpha
  beta_true <- data_spec$beta
  sigma_true <- data_spec$sigma
  
  x <- rnorm(n)
  y <- alpha_true + beta_true * x + rnorm(n, sd = sigma_true)
  
  list(
    train = data.frame(y = y, x = x),
    test = NULL,
    response = "y",
    true_params = c(alpha = alpha_true, beta = beta_true, sigma = sigma_true),
    vars_of_interest = c("alpha", "beta", "sigma"),
    references = NULL,
    meta = list()
  )
}

Run the Simulation

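# Create the fitter. Compiling once up front and passing the model in
# avoids compiling inside every fit() call.
fitter <- CmdStanFitter(
  chains = 4L,
  iter_sampling = 1000L,
  iter_warmup = 1000L,
  compiled_model = cmdstanr::cmdstan_model(cmdstanr::write_stan_file(stan_model_code))
)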

# Simulation configuration
config <- simulation_config(
  data_grid = tibble::tibble(
    n = c(50, 100),
    alpha = 0,
    beta = 1,      # True slope is 1
    sigma = 1
  ),
  fit_grid = fit_grid,  # Different priors!
  data_generator = linear_data_generator,
  fitter = fitter,
  metrics = list(
    rmse_metric(),
    bias_metric(),
    coverage_metric(),
    posterior_mean_metric()
  ),
  n_replicates = 10L,
  seed = 42L
)

# Run simulation
result <- run_simulation(config, progress = TRUE)
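It helps to count the work before calling run_simulation(). Assuming every combination of a data_grid row, a fit_grid row and a replicate becomes one fit (an assumption about how bayesim expands the grids; the column order below is illustrative, not bayesim's actual task ordering), this configuration produces 2 × 2 × 10 = 40 fits, each running 4 chains.

```r
# Hypothetical task layout: one row per (data row, fit row, replicate)
tasks <- expand.grid(
  rep_idx = seq_len(10),   # n_replicates
  fit_idx = seq_len(2),    # rows of fit_grid
  data_idx = seq_len(2)    # rows of data_grid
)
nrow(tasks)  # 40

chains <- 4
iter_sampling <- 1000
iter_warmup <- 1000
draws_per_fit <- chains * iter_sampling  # 4000 post-warmup draws per fit
total_iterations <- nrow(tasks) * chains * (iter_sampling + iter_warmup)
total_iterations  # 320000 MCMC iterations across the whole study
```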

Analyze Results

# Compare posterior means across prior specifications
library(dplyr)

result$summary |>
  select(data_idx, fit_idx, posterior_mean__alpha, posterior_mean__beta, posterior_mean__sigma)
#> # A tibble: 4 × 5
#>   data_idx fit_idx posterior_mean__alpha posterior_mean__beta posterior_mean__sigma
#>      <int>   <int>                 <dbl>                <dbl>                <dbl>
#> 1        1       1               0.0234                0.982                0.998
#> 2        1       2              -0.0412                1.012                1.023
#> 3        2       1               0.0089                0.995                0.987
#> 4        2       2              -0.0156                1.003                1.001

# data_idx 1 and 2 correspond to n = 50 and n = 100; fit_idx 1 and 2 to
# weak_prior and strong_prior. The strong prior, centred on the true slope,
# puts the slope estimate slightly closer to 1 at both sample sizes, but its
# intercept and sigma estimates are not consistently closer to the truth.
# With n >= 50 the likelihood dominates, so all differences are small.
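The metric names suggest the usual definitions, but bayesim's exact formulas (interval level, whether RMSE is taken over draws or over replicates) are not shown in this vignette. The sketch below computes common versions for one parameter from a vector of posterior draws; metric_sketch() and its 90% central interval are assumptions for illustration only.

```r
# Hypothetical helper: common per-parameter metrics from posterior draws
metric_sketch <- function(draws, truth, level = 0.90) {
  tail_prob <- (1 - level) / 2
  interval <- quantile(draws, c(tail_prob, 1 - tail_prob), names = FALSE)
  c(
    posterior_mean = mean(draws),
    bias = mean(draws) - truth,             # error of the posterior mean
    rmse = sqrt(mean((draws - truth)^2)),   # RMSE over draws (variance + bias^2)
    covered = as.numeric(interval[1] <= truth && truth <= interval[2])
  )
}

set.seed(4)
slope_draws <- rnorm(4000, mean = 1.05, sd = 0.1)  # illustrative draws
round(metric_sketch(slope_draws, truth = 1), 3)
```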

Key Takeaways

  1. fit_spec controls model aspects: In this example, priors flow from fit_grid → fit_spec → stan_data. This pattern works for any model hyperparameters.

  2. cmdstanr + posterior is the modern stack: Use cmdstanr::cmdstan_model() for compilation and the posterior package for working with the draws.

  3. Pattern generalizes: The same approach works for:

    • Different likelihood functions
    • Hierarchical model structures
    • Regularization strengths
    • Any model configuration you want to vary in a simulation

Testing Your Fitter

Once you’ve implemented your fitter, test it thoroughly to ensure it works correctly with bayesim.

Using validate_fitter()

bayesim provides a validate_fitter() function that checks your fitter implements the required interface correctly:

library(bayesim)

# Create your fitter instance
fitter <- CmdStanFitter(
  chains = 2L,      # Fewer chains for faster testing
  iter_sampling = 500L,
  iter_warmup = 500L
)

# Basic validation (checks class, properties, and methods)
validate_fitter(fitter)

# Full validation with smoke test (actually calls all methods)
validate_fitter(fitter, smoke_test = TRUE, verbose = TRUE)
#> Checking S7 Fitter class...
#>   [OK] Object is an S7 Fitter class
#> Checking required properties...
#>   [OK] Property 'name' exists: cmdstanr
#>   [OK] Property 'supports_predictions' exists: TRUE
#>   ...
#> Validation passed!

The validate_fitter() function performs:

  1. Property checks: Verifies all required properties exist (name, supports_predictions, supports_log_lik, supports_loo)

  2. Method checks: Confirms all S7 methods are implemented (fit, extract_draws, predict_fit, log_lik, loo, diagnostics)

  3. Smoke test (when smoke_test = TRUE):

    • Creates simple test data
    • Calls fit() and verifies bayesim_fit_result structure
    • Calls extract_draws() and verifies matrix with colnames
    • If supported, calls predict_fit() and verifies output structure
    • If supported, calls log_lik() and verifies matrix dimensions
    • Calls diagnostics() and verifies list output

If any check fails, validate_fitter() throws an informative error explaining what went wrong.
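For a lighter-weight check inside your own test suite, something like the following base-R sketch can verify the shapes the smoke test looks at. check_fit_result_shape() is hypothetical and not part of bayesim; it only checks the documented fields, the S × P draws layout with named columns, and the timing components, using a plain list as the result.

```r
# Hypothetical helper: structural checks on a fit result shaped like bayesim_fit_result
check_fit_result_shape <- function(res, params) {
  required <- c("success", "fit", "draws", "diagnostics",
                "timing", "warnings", "error")
  missing_fields <- setdiff(required, names(res))
  if (length(missing_fields) > 0) {
    stop("missing fields: ", paste(missing_fields, collapse = ", "))
  }
  if (isTRUE(res$success)) {
    if (!is.matrix(res$draws)) stop("draws must be a matrix (S x P)")
    absent <- setdiff(params, colnames(res$draws))
    if (length(absent) > 0) {
      stop("draws lack columns: ", paste(absent, collapse = ", "))
    }
    if (!all(c("total", "warmup", "sample") %in% names(res$timing))) {
      stop("timing must contain total, warmup and sample")
    }
  }
  invisible(TRUE)
}

set.seed(6)
ok <- list(
  success = TRUE, fit = NULL,
  draws = cbind(alpha = rnorm(10), beta = rnorm(10), sigma = rexp(10)),
  diagnostics = list(), timing = list(total = 1, warmup = 0.4, sample = 0.5),
  warnings = character(), error = NULL
)
check_fit_result_shape(ok, c("alpha", "beta", "sigma"))

bad <- ok
bad$draws <- ok$draws[, c("alpha", "beta")]
try(check_fit_result_shape(bad, c("alpha", "beta", "sigma")))  # Error: draws lack columns: sigma
```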

Manual Unit Tests

For more granular testing or to test specific scenarios, you can write manual tests:

# Create minimal test data
set.seed(42)
n <- 50
x <- rnorm(n)
test_data_bundle <- list(
  train = data.frame(
    y = 0 + 1 * x + rnorm(n),  # simulated with alpha = 0, beta = 1, sigma = 1
    x = x
  ),
  test = NULL,
  response = "y",
  true_params = c(alpha = 0, beta = 1, sigma = 1),
  vars_of_interest = c("alpha", "beta", "sigma"),
  references = NULL,
  meta = list()
)

# fit_spec must include prior parameters!
test_fit_spec <- data.frame(
  model = "test",
  prior_intercept_mean = 0,
  prior_intercept_sd = 10,
  prior_slope_mean = 0,
  prior_slope_sd = 10
)
test_task_ctx <- list(
  task_id = "test_001",
  data_idx = 1L,
  fit_idx = 1L,
  rep_idx = 1L
)

# Test fit() returns valid bayesim_fit_result
fit_result <- fit(fitter, test_data_bundle, test_fit_spec, seed = 42L, test_task_ctx)
stopifnot(inherits(fit_result, "bayesim_fit_result"))
stopifnot(fit_result$success == TRUE)
cat("✓ fit() returns valid bayesim_fit_result\n")

# Test extract_draws() returns matrix with colnames
draws <- extract_draws(fitter, fit_result)
stopifnot(is.matrix(draws))
stopifnot(!is.null(colnames(draws)))
cat("✓ extract_draws() returns matrix with colnames\n")

# Test predict_fit() returns correct structure
preds <- predict_fit(fitter, fit_result)
stopifnot(!is.null(preds$predicted_mean))
stopifnot(!is.null(preds$predicted_samples))
stopifnot(!is.null(preds$predicted_sd))
cat("✓ predict_fit() returns correct structure\n")

cat("All unit tests passed!\n")

Best Practices

Error Handling in fit()

Always wrap fitting code in tryCatch() to gracefully handle failures:

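start_time <- Sys.time()
warnings <- character()

result <- tryCatch(
  {
    # Fitting code here
    new_fit_result(success = TRUE, ...)
  },
  error = function(e) {
    elapsed_time <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
    new_fit_result(
      success = FALSE,
      error = e,
      timing = list(total = elapsed_time, warmup = 0, sample = 0),
      warnings = warnings
    )
  }
)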

This ensures that:

  • Failed fits don’t crash the entire simulation
  • Error information is preserved for debugging
  • Timing information is still recorded
  • Partial progress can be checkpointed
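The same pattern can be packaged as a small helper so that every fitter gets identical failure handling. run_safely() is a hypothetical base-R sketch that returns a plain list rather than calling new_fit_result(): it times the call, records warnings while letting the code continue, and turns an error into success = FALSE instead of stopping the surrounding loop.

```r
# Hypothetical helper: tryCatch + withCallingHandlers + timing in one place
run_safely <- function(fit_fun) {
  start_time <- Sys.time()
  warnings <- character()
  elapsed <- function() as.numeric(difftime(Sys.time(), start_time, units = "secs"))
  tryCatch({
    value <- withCallingHandlers(
      fit_fun(),
      warning = function(w) {
        warnings <<- c(warnings, conditionMessage(w))
        invokeRestart("muffleWarning")  # keep running after the warning
      }
    )
    list(success = TRUE, fit = value, warnings = warnings,
         error = NULL, timing = list(total = elapsed()))
  }, error = function(e) {
    # Warnings seen before the error are still reported
    list(success = FALSE, fit = NULL, warnings = warnings,
         error = e, timing = list(total = elapsed()))
  })
}

results <- lapply(
  list(
    function() sqrt(16),
    function() { warning("step size adapted"); 1 },
    function() stop("chain 2 failed to initialize")
  ),
  run_safely
)
vapply(results, function(r) r$success, logical(1))  # TRUE TRUE FALSE
```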

Timing the Fit

Always record timing information:

start_time <- Sys.time()
# ... fitting code ...
end_time <- Sys.time()
total_time <- as.numeric(difftime(end_time, start_time, units = "secs"))

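The bayesim_fit_result expects timing as a list with total, warmup, and sample components. cmdstanr reports warmup and sampling times per chain in fit_obj$time()$chains (and the overall elapsed time in fit_obj$time()$total), so aggregate the per-chain values, for example by summing them, into single numbers.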
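A minimal sketch of that aggregation step. The data frame only mimics the layout of cmdstanr's per-chain timing table and its values are invented for illustration; summing gives total compute time, while taking the maximum roughly approximates wall-clock time when chains run in parallel. Which one to report is a design choice.

```r
# Illustrative per-chain timings in seconds (layout mimics fit_obj$time()$chains;
# the numbers are invented)
chain_times <- data.frame(
  chain_id = 1:4,
  warmup   = c(0.21, 0.25, 0.19, 0.23),
  sampling = c(0.30, 0.28, 0.33, 0.29)
)

# Total compute across chains
timing_sum <- list(
  warmup = sum(chain_times$warmup),
  sample = sum(chain_times$sampling)
)

# Rough wall-clock approximation when chains run in parallel
timing_wall <- list(
  warmup = max(chain_times$warmup),
  sample = max(chain_times$sampling)
)
```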

Handling Warnings

Capture warnings during fitting to prevent them from cluttering output while preserving them for diagnostics:

warnings <- character()

fit_obj <- withCallingHandlers(
  {
    fitting_function(...)
  },
  warning = function(w) {
    warnings <<- c(warnings, conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)

Seed Management

Pass the seed to the backend for reproducibility:

# Pass seed to cmdstanr
model$sample(..., seed = seed)

# For predictions, accept a seed parameter
if (!is.null(seed)) {
  set.seed(seed)
}
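Calling set.seed() inside a method resets the global RNG stream as a side effect, which can change what later code in the same session draws. One way to keep a prediction seed local is to save and restore .Random.seed around the call. with_local_seed() below is a hypothetical base-R sketch of that idea; withr::with_seed() offers the same behaviour as a package function.

```r
# Hypothetical helper: evaluate expr under a seed without disturbing the caller's RNG
with_local_seed <- function(seed, expr) {
  env <- globalenv()
  had_seed <- exists(".Random.seed", envir = env, inherits = FALSE)
  if (had_seed) old_seed <- get(".Random.seed", envir = env, inherits = FALSE)
  # Restore the caller's RNG state however expr exits
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = env)
    } else if (exists(".Random.seed", envir = env, inherits = FALSE)) {
      rm(".Random.seed", envir = env)
    }
  })
  set.seed(seed)
  expr  # lazily evaluated here, after the seed is set
}

set.seed(123)
u_expected <- runif(1)

set.seed(123)
pred_a <- with_local_seed(1L, rnorm(3))
u_after <- runif(1)  # same as u_expected: the inner seed did not leak out
pred_b <- with_local_seed(1L, rnorm(3))

identical(u_expected, u_after)  # TRUE
identical(pred_a, pred_b)       # TRUE
```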

Checking Fitter Capabilities

Before calling methods that may not be supported, check the fitter’s properties:

if (fitter@supports_predictions) {
  preds <- predict_fit(fitter, fit_result, newdata)
} else {
  preds <- NULL
}

if (fitter@supports_loo) {
  loo_result <- loo(fitter, fit_result)
} else {
  loo_result <- NULL
}

Summary

Creating custom fitters in bayesim involves:

  1. Extending the Fitter class with your fitter-specific properties
  2. Implementing six S7 methods: fit(), extract_draws(), predict_fit(), log_lik(), loo(), and diagnostics()
  3. Returning properly structured bayesim_fit_result objects from fit()
  4. Using fit_spec to control model aspects (priors, hyperparameters, etc.)
  5. Handling errors gracefully with tryCatch() and preserving warnings with withCallingHandlers()
  6. Recording timing information for performance analysis
  7. Extracting real diagnostics from the backend (R-hat, ESS, divergences)
  8. Testing thoroughly to ensure integration with the simulation framework

The CmdStanFitter example in this vignette demonstrates all these patterns. By following these patterns, you can extend bayesim to work with any Bayesian modeling backend.

Appendix: LinearFitter (No External Dependencies)

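For users without Stan installed, here’s a LinearFitter that uses base R’s lm(). It is a better choice than MockFitter when you need a fitter that does not require Stan:

  • Unlike MockFitter, LinearFitter performs actual statistical inference on the simulated data (note that it ignores fit_spec, so every fit_grid row gets the same model)
  • Provides approximate posterior draws by sampling from the estimated sampling distribution of the coefficients and residual SD, which is close to the posterior under flat priors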
  • Suitable for lightweight simulation studies where MCMC isn’t required

Note: This fitter doesn’t provide true Bayesian inference (no priors, no MCMC), but it’s useful for testing the simulation framework with realistic statistical behavior:

library(bayesim)
library(S7)

# Define the LinearFitter class
LinearFitter <- new_class(
  "LinearFitter",
  parent = Fitter,
  properties = list(
    name = new_property(class_character, default = "lm"),
    supports_predictions = new_property(class_logical, default = TRUE),
    supports_log_lik = new_property(class_logical, default = TRUE),
    supports_loo = new_property(class_logical, default = FALSE),
    n_draws = new_property(class_integer, default = 1000L)
  )
)

# fit() method
method(fit, LinearFitter) <- function(
  fitter,
  data_bundle,
  fit_spec,
  seed,
  task_ctx
) {
  # Note: seed is a scalar task seed.
  # The simulation engine also restores the task RNG stream before each call,
  # so you typically don't need to call set.seed() here.
  start_time <- Sys.time()
  warnings <- character()
  
  result <- tryCatch({
    train_data <- data_bundle$train
    response <- data_bundle$response
    formula <- as.formula(paste(response, "~ ."))
    
    fit_obj <- lm(formula, data = train_data)
    fit_obj$bayesim_seed <- seed  # Store seed for extract_draws
    
    end_time <- Sys.time()
    total_time <- as.numeric(difftime(end_time, start_time, units = "secs"))
    
    new_fit_result(
      success = TRUE,
      fit = fit_obj,
      draws = NULL,  # Generated on-demand
      diagnostics = list(
        rhat_max = 1.0,
        ess_bulk = fitter@n_draws,
        ess_tail = fitter@n_draws,
        divergent = 0L,
        max_treedepth = 0L
      ),
      timing = list(total = total_time, warmup = 0, sample = total_time),
      warnings = warnings,
      error = NULL
    )
  }, error = function(e) {
    end_time <- Sys.time()
    new_fit_result(
      success = FALSE,
      error = e,
      timing = list(total = as.numeric(difftime(end_time, start_time, units = "secs")), warmup = 0, sample = 0),
      warnings = warnings
    )
  })
  
  result
}

# extract_draws() method
method(extract_draws, LinearFitter) <- function(
  fitter,
  fit_result,
  variables = NULL
) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(NULL)
  }
  
  fit_obj <- fit_result$fit
  set.seed(fit_obj$bayesim_seed)
  
  coef_est <- coef(fit_obj)
  vcov_mat <- vcov(fit_obj)
  sigma_est <- summary(fit_obj)$sigma
  
  n_draws <- fitter@n_draws
  n_coef <- length(coef_est)
  
  # Generate draws from sampling distribution
  coef_draws <- MASS::mvrnorm(n_draws, coef_est, vcov_mat)
  
  # Chi-squared approximation for sigma
  n_obs <- nobs(fit_obj)
  df <- n_obs - n_coef
  sigma_draws <- sigma_est * sqrt(df / rchisq(n_draws, df))
  
  draws <- cbind(coef_draws, sigma = sigma_draws)
  
  if (!is.null(variables)) {
    draws <- draws[, intersect(colnames(draws), variables), drop = FALSE]
  }
  
  draws
}

# predict_fit() method
method(predict_fit, LinearFitter) <- function(
  fitter,
  fit_result,
  newdata = NULL,
  seed = NULL
) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(NULL)
  }
  
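  fit_obj <- fit_result$fit
  # Fall back to the training model frame when no newdata is given
  # (base R's %||% only exists from R 4.4.0)
  data <- if (is.null(newdata)) fit_obj$model else newdata
  n_obs <- nrow(data)
  
  # extract_draws() reseeds with the fit's stored seed, so call it first and
  # only then apply the prediction seed to the noise draws below
  draws <- extract_draws(fitter, fit_result)
  n_draws <- nrow(draws)
  
  if (!is.null(seed)) set.seed(seed)
  
  # Drop the response from the terms so newdata need not contain it
  X <- model.matrix(delete.response(terms(fit_obj)), data,
                    xlev = fit_obj$xlevels)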
  coef_names <- setdiff(colnames(draws), "sigma")
  coef_draws <- draws[, coef_names, drop = FALSE]
  sigma_draws <- draws[, "sigma"]
  
  linear_pred <- X %*% t(coef_draws)
  predicted_samples <- linear_pred + 
    matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws) * 
    rep(sigma_draws, each = n_obs)
  
  list(
    predicted_mean = rowMeans(predicted_samples),
    predicted_samples = predicted_samples,
    predicted_sd = apply(predicted_samples, 1, sd)
  )
}

# log_lik() method
method(log_lik, LinearFitter) <- function(
  fitter,
  fit_result,
  newdata = NULL
) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(NULL)
  }
  
  fit_obj <- fit_result$fit
  data <- if (is.null(newdata)) fit_obj$model else newdata
  n_obs <- nrow(data)
  
  response <- as.character(formula(fit_obj)[[2]])
  y <- data[[response]]
  
  draws <- extract_draws(fitter, fit_result)
  n_draws <- nrow(draws)
  
  X <- model.matrix(formula(fit_obj), data)
  coef_names <- setdiff(colnames(draws), "sigma")
  coef_draws <- draws[, coef_names, drop = FALSE]
  sigma_draws <- draws[, "sigma"]
  
  linear_pred <- X %*% t(coef_draws)
  sigma_matrix <- rep(sigma_draws, each = n_obs)
  
  matrix(dnorm(y, mean = as.vector(linear_pred), sd = sigma_matrix, log = TRUE),
         nrow = n_obs, ncol = n_draws)
}

# loo() method
method(loo, LinearFitter) <- function(fitter, fit_result) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(list(elpd = NA_real_, p_loo = NA_real_, elpd_se = NA_real_, pareto_k = numeric()))
  }
  
  n_obs <- nobs(fit_result$fit)
  list(
    elpd = NA_real_,
    p_loo = length(coef(fit_result$fit)),
    elpd_se = NA_real_,
    pareto_k = rep(NA_real_, n_obs)
  )
}

# diagnostics() method
method(diagnostics, LinearFitter) <- function(fitter, fit_result) {
  if (!fit_result$success || is.null(fit_result$fit)) {
    return(list())
  }
  fit_result$diagnostics
}

Test the LinearFitter:

# Quick test with LinearFitter
fitter <- LinearFitter()

test_data_bundle <- list(
  train = data.frame(y = rnorm(50), x = rnorm(50)),
  test = NULL,
  response = "y",
  true_params = c(`(Intercept)` = 0, x = 0, sigma = 1),
  vars_of_interest = c("(Intercept)", "x", "sigma"),
  references = NULL,
  meta = list()
)

test_fit_spec <- data.frame(model = "test")
test_task_ctx <- list(task_id = "test", data_idx = 1L, fit_idx = 1L, rep_idx = 1L)

fit_result <- fit(fitter, test_data_bundle, test_fit_spec, seed = 42L, test_task_ctx)
stopifnot(fit_result$success)

draws <- extract_draws(fitter, fit_result)
stopifnot(is.matrix(draws))
stopifnot(all(c("(Intercept)", "x", "sigma") %in% colnames(draws)))

cat("LinearFitter test passed!\n")
#> LinearFitter test passed!
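If MASS is not available, the coefficient draws in extract_draws() can be generated in base R from a Cholesky factor of vcov(). The sketch below swaps MASS::mvrnorm() for that approach; mvrnorm uses an eigendecomposition instead, so individual draws differ for the same seed even though the target distribution is the same. lm_draws_sketch() is a hypothetical helper.

```r
# Hypothetical helper: approximate posterior draws for an lm fit without MASS
lm_draws_sketch <- function(fit_obj, n_draws = 1000L) {
  coef_est <- coef(fit_obj)
  # Upper-triangular R with t(R) %*% R equal to vcov(fit_obj)
  R <- chol(vcov(fit_obj))
  z <- matrix(rnorm(n_draws * length(coef_est)), nrow = n_draws)
  # Each row is coef_est + z_row %*% R, a draw from N(coef_est, vcov)
  coef_draws <- z %*% R +
    matrix(coef_est, nrow = n_draws, ncol = length(coef_est), byrow = TRUE)
  colnames(coef_draws) <- names(coef_est)
  # Scaled inverse chi-squared draws for sigma, as in extract_draws() above
  df <- fit_obj$df.residual
  sigma_draws <- summary(fit_obj)$sigma * sqrt(df / rchisq(n_draws, df))
  cbind(coef_draws, sigma = sigma_draws)
}

set.seed(5)
x <- rnorm(200)
y <- 1 + 2 * x + rnorm(200, sd = 0.5)
fit_obj <- lm(y ~ x)
draws <- lm_draws_sketch(fit_obj, n_draws = 4000L)
colnames(draws)  # "(Intercept)" "x" "sigma"
```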