diff --git a/swarms/agents/cr_ca_agent.py b/swarms/agents/cr_ca_agent.py index f03f5ec5..3e0db60b 100644 --- a/swarms/agents/cr_ca_agent.py +++ b/swarms/agents/cr_ca_agent.py @@ -91,15 +91,19 @@ class CRCALite: # Initialize graph if variables: for var in variables: - if var not in self.causal_graph: - self.causal_graph[var] = {} - if var not in self.causal_graph_reverse: - self.causal_graph_reverse[var] = [] - + self._ensure_node_exists(var) + if causal_edges: for source, target in causal_edges: self.add_causal_relationship(source, target) + def _ensure_node_exists(self, node: str) -> None: + """Ensure node present in graph structures.""" + if node not in self.causal_graph: + self.causal_graph[node] = {} + if node not in self.causal_graph_reverse: + self.causal_graph_reverse[node] = [] + def add_causal_relationship( self, source: str, @@ -118,20 +122,14 @@ class CRCALite: relation_type: Type of causal relation (default: DIRECT) confidence: Confidence in the relationship (default: 1.0) """ - # Initialize nodes if needed - if source not in self.causal_graph: - self.causal_graph[source] = {} - if target not in self.causal_graph: - self.causal_graph[target] = {} - if source not in self.causal_graph_reverse: - self.causal_graph_reverse[source] = [] - if target not in self.causal_graph_reverse: - self.causal_graph_reverse[target] = [] - - # Add edge: source -> target with strength - self.causal_graph[source][target] = strength - - # Update reverse mapping for parent lookup + # Ensure nodes exist + self._ensure_node_exists(source) + self._ensure_node_exists(target) + + # Add or update edge: source -> target with strength + self.causal_graph[source][target] = float(strength) + + # Update reverse mapping for parent lookup (avoid duplicates) if source not in self.causal_graph_reverse[target]: self.causal_graph_reverse[target].append(source) @@ -226,53 +224,7 @@ class CRCALite: return [] # No path found - def detect_confounders(self, treatment: str, outcome: str) -> List[str]: - """ - Detect confounders: variables that are ancestors of both treatment and outcome. - - Args: - treatment: Treatment variable - outcome: Outcome variable - - Returns: - List of confounder variable names - """ - def get_ancestors(node: str) -> set: - """Get all ancestors of a node using DFS.""" - ancestors = set() - stack = [node] - visited = set() - - while stack: - current = stack.pop() - if current in visited: - continue - visited.add(current) - - parents = self._get_parents(current) - for parent in parents: - if parent not in ancestors: - ancestors.add(parent) - stack.append(parent) - - return ancestors - - if treatment not in self.causal_graph or outcome not in self.causal_graph: - return [] - - t_ancestors = get_ancestors(treatment) - o_ancestors = get_ancestors(outcome) - - # Confounders are common ancestors - confounders = list(t_ancestors & o_ancestors) - - # Verify they have paths to both treatment and outcome - valid_confounders = [] - for conf in confounders: - if (self._has_path(conf, treatment) and self._has_path(conf, outcome)): - valid_confounders.append(conf) - - return valid_confounders + # detect_confounders removed in Lite version (advanced inference) def _has_path(self, start: str, end: str) -> bool: """ @@ -305,52 +257,7 @@ class CRCALite: return False - def identify_adjustment_set(self, treatment: str, outcome: str) -> List[str]: - """ - Identify back-door adjustment set for causal effect estimation. - - Args: - treatment: Treatment variable - outcome: Outcome variable - - Returns: - List of variables in the adjustment set - """ - if treatment not in self.causal_graph or outcome not in self.causal_graph: - return [] - - # Get parents of treatment - parents_t = set(self._get_parents(treatment)) - - # Get descendants of treatment - def get_descendants(node: str) -> set: - """Get all descendants using DFS.""" - descendants = set() - stack = [node] - visited = set() - - while stack: - current = stack.pop() - if current in visited: - continue - visited.add(current) - - for child in self._get_children(current): - if child not in descendants: - descendants.add(child) - stack.append(child) - - return descendants - - descendants_t = get_descendants(treatment) - - # Adjustment set: parents of treatment that are not descendants and not the outcome - adjustment = [ - z for z in parents_t - if z not in descendants_t and z != outcome - ] - - return adjustment + # identify_adjustment_set removed in Lite version (advanced inference) def _standardize_state(self, state: Dict[str, float]) -> Dict[str, float]: """ @@ -399,7 +306,7 @@ class CRCALite: x_{t+1} = E(x_t) Mathematical foundation: - - Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε + - Linear structural causal model: y = Σᵢ βᵢ·xᵢ + ε\n+ - NOTE: This implementation is linear. To model nonlinear dynamics override\n+ `_predict_outcomes` in a subclass with a custom evolution operator.\n*** End Patch - Propagates effects through causal graph in topological order - Standardizes inputs, computes in z-space, de-standardizes outputs @@ -449,16 +356,18 @@ class CRCALite: interventions: Dict[str, float] ) -> float: """ - Calculate probability of a counterfactual scenario. + Calculate a heuristic probability of a counterfactual scenario. - Uses Mahalanobis distance in standardized space. + NOTE: This is a lightweight heuristic proximity measure (Mahalanobis-like) + and NOT a full statistical estimator — it ignores covariance and should + be treated as a relative plausibility score for Lite usage. Args: factual_state: Baseline state interventions: Intervention values Returns: - Probability value between 0.05 and 0.98 + Heuristic probability value between 0.05 and 0.98 """ z_sq = 0.0 for var, new in interventions.items(): @@ -490,26 +399,7 @@ class CRCALite: Returns: List of CounterfactualScenario objects """ - scenarios = [] - z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0] - - for i, tv in enumerate(target_variables[:max_scenarios]): - s = self.standardization_stats.get(tv, {"mean": 0.0, "std": 1.0}) - cur = factual_state.get(tv, 0.0) - cz = (cur - s["mean"]) / s["std"] if s["std"] > 0 else 0.0 - - vals = [(cz + dz) * s["std"] + s["mean"] for dz in z_steps] - - for j, v in enumerate(vals): - scenarios.append(CounterfactualScenario( - name=f"scenario_{i}_{j}", - interventions={tv: v}, - expected_outcomes=self._predict_outcomes(factual_state, {tv: v}), - probability=self._calculate_scenario_probability(factual_state, {tv: v}), - reasoning=f"Intervention on {tv} with value {v}" - )) - - return scenarios + # Ensure stats exist for variables in factual_state (fallback behavior)\n+ self.ensure_standardization_stats(factual_state)\n+\n+ scenarios: List[CounterfactualScenario] = []\n+ z_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+\n+ for i, tv in enumerate(target_variables[:max_scenarios]):\n+ stats = self.standardization_stats.get(tv, {\"mean\": 0.0, \"std\": 1.0})\n+ cur = factual_state.get(tv, stats.get(\"mean\", 0.0))\n+\n+ # If std is zero or missing, use absolute perturbations instead\n+ if not stats or stats.get(\"std\", 0.0) <= 0:\n+ base = cur\n+ abs_steps = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]\n+ vals = [base + step for step in abs_steps]\n+ else:\n+ mean = stats[\"mean\"]\n+ std = stats[\"std\"]\n+ cz = (cur - mean) / std\n+ vals = [(cz + dz) * std + mean for dz in z_steps]\n+\n+ for j, v in enumerate(vals):\n+ scenarios.append(CounterfactualScenario(\n+ name=f\"scenario_{i}_{j}\",\n+ interventions={tv: float(v)},\n+ expected_outcomes=self._predict_outcomes(factual_state, {tv: float(v)}),\n+ probability=self._calculate_scenario_probability(factual_state, {tv: float(v)}),\n+ reasoning=f\"Intervention on {tv} with value {v}\"\n+ ))\n+\n+ return scenarios\n*** End Patch def analyze_causal_strength(self, source: str, target: str) -> Dict[str, float]: """ @@ -552,6 +442,16 @@ class CRCALite: """ self.standardization_stats[variable] = {"mean": mean, "std": std if std > 0 else 1.0} + def ensure_standardization_stats(self, state: Dict[str, float]) -> None: + """ + Ensure standardization stats exist for all variables in a given state. + If stats are missing, create a sensible fallback (mean=observed, std=1.0). + This prevents degenerate std=0 issues in Lite mode. + """ + for var, val in state.items(): + if var not in self.standardization_stats: + self.standardization_stats[var] = {"mean": float(val), "std": 1.0} + def get_nodes(self) -> List[str]: """ Get all nodes in the causal graph. @@ -636,9 +536,10 @@ class CRCALite: for step in range(max_steps): current_state = self._predict_outcomes(current_state, {}) - # Generate counterfactual scenarios + # Ensure standardization stats exist for the evolved state and generate counterfactuals from it + self.ensure_standardization_stats(current_state) counterfactual_scenarios = self.generate_counterfactual_scenarios( - initial_state, + current_state, target_variables, max_scenarios=5 )