Merge b15f70ce86 into d1cb8a6da2
commit
9d5d623827
@ -0,0 +1,630 @@
|
||||
"""
|
||||
CR-CA Lite: A lightweight Causal Reasoning with Counterfactual Analysis engine.
|
||||
|
||||
This is a minimal implementation of the ASTT/CR-CA framework focusing on:
|
||||
- Core evolution operator E(x)
|
||||
- Counterfactual scenario generation
|
||||
- Causal chain identification
|
||||
- Basic causal graph operations
|
||||
|
||||
Dependencies: numpy only (typing, dataclasses, enum are stdlib)
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Tuple, Optional, Union
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CausalRelationType(Enum):
|
||||
"""Types of causal relationships"""
|
||||
DIRECT = "direct"
|
||||
INDIRECT = "indirect"
|
||||
CONFOUNDING = "confounding"
|
||||
MEDIATING = "mediating"
|
||||
MODERATING = "moderating"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalNode:
|
||||
"""Represents a node in the causal graph"""
|
||||
name: str
|
||||
value: Optional[float] = None
|
||||
confidence: float = 1.0
|
||||
node_type: str = "variable"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalEdge:
|
||||
"""Represents an edge in the causal graph"""
|
||||
source: str
|
||||
target: str
|
||||
strength: float = 1.0
|
||||
relation_type: CausalRelationType = CausalRelationType.DIRECT
|
||||
confidence: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CounterfactualScenario:
|
||||
"""Represents a counterfactual scenario"""
|
||||
name: str
|
||||
interventions: Dict[str, float]
|
||||
expected_outcomes: Dict[str, float]
|
||||
probability: float = 1.0
|
||||
reasoning: str = ""
|
||||
|
||||
|
||||
class CRCAgent:
|
||||
"""
|
||||
CR-CA Lite: Lightweight Causal Reasoning with Counterfactual Analysis engine.
|
||||
|
||||
Core components:
|
||||
- Evolution operator: E(x) = _predict_outcomes()
|
||||
- Counterfactual generation: generate_counterfactual_scenarios()
|
||||
- Causal chain identification: identify_causal_chain()
|
||||
- State mapping: _standardize_state() / _destandardize_value()
|
||||
|
||||
Args:
|
||||
variables: Optional list of variable names
|
||||
causal_edges: Optional list of (source, target) tuples for initial edges
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
variables: Optional[List[str]] = None,
|
||||
causal_edges: Optional[List[Tuple[str, str]]] = None,
|
||||
max_loops: Optional[Union[int, str]] = 1,
|
||||
):
|
||||
"""
|
||||
Initialize CR-CA Lite engine.
|
||||
|
||||
Args:
|
||||
variables: Optional list of variable names to add to graph
|
||||
causal_edges: Optional list of (source, target) tuples for initial edges
|
||||
"""
|
||||
# Pure Python graph representation: {node: {child: strength}}
|
||||
self.causal_graph: Dict[str, Dict[str, float]] = {}
|
||||
self.causal_graph_reverse: Dict[str, List[str]] = {} # For fast parent lookup
|
||||
|
||||
# Standardization statistics: {'var': {'mean': m, 'std': s}}
|
||||
self.standardization_stats: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
# Initialize graph
|
||||
if variables:
|
||||
for var in variables:
|
||||
self._ensure_node_exists(var)
|
||||
|
||||
if causal_edges:
|
||||
for source, target in causal_edges:
|
||||
self.add_causal_relationship(source, target)
|
||||
|
||||
# Agent-like loop control: accept numeric or "auto"
|
||||
# Keep the original (possibly "auto") value; resolution happens at run time.
|
||||
self.max_loops = max_loops
|
||||
|
||||
def _ensure_node_exists(self, node: str) -> None:
|
||||
"""Ensure node present in graph structures."""
|
||||
if node not in self.causal_graph:
|
||||
self.causal_graph[node] = {}
|
||||
if node not in self.causal_graph_reverse:
|
||||
self.causal_graph_reverse[node] = []
|
||||
|
||||
def add_causal_relationship(
|
||||
self,
|
||||
source: str,
|
||||
target: str,
|
||||
strength: float = 1.0,
|
||||
relation_type: CausalRelationType = CausalRelationType.DIRECT,
|
||||
confidence: float = 1.0
|
||||
) -> None:
|
||||
"""
|
||||
Add a causal edge to the graph.
|
||||
|
||||
Args:
|
||||
source: Source variable name
|
||||
target: Target variable name
|
||||
strength: Causal effect strength (default: 1.0)
|
||||
relation_type: Type of causal relation (default: DIRECT)
|
||||
confidence: Confidence in the relationship (default: 1.0)
|
||||
"""
|
||||
# Ensure nodes exist
|
||||
self._ensure_node_exists(source)
|
||||
self._ensure_node_exists(target)
|
||||
|
||||
# Add or update edge: source -> target with strength
|
||||
self.causal_graph[source][target] = float(strength)
|
||||
|
||||
# Update reverse mapping for parent lookup (avoid duplicates)
|
||||
if source not in self.causal_graph_reverse[target]:
|
||||
self.causal_graph_reverse[target].append(source)
|
||||
|
||||
def _get_parents(self, node: str) -> List[str]:
|
||||
"""
|
||||
Get parent nodes (predecessors) of a node.
|
||||
|
||||
Args:
|
||||
node: Node name
|
||||
|
||||
Returns:
|
||||
List of parent node names
|
||||
"""
|
||||
return self.causal_graph_reverse.get(node, [])
|
||||
|
||||
def _get_children(self, node: str) -> List[str]:
|
||||
"""
|
||||
Get child nodes (successors) of a node.
|
||||
|
||||
Args:
|
||||
node: Node name
|
||||
|
||||
Returns:
|
||||
List of child node names
|
||||
"""
|
||||
return list(self.causal_graph.get(node, {}).keys())
|
||||
|
||||
def _topological_sort(self) -> List[str]:
|
||||
"""
|
||||
Perform topological sort using Kahn's algorithm (pure Python).
|
||||
|
||||
Returns:
|
||||
List of nodes in topological order
|
||||
"""
|
||||
# Compute in-degrees
|
||||
in_degree: Dict[str, int] = {node: 0 for node in self.causal_graph.keys()}
|
||||
for node in self.causal_graph:
|
||||
for child in self._get_children(node):
|
||||
in_degree[child] = in_degree.get(child, 0) + 1
|
||||
|
||||
# Initialize queue with nodes having no incoming edges
|
||||
queue: List[str] = [node for node, degree in in_degree.items() if degree == 0]
|
||||
result: List[str] = []
|
||||
|
||||
# Process nodes
|
||||
while queue:
|
||||
node = queue.pop(0)
|
||||
result.append(node)
|
||||
|
||||
# Reduce in-degree of children
|
||||
for child in self._get_children(node):
|
||||
in_degree[child] -= 1
|
||||
if in_degree[child] == 0:
|
||||
queue.append(child)
|
||||
|
||||
return result
|
||||
|
||||
def identify_causal_chain(self, start: str, end: str) -> List[str]:
|
||||
"""
|
||||
Find shortest causal path from start to end using BFS (pure Python).
|
||||
|
||||
Implements core causal chain identification (Ax2, Ax6).
|
||||
|
||||
Args:
|
||||
start: Starting variable
|
||||
end: Target variable
|
||||
|
||||
Returns:
|
||||
List of variables forming the causal chain, or empty list if no path exists
|
||||
"""
|
||||
if start not in self.causal_graph or end not in self.causal_graph:
|
||||
return []
|
||||
|
||||
if start == end:
|
||||
return [start]
|
||||
|
||||
# BFS to find shortest path
|
||||
queue: List[Tuple[str, List[str]]] = [(start, [start])]
|
||||
visited: set = {start}
|
||||
|
||||
while queue:
|
||||
current, path = queue.pop(0)
|
||||
|
||||
# Check all children
|
||||
for child in self._get_children(current):
|
||||
if child == end:
|
||||
return path + [child]
|
||||
|
||||
if child not in visited:
|
||||
visited.add(child)
|
||||
queue.append((child, path + [child]))
|
||||
|
||||
return [] # No path found
|
||||
|
||||
# detect_confounders removed in Lite version (advanced inference)
|
||||
|
||||
def _has_path(self, start: str, end: str) -> bool:
|
||||
"""
|
||||
Check if a path exists from start to end using DFS.
|
||||
|
||||
Args:
|
||||
start: Starting node
|
||||
end: Target node
|
||||
|
||||
Returns:
|
||||
True if path exists, False otherwise
|
||||
"""
|
||||
if start == end:
|
||||
return True
|
||||
|
||||
stack = [start]
|
||||
visited = set()
|
||||
|
||||
while stack:
|
||||
current = stack.pop()
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
|
||||
for child in self._get_children(current):
|
||||
if child == end:
|
||||
return True
|
||||
if child not in visited:
|
||||
stack.append(child)
|
||||
|
||||
return False
|
||||
|
||||
# identify_adjustment_set removed in Lite version (advanced inference)
|
||||
|
||||
def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]:
|
||||
"""
|
||||
Standardize state values to z-scores.
|
||||
|
||||
Args:
|
||||
state: Dictionary of variable values
|
||||
|
||||
Returns:
|
||||
Dictionary of standardized (z-score) values
|
||||
"""
|
||||
z: Dict[str, float] = {}
|
||||
for k, v in state.items():
|
||||
s = self.standardization_stats.get(k)
|
||||
if s and s.get("std", 0.0) > 0:
|
||||
z[k] = (v - s["mean"]) / s["std"]
|
||||
else:
|
||||
z[k] = v
|
||||
return z
|
||||
|
||||
def _destandardize_value(self, var: str, z_value: float) -> float:
|
||||
"""
|
||||
Convert z-score back to original scale.
|
||||
|
||||
Args:
|
||||
var: Variable name
|
||||
z_value: Standardized (z-score) value
|
||||
|
||||
Returns:
|
||||
Original scale value
|
||||
"""
|
||||
s = self.standardization_stats.get(var)
|
||||
if s and s.get("std", 0.0) > 0:
|
||||
return z_value * s["std"] + s["mean"]
|
||||
return z_value
|
||||
|
||||
def _predict_outcomes(
|
||||
self,
|
||||
factual_state: Dict[str, float],
|
||||
interventions: Dict[str, float]
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Evolution operator E(x): Predict outcomes given state and interventions.
|
||||
|
||||
This is the core CR-CA evolution operator implementing:
|
||||
x_{t+1} = E(x_t)
|
||||
|
||||
Mathematical foundation:
|
||||
- Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε\n+ - NOTE: This implementation is linear. To model nonlinear dynamics override\n+ `_predict_outcomes` in a subclass with a custom evolution operator.\n*** End Patch
|
||||
- Propagates effects through causal graph in topological order
|
||||
- Standardizes inputs, computes in z-space, de-standardizes outputs
|
||||
|
||||
Args:
|
||||
factual_state: Current world state (baseline)
|
||||
interventions: Interventions to apply (do-operator)
|
||||
|
||||
Returns:
|
||||
Dictionary of predicted variable values
|
||||
"""
|
||||
# Merge factual state with interventions
|
||||
raw = factual_state.copy()
|
||||
raw.update(interventions)
|
||||
|
||||
# Standardize to z-scores
|
||||
z_state = self._standardize_state(raw)
|
||||
z_pred = dict(z_state)
|
||||
|
||||
# Process nodes in topological order
|
||||
for node in self._topological_sort():
|
||||
# If node is intervened on, keep its value
|
||||
if node in interventions:
|
||||
if node not in z_pred:
|
||||
z_pred[node] = z_state.get(node, 0.0)
|
||||
continue
|
||||
|
||||
# Get parents
|
||||
parents = self._get_parents(node)
|
||||
if not parents:
|
||||
continue
|
||||
|
||||
# Compute linear combination: Σᵢ βᵢ·z_xi
|
||||
s = 0.0
|
||||
for p in parents:
|
||||
pz = z_pred.get(p, z_state.get(p, 0.0))
|
||||
strength = self.causal_graph.get(p, {}).get(node, 0.0)
|
||||
s += pz * strength
|
||||
|
||||
z_pred[node] = s
|
||||
|
||||
# De-standardize results
|
||||
return {v: self._destandardize_value(v, z) for v, z in z_pred.items()}
|
||||
|
||||
def _calculate_scenario_probability(
|
||||
self,
|
||||
factual_state: Dict[str, float],
|
||||
interventions: Dict[str, float]
|
||||
) -> float:
|
||||
"""
|
||||
Calculate a heuristic probability of a counterfactual scenario.
|
||||
|
||||
NOTE: This is a lightweight heuristic proximity measure (Mahalanobis-like)
|
||||
and NOT a full statistical estimator — it ignores covariance and should
|
||||
be treated as a relative plausibility score for Lite usage.
|
||||
|
||||
Args:
|
||||
factual_state: Baseline state
|
||||
interventions: Intervention values
|
||||
|
||||
Returns:
|
||||
Heuristic probability value between 0.05 and 0.98
|
||||
"""
|
||||
z_sq = 0.0
|
||||
for var, new in interventions.items():
|
||||
s = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
|
||||
mu, sd = s.get("mean", 0.0), s.get("std", 1.0) or 1.0
|
||||
old = factual_state.get(var, mu)
|
||||
dz = (new - mu) / sd - (old - mu) / sd
|
||||
z_sq += float(dz) * float(dz)
|
||||
|
||||
p = 0.95 * float(np.exp(-0.5 * z_sq)) + 0.05
|
||||
return float(max(0.05, min(0.98, p)))
|
||||
|
||||
def generate_counterfactual_scenarios(
|
||||
self,
|
||||
factual_state: Dict[str, float],
|
||||
target_variables: List[str],
|
||||
max_scenarios: int = 5
|
||||
) -> List[CounterfactualScenario]:
|
||||
"""
|
||||
Generate counterfactual scenarios for target variables.
|
||||
|
||||
Implements Ax8 (Counterfactuals) - core CR-CA functionality.
|
||||
|
||||
Args:
|
||||
factual_state: Current factual state
|
||||
target_variables: Variables to generate counterfactuals for
|
||||
max_scenarios: Maximum number of scenarios per variable
|
||||
|
||||
Returns:
|
||||
List of CounterfactualScenario objects
|
||||
"""
|
||||
# Ensure stats exist for variables in factual_state (fallback behavior)
|
||||
self.ensure_standardization_stats(factual_state)
|
||||
|
||||
scenarios: List[CounterfactualScenario] = []
|
||||
z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
|
||||
|
||||
for i, tv in enumerate(target_variables[:max_scenarios]):
|
||||
stats = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0})
|
||||
cur = factual_state.get(tv, stats.get("mean", 0.0))
|
||||
|
||||
# If std is zero or missing, use absolute perturbations instead
|
||||
if not stats or stats.get("std", 0.0) <= 0:
|
||||
base = cur
|
||||
abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
|
||||
vals = [base + step for step in abs_steps]
|
||||
else:
|
||||
mean = stats["mean"]
|
||||
std = stats["std"]
|
||||
cz = (cur - mean) / std
|
||||
vals = [(cz + dz) * std + mean for dz in z_steps]
|
||||
|
||||
for j, v in enumerate(vals):
|
||||
interventions = {tv: float(v)}
|
||||
scenarios.append(
|
||||
CounterfactualScenario(
|
||||
name=f"scenario_{i}_{j}",
|
||||
interventions=interventions,
|
||||
expected_outcomes=self._predict_outcomes(
|
||||
factual_state, interventions
|
||||
),
|
||||
probability=self._calculate_scenario_probability(
|
||||
factual_state, interventions
|
||||
),
|
||||
reasoning=f"Intervention on {tv} with value {v}",
|
||||
)
|
||||
)
|
||||
|
||||
return scenarios
|
||||
|
||||
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
|
||||
"""
|
||||
Analyze the strength of causal relationship between two variables.
|
||||
|
||||
Args:
|
||||
source: Source variable
|
||||
target: Target variable
|
||||
|
||||
Returns:
|
||||
Dictionary with strength, confidence, path_length, relation_type
|
||||
"""
|
||||
if source not in self.causal_graph or target not in self.causal_graph[source]:
|
||||
return {"strength": 0.0, "confidence": 0.0, "path_length": float('inf')}
|
||||
|
||||
strength = self.causal_graph[source].get(target, 0.0)
|
||||
path = self.identify_causal_chain(source, target)
|
||||
path_length = len(path) - 1 if path else float('inf')
|
||||
|
||||
return {
|
||||
"strength": float(strength),
|
||||
"confidence": 1.0, # Simplified: assume full confidence
|
||||
"path_length": path_length,
|
||||
"relation_type": CausalRelationType.DIRECT.value
|
||||
}
|
||||
|
||||
def set_standardization_stats(
|
||||
self,
|
||||
variable: str,
|
||||
mean: float,
|
||||
std: float
|
||||
) -> None:
|
||||
"""
|
||||
Set standardization statistics for a variable.
|
||||
|
||||
Args:
|
||||
variable: Variable name
|
||||
mean: Mean value
|
||||
std: Standard deviation
|
||||
"""
|
||||
self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0}
|
||||
|
||||
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
|
||||
"""
|
||||
Ensure standardization stats exist for all variables in a given state.
|
||||
If stats are missing, create a sensible fallback (mean=observed, std=1.0).
|
||||
This prevents degenerate std=0 issues in Lite mode.
|
||||
"""
|
||||
for var, val in state.items():
|
||||
if var not in self.standardization_stats:
|
||||
self.standardization_stats[var] = {"mean": float(val), "std": 1.0}
|
||||
|
||||
def get_nodes(self) -> List[str]:
|
||||
"""
|
||||
Get all nodes in the causal graph.
|
||||
|
||||
Returns:
|
||||
List of node names
|
||||
"""
|
||||
return list(self.causal_graph.keys())
|
||||
|
||||
def get_edges(self) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Get all edges in the causal graph.
|
||||
|
||||
Returns:
|
||||
List of (source, target) tuples
|
||||
"""
|
||||
edges = []
|
||||
for source, targets in self.causal_graph.items():
|
||||
for target in targets.keys():
|
||||
edges.append((source, target))
|
||||
return edges
|
||||
|
||||
def is_dag(self) -> bool:
|
||||
"""
|
||||
Check if the causal graph is a DAG (no cycles).
|
||||
|
||||
Uses DFS to detect cycles.
|
||||
|
||||
Returns:
|
||||
True if DAG, False if cycles exist
|
||||
"""
|
||||
def has_cycle(node: str, visited: set, rec_stack: set) -> bool:
|
||||
"""DFS to detect cycles."""
|
||||
visited.add(node)
|
||||
rec_stack.add(node)
|
||||
|
||||
for child in self._get_children(node):
|
||||
if child not in visited:
|
||||
if has_cycle(child, visited, rec_stack):
|
||||
return True
|
||||
elif child in rec_stack:
|
||||
return True
|
||||
|
||||
rec_stack.remove(node)
|
||||
return False
|
||||
|
||||
visited = set()
|
||||
rec_stack = set()
|
||||
|
||||
for node in self.causal_graph:
|
||||
if node not in visited:
|
||||
if has_cycle(node, visited, rec_stack):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def run(
|
||||
self,
|
||||
initial_state: Any,
|
||||
target_variables: Optional[List[str]] = None,
|
||||
max_steps: Union[int, str] = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run causal simulation: evolve state and generate counterfactuals.
|
||||
|
||||
Simple entry point for CR-CA engine.
|
||||
|
||||
Args:
|
||||
initial_state: Initial world state
|
||||
target_variables: Variables to generate counterfactuals for (default: all nodes)
|
||||
max_steps: Number of evolution steps (default: 1)
|
||||
|
||||
Returns:
|
||||
Dictionary with evolved state, counterfactuals, and graph info
|
||||
"""
|
||||
# Accept either a dict initial_state or a JSON string (agent-like behavior)
|
||||
if not isinstance(initial_state, dict):
|
||||
try:
|
||||
import json
|
||||
parsed = json.loads(initial_state)
|
||||
if isinstance(parsed, dict):
|
||||
initial_state = parsed
|
||||
else:
|
||||
return {"error": "initial_state JSON must decode to a dict"}
|
||||
except Exception:
|
||||
return {"error": "initial_state must be a dict or JSON-encoded dict"}
|
||||
|
||||
# Use all nodes as targets if not specified
|
||||
if target_variables is None:
|
||||
target_variables = list(self.causal_graph.keys())
|
||||
|
||||
# Resolve "auto" sentinel for max_steps (accepts method arg or instance-level default)
|
||||
def _resolve_max_steps(value: Union[int, str]) -> int:
|
||||
if isinstance(value, str) and value == "auto":
|
||||
# Heuristic: one step per variable (at least 1)
|
||||
return max(1, len(self.causal_graph))
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return max(1, len(self.causal_graph))
|
||||
|
||||
effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.max_loops == 1 else self.max_loops)
|
||||
# If caller passed default 1 and instance set a different max_loops, prefer instance value
|
||||
if max_steps == 1 and self.max_loops != 1:
|
||||
effective_steps = _resolve_max_steps(self.max_loops)
|
||||
|
||||
# Evolve state
|
||||
current_state = initial_state.copy()
|
||||
for step in range(effective_steps):
|
||||
current_state = self._predict_outcomes(current_state, {})
|
||||
|
||||
# Ensure standardization stats exist for the evolved state and generate counterfactuals from it
|
||||
self.ensure_standardization_stats(current_state)
|
||||
counterfactual_scenarios = self.generate_counterfactual_scenarios(
|
||||
current_state,
|
||||
target_variables,
|
||||
max_scenarios=5
|
||||
)
|
||||
|
||||
return {
|
||||
"initial_state": initial_state,
|
||||
"evolved_state": current_state,
|
||||
"counterfactual_scenarios": counterfactual_scenarios,
|
||||
"causal_graph_info": {
|
||||
"nodes": self.get_nodes(),
|
||||
"edges": self.get_edges(),
|
||||
"is_dag": self.is_dag()
|
||||
},
|
||||
"steps": effective_steps
|
||||
}
|
||||
|
||||
|
||||
# Agent-like behavior: `run` accepts either a dict or a JSON string as the initial_state
|
||||
# so the engine behaves like a normal agent by default.
|
||||
Loading…
Reference in new issue