Refactor causal graph handling and remove advanced methods

Refactor causal graph node initialization and edge addition. Remove the advanced methods for confounder detection and adjustment-set identification from the Lite version.
pull/1233/head
CI-DEV 4 weeks ago committed by GitHub
parent b987ea0342
commit 38a7cf83b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -91,15 +91,19 @@ class CRCALite:
# Initialize graph # Initialize graph
if variables: if variables:
for var in variables: for var in variables:
if var not in self.causal_graph: self._ensure_node_exists(var)
self.causal_graph[var] = {}
if var not in self.causal_graph_reverse:
self.causal_graph_reverse[var] = []
if causal_edges: if causal_edges:
for source, target in causal_edges: for source, target in causal_edges:
self.add_causal_relationship(source, target) self.add_causal_relationship(source, target)
def _ensure_node_exists(self, node: str) -> None:
"""Ensure node present in graph structures."""
if node not in self.causal_graph:
self.causal_graph[node] = {}
if node not in self.causal_graph_reverse:
self.causal_graph_reverse[node] = []
def add_causal_relationship( def add_causal_relationship(
self, self,
source: str, source: str,
@ -118,20 +122,14 @@ class CRCALite:
relation_type: Type of causal relation (default: DIRECT) relation_type: Type of causal relation (default: DIRECT)
confidence: Confidence in the relationship (default: 1.0) confidence: Confidence in the relationship (default: 1.0)
""" """
# Initialize nodes if needed # Ensure nodes exist
if source not in self.causal_graph: self._ensure_node_exists(source)
self.causal_graph[source] = {} self._ensure_node_exists(target)
if target not in self.causal_graph:
self.causal_graph[target] = {}
if source not in self.causal_graph_reverse:
self.causal_graph_reverse[source] = []
if target not in self.causal_graph_reverse:
self.causal_graph_reverse[target] = []
# Add edge: source -> target with strength # Add or update edge: source -> target with strength
self.causal_graph[source][target] = strength self.causal_graph[source][target] = float(strength)
# Update reverse mapping for parent lookup # Update reverse mapping for parent lookup (avoid duplicates)
if source not in self.causal_graph_reverse[target]: if source not in self.causal_graph_reverse[target]:
self.causal_graph_reverse[target].append(source) self.causal_graph_reverse[target].append(source)
@ -226,53 +224,7 @@ class CRCALite:
return [] # No path found return [] # No path found
def detect_confounders(self, treatment: str, outcome: str) -> List[str]:
    """Find the variables that are ancestors of both treatment and outcome.

    Args:
        treatment: Treatment variable.
        outcome: Outcome variable.

    Returns:
        Names of the common ancestors for which ``self._has_path`` reports a
        path to both ``treatment`` and ``outcome``. An empty list is
        returned when either variable is not a key of ``self.causal_graph``.
    """
    # detect_confounders removed in Lite version (advanced inference)
    if not (treatment in self.causal_graph and outcome in self.causal_graph):
        return []

    def collect_upstream(start: str) -> set:
        """Gather every ancestor of ``start`` with an iterative DFS over parents."""
        found = set()
        seen = set()
        pending = [start]
        while pending:
            here = pending.pop()
            if here in seen:
                continue
            seen.add(here)
            for upstream in self._get_parents(here):
                if upstream in found:
                    continue
                found.add(upstream)
                pending.append(upstream)
        return found

    # Candidates are the ancestors shared by both endpoints.
    shared = list(collect_upstream(treatment) & collect_upstream(outcome))

    # Keep only candidates that reach both endpoints according to _has_path.
    return [
        candidate
        for candidate in shared
        if self._has_path(candidate, treatment) and self._has_path(candidate, outcome)
    ]
def _has_path(self, start: str, end: str) -> bool: def _has_path(self, start: str, end: str) -> bool:
""" """
@ -305,52 +257,7 @@ class CRCALite:
return False return False
def identify_adjustment_set(self, treatment: str, outcome: str) -> List[str]:
    """Pick a back-door adjustment set for estimating a causal effect.

    Args:
        treatment: Treatment variable.
        outcome: Outcome variable.

    Returns:
        The parents of ``treatment`` that are neither descendants of
        ``treatment`` nor the ``outcome`` itself. An empty list is returned
        when either variable is not a key of ``self.causal_graph``.
    """
    # identify_adjustment_set removed in Lite version (advanced inference)
    graph = self.causal_graph
    if treatment not in graph or outcome not in graph:
        return []

    direct_causes = set(self._get_parents(treatment))

    def collect_downstream(start: str) -> set:
        """Gather every descendant of ``start`` with an iterative DFS over children."""
        reached = set()
        seen = set()
        pending = [start]
        while pending:
            here = pending.pop()
            if here in seen:
                continue
            seen.add(here)
            for downstream in self._get_children(here):
                if downstream in reached:
                    continue
                reached.add(downstream)
                pending.append(downstream)
        return reached

    downstream_of_treatment = collect_downstream(treatment)

    # Drop parents that are also descendants of the treatment, and the outcome itself.
    selected = []
    for candidate in direct_causes:
        if candidate in downstream_of_treatment or candidate == outcome:
            continue
        selected.append(candidate)
    return selected
def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]: def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]:
""" """
@ -399,7 +306,7 @@ class CRCALite:
x_{t+1} = E(x_t) x_{t+1} = E(x_t)
Mathematical foundation: Mathematical foundation:
- Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε - Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε
- NOTE: This implementation is linear. To model nonlinear dynamics, override
`_predict_outcomes` in a subclass with a custom evolution operator.
- Propagates effects through causal graph in topological order - Propagates effects through causal graph in topological order
- Standardizes inputs, computes in z-space, de-standardizes outputs - Standardizes inputs, computes in z-space, de-standardizes outputs
@ -449,16 +356,18 @@ class CRCALite:
interventions: Dict[str, float] interventions: Dict[str, float]
) -> float: ) -> float:
""" """
Calculate probability of a counterfactual scenario. Calculate a heuristic probability of a counterfactual scenario.
Uses Mahalanobis distance in standardized space. NOTE: This is a lightweight heuristic proximity measure (Mahalanobis-like)
and NOT a full statistical estimator; it ignores covariance and should
be treated as a relative plausibility score for Lite usage.
Args: Args:
factual_state: Baseline state factual_state: Baseline state
interventions: Intervention values interventions: Intervention values
Returns: Returns:
Probability value between 0.05 and 0.98 Heuristic probability value between 0.05 and 0.98
""" """
z_sq = 0.0 z_sq = 0.0
for var, new in interventions.items(): for var, new in interventions.items():
@ -490,26 +399,7 @@ class CRCALite:
Returns: Returns:
List of CounterfactualScenario objects List of CounterfactualScenario objects
""" """
scenarios = [] # Ensure stats exist for variables in factual_state (fallback behavior)\n+ self.ensure_standardization_stats(factual_state)\n+\n+ scenarios: List[CounterfactualScenario] = []\n+ z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+\n+ for i, tv in enumerate(target_variables[:max_scenarios]):\n+ stats = self.standardization_stats.get(tv, {\"mean\": 0.0, \"std\": 1.0})\n+ cur = factual_state.get(tv, stats.get(\"mean\", 0.0))\n+\n+ # If std is zero or missing, use absolute perturbations instead\n+ if not stats or stats.get(\"std\", 0.0) <= 0:\n+ base = cur\n+ abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+ vals = [base + step for step in abs_steps]\n+ else:\n+ mean = stats[\"mean\"]\n+ std = stats[\"std\"]\n+ cz = (cur - mean) / std\n+ vals = [(cz + dz) * std + mean for dz in z_steps]\n+\n+ for j, v in enumerate(vals):\n+ scenarios.append(CounterfactualScenario(\n+ name=f\"scenario_{i}_{j}\",\n+ interventions={tv: float(v)},\n+ expected_outcomes=self._predict_outcomes(factual_state, {tv: float(v)}),\n+ probability=self._calculate_scenario_probability(factual_state, {tv: float(v)}),\n+ reasoning=f\"Intervention on {tv} with value {v}\"\n+ ))\n+\n+ return scenarios\n*** End Patch
z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
for i, tv in enumerate(target_variables[:max_scenarios]):
s = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0})
cur = factual_state.get(tv, 0.0)
cz = (cur - s["mean"]) / s["std"] if s["std"] > 0 else 0.0
vals = [(cz + dz) * s["std"] + s["mean"] for dz in z_steps]
for j, v in enumerate(vals):
scenarios.append(CounterfactualScenario(
name=f"scenario_{i}_{j}",
interventions={tv: v},
expected_outcomes=self._predict_outcomes(factual_state, {tv: v}),
probability=self._calculate_scenario_probability(factual_state, {tv: v}),
reasoning=f"Intervention on {tv} with value {v}"
))
return scenarios
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]: def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
""" """
@ -552,6 +442,16 @@ class CRCALite:
""" """
self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0} self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0}
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
    """
    Fill in standardization stats for any variable in ``state`` that lacks them.

    Args:
        state: Mapping of variable name to its observed value.

    A variable with no entry in ``self.standardization_stats`` gets a
    fallback entry whose mean is the observed value (converted to float)
    and whose std is 1.0. Existing entries are never overwritten.
    """
    known = self.standardization_stats
    for name, observed in state.items():
        if name in known:
            continue
        # Unit std keeps later divisions by std well-defined.
        known[name] = {"mean": float(observed), "std": 1.0}
def get_nodes(self) -> List[str]: def get_nodes(self) -> List[str]:
""" """
Get all nodes in the causal graph. Get all nodes in the causal graph.
@ -636,9 +536,10 @@ class CRCALite:
for step in range(max_steps): for step in range(max_steps):
current_state = self._predict_outcomes(current_state, {}) current_state = self._predict_outcomes(current_state, {})
# Generate counterfactual scenarios # Ensure standardization stats exist for the evolved state and generate counterfactuals from it
self.ensure_standardization_stats(current_state)
counterfactual_scenarios = self.generate_counterfactual_scenarios( counterfactual_scenarios = self.generate_counterfactual_scenarios(
initial_state, current_state,
target_variables, target_variables,
max_scenarios=5 max_scenarios=5
) )

Loading…
Cancel
Save