# Source code for bambi.families.family

from statsmodels.genmod import families as sm_families

from bambi.families.link import Link

# statsmodels GLM family *classes* (not instances), keyed by Bambi family name.
# A family name missing from this mapping simply has no statsmodels counterpart.
STATSMODELS_FAMILIES = dict(
    bernoulli=sm_families.Binomial,
    gamma=sm_families.Gamma,
    gaussian=sm_families.Gaussian,
    wald=sm_families.InverseGaussian,
    negativebinomial=sm_families.NegativeBinomial,
    poisson=sm_families.Poisson,
)

# Ready-made statsmodels link *instances*, keyed by Bambi link name.
# Note that Bambi's "inverse" corresponds to statsmodels' ``inverse_power`` link.
STATSMODELS_LINKS = dict(
    identity=sm_families.links.identity(),
    logit=sm_families.links.logit(),
    probit=sm_families.links.probit(),
    cloglog=sm_families.links.cloglog(),
    inverse=sm_families.links.inverse_power(),
    inverse_squared=sm_families.links.inverse_squared(),
    log=sm_families.links.log(),
)


class Family:
    """A specification of model family.

    Parameters
    ----------
    name : str
        The name of the family. It can be any string.
    likelihood : Likelihood
        A ``bambi.families.Likelihood`` instance specifying the model likelihood function.
    link : str or Link
        The name of the link function or a ``bambi.families.Link`` instance. The link function
        transforms the linear model prediction to the mean parameter of the likelihood function.

    Examples
    --------

    >>> import bambi as bmb

    Replicate the Gaussian built-in family.

    >>> sigma_prior = bmb.Prior("HalfNormal", sigma=1)
    >>> likelihood = bmb.Likelihood("Gaussian", parent="mu", sigma=sigma_prior)
    >>> family = bmb.Family("my_gaussian", likelihood, "identity")
    >>> # Then you can do
    >>> # bmb.Model("y ~ x", data, family=family)

    Replicate the Bernoulli built-in family.

    >>> likelihood = bmb.Likelihood("Bernoulli", parent="p")
    >>> family = bmb.Family("bernoulli2", likelihood, "logit")
    """

    def __init__(self, name, likelihood, link):
        self.smlink = None
        self.link = None
        self.name = name
        self.likelihood = likelihood
        # statsmodels family class when ``name`` is a key of STATSMODELS_FAMILIES, else None.
        self.smfamily = STATSMODELS_FAMILIES.get(name, None)
        self._set_link(link)

    def _set_link(self, link):
        """Set new link function.

        If ``link`` is of type ``str``, this method attempts to create a ``bambi.families.Link``
        from the name passed. If it is a recognized name, a builtin ``bambi.families.Link`` will
        be used. Otherwise, ``bambi.families.Link`` instantiation will raise an error.

        Parameters
        ----------
        link : str or bambi.families.Link
            If a string, it must be the name of a link function recognized by Bambi.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``link`` is neither a string nor a ``bambi.families.Link`` instance.
        """
        if isinstance(link, str):
            self.link = Link(link)
            # None when statsmodels has no equivalent of this link name.
            self.smlink = STATSMODELS_LINKS.get(link, None)
        elif isinstance(link, Link):
            self.link = link
            # No statsmodels link is derived from a user-provided Link. Reset it explicitly so
            # that a link previously set from a string does not leave a stale statsmodels link
            # that no longer matches ``self.link``.
            self.smlink = None
        else:
            raise ValueError("'link' must be a string or a Link instance.")

    def __str__(self):
        lines = [f"Response distribution: {self.likelihood.name}", f"Link: {self.link.name}"]
        if self.likelihood.priors:
            priors_msg = "\n ".join([f"{k} ~ {v}" for k, v in self.likelihood.priors.items()])
            lines += [f"Priors:\n {priors_msg}"]
        return "\n".join(lines)

    def __repr__(self):
        return self.__str__()
# Names of parameters that can receive a prior distribution for the built-in families FAMILY_PARAMS = { "beta": ("kappa",), "gamma": ("alpha",), "gaussian": ("sigma",), "negativebinomial": ("alpha",), "t": ("sigma", "nu"), "wald": ("lam",), } def _extract_family_prior(family, priors): """Extract priors for a given family If a key in the priors dictionary matches the name of a nuisance parameter of the response distribution for the given family, this function extracts and returns the prior for that nuisance parameter. The result of this function can be safely used to update the ``Prior`` of the response term. Parameters ---------- family: str or ``bambi.families.Family`` The family for which we want to extract priors. priors: dict A dictionary where keys represent parameter/term names and values represent prior distributions. """ if isinstance(family, str) and family in FAMILY_PARAMS: names = FAMILY_PARAMS[family] priors = {name: priors.pop(name) for name in names if priors.get(name) is not None} if priors: return priors elif isinstance(family, Family): # Only work if there are auxiliary parameters in the family, and if any of these are # present in 'priors' dictionary. nuisance_params = list(family.likelihood.priors) if set(nuisance_params).intersection(set(priors)): return {k: priors.pop(k) for k in nuisance_params if k in priors} return None