diff --git a/swarms/agents/cr_ca_agent.py b/swarms/agents/cr_ca_agent.py index 3e0db60b..8e456995 100644 --- a/swarms/agents/cr_ca_agent.py +++ b/swarms/agents/cr_ca_agent.py @@ -10,7 +10,7 @@ This is a minimal implementation of the ASTT/CR-CA framework focusing on: Dependencies: numpy only (typing, dataclasses, enum are stdlib) """ -from typing import Dict, Any, List, Tuple, Optional +from typing import Dict, Any, List, Tuple, Optional, Union import numpy as np from dataclasses import dataclass from enum import Enum @@ -54,7 +54,7 @@ class CounterfactualScenario: reasoning: str = "" -class CRCALite: +class CRCAgent: """ CR-CA Lite: Lightweight Causal Reasoning with Counterfactual Analysis engine. @@ -73,6 +73,7 @@ class CRCALite: self, variables: Optional[List[str]] = None, causal_edges: Optional[List[Tuple[str, str]]] = None, + max_loops: Optional[Union[int, str]] = 1, ): """ Initialize CR-CA Lite engine. @@ -97,6 +98,10 @@ class CRCALite: for source, target in causal_edges: self.add_causal_relationship(source, target) + # Agent-like loop control: accept numeric or "auto" + # Keep the original (possibly "auto") value; resolution happens at run time. + self.max_loops = max_loops + def _ensure_node_exists(self, node: str) -> None: """Ensure node present in graph structures.""" if node not in self.causal_graph: @@ -269,7 +274,7 @@ class CRCALite: Returns: Dictionary of standardized (z-score) values """ - z = {} + z: Dict[str, float] = {} for k, v in state.items(): s = self.standardization_stats.get(k) if s and s.get("std", 0.0) > 0: @@ -399,7 +404,44 @@ class CRCALite: Returns: List of CounterfactualScenario objects """ - # Ensure stats exist for variables in factual_state (fallback behavior)\n+ self.ensure_standardization_stats(factual_state)\n+\n+ scenarios: List[CounterfactualScenario] = []\n+ z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+\n+ for i, tv in enumerate(target_variables[:max_scenarios]):\n+ stats = self.standardization_stats.get(tv, {\"mean\": 0.0, \"std\": 1.0})\n+ cur = factual_state.get(tv, stats.get(\"mean\", 0.0))\n+\n+ # If std is zero or missing, use absolute perturbations instead\n+ if not stats or stats.get(\"std\", 0.0) <= 0:\n+ base = cur\n+ abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+ vals = [base + step for step in abs_steps]\n+ else:\n+ mean = stats[\"mean\"]\n+ std = stats[\"std\"]\n+ cz = (cur - mean) / std\n+ vals = [(cz + dz) * std + mean for dz in z_steps]\n+\n+ for j, v in enumerate(vals):\n+ scenarios.append(CounterfactualScenario(\n+ name=f\"scenario_{i}_{j}\",\n+ interventions={tv: float(v)},\n+ expected_outcomes=self._predict_outcomes(factual_state, {tv: float(v)}),\n+ probability=self._calculate_scenario_probability(factual_state, {tv: float(v)}),\n+ reasoning=f\"Intervention on {tv} with value {v}\"\n+ ))\n+\n+ return scenarios\n*** End Patch + # Ensure stats exist for variables in factual_state (fallback behavior) + self.ensure_standardization_stats(factual_state) + + scenarios: List[CounterfactualScenario] = [] + z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0] + + for i, tv in enumerate(target_variables[:max_scenarios]): + stats = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0}) + cur = factual_state.get(tv, stats.get("mean", 0.0)) + + # If std is zero or missing, use absolute perturbations instead + if not stats or stats.get("std", 0.0) <= 0: + base = cur + abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0] + vals = [base + step for step in abs_steps] + else: + mean = stats["mean"] + std = stats["std"] + cz = (cur - mean) / std + vals = [(cz + dz) * std + mean for dz in z_steps] + + for j, v in enumerate(vals): + interventions = {tv: float(v)} + scenarios.append( + CounterfactualScenario( + name=f"scenario_{i}_{j}", + interventions=interventions, + expected_outcomes=self._predict_outcomes( + factual_state, interventions + ), + probability=self._calculate_scenario_probability( + factual_state, interventions + ), + reasoning=f"Intervention on {tv} with value {v}", + ) + ) + + return scenarios def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]: """ @@ -510,9 +552,9 @@ class CRCALite: def run( self, - initial_state: Dict[str, float], + initial_state: Any, target_variables: Optional[List[str]] = None, - max_steps: int = 1 + max_steps: Union[int, str] = 1 ) -> Dict[str, Any]: """ Run causal simulation: evolve state and generate counterfactuals. @@ -527,13 +569,40 @@ class CRCALite: Returns: Dictionary with evolved state, counterfactuals, and graph info """ + # Accept either a dict initial_state or a JSON string (agent-like behavior) + if not isinstance(initial_state, dict): + try: + import json + parsed = json.loads(initial_state) + if isinstance(parsed, dict): + initial_state = parsed + else: + return {"error": "initial_state JSON must decode to a dict"} + except Exception: + return {"error": "initial_state must be a dict or JSON-encoded dict"} + # Use all nodes as targets if not specified if target_variables is None: target_variables = list(self.causal_graph.keys()) + # Resolve "auto" sentinel for max_steps (accepts method arg or instance-level default) + def _resolve_max_steps(value: Union[int, str]) -> int: + if isinstance(value, str) and value == "auto": + # Heuristic: one step per variable (at least 1) + return max(1, len(self.causal_graph)) + try: + return int(value) + except Exception: + return max(1, len(self.causal_graph)) + + effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.max_loops == 1 else self.max_loops) + # If caller passed default 1 and instance set a different max_loops, prefer instance value + if max_steps == 1 and self.max_loops != 1: + effective_steps = _resolve_max_steps(self.max_loops) + # Evolve state current_state = initial_state.copy() - for step in range(max_steps): + for step in range(effective_steps): current_state = self._predict_outcomes(current_state, {}) # Ensure standardization stats exist for the evolved state and generate counterfactuals from it @@ -553,5 +622,9 @@ class CRCALite: "edges": self.get_edges(), "is_dag": self.is_dag() }, - "steps": max_steps + "steps": effective_steps } + + +# Agent-like behavior: `run` accepts either a dict or a JSON string as the initial_state +# so the engine behaves like a normal agent by default.