Refactor causal graph handling and remove advanced methods

Refactor causal graph node initialization and edge addition. Removed advanced methods for confounder detection and adjustment set identification in Lite version.
pull/1233/head
CI-DEV 4 weeks ago committed by GitHub
parent b987ea0342
commit 38a7cf83b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -91,15 +91,19 @@ class CRCALite:
# Initialize graph
if variables:
for var in variables:
if var not in self.causal_graph:
self.causal_graph[var] = {}
if var not in self.causal_graph_reverse:
self.causal_graph_reverse[var] = []
self._ensure_node_exists(var)
if causal_edges:
for source, target in causal_edges:
self.add_causal_relationship(source, target)
def _ensure_node_exists(self, node: str) -> None:
"""Ensure node present in graph structures."""
if node not in self.causal_graph:
self.causal_graph[node] = {}
if node not in self.causal_graph_reverse:
self.causal_graph_reverse[node] = []
def add_causal_relationship(
self,
source: str,
@ -118,20 +122,14 @@ class CRCALite:
relation_type: Type of causal relation (default: DIRECT)
confidence: Confidence in the relationship (default: 1.0)
"""
# Initialize nodes if needed
if source not in self.causal_graph:
self.causal_graph[source] = {}
if target not in self.causal_graph:
self.causal_graph[target] = {}
if source not in self.causal_graph_reverse:
self.causal_graph_reverse[source] = []
if target not in self.causal_graph_reverse:
self.causal_graph_reverse[target] = []
# Add edge: source -> target with strength
self.causal_graph[source][target] = strength
# Update reverse mapping for parent lookup
# Ensure nodes exist
self._ensure_node_exists(source)
self._ensure_node_exists(target)
# Add or update edge: source -> target with strength
self.causal_graph[source][target] = float(strength)
# Update reverse mapping for parent lookup (avoid duplicates)
if source not in self.causal_graph_reverse[target]:
self.causal_graph_reverse[target].append(source)
@ -226,53 +224,7 @@ class CRCALite:
return [] # No path found
def detect_confounders(self, treatment: str, outcome: str) -> List[str]:
"""
Detect confounders: variables that are ancestors of both treatment and outcome.
Args:
treatment: Treatment variable
outcome: Outcome variable
Returns:
List of confounder variable names
"""
def get_ancestors(node: str) -> set:
"""Get all ancestors of a node using DFS."""
ancestors = set()
stack = [node]
visited = set()
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
parents = self._get_parents(current)
for parent in parents:
if parent not in ancestors:
ancestors.add(parent)
stack.append(parent)
return ancestors
if treatment not in self.causal_graph or outcome not in self.causal_graph:
return []
t_ancestors = get_ancestors(treatment)
o_ancestors = get_ancestors(outcome)
# Confounders are common ancestors
confounders = list(t_ancestors & o_ancestors)
# Verify they have paths to both treatment and outcome
valid_confounders = []
for conf in confounders:
if (self._has_path(conf, treatment) and self._has_path(conf, outcome)):
valid_confounders.append(conf)
return valid_confounders
# detect_confounders removed in Lite version (advanced inference)
def _has_path(self, start: str, end: str) -> bool:
"""
@ -305,52 +257,7 @@ class CRCALite:
return False
def identify_adjustment_set(self, treatment: str, outcome: str) -> List[str]:
"""
Identify back-door adjustment set for causal effect estimation.
Args:
treatment: Treatment variable
outcome: Outcome variable
Returns:
List of variables in the adjustment set
"""
if treatment not in self.causal_graph or outcome not in self.causal_graph:
return []
# Get parents of treatment
parents_t = set(self._get_parents(treatment))
# Get descendants of treatment
def get_descendants(node: str) -> set:
"""Get all descendants using DFS."""
descendants = set()
stack = [node]
visited = set()
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
for child in self._get_children(current):
if child not in descendants:
descendants.add(child)
stack.append(child)
return descendants
descendants_t = get_descendants(treatment)
# Adjustment set: parents of treatment that are not descendants and not the outcome
adjustment = [
z for z in parents_t
if z not in descendants_t and z != outcome
]
return adjustment
# identify_adjustment_set removed in Lite version (advanced inference)
def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]:
"""
@ -399,7 +306,7 @@ class CRCALite:
x_{t+1} = E(x_t)
Mathematical foundation:
- Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε
- Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε\n+ - NOTE: This implementation is linear. To model nonlinear dynamics override\n+ `_predict_outcomes` in a subclass with a custom evolution operator.\n*** End Patch
- Propagates effects through causal graph in topological order
- Standardizes inputs, computes in z-space, de-standardizes outputs
@ -449,16 +356,18 @@ class CRCALite:
interventions: Dict[str, float]
) -> float:
"""
Calculate probability of a counterfactual scenario.
Calculate a heuristic probability of a counterfactual scenario.
Uses Mahalanobis distance in standardized space.
NOTE: This is a lightweight heuristic proximity measure (Mahalanobis-like)
and NOT a full statistical estimator it ignores covariance and should
be treated as a relative plausibility score for Lite usage.
Args:
factual_state: Baseline state
interventions: Intervention values
Returns:
Probability value between 0.05 and 0.98
Heuristic probability value between 0.05 and 0.98
"""
z_sq = 0.0
for var, new in interventions.items():
@ -490,26 +399,7 @@ class CRCALite:
Returns:
List of CounterfactualScenario objects
"""
scenarios = []
z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
for i, tv in enumerate(target_variables[:max_scenarios]):
s = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0})
cur = factual_state.get(tv, 0.0)
cz = (cur - s["mean"]) / s["std"] if s["std"] > 0 else 0.0
vals = [(cz + dz) * s["std"] + s["mean"] for dz in z_steps]
for j, v in enumerate(vals):
scenarios.append(CounterfactualScenario(
name=f"scenario_{i}_{j}",
interventions={tv: v},
expected_outcomes=self._predict_outcomes(factual_state, {tv: v}),
probability=self._calculate_scenario_probability(factual_state, {tv: v}),
reasoning=f"Intervention on {tv} with value {v}"
))
return scenarios
# Ensure stats exist for variables in factual_state (fallback behavior)\n+ self.ensure_standardization_stats(factual_state)\n+\n+ scenarios: List[CounterfactualScenario] = []\n+ z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+\n+ for i, tv in enumerate(target_variables[:max_scenarios]):\n+ stats = self.standardization_stats.get(tv, {\"mean\": 0.0, \"std\": 1.0})\n+ cur = factual_state.get(tv, stats.get(\"mean\", 0.0))\n+\n+ # If std is zero or missing, use absolute perturbations instead\n+ if not stats or stats.get(\"std\", 0.0) <= 0:\n+ base = cur\n+ abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+ vals = [base + step for step in abs_steps]\n+ else:\n+ mean = stats[\"mean\"]\n+ std = stats[\"std\"]\n+ cz = (cur - mean) / std\n+ vals = [(cz + dz) * std + mean for dz in z_steps]\n+\n+ for j, v in enumerate(vals):\n+ scenarios.append(CounterfactualScenario(\n+ name=f\"scenario_{i}_{j}\",\n+ interventions={tv: float(v)},\n+ expected_outcomes=self._predict_outcomes(factual_state, {tv: float(v)}),\n+ probability=self._calculate_scenario_probability(factual_state, {tv: float(v)}),\n+ reasoning=f\"Intervention on {tv} with value {v}\"\n+ ))\n+\n+ return scenarios\n*** End Patch
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
"""
@ -552,6 +442,16 @@ class CRCALite:
"""
self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0}
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
"""
Ensure standardization stats exist for all variables in a given state.
If stats are missing, create a sensible fallback (mean=observed, std=1.0).
This prevents degenerate std=0 issues in Lite mode.
"""
for var, val in state.items():
if var not in self.standardization_stats:
self.standardization_stats[var] = {"mean": float(val), "std": 1.0}
def get_nodes(self) -> List[str]:
"""
Get all nodes in the causal graph.
@ -636,9 +536,10 @@ class CRCALite:
for step in range(max_steps):
current_state = self._predict_outcomes(current_state, {})
# Generate counterfactual scenarios
# Ensure standardization stats exist for the evolved state and generate counterfactuals from it
self.ensure_standardization_stats(current_state)
counterfactual_scenarios = self.generate_counterfactual_scenarios(
initial_state,
current_state,
target_variables,
max_scenarios=5
)

Loading…
Cancel
Save