from statsmodels.genmod import families as sm_families
from bambi.families.link import Link
STATSMODELS_FAMILIES = {
"bernoulli": sm_families.Binomial,
"gamma": sm_families.Gamma,
"gaussian": sm_families.Gaussian,
"wald": sm_families.InverseGaussian,
"negativebinomial": sm_families.NegativeBinomial,
"poisson": sm_families.Poisson,
}
STATSMODELS_LINKS = {
"identity": sm_families.links.identity(),
"logit": sm_families.links.logit(),
"probit": sm_families.links.probit(),
"cloglog": sm_families.links.cloglog(),
"inverse": sm_families.links.inverse_power(),
"inverse_squared": sm_families.links.inverse_squared(),
"log": sm_families.links.log(),
}
[docs]class Family:
"""A specification of model family.
Parameters
----------
name : str
The name of the family. It can be any string.
likelihood: Likelihood
A ``bambi.families.Likelihood`` instace specifying the model likelihood function.
link : str or Link
The name of the link function or a ``bambi.families.Link`` instance. The link function
transforms the linear model prediction to the mean parameter of the likelihood funtion.
Examples
--------
>>> import bambi as bmb
Replicate the Gaussian built-in family.
>>> sigma_prior = bmb.Prior("HalfNormal", sigma=1)
>>> likelihood = bmb.Likelihood("Gaussian", parent="mu", sigma=sigma_prior)
>>> family = bmb.Family("my_gaussian", likelihood, "identity")
>>> # Then you can do
>>> # bmb.Model("y ~ x", data, family=family)
Replicate the Bernoulli built-in family.
>>> likelihood = bmb.Likelihood("Bernoulli", parent="p")
>>> family = bmb.Family("bernoulli2", likelihood, "logit")
"""
def __init__(self, name, likelihood, link):
self.smlink = None
self.link = None
self.name = name
self.likelihood = likelihood
self.smfamily = STATSMODELS_FAMILIES.get(name, None)
self._set_link(link)
def _set_link(self, link):
"""Set new link function.
If ``link`` is of type ``str``, this method attempts to create a ``bambi.families.Link``
from the name passed. If it is a recognized name, a builtin ``bambi.families.Link`` will be
used. Otherwise, ``bambi.families.Link`` instantiation will raise an error.
Parameters
----------
link: str or bambi.families.Link
If a string, it must the name of a link function recognized by Bambi.
Returns
-------
None
"""
if isinstance(link, str):
self.link = Link(link)
self.smlink = STATSMODELS_LINKS.get(link, None)
elif isinstance(link, Link):
self.link = link
else:
raise ValueError("'link' must be a string or a Link instance.")
def __str__(self):
msg_list = [f"Response distribution: {self.likelihood.name}", f"Link: {self.link.name}"]
if self.likelihood.priors:
priors_msg = "\n ".join([f"{k} ~ {v}" for k, v in self.likelihood.priors.items()])
msg_list += [f"Priors:\n {priors_msg}"]
msg = "\n".join(msg_list)
return msg
def __repr__(self):
return self.__str__()
# Names of parameters that can receive a prior distribution for the built-in families
FAMILY_PARAMS = {
"beta": ("kappa",),
"gamma": ("alpha",),
"gaussian": ("sigma",),
"negativebinomial": ("alpha",),
"t": ("sigma", "nu"),
"wald": ("lam",),
}
def _extract_family_prior(family, priors):
"""Extract priors for a given family
If a key in the priors dictionary matches the name of a nuisance parameter of the response
distribution for the given family, this function extracts and returns the prior for that
nuisance parameter. The result of this function can be safely used to update the ``Prior`` of
the response term.
Parameters
----------
family: str or ``bambi.families.Family``
The family for which we want to extract priors.
priors: dict
A dictionary where keys represent parameter/term names and values represent
prior distributions.
"""
if isinstance(family, str) and family in FAMILY_PARAMS:
names = FAMILY_PARAMS[family]
priors = {name: priors.pop(name) for name in names if priors.get(name) is not None}
if priors:
return priors
elif isinstance(family, Family):
# Only work if there are auxiliary parameters in the family, and if any of these are
# present in 'priors' dictionary.
nuisance_params = list(family.likelihood.priors)
if set(nuisance_params).intersection(set(priors)):
return {k: priors.pop(k) for k in nuisance_params if k in priors}
return None