@ -91,15 +91,19 @@ class CRCALite:
# Initialize graph
if variables :
for var in variables :
if var not in self . causal_graph :
self . causal_graph [ var ] = { }
if var not in self . causal_graph_reverse :
self . causal_graph_reverse [ var ] = [ ]
self . _ensure_node_exists ( var )
if causal_edges :
for source , target in causal_edges :
self . add_causal_relationship ( source , target )
def _ensure_node_exists ( self , node : str ) - > None :
""" Ensure node present in graph structures. """
if node not in self . causal_graph :
self . causal_graph [ node ] = { }
if node not in self . causal_graph_reverse :
self . causal_graph_reverse [ node ] = [ ]
def add_causal_relationship (
self ,
source : str ,
@ -118,20 +122,14 @@ class CRCALite:
relation_type : Type of causal relation ( default : DIRECT )
confidence : Confidence in the relationship ( default : 1.0 )
"""
# Initialize nodes if needed
if source not in self . causal_graph :
self . causal_graph [ source ] = { }
if target not in self . causal_graph :
self . causal_graph [ target ] = { }
if source not in self . causal_graph_reverse :
self . causal_graph_reverse [ source ] = [ ]
if target not in self . causal_graph_reverse :
self . causal_graph_reverse [ target ] = [ ]
# Ensure nodes exist
self . _ensure_node_exists ( source )
self . _ensure_node_exists ( target )
# Add edge: source -> target with strength
self . causal_graph [ source ] [ target ] = strength
# Add or update edge: source -> target with strength
self . causal_graph [ source ] [ target ] = float ( strength )
# Update reverse mapping for parent lookup
# Update reverse mapping for parent lookup (avoid duplicates)
if source not in self . causal_graph_reverse [ target ] :
self . causal_graph_reverse [ target ] . append ( source )
@ -226,53 +224,7 @@ class CRCALite:
return [ ] # No path found
def detect_confounders(self, treatment: str, outcome: str) -> List[str]:
    """Return the common ancestors of ``treatment`` and ``outcome``.

    Ancestors are collected by repeatedly following ``self._get_parents``.
    A common ancestor is reported only if ``self._has_path`` confirms a path
    from it to both ``treatment`` and ``outcome``.

    The result follows the order in which the ancestors of ``treatment`` are
    discovered, rather than set-iteration order, so the same graph yields the
    same list on every run (assuming ``_get_parents`` returns parents in a
    stable order).

    Args:
        treatment: Treatment variable.
        outcome: Outcome variable.

    Returns:
        List of confounder variable names. Empty if either variable is not a
        key of ``self.causal_graph``.
    """
    if treatment not in self.causal_graph or outcome not in self.causal_graph:
        return []

    def get_ancestors(node: str) -> dict:
        """Collect all ancestors of ``node`` with an iterative DFS.

        A dict is used as an insertion-ordered set so that discovery order
        is kept.
        """
        ancestors = {}
        stack = [node]
        while stack:
            current = stack.pop()
            for parent in self._get_parents(current):
                # Each ancestor is recorded and pushed at most once, which
                # also guarantees termination if the graph contains a cycle.
                if parent not in ancestors:
                    ancestors[parent] = None
                    stack.append(parent)
        return ancestors

    t_ancestors = get_ancestors(treatment)
    o_ancestors = get_ancestors(outcome)

    # Common ancestors, kept only when a path to both endpoints is confirmed.
    return [
        conf
        for conf in t_ancestors
        if conf in o_ancestors
        and self._has_path(conf, treatment)
        and self._has_path(conf, outcome)
    ]
# detect_confounders removed in Lite version (advanced inference)
def _has_path ( self , start : str , end : str ) - > bool :
"""
@ -305,52 +257,7 @@ class CRCALite:
return False
def identify_adjustment_set(self, treatment: str, outcome: str) -> List[str]:
    """Identify a back-door adjustment set for causal effect estimation.

    The set consists of the parents of ``treatment`` (as returned by
    ``self._get_parents``) that are neither descendants of ``treatment`` nor
    the ``outcome`` itself.

    The result keeps the order in which ``_get_parents`` returns the parents,
    with duplicates removed, rather than set-iteration order, so repeated
    runs on the same graph return the same list.

    Args:
        treatment: Treatment variable.
        outcome: Outcome variable.

    Returns:
        List of variables in the adjustment set. Empty if either variable is
        not a key of ``self.causal_graph``.
    """
    if treatment not in self.causal_graph or outcome not in self.causal_graph:
        return []

    # dict.fromkeys de-duplicates while preserving first-seen order.
    parents_t = dict.fromkeys(self._get_parents(treatment))

    # Collect every descendant of the treatment with an iterative DFS.
    descendants_t = set()
    stack = [treatment]
    while stack:
        current = stack.pop()
        for child in self._get_children(current):
            # Each descendant is pushed at most once, so the walk terminates
            # even if the graph contains a cycle.
            if child not in descendants_t:
                descendants_t.add(child)
                stack.append(child)

    return [
        z for z in parents_t
        if z not in descendants_t and z != outcome
    ]
# identify_adjustment_set removed in Lite version (advanced inference)
def _standardize_state ( self , state : Dict [ str , float ] ) - > Dict [ str , float ] :
"""
@ -399,7 +306,7 @@ class CRCALite:
x_ { t + 1 } = E ( x_t )
Mathematical foundation :
- Linear structural causal model : y = Σᵢ βᵢ · xᵢ + ε
- Linear structural causal model: y = Σᵢ βᵢ · xᵢ + ε
- NOTE: This implementation is linear. To model nonlinear dynamics, override
  `_predict_outcomes` in a subclass with a custom evolution operator.
- Propagates effects through causal graph in topological order
- Standardizes inputs , computes in z - space , de - standardizes outputs
@ -449,16 +356,18 @@ class CRCALite:
interventions : Dict [ str , float ]
) - > float :
"""
Calculate probability of a counterfactual scenario .
Calculate a heuristic probability of a counterfactual scenario .
Uses Mahalanobis distance in standardized space .
NOTE : This is a lightweight heuristic proximity measure ( Mahalanobis - like )
and NOT a full statistical estimator — it ignores covariance and should
be treated as a relative plausibility score for Lite usage .
Args :
factual_state : Baseline state
interventions : Intervention values
Returns :
Probability value between 0.05 and 0.98
Heuristic probability value between 0.05 and 0.98
"""
z_sq = 0.0
for var , new in interventions . items ( ) :
@ -490,26 +399,7 @@ class CRCALite:
Returns :
List of CounterfactualScenario objects
"""
scenarios = [ ]
z_steps = [ - 2.0 , - 1.0 , - 0.5 , 0.5 , 1.0 , 2.0 ]
for i , tv in enumerate ( target_variables [ : max_scenarios ] ) :
s = self . standardization_stats . get ( tv , { " mean " : 0.0 , " std " : 1.0 } )
cur = factual_state . get ( tv , 0.0 )
cz = ( cur - s [ " mean " ] ) / s [ " std " ] if s [ " std " ] > 0 else 0.0
vals = [ ( cz + dz ) * s [ " std " ] + s [ " mean " ] for dz in z_steps ]
for j , v in enumerate ( vals ) :
scenarios . append ( CounterfactualScenario (
name = f " scenario_ { i } _ { j } " ,
interventions = { tv : v } ,
expected_outcomes = self . _predict_outcomes ( factual_state , { tv : v } ) ,
probability = self . _calculate_scenario_probability ( factual_state , { tv : v } ) ,
reasoning = f " Intervention on { tv } with value { v } "
) )
return scenarios
# Ensure stats exist for variables in factual_state (fallback behavior)\n+ self.ensure_standardization_stats(factual_state)\n+\n+ scenarios: List[CounterfactualScenario] = []\n+ z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+\n+ for i, tv in enumerate(target_variables[:max_scenarios]):\n+ stats = self.standardization_stats.get(tv, {\"mean\": 0.0, \"std\": 1.0})\n+ cur = factual_state.get(tv, stats.get(\"mean\", 0.0))\n+\n+ # If std is zero or missing, use absolute perturbations instead\n+ if not stats or stats.get(\"std\", 0.0) <= 0:\n+ base = cur\n+ abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+ vals = [base + step for step in abs_steps]\n+ else:\n+ mean = stats[\"mean\"]\n+ std = stats[\"std\"]\n+ cz = (cur - mean) / std\n+ vals = [(cz + dz) * std + mean for dz in z_steps]\n+\n+ for j, v in enumerate(vals):\n+ scenarios.append(CounterfactualScenario(\n+ name=f\"scenario_{i}_{j}\",\n+ interventions={tv: float(v)},\n+ expected_outcomes=self._predict_outcomes(factual_state, {tv: float(v)}),\n+ probability=self._calculate_scenario_probability(factual_state, {tv: float(v)}),\n+ reasoning=f\"Intervention on {tv} with value {v}\"\n+ ))\n+\n+ return scenarios\n*** End Patch
def analyze_causal_strength ( self , source : str , target : str ) - > Dict [ str , float ] :
"""
@ -552,6 +442,16 @@ class CRCALite:
"""
self . standardization_stats [ variable ] = { " mean " : mean , " std " : std if std > 0 else 1.0 }
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
    """Make sure every variable in ``state`` has standardization stats.

    For each variable without an entry in ``self.standardization_stats``, a
    fallback of ``{"mean": <observed value>, "std": 1.0}`` is stored, so the
    variable always has a non-zero std available. Existing entries are never
    modified; this method only fills in missing ones.

    Args:
        state: Mapping of variable name to its observed value. Values must be
            convertible with ``float()``.
    """
    stats = self.standardization_stats
    for var, val in state.items():
        if var in stats:
            continue
        # Use the unpadded "mean"/"std" keys so the entry can be read back by
        # code that looks up stats["mean"] / stats["std"].
        stats[var] = {"mean": float(val), "std": 1.0}
def get_nodes ( self ) - > List [ str ] :
"""
Get all nodes in the causal graph .
@ -636,9 +536,10 @@ class CRCALite:
for step in range ( max_steps ) :
current_state = self . _predict_outcomes ( current_state , { } )
# Generate counterfactual scenarios
# Ensure standardization stats exist for the evolved state and generate counterfactuals from it
self . ensure_standardization_stats ( current_state )
counterfactual_scenarios = self . generate_counterfactual_scenarios (
initial _state,
current _state,
target_variables ,
max_scenarios = 5
)