pull/1233/merge
CI-DEV 1 week ago committed by GitHub
commit 9d5d623827
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -262,6 +262,9 @@ nav:
- Agent Judge: "swarms/agents/agent_judge.md"
- Reasoning Agent Router: "swarms/agents/reasoning_agent_router.md"
- CR-CA Agent:
- CR-CA Agent: "swarms/agents/cr_ca_agent.md"
- Multi-Agent Architectures:
- Overview: "swarms/concept/swarm_architectures.md"
- Benefits: "swarms/concept/why.md"
@ -298,8 +301,8 @@ nav:
- Routers:
- SwarmRouter: "swarms/structs/swarm_router.md"
- MultiAgentRouter: "swarms/structs/multi_agent_router.md"
- AgentRouter: "swarms/structs/agent_router.md"
- ModelRouter: "swarms/structs/model_router.md"
- AgentRouter: "swarms/structs/agent_router.md"
- Rearrangers:
@ -457,12 +460,12 @@ nav:
- Overview: "examples/apps_examples_overview.md"
- Web Scraper Agents: "developer_guides/web_scraper.md"
- Smart Database: "examples/smart_database.md"
- AOP:
- Overview: "examples/aop_examples_overview.md"
- Medical AOP Example: "examples/aop_medical.md"
- X402:
- X402:
- x402 Quickstart Example: "examples/x402_payment_integration.md"
- X402 Discovery Query Agent: "examples/x402_discovery_query.md"

@ -0,0 +1,498 @@
<!-- swarms/agents/cr_ca_agent.py — CRCAAgent (CRCA Lite) -->
# CRCAAgent
Short summary
-------------
CRCAAgent is a lightweight causal reasoning Agent with LLM integration,
implemented in pure Python and intended as a flexible CRCA engine for Swarms.
It provides both LLM-based causal analysis and deterministic causal simulation,
focusing on the core ASTT primitives: a causal DAG, a linear structural
evolution operator (in z-space), and compact counterfactual generation.
Key properties
- LLM integration for sophisticated causal reasoning (like full CRCAAgent)
- Dual-mode operation: LLM-based analysis and deterministic simulation
- Minimal dependencies (numpy + swarms Agent base)
- Pure-Python causal graph (adjacency dicts)
- Linear SCM evolution by default (overrideable)
- Agent-first `run()` entrypoint (accepts task string or dict/JSON payloads)
Canonical import
----------------
Use the canonical agent import in application code:
```python
from swarms.agents import CRCAAgent
```
Quickstart
----------
Minimal example — deterministic mode: initialize, add edges, evolve state and get counterfactuals:
```python
from swarms.agents import CRCAAgent
agent = CRCAAgent(variables=["price", "demand", "inventory"])
agent.add_causal_relationship("price", "demand", strength=-0.5)
agent.add_causal_relationship("demand", "inventory", strength=-0.2)
state = {"price": 100.0, "demand": 1000.0, "inventory": 5000.0}
out = agent.run(initial_state=state, target_variables=["price", "demand"], max_steps=1)
print("Evolved:", out["evolved_state"]) # evolved world state
for sc in out["counterfactual_scenarios"][:5]: # candidate CFs
print(sc.name, sc.interventions, sc.probability)
```
LLM-based causal analysis example
----------------------------------
```python
from swarms.agents import CRCAAgent
agent = CRCAAgent(
variables=["price", "demand", "inventory"],
model_name="gpt-4o",
max_loops=3
)
agent.add_causal_relationship("price", "demand", strength=-0.5)
# LLM mode: pass task as string
task = "Analyze how increasing price affects demand and inventory levels"
result = agent.run(task=task)
print("Causal Analysis:", result["causal_analysis"])
print("Counterfactual Scenarios:", result["counterfactual_scenarios"])
print("Analysis Steps:", result["analysis_steps"])
```
Agent-style JSON payload example (orchestrators)
------------------------------------------------
```python
import json
from swarms.agents import CRCAAgent
agent = CRCAAgent(variables=["price","demand","inventory"])
payload = json.dumps({"price": 100.0, "demand": 1000.0})
out = agent.run(initial_state=payload, target_variables=["price"], max_steps=1)
print(out["evolved_state"])
```
Complete example: Full workflow with system prompt
--------------------------------------------------
This example demonstrates a complete workflow from imports to execution, including
system prompt configuration, causal graph construction, and both LLM and deterministic modes.
```python
"""
Complete CRCAAgent example: Full workflow from initialization to execution
"""
# 1. Imports
from typing import Dict, Any
from swarms.agents import CRCAAgent
# 2. System prompt configuration
# Define a custom system prompt for domain-specific causal reasoning
SYSTEM_PROMPT = """You are an expert causal reasoning analyst specializing in economic systems.
Your role is to:
- Identify causal relationships between economic variables
- Analyze how interventions affect system outcomes
- Generate plausible counterfactual scenarios
- Provide clear, evidence-based causal explanations
When analyzing causal relationships:
1. Consider both direct and indirect causal paths
2. Account for confounding factors
3. Evaluate intervention plausibility
4. Quantify expected causal effects when possible
Always ground your analysis in the provided causal graph structure and observed data."""
# 3. Agent initialization with system prompt
agent = CRCAAgent(
variables=["price", "demand", "inventory", "supply", "competition"],
agent_name="economic-causal-analyst",
agent_description="Expert economic causal reasoning agent",
model_name="gpt-4o", # or "gpt-4o-mini" for faster/cheaper analysis
max_loops=3, # Number of reasoning loops for LLM-based analysis
system_prompt=SYSTEM_PROMPT,
verbose=True, # Enable detailed logging
)
# 4. Build causal graph: Add causal relationships
# Price negatively affects demand (higher price → lower demand)
agent.add_causal_relationship("price", "demand", strength=-0.5)
# Demand negatively affects inventory (higher demand → lower inventory)
agent.add_causal_relationship("demand", "inventory", strength=-0.2)
# Supply positively affects inventory (higher supply → higher inventory)
agent.add_causal_relationship("supply", "inventory", strength=0.3)
# Competition negatively affects price (more competition → lower price)
agent.add_causal_relationship("competition", "price", strength=-0.4)
# Price positively affects supply (higher price → more supply)
agent.add_causal_relationship("price", "supply", strength=0.2)
# 5. Verify graph structure
print("Causal Graph Nodes:", agent.get_nodes())
print("Causal Graph Edges:", agent.get_edges())
print("Is DAG:", agent.is_dag())
# 6. Example 1: LLM-based causal analysis (sophisticated reasoning)
print("\n" + "="*80)
print("EXAMPLE 1: LLM-Based Causal Analysis")
print("="*80)
task = """
Analyze the causal relationship between price increases and inventory levels.
Consider both direct and indirect causal paths. What interventions could
stabilize inventory while maintaining profitability?
"""
result = agent.run(task=task)
print("\n--- Causal Analysis Report ---")
print(result["causal_analysis"])
print("\n--- Counterfactual Scenarios ---")
for i, scenario in enumerate(result["counterfactual_scenarios"][:3], 1):
print(f"\nScenario {i}: {scenario.name}")
print(f" Interventions: {scenario.interventions}")
print(f" Expected Outcomes: {scenario.expected_outcomes}")
print(f" Probability: {scenario.probability:.3f}")
print(f" Reasoning: {scenario.reasoning}")
print("\n--- Causal Graph Info ---")
print(f"Nodes: {result['causal_graph_info']['nodes']}")
print(f"Edges: {result['causal_graph_info']['edges']}")
print(f"Is DAG: {result['causal_graph_info']['is_dag']}")
# 7. Example 2: Deterministic simulation (script-style)
print("\n" + "="*80)
print("EXAMPLE 2: Deterministic Causal Simulation")
print("="*80)
# Initial state
initial_state = {
"price": 100.0,
"demand": 1000.0,
"inventory": 5000.0,
"supply": 2000.0,
"competition": 5.0,
}
# Run deterministic evolution
simulation_result = agent.run(
initial_state=initial_state,
target_variables=["price", "demand", "inventory"],
max_steps=3, # Evolve for 3 time steps
)
print("\n--- Initial State ---")
for var, value in initial_state.items():
print(f" {var}: {value}")
print("\n--- Evolved State (after 3 steps) ---")
for var, value in simulation_result["evolved_state"].items():
print(f" {var}: {value:.2f}")
print("\n--- Counterfactual Scenarios ---")
for i, scenario in enumerate(simulation_result["counterfactual_scenarios"][:3], 1):
print(f"\nScenario {i}: {scenario.name}")
print(f" Interventions: {scenario.interventions}")
print(f" Expected Outcomes: {scenario.expected_outcomes}")
print(f" Probability: {scenario.probability:.3f}")
# 8. Example 3: Causal chain identification
print("\n" + "="*80)
print("EXAMPLE 3: Causal Chain Analysis")
print("="*80)
chain = agent.identify_causal_chain("competition", "inventory")
if chain:
print(f"Causal chain from 'competition' to 'inventory': {' → '.join(chain)}")
else:
print("No direct causal chain found")
# 9. Example 4: Analyze causal strength
print("\n" + "="*80)
print("EXAMPLE 4: Causal Strength Analysis")
print("="*80)
strength_analysis = agent.analyze_causal_strength("price", "inventory")
print(f"Direct edge strength: {strength_analysis.get('strength', 'N/A')}")
print(f"Path length: {strength_analysis.get('path_length', 'N/A')}")
# 10. Example 5: Custom intervention prediction
print("\n" + "="*80)
print("EXAMPLE 5: Custom Intervention Prediction")
print("="*80)
# What if we reduce price by 20%?
interventions = {"price": 80.0} # 20% reduction from 100
predicted = agent._predict_outcomes(initial_state, interventions)
print("\nIntervention: Reduce price from 100 to 80 (20% reduction)")
print("Predicted outcomes:")
for var, value in predicted.items():
if var in initial_state:
change = value - initial_state[var]
change_pct = (change / initial_state[var]) * 100 if initial_state[var] != 0 else 0
print(f" {var}: {initial_state[var]:.2f} → {value:.2f} ({change_pct:+.1f}%)")
print("\n" + "="*80)
print("Example execution complete!")
print("="*80)
```
Expected output structure
-------------------------
The `run()` method returns different structures depending on the mode:
**LLM Mode** (task string):
```python
{
"task": str, # The provided task string
"causal_analysis": str, # Synthesized analysis report
"counterfactual_scenarios": List[CounterfactualScenario], # Generated scenarios
"causal_graph_info": {
"nodes": List[str],
"edges": List[Tuple[str, str]],
"is_dag": bool
},
"analysis_steps": List[Dict[str, Any]] # Step-by-step reasoning history
}
```
**Deterministic Mode** (initial_state dict):
```python
{
"initial_state": Dict[str, float], # Input state
"evolved_state": Dict[str, float], # State after max_steps evolution
"counterfactual_scenarios": List[CounterfactualScenario], # Generated scenarios
"causal_graph_info": {
"nodes": List[str],
"edges": List[Tuple[str, str]],
"is_dag": bool
},
"steps": int # Number of evolution steps applied
}
```
System prompt best practices
-----------------------------
1. **Domain-specific guidance**: Include domain knowledge relevant to your causal model
2. **Causal reasoning principles**: Reference Pearl's causal hierarchy (association, intervention, counterfactual)
3. **Output format**: Specify desired analysis structure and detail level
4. **Plausibility constraints**: Guide the agent on what interventions are realistic
5. **Quantification**: Encourage numerical estimates when appropriate
Example system prompt template:
```python
SYSTEM_PROMPT = """You are a {domain} causal reasoning expert.
Your analysis should:
- Identify {specific_relationships}
- Consider {relevant_factors}
- Generate {scenario_types}
- Provide {output_format}
Ground your reasoning in the causal graph structure provided."""
```
Why use `run()`
--------------
- **Dual-mode operation**: Automatically selects LLM mode (task string) or deterministic mode (initial_state dict)
- **LLM mode**: Performs sophisticated multi-loop causal reasoning with structured output
- **Deterministic mode**: Evolves the world state for `max_steps` using the deterministic evolution
operator, then generates counterfactuals from the evolved state (consistent timelines)
- Accepts both dict and JSON payloads for flexible integration
- Returns a compact result dict used across Swarms agents
Architecture (high level)
-------------------------
```mermaid
flowchart TB
subgraph Input[Input]
I1[Task string]
I2[Initial state dict]
end
I1 -->|String| LLM[LLM Mode]
I2 -->|Dict| DET[Deterministic Mode]
subgraph LLMFlow[LLM Causal Analysis]
LLM --> P1[Build Causal Prompt]
P1 --> L1[LLM Step 1]
L1 --> L2[LLM Step 2...N]
L2 --> SYN[Synthesize Analysis]
SYN --> CF1[Generate Counterfactuals]
end
subgraph DetFlow[Deterministic Simulation]
DET --> P2[ensure_standardization_stats]
P2 --> P3[Standardize to z-space]
P3 --> T[Topological sort]
T --> E[predict_outcomes linear SCM]
E --> D[De-standardize outputs]
D --> R[Timeline rollout]
R --> CF2[Generate Counterfactuals]
end
subgraph Model[Causal Model]
G1[Causal graph]
G2[Edge strengths]
G3[Standardization stats]
end
subgraph Output[Outputs]
O1[Causal analysis / Evolved state]
O2[Counterfactual scenarios]
O3[Causal graph info]
end
G1 --> LLMFlow
G1 --> DetFlow
G2 --> DetFlow
G3 --> DetFlow
CF1 --> O2
CF2 --> O2
SYN --> O1
R --> O1
G1 --> O3
```
Complete method index (quick)
-----------------------------
The following methods (public and internal, underscore-prefixed) are implemented by
`CRCAAgent` (Lite) in `swarms/agents/cr_ca_agent.py`.
LLM integration
- `_get_cr_ca_schema()` — CR-CA function calling schema for structured reasoning
- `step(task)` — Execute a single step of LLM-based causal reasoning
- `_build_causal_prompt(task)` — Build causal analysis prompt with graph context
- `_build_memory_context()` — Build memory context from previous analysis steps
- `_synthesize_causal_analysis(task)` — Synthesize final causal analysis using LLM
- `_run_llm_causal_analysis(task)` — Run multi-loop LLM-based causal analysis
Core graph & state
- `_ensure_node_exists(node)` — ensure node present in internal maps
- `add_causal_relationship(source, target, strength=1.0, ...)` — add/update edge
- `_get_parents(node)`, `_get_children(node)` — graph accessors
- `_topological_sort()` — Kahn's algorithm
- `get_nodes()`, `get_edges()`, `is_dag()` — graph introspection
Standardization & prediction
- `set_standardization_stats(var, mean, std)` — set z-stats
- `ensure_standardization_stats(state)` — auto-fill sensible stats
- `_standardize_state(state)` / `_destandardize_value(var, z)` — z-score transforms
- `_predict_outcomes(factual_state, interventions)` — evolution operator (linear SCM)
- `_predict_outcomes_cached(...)` — cached wrapper
Counterfactuals & reasoning
- `generate_counterfactual_scenarios(factual_state, target_variables, max_scenarios=5)`
- `_calculate_scenario_probability(factual_state, interventions)` — heuristic plausibility
- `counterfactual_abduction_action_prediction(factual_state, interventions)` — abduction–action–prediction (Pearl)
Estimation, analysis & utilities
- `fit_from_dataframe(df, variables, window=30, ...)` — WLS edge estimation and stats
- `quantify_uncertainty(df, variables, windows=200, ...)` — bootstrap CIs
- `analyze_causal_strength(source, target)` — path/edge summary
- `identify_causal_chain(start, end)` — BFS shortest path
- `detect_change_points(series, threshold=2.5)` — simple detector
Advanced (optional / Pro)
- `learn_structure(...)`, `plan_interventions(...)`, `gradient_based_intervention_optimization(...)`,
`convex_intervention_optimization(...)`, `evolutionary_multi_objective_optimization(...)`,
`probabilistic_nested_simulation(...)`, `deep_root_cause_analysis(...)`, and more.
Return shape from `run()`
-------------------------
`run()` returns a dictionary with different keys depending on mode:
**LLM Mode** (when `task` is a string):
- `task`: the provided task/problem string
- `causal_analysis`: synthesized causal analysis report (string)
- `counterfactual_scenarios`: list of `CounterfactualScenario` objects
- `causal_graph_info`: {"nodes": [...], "edges": [...], "is_dag": bool}
- `analysis_steps`: list of analysis steps with memory context
**Deterministic Mode** (when `initial_state` is a dict):
- `initial_state`: the provided input state (dict)
- `evolved_state`: state after applying `max_steps` of the evolution operator
- `counterfactual_scenarios`: list of `CounterfactualScenario` with name/interventions/expected_outcomes/probability/reasoning
- `causal_graph_info`: {"nodes": [...], "edges": [...], "is_dag": bool}
- `steps`: `max_steps` used
Usage patterns & examples
-------------------------
1) LLM-based causal analysis (sophisticated reasoning)
```python
agent = CRCAAgent(
variables=["a","b","c"],
model_name="gpt-4o",
max_loops=3
)
agent.add_causal_relationship("a","b", strength=0.8)
# LLM mode: pass task as string
task = "Analyze the causal relationship between a and b"
res = agent.run(task=task)
print(res["causal_analysis"])
print(res["analysis_steps"])
```
2) Deterministic simulation (script-style)
```python
agent = CRCAAgent(variables=["a","b","c"])
agent.add_causal_relationship("a","b", strength=0.8)
state = {"a":1.0, "b":2.0, "c":3.0}
res = agent.run(initial_state=state, max_steps=2)
print(res["evolved_state"])
```
3) Orchestration / agent-style (JSON payloads)
```python
payload = '{"a":1.0,"b":2.0,"c":3.0}'
res = agent.run(initial_state=payload, max_steps=1)
if "error" in res:
print("Bad payload:", res["error"])
else:
print("Evolved:", res["evolved_state"])
```
4) Lower-level testing & research
```python
pred = agent._predict_outcomes({"a":1.0,"b":2.0},{"a":0.0})
print(pred)
```
Design notes & limitations
--------------------------
- **LLM Integration**: Uses swarms Agent infrastructure for LLM calls. Configure model via `model_name` parameter. Multi-loop reasoning enabled by default.
- **Dual-mode operation**: Automatically selects LLM mode (task string) or deterministic mode (initial_state dict). Both modes generate counterfactuals using deterministic methods.
- **Linearity**: default `_predict_outcomes` is linear in standardized z-space. To model non-linear dynamics, subclass `CRCAAgent` and override `_predict_outcomes`.
- **Probabilities**: scenario probability is a heuristic proximity measure (Mahalanobis-like) — not a formal posterior.
- **Stats**: the engine auto-fills standardization stats with sensible defaults (`mean=observed`, `std=1.0`) via `ensure_standardization_stats` to avoid degenerate std=0 cases.
- **Dependencies**: Lite intentionally avoids heavy libs (pandas/scipy/cvxpy) but includes LLM integration via swarms Agent base.
Extending & integration
-----------------------
For advanced capabilities (structure learning, Bayesian inference, optimization,
extensive statistical methods), use the full CRCA Agent (work in progress).
The Lite version provides core causal reasoning with LLM support while maintaining minimal dependencies.
References
----------
- Pearl, J. (2009). *Causality: Models, Reasoning, and Inference*.
- Pearl, J., & Mackenzie, D. (2018). *The Book of Why*.

@ -3,6 +3,7 @@ from swarms.agents.consistency_agent import SelfConsistencyAgent
from swarms.agents.create_agents_from_yaml import (
create_agents_from_yaml,
)
from swarms.agents.cr_ca_agent import CRCAAgent
from swarms.agents.flexion_agent import ReflexionAgent
from swarms.agents.gkp_agent import GKPAgent
from swarms.agents.i_agent import IterativeReflectiveExpansion
@ -22,4 +23,5 @@ __all__ = [
"ReflexionAgent",
"GKPAgent",
"AgentJudge",
"CRCAAgent",
]

@ -0,0 +1,630 @@
"""
CR-CA Lite: A lightweight Causal Reasoning with Counterfactual Analysis engine.
This is a minimal implementation of the ASTT/CR-CA framework focusing on:
- Core evolution operator E(x)
- Counterfactual scenario generation
- Causal chain identification
- Basic causal graph operations
Dependencies: numpy only (typing, dataclasses, enum are stdlib)
"""
from typing import Dict, Any, List, Tuple, Optional, Union
import numpy as np
from dataclasses import dataclass
from enum import Enum
class CausalRelationType(Enum):
    """Types of causal relationships.

    Each member's value is the lowercase string form of its name. In the code
    visible in this module only ``DIRECT`` is referenced (as a default
    argument and as the reported ``relation_type``).
    """

    DIRECT = "direct"
    INDIRECT = "indirect"
    CONFOUNDING = "confounding"
    MEDIATING = "mediating"
    MODERATING = "moderating"
@dataclass
class CausalNode:
    """Represents a node in the causal graph.

    NOTE(review): not instantiated by the code visible in this module; the
    engine keeps nodes as plain dict keys. Presumably kept for API parity
    with a fuller implementation -- confirm.
    """

    # Node identifier (required).
    name: str
    # Optional numeric value; no value is attached by default.
    value: Optional[float] = None
    # Defaults to 1.0.
    confidence: float = 1.0
    # Free-form category label; defaults to "variable".
    node_type: str = "variable"
@dataclass
class CausalEdge:
    """Represents an edge in the causal graph.

    NOTE(review): not instantiated by the code visible in this module;
    ``add_causal_relationship`` stores edges as ``{source: {target: strength}}``
    dict entries instead -- confirm whether this record is used elsewhere.
    """

    # Name of the node the edge leaves.
    source: str
    # Name of the node the edge enters.
    target: str
    # Edge weight; defaults to 1.0.
    strength: float = 1.0
    # Relationship category; defaults to DIRECT.
    relation_type: CausalRelationType = CausalRelationType.DIRECT
    # Defaults to 1.0.
    confidence: float = 1.0
@dataclass
class CounterfactualScenario:
    """Represents a counterfactual scenario.

    Instances are produced by ``generate_counterfactual_scenarios``, one per
    perturbed value of a target variable.
    """

    # Scenario label (the generator uses "scenario_<i>_<j>").
    name: str
    # Mapping of intervened variable -> value forced onto it.
    interventions: Dict[str, float]
    # Mapping of variable -> predicted value under the interventions.
    expected_outcomes: Dict[str, float]
    # Plausibility score; defaults to 1.0.
    probability: float = 1.0
    # Human-readable explanation; empty by default.
    reasoning: str = ""
# NOTE(review): a package ``__init__`` elsewhere in this change imports
# ``CRCAAgent`` from ``swarms.agents.cr_ca_agent``. If this is that module, the
# class name below (``CRCAgent``) will not satisfy that import -- confirm.
class CRCAgent:
    """
    CR-CA Lite: Lightweight Causal Reasoning with Counterfactual Analysis engine.

    Core components:
    - Evolution operator: E(x) = _predict_outcomes()
    - Counterfactual generation: generate_counterfactual_scenarios()
    - Causal chain identification: identify_causal_chain()
    - State mapping: _standardize_state() / _destandardize_value()

    Args:
        variables: Optional list of variable names
        causal_edges: Optional list of (source, target) tuples for initial edges
        max_loops: Stored as given (an int or the string "auto"); ``run`` uses
            it as the step count when its own ``max_steps`` is left at 1.
    """
def __init__(
self,
variables: Optional[List[str]] = None,
causal_edges: Optional[List[Tuple[str, str]]] = None,
max_loops: Optional[Union[int, str]] = 1,
):
"""
Initialize CR-CA Lite engine.
Args:
variables: Optional list of variable names to add to graph
causal_edges: Optional list of (source, target) tuples for initial edges
"""
# Pure Python graph representation: {node: {child: strength}}
self.causal_graph: Dict[str, Dict[str, float]] = {}
self.causal_graph_reverse: Dict[str, List[str]] = {} # For fast parent lookup
# Standardization statistics: {'var': {'mean': m, 'std': s}}
self.standardization_stats: Dict[str, Dict[str, float]] = {}
# Initialize graph
if variables:
for var in variables:
self._ensure_node_exists(var)
if causal_edges:
for source, target in causal_edges:
self.add_causal_relationship(source, target)
# Agent-like loop control: accept numeric or "auto"
# Keep the original (possibly "auto") value; resolution happens at run time.
self.max_loops = max_loops
def _ensure_node_exists(self, node: str) -> None:
"""Ensure node present in graph structures."""
if node not in self.causal_graph:
self.causal_graph[node] = {}
if node not in self.causal_graph_reverse:
self.causal_graph_reverse[node] = []
def add_causal_relationship(
self,
source: str,
target: str,
strength: float = 1.0,
relation_type: CausalRelationType = CausalRelationType.DIRECT,
confidence: float = 1.0
) -> None:
"""
Add a causal edge to the graph.
Args:
source: Source variable name
target: Target variable name
strength: Causal effect strength (default: 1.0)
relation_type: Type of causal relation (default: DIRECT)
confidence: Confidence in the relationship (default: 1.0)
"""
# Ensure nodes exist
self._ensure_node_exists(source)
self._ensure_node_exists(target)
# Add or update edge: source -> target with strength
self.causal_graph[source][target] = float(strength)
# Update reverse mapping for parent lookup (avoid duplicates)
if source not in self.causal_graph_reverse[target]:
self.causal_graph_reverse[target].append(source)
def _get_parents(self, node: str) -> List[str]:
"""
Get parent nodes (predecessors) of a node.
Args:
node: Node name
Returns:
List of parent node names
"""
return self.causal_graph_reverse.get(node, [])
def _get_children(self, node: str) -> List[str]:
"""
Get child nodes (successors) of a node.
Args:
node: Node name
Returns:
List of child node names
"""
return list(self.causal_graph.get(node, {}).keys())
def _topological_sort(self) -> List[str]:
"""
Perform topological sort using Kahn's algorithm (pure Python).
Returns:
List of nodes in topological order
"""
# Compute in-degrees
in_degree: Dict[str, int] = {node: 0 for node in self.causal_graph.keys()}
for node in self.causal_graph:
for child in self._get_children(node):
in_degree[child] = in_degree.get(child, 0) + 1
# Initialize queue with nodes having no incoming edges
queue: List[str] = [node for node, degree in in_degree.items() if degree == 0]
result: List[str] = []
# Process nodes
while queue:
node = queue.pop(0)
result.append(node)
# Reduce in-degree of children
for child in self._get_children(node):
in_degree[child] -= 1
if in_degree[child] == 0:
queue.append(child)
return result
def identify_causal_chain(self, start: str, end: str) -> List[str]:
"""
Find shortest causal path from start to end using BFS (pure Python).
Implements core causal chain identification (Ax2, Ax6).
Args:
start: Starting variable
end: Target variable
Returns:
List of variables forming the causal chain, or empty list if no path exists
"""
if start not in self.causal_graph or end not in self.causal_graph:
return []
if start == end:
return [start]
# BFS to find shortest path
queue: List[Tuple[str, List[str]]] = [(start, [start])]
visited: set = {start}
while queue:
current, path = queue.pop(0)
# Check all children
for child in self._get_children(current):
if child == end:
return path + [child]
if child not in visited:
visited.add(child)
queue.append((child, path + [child]))
return [] # No path found
# detect_confounders removed in Lite version (advanced inference)
def _has_path(self, start: str, end: str) -> bool:
"""
Check if a path exists from start to end using DFS.
Args:
start: Starting node
end: Target node
Returns:
True if path exists, False otherwise
"""
if start == end:
return True
stack = [start]
visited = set()
while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
for child in self._get_children(current):
if child == end:
return True
if child not in visited:
stack.append(child)
return False
# identify_adjustment_set removed in Lite version (advanced inference)
def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]:
"""
Standardize state values to z-scores.
Args:
state: Dictionary of variable values
Returns:
Dictionary of standardized (z-score) values
"""
z: Dict[str, float] = {}
for k, v in state.items():
s = self.standardization_stats.get(k)
if s and s.get("std", 0.0) > 0:
z[k] = (v - s["mean"]) / s["std"]
else:
z[k] = v
return z
def _destandardize_value(self, var: str, z_value: float) -> float:
"""
Convert z-score back to original scale.
Args:
var: Variable name
z_value: Standardized (z-score) value
Returns:
Original scale value
"""
s = self.standardization_stats.get(var)
if s and s.get("std", 0.0) > 0:
return z_value * s["std"] + s["mean"]
return z_value
def _predict_outcomes(
self,
factual_state: Dict[str, float],
interventions: Dict[str, float]
) -> Dict[str, float]:
"""
Evolution operator E(x): Predict outcomes given state and interventions.
This is the core CR-CA evolution operator implementing:
x_{t+1} = E(x_t)
Mathematical foundation:
- Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε\n+ - NOTE: This implementation is linear. To model nonlinear dynamics override\n+ `_predict_outcomes` in a subclass with a custom evolution operator.\n*** End Patch
- Propagates effects through causal graph in topological order
- Standardizes inputs, computes in z-space, de-standardizes outputs
Args:
factual_state: Current world state (baseline)
interventions: Interventions to apply (do-operator)
Returns:
Dictionary of predicted variable values
"""
# Merge factual state with interventions
raw = factual_state.copy()
raw.update(interventions)
# Standardize to z-scores
z_state = self._standardize_state(raw)
z_pred = dict(z_state)
# Process nodes in topological order
for node in self._topological_sort():
# If node is intervened on, keep its value
if node in interventions:
if node not in z_pred:
z_pred[node] = z_state.get(node, 0.0)
continue
# Get parents
parents = self._get_parents(node)
if not parents:
continue
# Compute linear combination: Σᵢ βᵢ·z_xi
s = 0.0
for p in parents:
pz = z_pred.get(p, z_state.get(p, 0.0))
strength = self.causal_graph.get(p, {}).get(node, 0.0)
s += pz * strength
z_pred[node] = s
# De-standardize results
return {v: self._destandardize_value(v, z) for v, z in z_pred.items()}
def _calculate_scenario_probability(
self,
factual_state: Dict[str, float],
interventions: Dict[str, float]
) -> float:
"""
Calculate a heuristic probability of a counterfactual scenario.
NOTE: This is a lightweight heuristic proximity measure (Mahalanobis-like)
and NOT a full statistical estimator it ignores covariance and should
be treated as a relative plausibility score for Lite usage.
Args:
factual_state: Baseline state
interventions: Intervention values
Returns:
Heuristic probability value between 0.05 and 0.98
"""
z_sq = 0.0
for var, new in interventions.items():
s = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0})
mu, sd = s.get("mean", 0.0), s.get("std", 1.0) or 1.0
old = factual_state.get(var, mu)
dz = (new - mu) / sd - (old - mu) / sd
z_sq += float(dz) * float(dz)
p = 0.95 * float(np.exp(-0.5 * z_sq)) + 0.05
return float(max(0.05, min(0.98, p)))
def generate_counterfactual_scenarios(
self,
factual_state: Dict[str, float],
target_variables: List[str],
max_scenarios: int = 5
) -> List[CounterfactualScenario]:
"""
Generate counterfactual scenarios for target variables.
Implements Ax8 (Counterfactuals) - core CR-CA functionality.
Args:
factual_state: Current factual state
target_variables: Variables to generate counterfactuals for
max_scenarios: Maximum number of scenarios per variable
Returns:
List of CounterfactualScenario objects
"""
# Ensure stats exist for variables in factual_state (fallback behavior)
self.ensure_standardization_stats(factual_state)
scenarios: List[CounterfactualScenario] = []
z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
for i, tv in enumerate(target_variables[:max_scenarios]):
stats = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0})
cur = factual_state.get(tv, stats.get("mean", 0.0))
# If std is zero or missing, use absolute perturbations instead
if not stats or stats.get("std", 0.0) <= 0:
base = cur
abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
vals = [base + step for step in abs_steps]
else:
mean = stats["mean"]
std = stats["std"]
cz = (cur - mean) / std
vals = [(cz + dz) * std + mean for dz in z_steps]
for j, v in enumerate(vals):
interventions = {tv: float(v)}
scenarios.append(
CounterfactualScenario(
name=f"scenario_{i}_{j}",
interventions=interventions,
expected_outcomes=self._predict_outcomes(
factual_state, interventions
),
probability=self._calculate_scenario_probability(
factual_state, interventions
),
reasoning=f"Intervention on {tv} with value {v}",
)
)
return scenarios
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
"""
Analyze the strength of causal relationship between two variables.
Args:
source: Source variable
target: Target variable
Returns:
Dictionary with strength, confidence, path_length, relation_type
"""
if source not in self.causal_graph or target not in self.causal_graph[source]:
return {"strength": 0.0, "confidence": 0.0, "path_length": float('inf')}
strength = self.causal_graph[source].get(target, 0.0)
path = self.identify_causal_chain(source, target)
path_length = len(path) - 1 if path else float('inf')
return {
"strength": float(strength),
"confidence": 1.0, # Simplified: assume full confidence
"path_length": path_length,
"relation_type": CausalRelationType.DIRECT.value
}
def set_standardization_stats(
self,
variable: str,
mean: float,
std: float
) -> None:
"""
Set standardization statistics for a variable.
Args:
variable: Variable name
mean: Mean value
std: Standard deviation
"""
self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0}
def ensure_standardization_stats(self, state: Dict[str, float]) -> None:
"""
Ensure standardization stats exist for all variables in a given state.
If stats are missing, create a sensible fallback (mean=observed, std=1.0).
This prevents degenerate std=0 issues in Lite mode.
"""
for var, val in state.items():
if var not in self.standardization_stats:
self.standardization_stats[var] = {"mean": float(val), "std": 1.0}
def get_nodes(self) -> List[str]:
"""
Get all nodes in the causal graph.
Returns:
List of node names
"""
return list(self.causal_graph.keys())
def get_edges(self) -> List[Tuple[str, str]]:
"""
Get all edges in the causal graph.
Returns:
List of (source, target) tuples
"""
edges = []
for source, targets in self.causal_graph.items():
for target in targets.keys():
edges.append((source, target))
return edges
def is_dag(self) -> bool:
"""
Check if the causal graph is a DAG (no cycles).
Uses DFS to detect cycles.
Returns:
True if DAG, False if cycles exist
"""
def has_cycle(node: str, visited: set, rec_stack: set) -> bool:
"""DFS to detect cycles."""
visited.add(node)
rec_stack.add(node)
for child in self._get_children(node):
if child not in visited:
if has_cycle(child, visited, rec_stack):
return True
elif child in rec_stack:
return True
rec_stack.remove(node)
return False
visited = set()
rec_stack = set()
for node in self.causal_graph:
if node not in visited:
if has_cycle(node, visited, rec_stack):
return False
return True
def run(
self,
initial_state: Any,
target_variables: Optional[List[str]] = None,
max_steps: Union[int, str] = 1
) -> Dict[str, Any]:
"""
Run causal simulation: evolve state and generate counterfactuals.
Simple entry point for CR-CA engine.
Args:
initial_state: Initial world state
target_variables: Variables to generate counterfactuals for (default: all nodes)
max_steps: Number of evolution steps (default: 1)
Returns:
Dictionary with evolved state, counterfactuals, and graph info
"""
# Accept either a dict initial_state or a JSON string (agent-like behavior)
if not isinstance(initial_state, dict):
try:
import json
parsed = json.loads(initial_state)
if isinstance(parsed, dict):
initial_state = parsed
else:
return {"error": "initial_state JSON must decode to a dict"}
except Exception:
return {"error": "initial_state must be a dict or JSON-encoded dict"}
# Use all nodes as targets if not specified
if target_variables is None:
target_variables = list(self.causal_graph.keys())
# Resolve "auto" sentinel for max_steps (accepts method arg or instance-level default)
def _resolve_max_steps(value: Union[int, str]) -> int:
if isinstance(value, str) and value == "auto":
# Heuristic: one step per variable (at least 1)
return max(1, len(self.causal_graph))
try:
return int(value)
except Exception:
return max(1, len(self.causal_graph))
effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.max_loops == 1 else self.max_loops)
# If caller passed default 1 and instance set a different max_loops, prefer instance value
if max_steps == 1 and self.max_loops != 1:
effective_steps = _resolve_max_steps(self.max_loops)
# Evolve state
current_state = initial_state.copy()
for step in range(effective_steps):
current_state = self._predict_outcomes(current_state, {})
# Ensure standardization stats exist for the evolved state and generate counterfactuals from it
self.ensure_standardization_stats(current_state)
counterfactual_scenarios = self.generate_counterfactual_scenarios(
current_state,
target_variables,
max_scenarios=5
)
return {
"initial_state": initial_state,
"evolved_state": current_state,
"counterfactual_scenarios": counterfactual_scenarios,
"causal_graph_info": {
"nodes": self.get_nodes(),
"edges": self.get_edges(),
"is_dag": self.is_dag()
},
"steps": effective_steps
}
# Agent-like behavior: `run` accepts either a dict or a JSON string as the initial_state
# so the engine behaves like a normal agent by default.
Loading…
Cancel
Save