Source code for bambi.priors.scaler_default

import numpy as np

from .prior import Prior


class PriorScaler:
    """Scale prior distributions parameters.

    The scaler reads a model's response, common terms, intercept and
    group-specific terms, and updates the parameters of their priors in place.
    Only priors flagged with ``auto_scale`` are touched (see :meth:`scale`).
    """

    # Standard deviation multiplier.
    STD = 2.5

    def __init__(self, model):
        """Store the model and compute the response summary statistics.

        Parameters
        ----------
        model
            The model whose priors are going to be scaled. This class reads
            ``intercept_term``, ``family``, ``response``, ``common_terms``,
            ``terms`` and ``group_specific_terms`` from it.
        """
        self.model = model
        self.has_intercept = model.intercept_term is not None
        # Maps the name of each scaled common term to {"mu": ..., "sigma": ...}.
        # Filled by ``scale_common`` and read back by ``get_intercept_stats``.
        self.priors = {}

        # Compute mean and std of the response
        if self.model.family.name in ["gaussian", "t"]:
            self.response_mean = np.mean(model.response.data)
            self.response_std = np.std(self.model.response.data)
        else:
            self.response_mean = 0
            self.response_std = 1

    def get_intercept_stats(self):
        """Return the ``(mu, sigma)`` used for the intercept prior.

        ``mu`` is the response mean. ``sigma`` starts at
        ``STD * response_std`` and, when slope priors have been recorded in
        ``self.priors``, is widened to
        ``sqrt(sigma ** 2 + sum(slope_sigma ** 2 * x_mean ** 2))``.
        """
        mu = self.response_mean
        sigma = self.STD * self.response_std

        # Guard on the recorded priors rather than on ``model.common_terms``:
        # a model can have common terms none of which were scaled (auto-scale
        # disabled or non-Normal prior), and ``np.hstack`` raises ValueError
        # when it receives an empty list.
        if self.priors:
            sigmas = np.hstack([prior["sigma"] for prior in self.priors.values()])
            x_mean = np.hstack([self.model.terms[term].data.mean(axis=0) for term in self.priors])
            sigma = (sigma ** 2 + np.dot(sigmas ** 2, x_mean ** 2)) ** 0.5

        return mu, sigma

    def get_slope_sigma(self, x):
        """Return ``STD * response_std / std(x)`` for the predictor values ``x``."""
        return self.STD * (self.response_std / np.std(x))

    def scale_response(self):
        """Replace the auto-scaled ``sigma`` prior of gaussian and t families.

        The new prior is ``HalfStudentT`` with ``nu=4`` and ``sigma`` equal to
        the standard deviation of the response.
        """
        # Add cases for other families
        priors = self.model.response.family.likelihood.priors
        if self.model.family.name in ["gaussian", "t"]:
            if priors["sigma"].auto_scale:
                priors["sigma"] = Prior("HalfStudentT", nu=4, sigma=self.response_std)

    def scale_intercept(self, term):
        """Update the intercept prior of ``term``; only ``Normal`` priors are scaled."""
        if term.prior.name != "Normal":
            return
        mu, sigma = self.get_intercept_stats()
        term.prior.update(mu=mu, sigma=sigma)

    def scale_common(self, term):
        """Update the prior of a common term; only ``Normal`` priors are scaled.

        The computed values are also stored in ``self.priors`` under the name
        of the term so the intercept scaling can account for them.
        """
        if term.prior.name != "Normal":
            return

        # As many zeros as columns in the data.
        # It can be greater than 1 for categorical variables
        mu = np.zeros(term.data.shape[1])
        sigma = np.zeros(term.data.shape[1])

        # Iterate over columns in the data
        for i, x in enumerate(term.data.T):
            sigma[i] = self.get_slope_sigma(x)

        # Save and set prior
        self.priors.update({term.name: {"mu": mu, "sigma": sigma}})
        term.prior.update(mu=mu, sigma=sigma)

    def scale_group_specific(self, term):
        """Update the ``sigma`` hyperprior of a group-specific term.

        Nothing is done unless that hyperprior is ``HalfNormal``.
        """
        if term.prior.args["sigma"].name != "HalfNormal":
            return

        # Recreate the corresponding common effect data
        data_as_common = term.predictor

        # Handle intercepts
        if term.type == "intercept":
            _, sigma = self.get_intercept_stats()
        # Handle slopes
        else:
            sigma = np.zeros(data_as_common.shape[1])
            for i, x in enumerate(data_as_common.T):
                sigma[i] = self.get_slope_sigma(x)

        # Collapse length-one arrays so a single column yields a 0-d value.
        term.prior.args["sigma"].update(sigma=np.squeeze(np.atleast_1d(sigma)))

    def scale(self):
        """Scale every prior of the model that has ``auto_scale`` enabled.

        Common terms are processed before the intercept because the intercept
        sigma depends on the slope sigmas recorded in ``self.priors``.
        """
        # Scale response
        self.scale_response()

        # Scale common terms
        for term in self.model.common_terms.values():
            if term.prior.auto_scale:
                self.scale_common(term)

        # Scale intercept
        if self.has_intercept:
            term = self.model.intercept_term
            if term.prior.auto_scale:
                self.scale_intercept(term)

        # Scale group-specific terms
        for term in self.model.group_specific_terms.values():
            if term.prior.auto_scale:
                self.scale_group_specific(term)