Refactor CRCAgent class and enhance type hints

pull/1233/head
CI-DEV 4 weeks ago committed by GitHub
parent 535b2fe426
commit b53a52ec34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -10,7 +10,7 @@ This is a minimal implementation of the ASTT/CR-CA framework focusing on:
Dependencies: numpy only (typing, dataclasses, enum are stdlib) Dependencies: numpy only (typing, dataclasses, enum are stdlib)
""" """
from typing import Dict, Any, List, Tuple, Optional from typing import Dict, Any, List, Tuple, Optional, Union
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@ -54,7 +54,7 @@ class CounterfactualScenario:
reasoning: str = "" reasoning: str = ""
class CRCALite: class CRCAgent:
""" """
CR-CA Lite: Lightweight Causal Reasoning with Counterfactual Analysis engine. CR-CA Lite: Lightweight Causal Reasoning with Counterfactual Analysis engine.
@ -73,6 +73,7 @@ class CRCALite:
self, self,
variables: Optional[List[str]] = None, variables: Optional[List[str]] = None,
causal_edges: Optional[List[Tuple[str, str]]] = None, causal_edges: Optional[List[Tuple[str, str]]] = None,
max_loops: Optional[Union[int, str]] = 1,
): ):
""" """
Initialize CR-CA Lite engine. Initialize CR-CA Lite engine.
@ -97,6 +98,10 @@ class CRCALite:
for source, target in causal_edges: for source, target in causal_edges:
self.add_causal_relationship(source, target) self.add_causal_relationship(source, target)
# Agent-like loop control: accept numeric or "auto"
# Keep the original (possibly "auto") value; resolution happens at run time.
self.max_loops = max_loops
def _ensure_node_exists(self, node: str) -> None: def _ensure_node_exists(self, node: str) -> None:
"""Ensure node present in graph structures.""" """Ensure node present in graph structures."""
if node not in self.causal_graph: if node not in self.causal_graph:
@ -269,7 +274,7 @@ class CRCALite:
Returns: Returns:
Dictionary of standardized (z-score) values Dictionary of standardized (z-score) values
""" """
z = {} z: Dict[str, float] = {}
for k, v in state.items(): for k, v in state.items():
s = self.standardization_stats.get(k) s = self.standardization_stats.get(k)
if s and s.get("std", 0.0) > 0: if s and s.get("std", 0.0) > 0:
@ -399,7 +404,44 @@ class CRCALite:
Returns: Returns:
List of CounterfactualScenario objects List of CounterfactualScenario objects
""" """
# Ensure stats exist for variables in factual_state (fallback behavior)\n+ self.ensure_standardization_stats(factual_state)\n+\n+ scenarios: List[CounterfactualScenario] = []\n+ z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+\n+ for i, tv in enumerate(target_variables[:max_scenarios]):\n+ stats = self.standardization_stats.get(tv, {\"mean\": 0.0, \"std\": 1.0})\n+ cur = factual_state.get(tv, stats.get(\"mean\", 0.0))\n+\n+ # If std is zero or missing, use absolute perturbations instead\n+ if not stats or stats.get(\"std\", 0.0) <= 0:\n+ base = cur\n+ abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+ vals = [base + step for step in abs_steps]\n+ else:\n+ mean = stats[\"mean\"]\n+ std = stats[\"std\"]\n+ cz = (cur - mean) / std\n+ vals = [(cz + dz) * std + mean for dz in z_steps]\n+\n+ for j, v in enumerate(vals):\n+ scenarios.append(CounterfactualScenario(\n+ name=f\"scenario_{i}_{j}\",\n+ interventions={tv: float(v)},\n+ expected_outcomes=self._predict_outcomes(factual_state, {tv: float(v)}),\n+ probability=self._calculate_scenario_probability(factual_state, {tv: float(v)}),\n+ reasoning=f\"Intervention on {tv} with value {v}\"\n+ ))\n+\n+ return scenarios\n*** End Patch # Ensure stats exist for variables in factual_state (fallback behavior)
self.ensure_standardization_stats(factual_state)
scenarios: List[CounterfactualScenario] = []
z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
for i, tv in enumerate(target_variables[:max_scenarios]):
stats = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0})
cur = factual_state.get(tv, stats.get("mean", 0.0))
# If std is zero or missing, use absolute perturbations instead
if not stats or stats.get("std", 0.0) <= 0:
base = cur
abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
vals = [base + step for step in abs_steps]
else:
mean = stats["mean"]
std = stats["std"]
cz = (cur - mean) / std
vals = [(cz + dz) * std + mean for dz in z_steps]
for j, v in enumerate(vals):
interventions = {tv: float(v)}
scenarios.append(
CounterfactualScenario(
name=f"scenario_{i}_{j}",
interventions=interventions,
expected_outcomes=self._predict_outcomes(
factual_state, interventions
),
probability=self._calculate_scenario_probability(
factual_state, interventions
),
reasoning=f"Intervention on {tv} with value {v}",
)
)
return scenarios
def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]: def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]:
""" """
@ -510,9 +552,9 @@ class CRCALite:
def run( def run(
self, self,
initial_state: Dict[str, float], initial_state: Any,
target_variables: Optional[List[str]] = None, target_variables: Optional[List[str]] = None,
max_steps: int = 1 max_steps: Union[int, str] = 1
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run causal simulation: evolve state and generate counterfactuals. Run causal simulation: evolve state and generate counterfactuals.
@ -527,13 +569,40 @@ class CRCALite:
Returns: Returns:
Dictionary with evolved state, counterfactuals, and graph info Dictionary with evolved state, counterfactuals, and graph info
""" """
# Accept either a dict initial_state or a JSON string (agent-like behavior)
if not isinstance(initial_state, dict):
try:
import json
parsed = json.loads(initial_state)
if isinstance(parsed, dict):
initial_state = parsed
else:
return {"error": "initial_state JSON must decode to a dict"}
except Exception:
return {"error": "initial_state must be a dict or JSON-encoded dict"}
# Use all nodes as targets if not specified # Use all nodes as targets if not specified
if target_variables is None: if target_variables is None:
target_variables = list(self.causal_graph.keys()) target_variables = list(self.causal_graph.keys())
# Resolve "auto" sentinel for max_steps (accepts method arg or instance-level default)
def _resolve_max_steps(value: Union[int, str]) -> int:
if isinstance(value, str) and value == "auto":
# Heuristic: one step per variable (at least 1)
return max(1, len(self.causal_graph))
try:
return int(value)
except Exception:
return max(1, len(self.causal_graph))
effective_steps = _resolve_max_steps(max_steps if max_steps != 1 or self.max_loops == 1 else self.max_loops)
# If caller passed default 1 and instance set a different max_loops, prefer instance value
if max_steps == 1 and self.max_loops != 1:
effective_steps = _resolve_max_steps(self.max_loops)
# Evolve state # Evolve state
current_state = initial_state.copy() current_state = initial_state.copy()
for step in range(max_steps): for step in range(effective_steps):
current_state = self._predict_outcomes(current_state, {}) current_state = self._predict_outcomes(current_state, {})
# Ensure standardization stats exist for the evolved state and generate counterfactuals from it # Ensure standardization stats exist for the evolved state and generate counterfactuals from it
@ -553,5 +622,9 @@ class CRCALite:
"edges": self.get_edges(), "edges": self.get_edges(),
"is_dag": self.is_dag() "is_dag": self.is_dag()
}, },
"steps": max_steps "steps": effective_steps
} }
# Agent-like behavior: `run` accepts either a dict or a JSON string as the initial_state
# so the engine behaves like a normal agent by default.

Loading…
Cancel
Save