diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 7f7516b3..c4819b86 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -262,6 +262,9 @@ nav: - Agent Judge: "swarms/agents/agent_judge.md" - Reasoning Agent Router: "swarms/agents/reasoning_agent_router.md" + - CR-CA Agent: + - CR-CA Agent: "swarms/agents/cr_ca_agent.md" + - Multi-Agent Architectures: - Overview: "swarms/concept/swarm_architectures.md" - Benefits: "swarms/concept/why.md" @@ -298,8 +301,8 @@ nav: - Routers: - SwarmRouter: "swarms/structs/swarm_router.md" - MultiAgentRouter: "swarms/structs/multi_agent_router.md" - - AgentRouter: "swarms/structs/agent_router.md" - ModelRouter: "swarms/structs/model_router.md" + - AgentRouter: "swarms/structs/agent_router.md" - Rearrangers: @@ -457,12 +460,12 @@ nav: - Overview: "examples/apps_examples_overview.md" - Web Scraper Agents: "developer_guides/web_scraper.md" - Smart Database: "examples/smart_database.md" - + - AOP: - Overview: "examples/aop_examples_overview.md" - Medical AOP Example: "examples/aop_medical.md" - - X402: + - X402: - x402 Quickstart Example: "examples/x402_payment_integration.md" - X402 Discovery Query Agent: "examples/x402_discovery_query.md" diff --git a/docs/swarms/agents/cr_ca_agent.md b/docs/swarms/agents/cr_ca_agent.md new file mode 100644 index 00000000..3ecd9695 --- /dev/null +++ b/docs/swarms/agents/cr_ca_agent.md @@ -0,0 +1,498 @@ + +# CRCAAgent + +Short summary +------------- +CRCAAgent is a lightweight causal reasoning Agent with LLM integration, +implemented in pure Python and intended as a flexible CR‑CA engine for Swarms. +It provides both LLM-based causal analysis and deterministic causal simulation, +focusing on the core ASTT primitives: a causal DAG, a linear structural +evolution operator (in z-space), and compact counterfactual generation. + +Key properties +- LLM integration for sophisticated causal reasoning (like full CRCAAgent) +- Dual-mode operation: LLM-based analysis and deterministic simulation +- Minimal dependencies (numpy + swarms Agent base) +- Pure-Python causal graph (adjacency dicts) +- Linear SCM evolution by default (overrideable) +- Agent-first `run()` entrypoint (accepts task string or dict/JSON payloads) + +Canonical import +---------------- +Use the canonical agent import in application code: + +```python +from path.to.crca_agent import CRCAAgent +``` + +Quickstart +---------- +Minimal example — deterministic mode: initialize, add edges, evolve state and get counterfactuals: + +```python +from path.to.crca_agent import CRCAAgent + +agent = CRCAAgent(variables=["price", "demand", "inventory"]) +agent.add_causal_relationship("price", "demand", strength=-0.5) +agent.add_causal_relationship("demand", "inventory", strength=-0.2) + +state = {"price": 100.0, "demand": 1000.0, "inventory": 5000.0} +out = agent.run(initial_state=state, target_variables=["price", "demand"], max_steps=1) + +print("Evolved:", out["evolved_state"]) # evolved world state +for sc in out["counterfactual_scenarios"][:5]: # candidate CFs + print(sc.name, sc.interventions, sc.probability) +``` + +LLM-based causal analysis example +---------------------------------- + +```python +from path.to.crca_agent import CRCAAgent + +agent = CRCAAgent( + variables=["price", "demand", "inventory"], + model_name="gpt-4o", + max_loops=3 +) +agent.add_causal_relationship("price", "demand", strength=-0.5) + +# LLM mode: pass task as string +task = "Analyze how increasing price affects demand and inventory levels" +result = agent.run(task=task) + +print("Causal Analysis:", result["causal_analysis"]) +print("Counterfactual Scenarios:", result["counterfactual_scenarios"]) +print("Analysis Steps:", result["analysis_steps"]) +``` + +Agent-style JSON payload example (orchestrators) +------------------------------------------------ + +```python +import json +from path.to.crca_agent import CRCAAgent + +agent = CRCAAgent(variables=["price","demand","inventory"]) +payload = json.dumps({"price": 100.0, "demand": 1000.0}) +out = agent.run(initial_state=payload, target_variables=["price"], max_steps=1) +print(out["evolved_state"]) +``` + +Complete example: Full workflow with system prompt +-------------------------------------------------- +This example demonstrates a complete workflow from imports to execution, including +system prompt configuration, causal graph construction, and both LLM and deterministic modes. + +```python +""" +Complete CRCAAgent example: Full workflow from initialization to execution +""" + +# 1. Imports +from typing import Dict, Any +from path.to.crca_agent import CRCAAgent + +# 2. System prompt configuration +# Define a custom system prompt for domain-specific causal reasoning +SYSTEM_PROMPT = """You are an expert causal reasoning analyst specializing in economic systems. +Your role is to: +- Identify causal relationships between economic variables +- Analyze how interventions affect system outcomes +- Generate plausible counterfactual scenarios +- Provide clear, evidence-based causal explanations + +When analyzing causal relationships: +1. Consider both direct and indirect causal paths +2. Account for confounding factors +3. Evaluate intervention plausibility +4. Quantify expected causal effects when possible + +Always ground your analysis in the provided causal graph structure and observed data.""" + +# 3. Agent initialization with system prompt +agent = CRCAAgent( + variables=["price", "demand", "inventory", "supply", "competition"], + agent_name="economic-causal-analyst", + agent_description="Expert economic causal reasoning agent", + model_name="gpt-4o", # or "gpt-4o-mini" for faster/cheaper analysis + max_loops=3, # Number of reasoning loops for LLM-based analysis + system_prompt=SYSTEM_PROMPT, + verbose=True, # Enable detailed logging +) + +# 4. Build causal graph: Add causal relationships +# Price negatively affects demand (higher price → lower demand) +agent.add_causal_relationship("price", "demand", strength=-0.5) + +# Demand negatively affects inventory (higher demand → lower inventory) +agent.add_causal_relationship("demand", "inventory", strength=-0.2) + +# Supply positively affects inventory (higher supply → higher inventory) +agent.add_causal_relationship("supply", "inventory", strength=0.3) + +# Competition negatively affects price (more competition → lower price) +agent.add_causal_relationship("competition", "price", strength=-0.4) + +# Price positively affects supply (higher price → more supply) +agent.add_causal_relationship("price", "supply", strength=0.2) + +# 5. Verify graph structure +print("Causal Graph Nodes:", agent.get_nodes()) +print("Causal Graph Edges:", agent.get_edges()) +print("Is DAG:", agent.is_dag()) + +# 6. Example 1: LLM-based causal analysis (sophisticated reasoning) +print("\n" + "="*80) +print("EXAMPLE 1: LLM-Based Causal Analysis") +print("="*80) + +task = """ +Analyze the causal relationship between price increases and inventory levels. +Consider both direct and indirect causal paths. What interventions could +stabilize inventory while maintaining profitability? +""" + +result = agent.run(task=task) + +print("\n--- Causal Analysis Report ---") +print(result["causal_analysis"]) + +print("\n--- Counterfactual Scenarios ---") +for i, scenario in enumerate(result["counterfactual_scenarios"][:3], 1): + print(f"\nScenario {i}: {scenario.name}") + print(f" Interventions: {scenario.interventions}") + print(f" Expected Outcomes: {scenario.expected_outcomes}") + print(f" Probability: {scenario.probability:.3f}") + print(f" Reasoning: {scenario.reasoning}") + +print("\n--- Causal Graph Info ---") +print(f"Nodes: {result['causal_graph_info']['nodes']}") +print(f"Edges: {result['causal_graph_info']['edges']}") +print(f"Is DAG: {result['causal_graph_info']['is_dag']}") + +# 7. Example 2: Deterministic simulation (script-style) +print("\n" + "="*80) +print("EXAMPLE 2: Deterministic Causal Simulation") +print("="*80) + +# Initial state +initial_state = { + "price": 100.0, + "demand": 1000.0, + "inventory": 5000.0, + "supply": 2000.0, + "competition": 5.0, +} + +# Run deterministic evolution +simulation_result = agent.run( + initial_state=initial_state, + target_variables=["price", "demand", "inventory"], + max_steps=3, # Evolve for 3 time steps +) + +print("\n--- Initial State ---") +for var, value in initial_state.items(): + print(f" {var}: {value}") + +print("\n--- Evolved State (after 3 steps) ---") +for var, value in simulation_result["evolved_state"].items(): + print(f" {var}: {value:.2f}") + +print("\n--- Counterfactual Scenarios ---") +for i, scenario in enumerate(simulation_result["counterfactual_scenarios"][:3], 1): + print(f"\nScenario {i}: {scenario.name}") + print(f" Interventions: {scenario.interventions}") + print(f" Expected Outcomes: {scenario.expected_outcomes}") + print(f" Probability: {scenario.probability:.3f}") + +# 8. Example 3: Causal chain identification +print("\n" + "="*80) +print("EXAMPLE 3: Causal Chain Analysis") +print("="*80) + +chain = agent.identify_causal_chain("competition", "inventory") +if chain: + print(f"Causal chain from 'competition' to 'inventory': {' → '.join(chain)}") +else: + print("No direct causal chain found") + +# 9. Example 4: Analyze causal strength +print("\n" + "="*80) +print("EXAMPLE 4: Causal Strength Analysis") +print("="*80) + +strength_analysis = agent.analyze_causal_strength("price", "inventory") +print(f"Direct edge strength: {strength_analysis.get('direct_strength', 'N/A')}") +print(f"Path strength: {strength_analysis.get('path_strength', 'N/A')}") + +# 10. Example 5: Custom intervention prediction +print("\n" + "="*80) +print("EXAMPLE 5: Custom Intervention Prediction") +print("="*80) + +# What if we reduce price by 20%? +interventions = {"price": 80.0} # 20% reduction from 100 +predicted = agent._predict_outcomes(initial_state, interventions) + +print("\nIntervention: Reduce price from 100 to 80 (20% reduction)") +print("Predicted outcomes:") +for var, value in predicted.items(): + if var in initial_state: + change = value - initial_state[var] + change_pct = (change / initial_state[var]) * 100 if initial_state[var] != 0 else 0 + print(f" {var}: {initial_state[var]:.2f} → {value:.2f} ({change_pct:+.1f}%)") + +print("\n" + "="*80) +print("Example execution complete!") +print("="*80) +``` + +Expected output structure +------------------------- +The `run()` method returns different structures depending on the mode: + +**LLM Mode** (task string): +```python +{ + "task": str, # The provided task string + "causal_analysis": str, # Synthesized analysis report + "counterfactual_scenarios": List[CounterfactualScenario], # Generated scenarios + "causal_graph_info": { + "nodes": List[str], + "edges": List[Tuple[str, str]], + "is_dag": bool + }, + "analysis_steps": List[Dict[str, Any]] # Step-by-step reasoning history +} +``` + +**Deterministic Mode** (initial_state dict): +```python +{ + "initial_state": Dict[str, float], # Input state + "evolved_state": Dict[str, float], # State after max_steps evolution + "counterfactual_scenarios": List[CounterfactualScenario], # Generated scenarios + "causal_graph_info": { + "nodes": List[str], + "edges": List[Tuple[str, str]], + "is_dag": bool + }, + "steps": int # Number of evolution steps applied +} +``` + +System prompt best practices +----------------------------- +1. **Domain-specific guidance**: Include domain knowledge relevant to your causal model +2. **Causal reasoning principles**: Reference Pearl's causal hierarchy (association, intervention, counterfactual) +3. **Output format**: Specify desired analysis structure and detail level +4. **Plausibility constraints**: Guide the agent on what interventions are realistic +5. **Quantification**: Encourage numerical estimates when appropriate + +Example system prompt template: +```python +SYSTEM_PROMPT = """You are a {domain} causal reasoning expert. +Your analysis should: +- Identify {specific_relationships} +- Consider {relevant_factors} +- Generate {scenario_types} +- Provide {output_format} + +Ground your reasoning in the causal graph structure provided.""" +``` + +Why use `run()` +-------------- +- **Dual-mode operation**: Automatically selects LLM mode (task string) or deterministic mode (initial_state dict) +- **LLM mode**: Performs sophisticated multi-loop causal reasoning with structured output +- **Deterministic mode**: Evolves the world state for `max_steps` using the deterministic evolution + operator, then generates counterfactuals from the evolved state (consistent timelines) +- Accepts both dict and JSON payloads for flexible integration +- Returns a compact result dict used across Swarms agents + +Architecture (high level) +------------------------- + +```mermaid +flowchart TB + subgraph Input[Input] + I1[Task string] + I2[Initial state dict] + end + + I1 -->|String| LLM[LLM Mode] + I2 -->|Dict| DET[Deterministic Mode] + + subgraph LLMFlow[LLM Causal Analysis] + LLM --> P1[Build Causal Prompt] + P1 --> L1[LLM Step 1] + L1 --> L2[LLM Step 2...N] + L2 --> SYN[Synthesize Analysis] + SYN --> CF1[Generate Counterfactuals] + end + + subgraph DetFlow[Deterministic Simulation] + DET --> P2[ensure_standardization_stats] + P2 --> P3[Standardize to z-space] + P3 --> T[Topological sort] + T --> E[predict_outcomes linear SCM] + E --> D[De-standardize outputs] + D --> R[Timeline rollout] + R --> CF2[Generate Counterfactuals] + end + + subgraph Model[Causal Model] + G1[Causal graph] + G2[Edge strengths] + G3[Standardization stats] + end + + subgraph Output[Outputs] + O1[Causal analysis / Evolved state] + O2[Counterfactual scenarios] + O3[Causal graph info] + end + + G1 --> LLMFlow + G1 --> DetFlow + G2 --> DetFlow + G3 --> DetFlow + CF1 --> O2 + CF2 --> O2 + SYN --> O1 + R --> O1 + G1 --> O3 +``` + +Complete method index (quick) +----------------------------- +The following is the public surface implemented by `CRCAAgent` (Lite) in +`ceca_lite/crca-lite.py`. + +LLM integration +- `_get_cr_ca_schema()` — CR-CA function calling schema for structured reasoning +- `step(task)` — Execute a single step of LLM-based causal reasoning +- `_build_causal_prompt(task)` — Build causal analysis prompt with graph context +- `_build_memory_context()` — Build memory context from previous analysis steps +- `_synthesize_causal_analysis(task)` — Synthesize final causal analysis using LLM +- `_run_llm_causal_analysis(task)` — Run multi-loop LLM-based causal analysis + +Core graph & state +- `_ensure_node_exists(node)` — ensure node present in internal maps +- `add_causal_relationship(source, target, strength=1.0, ...)` — add/update edge +- `_get_parents(node)`, `_get_children(node)` — graph accessors +- `_topological_sort()` — Kahn's algorithm +- `get_nodes()`, `get_edges()`, `is_dag()` — graph introspection + +Standardization & prediction +- `set_standardization_stats(var, mean, std)` — set z-stats +- `ensure_standardization_stats(state)` — auto-fill sensible stats +- `_standardize_state(state)` / `_destandardize_value(var, z)` — z-score transforms +- `_predict_outcomes(factual_state, interventions)` — evolution operator (linear SCM) +- `_predict_outcomes_cached(...)` — cached wrapper + +Counterfactuals & reasoning +- `generate_counterfactual_scenarios(factual_state, target_variables, max_scenarios=5)` +- `_calculate_scenario_probability(factual_state, interventions)` — heuristic plausibility +- `counterfactual_abduction_action_prediction(factual_state, interventions)` — abduction–action–prediction (Pearl) + +Estimation, analysis & utilities +- `fit_from_dataframe(df, variables, window=30, ...)` — WLS edge estimation and stats +- `quantify_uncertainty(df, variables, windows=200, ...)` — bootstrap CIs +- `analyze_causal_strength(source, target)` — path/edge summary +- `identify_causal_chain(start, end)` — BFS shortest path +- `detect_change_points(series, threshold=2.5)` — simple detector + +Advanced (optional / Pro) +- `learn_structure(...)`, `plan_interventions(...)`, `gradient_based_intervention_optimization(...)`, + `convex_intervention_optimization(...)`, `evolutionary_multi_objective_optimization(...)`, + `probabilistic_nested_simulation(...)`, `deep_root_cause_analysis(...)`, and more. + +Return shape from `run()` +------------------------- +`run()` returns a dictionary with different keys depending on mode: + +**LLM Mode** (when `task` is a string): +- `task`: the provided task/problem string +- `causal_analysis`: synthesized causal analysis report (string) +- `counterfactual_scenarios`: list of `CounterfactualScenario` objects +- `causal_graph_info`: {"nodes": [...], "edges": [...], "is_dag": bool} +- `analysis_steps`: list of analysis steps with memory context + +**Deterministic Mode** (when `initial_state` is a dict): +- `initial_state`: the provided input state (dict) +- `evolved_state`: state after applying `max_steps` of the evolution operator +- `counterfactual_scenarios`: list of `CounterfactualScenario` with name/interventions/expected_outcomes/probability/reasoning +- `causal_graph_info`: {"nodes": [...], "edges": [...], "is_dag": bool} +- `steps`: `max_steps` used + +Usage patterns & examples +------------------------- +1) LLM-based causal analysis (sophisticated reasoning) + +```python +agent = CRCAAgent( + variables=["a","b","c"], + model_name="gpt-4o", + max_loops=3 +) +agent.add_causal_relationship("a","b", strength=0.8) + +# LLM mode: pass task as string +task = "Analyze the causal relationship between a and b" +res = agent.run(task=task) +print(res["causal_analysis"]) +print(res["analysis_steps"]) +``` + +2) Deterministic simulation (script-style) + +```python +agent = CRCAAgent(variables=["a","b","c"]) +agent.add_causal_relationship("a","b", strength=0.8) +state = {"a":1.0, "b":2.0, "c":3.0} +res = agent.run(initial_state=state, max_steps=2) +print(res["evolved_state"]) +``` + +3) Orchestration / agent-style (JSON payloads) + +```python +payload = '{"a":1.0,"b":2.0,"c":3.0}' +res = agent.run(initial_state=payload, max_steps=1) +if "error" in res: + print("Bad payload:", res["error"]) +else: + print("Evolved:", res["evolved_state"]) +``` + +4) Lower-level testing & research + +```python +pred = agent._predict_outcomes({"a":1.0,"b":2.0},{"a":0.0}) +print(pred) +``` + +Design notes & limitations +-------------------------- +- **LLM Integration**: Uses swarms Agent infrastructure for LLM calls. Configure model via `model_name` parameter. Multi-loop reasoning enabled by default. +- **Dual-mode operation**: Automatically selects LLM mode (task string) or deterministic mode (initial_state dict). Both modes generate counterfactuals using deterministic methods. +- **Linearity**: default `_predict_outcomes` is linear in standardized z-space. To model non-linear dynamics, subclass `CRCAAgent` and override `_predict_outcomes`. +- **Probabilities**: scenario probability is a heuristic proximity measure (Mahalanobis-like) — not a formal posterior. +- **Stats**: the engine auto-fills standardization stats with sensible defaults (`mean=observed`, `std=1.0`) via `ensure_standardization_stats` to avoid degenerate std=0 cases. +- **Dependencies**: Lite intentionally avoids heavy libs (pandas/scipy/cvxpy) but includes LLM integration via swarms Agent base. + +Extending & integration +----------------------- +For advanced capabilities (structure learning, Bayesian inference, optimization, +extensive statistical methods), use the full CRCA Agent featured [WIP] +The Lite version provides core causal reasoning with LLM support while maintaining minimal dependencies. + +References +---------- +- Pearl, J. (2009). *Causality: Models, Reasoning, and Inference*. +- Pearl, J., & Mackenzie, D. (2018). *The Book of Why*. + + diff --git a/swarms/agents/__init__.py b/swarms/agents/__init__.py index ca3d52e6..f8bc09f1 100644 --- a/swarms/agents/__init__.py +++ b/swarms/agents/__init__.py @@ -3,6 +3,7 @@ from swarms.agents.consistency_agent import SelfConsistencyAgent from swarms.agents.create_agents_from_yaml import ( create_agents_from_yaml, ) +from swarms.agents.cr_ca_agent import CRCAAgent from swarms.agents.flexion_agent import ReflexionAgent from swarms.agents.gkp_agent import GKPAgent from swarms.agents.i_agent import IterativeReflectiveExpansion @@ -22,4 +23,5 @@ __all__ = [ "ReflexionAgent", "GKPAgent", "AgentJudge", + "CRCAAgent", ] diff --git a/swarms/agents/cr_ca_agent.py b/swarms/agents/cr_ca_agent.py new file mode 100644 index 00000000..8e456995 --- /dev/null +++ b/swarms/agents/cr_ca_agent.py @@ -0,0 +1,630 @@ +""" +CR-CA Lite: A lightweight Causal Reasoning with Counterfactual Analysis engine. + +This is a minimal implementation of the ASTT/CR-CA framework focusing on: +- Core evolution operator E(x) +- Counterfactual scenario generation +- Causal chain identification +- Basic causal graph operations + +Dependencies: numpy only (typing, dataclasses, enum are stdlib) +""" + +from typing import Dict, Any, List, Tuple, Optional, Union +import numpy as np +from dataclasses import dataclass +from enum import Enum + + +class CausalRelationType(Enum): + """Types of causal relationships""" + DIRECT = "direct" + INDIRECT = "indirect" + CONFOUNDING = "confounding" + MEDIATING = "mediating" + MODERATING = "moderating" + + +@dataclass +class CausalNode: + """Represents a node in the causal graph""" + name: str + value: Optional[float] = None + confidence: float = 1.0 + node_type: str = "variable" + + +@dataclass +class CausalEdge: + """Represents an edge in the causal graph""" + source: str + target: str + strength: float = 1.0 + relation_type: CausalRelationType = CausalRelationType.DIRECT + confidence: float = 1.0 + + +@dataclass +class CounterfactualScenario: + """Represents a counterfactual scenario""" + name: str + interventions: Dict[str, float] + expected_outcomes: Dict[str, float] + probability: float = 1.0 + reasoning: str = "" + + +class CRCAgent: + """ + CR-CA Lite: Lightweight Causal Reasoning with Counterfactual Analysis engine. + + Core components: + - Evolution operator: E(x) = _predict_outcomes() + - Counterfactual generation: generate_counterfactual_scenarios() + - Causal chain identification: identify_causal_chain() + - State mapping: _standardize_state() / _destandardize_value() + + Args: + variables: Optional list of variable names + causal_edges: Optional list of (source, target) tuples for initial edges + """ + + def __init__( + self, + variables: Optional[List[str]] = None, + causal_edges: Optional[List[Tuple[str, str]]] = None, + max_loops: Optional[Union[int, str]] = 1, + ): + """ + Initialize CR-CA Lite engine. + + Args: + variables: Optional list of variable names to add to graph + causal_edges: Optional list of (source, target) tuples for initial edges + """ + # Pure Python graph representation: {node: {child: strength}} + self.causal_graph: Dict[str, Dict[str, float]] = {} + self.causal_graph_reverse: Dict[str, List[str]] = {} # For fast parent lookup + + # Standardization statistics: {'var': {'mean': m, 'std': s}} + self.standardization_stats: Dict[str, Dict[str, float]] = {} + + # Initialize graph + if variables: + for var in variables: + self._ensure_node_exists(var) + + if causal_edges: + for source, target in causal_edges: + self.add_causal_relationship(source, target) + + # Agent-like loop control: accept numeric or "auto" + # Keep the original (possibly "auto") value; resolution happens at run time. + self.max_loops = max_loops + + def _ensure_node_exists(self, node: str) -> None: + """Ensure node present in graph structures.""" + if node not in self.causal_graph: + self.causal_graph[node] = {} + if node not in self.causal_graph_reverse: + self.causal_graph_reverse[node] = [] + + def add_causal_relationship( + self, + source: str, + target: str, + strength: float = 1.0, + relation_type: CausalRelationType = CausalRelationType.DIRECT, + confidence: float = 1.0 + ) -> None: + """ + Add a causal edge to the graph. + + Args: + source: Source variable name + target: Target variable name + strength: Causal effect strength (default: 1.0) + relation_type: Type of causal relation (default: DIRECT) + confidence: Confidence in the relationship (default: 1.0) + """ + # Ensure nodes exist + self._ensure_node_exists(source) + self._ensure_node_exists(target) + + # Add or update edge: source -> target with strength + self.causal_graph[source][target] = float(strength) + + # Update reverse mapping for parent lookup (avoid duplicates) + if source not in self.causal_graph_reverse[target]: + self.causal_graph_reverse[target].append(source) + + def _get_parents(self, node: str) -> List[str]: + """ + Get parent nodes (predecessors) of a node. + + Args: + node: Node name + + Returns: + List of parent node names + """ + return self.causal_graph_reverse.get(node, []) + + def _get_children(self, node: str) -> List[str]: + """ + Get child nodes (successors) of a node. + + Args: + node: Node name + + Returns: + List of child node names + """ + return list(self.causal_graph.get(node, {}).keys()) + + def _topological_sort(self) -> List[str]: + """ + Perform topological sort using Kahn's algorithm (pure Python). + + Returns: + List of nodes in topological order + """ + # Compute in-degrees + in_degree: Dict[str, int] = {node: 0 for node in self.causal_graph.keys()} + for node in self.causal_graph: + for child in self._get_children(node): + in_degree[child] = in_degree.get(child, 0) + 1 + + # Initialize queue with nodes having no incoming edges + queue: List[str] = [node for node, degree in in_degree.items() if degree == 0] + result: List[str] = [] + + # Process nodes + while queue: + node = queue.pop(0) + result.append(node) + + # Reduce in-degree of children + for child in self._get_children(node): + in_degree[child] -= 1 + if in_degree[child] == 0: + queue.append(child) + + return result + + def identify_causal_chain(self, start: str, end: str) -> List[str]: + """ + Find shortest causal path from start to end using BFS (pure Python). + + Implements core causal chain identification (Ax2, Ax6). + + Args: + start: Starting variable + end: Target variable + + Returns: + List of variables forming the causal chain, or empty list if no path exists + """ + if start not in self.causal_graph or end not in self.causal_graph: + return [] + + if start == end: + return [start] + + # BFS to find shortest path + queue: List[Tuple[str, List[str]]] = [(start, [start])] + visited: set = {start} + + while queue: + current, path = queue.pop(0) + + # Check all children + for child in self._get_children(current): + if child == end: + return path + [child] + + if child not in visited: + visited.add(child) + queue.append((child, path + [child])) + + return [] # No path found + + # detect_confounders removed in Lite version (advanced inference) + + def _has_path(self, start: str, end: str) -> bool: + """ + Check if a path exists from start to end using DFS. + + Args: + start: Starting node + end: Target node + + Returns: + True if path exists, False otherwise + """ + if start == end: + return True + + stack = [start] + visited = set() + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + + for child in self._get_children(current): + if child == end: + return True + if child not in visited: + stack.append(child) + + return False + + # identify_adjustment_set removed in Lite version (advanced inference) + + def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]: + """ + Standardize state values to z-scores. + + Args: + state: Dictionary of variable values + + Returns: + Dictionary of standardized (z-score) values + """ + z: Dict[str, float] = {} + for k, v in state.items(): + s = self.standardization_stats.get(k) + if s and s.get("std", 0.0) > 0: + z[k] = (v - s["mean"]) / s["std"] + else: + z[k] = v + return z + + def _destandardize_value(self, var: str, z_value: float) -> float: + """ + Convert z-score back to original scale. + + Args: + var: Variable name + z_value: Standardized (z-score) value + + Returns: + Original scale value + """ + s = self.standardization_stats.get(var) + if s and s.get("std", 0.0) > 0: + return z_value * s["std"] + s["mean"] + return z_value + + def _predict_outcomes( + self, + factual_state: Dict[str, float], + interventions: Dict[str, float] + ) -> Dict[str, float]: + """ + Evolution operator E(x): Predict outcomes given state and interventions. + + This is the core CR-CA evolution operator implementing: + x_{t+1} = E(x_t) + + Mathematical foundation: + - Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε\n+ - NOTE: This implementation is linear. To model nonlinear dynamics override\n+ `_predict_outcomes` in a subclass with a custom evolution operator.\n*** End Patch + - Propagates effects through causal graph in topological order + - Standardizes inputs, computes in z-space, de-standardizes outputs + + Args: + factual_state: Current world state (baseline) + interventions: Interventions to apply (do-operator) + + Returns: + Dictionary of predicted variable values + """ + # Merge factual state with interventions + raw = factual_state.copy() + raw.update(interventions) + + # Standardize to z-scores + z_state = self._standardize_state(raw) + z_pred = dict(z_state) + + # Process nodes in topological order + for node in self._topological_sort(): + # If node is intervened on, keep its value + if node in interventions: + if node not in z_pred: + z_pred[node] = z_state.get(node, 0.0) + continue + + # Get parents + parents = self._get_parents(node) + if not parents: + continue + + # Compute linear combination: Σᵢ βᵢ·z_xi + s = 0.0 + for p in parents: + pz = z_pred.get(p, z_state.get(p, 0.0)) + strength = self.causal_graph.get(p, {}).get(node, 0.0) + s += pz * strength + + z_pred[node] = s + + # De-standardize results + return {v: self._destandardize_value(v, z) for v, z in z_pred.items()} + + def _calculate_scenario_probability( + self, + factual_state: Dict[str, float], + interventions: Dict[str, float] + ) -> float: + """ + Calculate a heuristic probability of a counterfactual scenario. + + NOTE: This is a lightweight heuristic proximity measure (Mahalanobis-like) + and NOT a full statistical estimator — it ignores covariance and should + be treated as a relative plausibility score for Lite usage. + + Args: + factual_state: Baseline state + interventions: Intervention values + + Returns: + Heuristic probability value between 0.05 and 0.98 + """ + z_sq = 0.0 + for var, new in interventions.items(): + s = self.standardization_stats.get(var, {"mean": 0.0, "std": 1.0}) + mu, sd = s.get("mean", 0.0), s.get("std", 1.0) or 1.0 + old = factual_state.get(var, mu) + dz = (new - mu) / sd - (old - mu) / sd + z_sq += float(dz) * float(dz) + + p = 0.95 * float(np.exp(-0.5 * z_sq)) + 0.05 + return float(max(0.05, min(0.98, p))) + + def generate_counterfactual_scenarios( + self, + factual_state: Dict[str, float], + target_variables: List[str], + max_scenarios: int = 5 + ) -> List[CounterfactualScenario]: + """ + Generate counterfactual scenarios for target variables. + + Implements Ax8 (Counterfactuals) - core CR-CA functionality. + + Args: + factual_state: Current factual state + target_variables: Variables to generate counterfactuals for + max_scenarios: Maximum number of scenarios per variable + + Returns: + List of CounterfactualScenario objects + """ + # Ensure stats exist for variables in factual_state (fallback behavior) + self.ensure_standardization_stats(factual_state) + + scenarios: List[CounterfactualScenario] = [] + z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0] + + for i, tv in enumerate(target_variables[:max_scenarios]): + stats = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0}) + cur = factual_state.get(tv, stats.get("mean", 0.0)) + + # If std is zero or missing, use absolute perturbations instead + if not stats or stats.get("std", 0.0) <= 0: + base = cur + abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0] + vals = [base + step for step in abs_steps] + else: + mean = stats["mean"] + std = stats["std"] + cz = (cur - mean) / std + vals = [(cz + dz) * std + mean for dz in z_steps] + + for j, v in enumerate(vals): + interventions = {tv: float(v)} + scenarios.append( + CounterfactualScenario( + name=f"scenario_{i}_{j}", + interventions=interventions, + expected_outcomes=self._predict_outcomes( + factual_state, interventions + ), + probability=self._calculate_scenario_probability( + factual_state, interventions + ), + reasoning=f"Intervention on {tv} with value {v}", + ) + ) + + return scenarios + + def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]: + """ + Analyze the strength of causal relationship between two variables. + + Args: + source: Source variable + target: Target variable + + Returns: + Dictionary with strength, confidence, path_length, relation_type + """ + if source not in self.causal_graph or target not in self.causal_graph[source]: + return {"strength": 0.0, "confidence": 0.0, "path_length": float('inf')} + + strength = self.causal_graph[source].get(target, 0.0) + path = self.identify_causal_chain(source, target) + path_length = len(path) - 1 if path else float('inf') + + return { + "strength": float(strength), + "confidence": 1.0, # Simplified: assume full confidence + "path_length": path_length, + "relation_type": CausalRelationType.DIRECT.value + } + + def set_standardization_stats( + self, + variable: str, + mean: float, + std: float + ) -> None: + """ + Set standardization statistics for a variable. + + Args: + variable: Variable name + mean: Mean value + std: Standard deviation + """ + self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0} + + def ensure_standardization_stats(self, state: Dict[str, float]) -> None: + """ + Ensure standardization stats exist for all variables in a given state. + If stats are missing, create a sensible fallback (mean=observed, std=1.0). + This prevents degenerate std=0 issues in Lite mode. + """ + for var, val in state.items(): + if var not in self.standardization_stats: + self.standardization_stats[var] = {"mean": float(val), "std": 1.0} + + def get_nodes(self) -> List[str]: + """ + Get all nodes in the causal graph. + + Returns: + List of node names + """ + return list(self.causal_graph.keys()) + + def get_edges(self) -> List[Tuple[str, str]]: + """ + Get all edges in the causal graph. + + Returns: + List of (source, target) tuples + """ + edges = [] + for source, targets in self.causal_graph.items(): + for target in targets.keys(): + edges.append((source, target)) + return edges + + def is_dag(self) -> bool: + """ + Check if the causal graph is a DAG (no cycles). + + Uses DFS to detect cycles. + + Returns: + True if DAG, False if cycles exist + """ + def has_cycle(node: str, visited: set, rec_stack: set) -> bool: + """DFS to detect cycles.""" + visited.add(node) + rec_stack.add(node) + + for child in self._get_children(node): + if child not in visited: + if has_cycle(child, visited, rec_stack): + return True + elif child in rec_stack: + return True + + rec_stack.remove(node) + return False + + visited = set() + rec_stack = set() + + for node in self.causal_graph: + if node not in visited: + if has_cycle(node, visited, rec_stack): + return False + + return True + + def run( + self, + initial_state: Any, + target_variables: Optional[List[str]] = None, + max_steps: Union[int, str] = 1 + ) -> Dict[str, Any]: + """ + Run causal simulation: evolve state and generate counterfactuals. + + Simple entry point for CR-CA engine. + + Args: + initial_state: Initial world state + target_variables: Variables to generate counterfactuals for (default: all nodes) + max_steps: Number of evolution steps (default: 1) + + Returns: + Dictionary with evolved state, counterfactuals, and graph info + """ + # Accept either a dict initial_state or a JSON string (agent-like behavior) + if not isinstance(initial_state, dict): + try: + import json + parsed = json.loads(initial_state) + if isinstance(parsed, dict): + initial_state = parsed + else: + return {"error": "initial_state JSON must decode to a dict"} + except Exception: + return {"error": "initial_state must be a dict or JSON-encoded dict"} + + # Use all nodes as targets if not specified + if target_variables is None: + target_variables = list(self.causal_graph.keys()) + + # Resolve "auto" sentinel for max_steps (accepts method arg or instance-level default) + def _resolve_max_steps(value: Union[int, str]) -> int: + if isinstance(value, str) and value == "auto": + # Heuristic: one step per variable (at least 1) + return max(1, len(self.causal_graph)) + try: + return int(value) + except Exception: + return max(1, len(self.causal_graph)) + + effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.max_loops == 1 else self.max_loops) + # If caller passed default 1 and instance set a different max_loops, prefer instance value + if max_steps == 1 and self.max_loops != 1: + effective_steps = _resolve_max_steps(self.max_loops) + + # Evolve state + current_state = initial_state.copy() + for step in range(effective_steps): + current_state = self._predict_outcomes(current_state, {}) + + # Ensure standardization stats exist for the evolved state and generate counterfactuals from it + self.ensure_standardization_stats(current_state) + counterfactual_scenarios = self.generate_counterfactual_scenarios( + current_state, + target_variables, + max_scenarios=5 + ) + + return { + "initial_state": initial_state, + "evolved_state": current_state, + "counterfactual_scenarios": counterfactual_scenarios, + "causal_graph_info": { + "nodes": self.get_nodes(), + "edges": self.get_edges(), + "is_dag": self.is_dag() + }, + "steps": effective_steps + } + + +# Agent-like behavior: `run` accepts either a dict or a JSON string as the initial_state +# so the engine behaves like a normal agent by default.