# Source code for tablemage._src.causal.report

import numpy as np
from ..display.print_options import print_options
from ..display.print_utils import (
    color_text,
    bold_text,
    list_to_string,
    fill_ignore_format,
    format_two_column,
)
from scipy.stats import norm


[docs] class CausalReport: """Class for storing and displaying causal inference results.""" def __init__( self, estimate: float, se: float, n_units: int, n_units_treated: int, outcome_var: str, treatment_var: str, confounders: list[str], estimand: str, method: str, method_description: str, p_value: float | None = None, ): """Initializes a CausalReport object. Parameters ---------- estimate : float The estimate of the causal effect. se : float The standard error of the estimator. n_units : int The number of units in the data. n_units_treated : int The number of treated units in the data. outcome_var : str The name of the outcome variable. treatment_var : str The name of the treatment variable. confounders : list[str] The names of the confounding variables. estimand : str The estimand of the causal effect. Either "ate" or "att". method : str The method used to estimate the causal effect. method_description : str A description of the method used to estimate the causal effect. """ self._estimate = estimate self._estimate_se = se self._n_units = n_units self._n_units_treated = n_units_treated self._outcome_var = outcome_var self._treatment_var = treatment_var self._confounders = confounders self._estimand = estimand self._method = method self._method_description = method_description if p_value is not None: self._p_value = p_value else: # Calculate p-value from standard error self._p_value = 2 * (1 - norm.cdf(abs(self._estimate) / self._estimate_se))
[docs] def effect(self): """Returns the estimate of the causal effect.""" return self._estimate
[docs] def se(self): """Returns the standard error of the estimator.""" return self._estimate_se
[docs] def n_units(self): """Returns the number of units in the data.""" return self._n_units
[docs] def pval(self): """Returns the p-value of the estimator.""" return self._p_value
def _to_dict(self): """Converts the CausalReport object to a dictionary.""" return { "estimand": self._estimand, "estimate": self._estimate, "se": self._estimate_se, "p_value": self._p_value, "n_units": self._n_units, "n_units_treated": self._n_units_treated, "outcome_var": self._outcome_var, "treatment_var": self._treatment_var, "confounders": self._confounders, "method": self._method, "method_description": self._method_description, } def __str__(self): max_width = print_options._max_line_width n_dec = print_options._n_decimals top_divider = color_text("=" * max_width, "none") + "\n" bottom_divider = "\n" + color_text("=" * max_width, "none") divider = "\n" + color_text("-" * max_width, "none") + "\n" divider_invisible = "\n" + " " * max_width + "\n" title_message = bold_text("Causal Effect Estimation Report") estimand = ( "Avg Trmt Effect (ATE)" if self._estimand == "ate" else "Avg Trmt Effect on Trtd (ATT)" ) estimate_message = "" estimate_message += ( format_two_column( f"{bold_text('Estimate:')} " f"{color_text(f'{self._estimate:.{n_dec}f}', 'yellow')}", f"{bold_text('Std Err:')} " f"{color_text(f'{self._estimate_se:.{n_dec}f}', 'yellow')}", max_width, ) + "\n" ) pval_str = f"{self._p_value:.{n_dec}e}" if self._p_value <= 0.05: pval_color = "green" else: pval_color = "red" estimate_message += format_two_column( f"{bold_text('Estimand:')} " f"{color_text(estimand, 'blue')}", f"{bold_text('p-value:')} " f"{color_text(pval_str, pval_color)}", max_width, ) treatment_message = bold_text("Treatment variable:\n") + color_text( " '" + self._treatment_var + "'", "purple" ) outcome_message = bold_text("Outcome variable:\n") + color_text( " '" + self._outcome_var + "'", "purple" ) confounders_message = bold_text("Confounders:\n") if len(self._confounders) == 0: confounders_message += fill_ignore_format( color_text("None", "yellow"), initial_indent=2, subsequent_indent=2 ) else: confounders_message += fill_ignore_format( list_to_string(self._confounders), 
initial_indent=2, subsequent_indent=2 ) method_message = bold_text("Method:\n") + fill_ignore_format( color_text(self._method, "blue"), initial_indent=2, subsequent_indent=2 ) return ( top_divider + title_message + divider + estimate_message + divider + treatment_message + divider_invisible + outcome_message + divider_invisible + confounders_message + divider + method_message + bottom_divider ) def _repr_pretty_(self, p, cycle): if cycle: p.text(self.__class__.__name__ + "(...)") else: p.text(str(self))