import numpy as np
from .prior import Prior


class PriorScaler:
    """Scale the parameters of prior distributions."""

    # Standard deviation multiplier.
    STD = 2.5

    def __init__(self, model):
        self.model = model
        self.has_intercept = model.intercept_term is not None
        self.priors = {}

        # Compute mean and std of the response
        if self.model.family.name in ["gaussian", "t"]:
            self.response_mean = np.mean(self.model.response.data)
            self.response_std = np.std(self.model.response.data)
        else:
            self.response_mean = 0
            self.response_std = 1

    def get_intercept_stats(self):
        mu = self.response_mean
        sigma = self.STD * self.response_std

        # Widen the intercept scale using the priors already assigned to common terms
        # (``self.priors`` is filled by ``scale_common()``, which runs first):
        # sigma = sqrt((STD * sd_y) ** 2 + sum_j sigma_j ** 2 * mean(x_j) ** 2)
        if self.priors:
            sigmas = np.hstack([prior["sigma"] for prior in self.priors.values()])
            x_mean = np.hstack([self.model.terms[term].data.mean(axis=0) for term in self.priors])
            sigma = (sigma ** 2 + np.dot(sigmas ** 2, x_mean ** 2)) ** 0.5

        return mu, sigma

    def get_slope_sigma(self, x):
        return self.STD * (self.response_std / np.std(x))

    def scale_response(self):
        # Add cases for other families
        priors = self.model.response.family.likelihood.priors
        if self.model.family.name == "gaussian":
            if priors["sigma"].auto_scale:
                priors["sigma"] = Prior("HalfStudentT", nu=4, sigma=self.response_std)

    def scale_intercept(self, term):
        if term.prior.name != "Normal":
            return
        mu, sigma = self.get_intercept_stats()
        term.prior.update(mu=mu, sigma=sigma)

    def scale_common(self, term):
        if term.prior.name != "Normal":
            return

        # As many zeros as columns in the data. It can be greater than 1 for categorical variables
        mu = np.zeros(term.data.shape[1])
        sigma = np.zeros(term.data.shape[1])

        # Iterate over columns in the data
        for i, x in enumerate(term.data.T):
            sigma[i] = self.get_slope_sigma(x)

        # Save and set prior
        self.priors.update({term.name: {"mu": mu, "sigma": sigma}})
        term.prior.update(mu=mu, sigma=sigma)

    def scale_group_specific(self, term):
        if term.prior.args["sigma"].name != "HalfNormal":
            return

        # Recreate the corresponding common effect data
        data_as_common = term.predictor

        # Handle intercepts
        if term.type == "intercept":
            _, sigma = self.get_intercept_stats()
        # Handle slopes
        else:
            sigma = np.zeros(data_as_common.shape[1])
            for i, x in enumerate(data_as_common.T):
                sigma[i] = self.get_slope_sigma(x)
        term.prior.args["sigma"].update(sigma=np.squeeze(np.atleast_1d(sigma)))

    def scale(self):
        # Scale response
        self.scale_response()

        # Scale common terms
        for term in self.model.common_terms.values():
            if term.prior.auto_scale:
                self.scale_common(term)

        # Scale intercept
        if self.has_intercept:
            term = self.model.intercept_term
            if term.prior.auto_scale:
                self.scale_intercept(term)

        # Scale group-specific terms
        for term in self.model.group_specific_terms.values():
            if term.prior.auto_scale:
                self.scale_group_specific(term)
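

# --- Illustrative sketch (added; not part of the original class) ------------
# Within Bambi, an instance is normally created with a model and ``scale()``
# is called while the model is being built. The NumPy-only check below assumes
# a Gaussian response ``y`` and a single numeric predictor ``x`` (both made up
# here): the slope prior gets sigma = STD * sd(y) / sd(x), and the intercept
# scale is widened to sqrt((STD * sd(y)) ** 2 + sigma_slope ** 2 * mean(x) ** 2).
# It only exercises the formulas, not the Model/Prior machinery, and assumes
# the module's relative import resolves (i.e. it is run as part of its package).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y = rng.normal(10.0, 3.0, size=200)  # made-up response
    x = rng.normal(5.0, 2.0, size=200)  # made-up predictor

    sigma_slope = PriorScaler.STD * (np.std(y) / np.std(x))  # get_slope_sigma
    sigma_intercept = (
        (PriorScaler.STD * np.std(y)) ** 2 + sigma_slope ** 2 * np.mean(x) ** 2
    ) ** 0.5  # get_intercept_stats with one scaled common term
    print(f"slope sigma: {sigma_slope:.3f}")
    print(f"intercept sigma: {sigma_intercept:.3f}")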