import argparse
import json
import os
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import wraps

import numpy as np
import requests
from flashrag.config import Config
from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset
from transformers import AutoTokenizer

from config import logger
from src.agent import Agent, AgenticOutputs
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


def retry(max_retries=10, sleep=1, exceptions=Exception):
    """Decorator factory: retry the wrapped function with exponential backoff.

    Args:
        max_retries: Total number of attempts, including the first call. Values below 1
            are treated as 1, so the wrapped function is always called at least once.
        sleep: Base delay in seconds. After failed attempt ``k`` (0-based) the wrapper
            waits ``sleep * 2**k`` seconds before trying again.
        exceptions: Exception class (or tuple of classes) that triggers a retry. Any other
            exception propagates immediately without retrying. Defaults to ``Exception``,
            which preserves the previous retry-on-anything behavior.

    Returns:
        A decorator. The decorated function returns the result of the first successful
        attempt. After the final failed attempt, the last exception is re-raised with its
        original traceback.
    """
    # Guard against 0/negative values, which previously skipped the call entirely and
    # silently returned None.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_name = func.__name__
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
                    if attempt == attempts - 1:
                        logger.error(f"Function {func_name} failed after {attempts} retries.", exc_info=True)
                        # Bare raise keeps the original traceback without adding this frame.
                        raise
                    backoff_time = sleep * (2**attempt)
                    logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
                    time.sleep(backoff_time)
            # Not reached: with attempts >= 1 every loop path either returns or raises.

        return wrapper

    return decorator


class RemoteRetriever(BaseTextRetriever):
    """A wrapper for a remote retriever HTTP service with retry logic and logging.

    Requests are POSTed as JSON ``{"query", "top_n", "return_score"}`` to
    ``<remote_url>/search`` (single query) or ``<remote_url>/batch_search`` (list of
    queries). The decoded JSON response body is returned unchanged.
    """

    # Per-request timeouts, in seconds.
    search_timeout = 30
    batch_search_timeout = 60

    def __init__(self, config: Config):
        """Initialize from ``config.remote_retriever_url`` and ``config.retriever_topk``.

        The URL may be given either as ``host:port`` (the documented CLI form) or as a
        full ``http://`` / ``https://`` URL. It defaults to ``localhost:8001`` and top-k
        defaults to 5.
        """
        super().__init__(config)
        base_url = str(getattr(config, "remote_retriever_url", "localhost:8001")).strip().rstrip("/")
        # Only add a scheme when one is missing. Prefixing "http://" unconditionally turns
        # an already-qualified URL into "http://http://...".
        if not base_url.startswith(("http://", "https://")):
            base_url = f"http://{base_url}"
        self.remote_url = base_url
        self.topk = getattr(config, "retriever_topk", 5)
        logger.info(f"🔗 Remote retriever URL: {self.remote_url}")

    def _post_json(self, endpoint: str, payload: dict, timeout: float, label: str, context: str):
        """POST ``payload`` to ``<remote_url>/<endpoint>`` and return the decoded JSON body.

        Args:
            endpoint: Path segment appended to ``remote_url`` (e.g. ``"search"``).
            payload: JSON-serializable request body.
            timeout: Request timeout in seconds.
            label: Operation name used in log messages (e.g. ``"Search"``).
            context: Extra detail appended to the timeout log message.

        Raises:
            Every failure is logged and re-raised, so the ``retry`` decorator on the
            public callers decides whether to try again.
        """
        url = f"{self.remote_url}/{endpoint}"
        try:
            response = requests.post(url, json=payload, timeout=timeout)
            response.raise_for_status()
            return response.json()
        # Timeout and ConnectionError are subclasses of RequestException, so they must
        # be handled before it.
        except requests.exceptions.Timeout:
            logger.error(f"{label} request timed out after {timeout} seconds {context}")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to {label.lower()} service at {url}")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"{label} request failed: {e}", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Unexpected {label.lower()} error: {str(e)}", exc_info=True)
            raise

    @retry(max_retries=3, sleep=2)
    def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
        """Search for documents matching ``query``.

        Args:
            query: Query text.
            num: Number of results to request; ``None`` uses ``self.topk``.
            return_score: Forwarded to the service as ``return_score``.

        Returns:
            The decoded JSON response of the ``/search`` endpoint.
        """
        top_n = self.topk if num is None else num
        return self._post_json(
            "search",
            {"query": query, "top_n": top_n, "return_score": return_score},
            self.search_timeout,
            "Search",
            f"for query: {query[:50]}...",
        )

    @retry(max_retries=3, sleep=2)
    def _batch_search(
        self, queries: list[str], num: int | None = None, return_score: bool = False
    ) -> list[list[dict]]:
        """Search for documents for several queries in one request.

        Args:
            queries: Query texts.
            num: Number of results per query; ``None`` uses ``self.topk``.
            return_score: Forwarded to the service as ``return_score``.

        Returns:
            The decoded JSON response of the ``/batch_search`` endpoint.
        """
        top_n = self.topk if num is None else num
        return self._post_json(
            "batch_search",
            {"query": queries, "top_n": top_n, "return_score": return_score},
            self.batch_search_timeout,
            "Batch search",
            f"for {len(queries)} queries.",
        )


class ReSearchPipeline(BasicPipeline):
    """Pipeline for the ReSearch method, using an ``Agent`` for generation and retriever tool use."""

    # Compiled once at class creation instead of on every extraction call.
    _SEARCH_TAG_RE = re.compile(r"<search>(.*?)</search>", re.DOTALL)
    _ANSWER_TAG_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    # Maximum number of characters of each retrieved document shown to the agent.
    _MAX_CONTENT_LEN = 500

    def __init__(
        self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
    ):
        """Build the pipeline, its tokenizer, and the agent.

        Args:
            config: Run configuration. ``generator_model_path`` is used to load the tokenizer.
            retriever: Optional retriever. Defaults to ``RemoteRetriever(config)``.
            generator: Optional generator. Defaults to ``SGLRemoteGenerator(config)``.

        Raises:
            Exception: Re-raised if the tokenizer cannot be loaded.
        """
        super().__init__(config)
        logger.info("🔧 Initializing ReSearchPipeline...")

        self.retriever = retriever or RemoteRetriever(config)
        self.generator = generator or SGLRemoteGenerator(config)

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
            if not self.tokenizer.pad_token:
                logger.warning("Tokenizer does not have a pad token; setting to eos_token.")
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
            logger.info("✅ Tokenizer initialized.")
        except Exception as e:
            logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
            raise

        tokenizer_name = self.tokenizer.name_or_path.lower()

        # NOTE(review): this substring test requires the "deepseek-ai/" org prefix, so a
        # local checkpoint path without it falls through to the checks below. Confirm that
        # is intended.
        if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
            adapter = R1DistilTokenizerAdapter()
        elif "llama" in tokenizer_name:
            adapter = LlamaTokenizerAdapter()
        else:
            logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
            adapter = R1DistilTokenizerAdapter()
        logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")

        def retriever_search(query: str, return_type=str, results: int | None = None):
            """Search tool exposed to the agent; returns an ``<information>...</information>`` string.

            ``results=None`` defers to the retriever's own default top-k (``retriever_topk``
            for ``RemoteRetriever``) instead of a hard-coded 5, so the config/CLI setting
            takes effect. NOTE(review): a custom retriever's ``_search`` must also accept
            ``num=None``; confirm this when passing one in. ``return_type`` is accepted for
            call compatibility and ignored.
            """
            try:
                search_results = self.retriever._search(query, num=results)
                return self.format_search_results(search_results)
            except Exception as e:
                logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
                # Close the tag like every other tool result so the agent sees a well-formed block.
                return "<information>Search failed due to an internal error.</information>"

        self.agent = Agent(adapter, search_fn=retriever_search)
        logger.info("✅ Agent initialized.")
        logger.info("✅ ReSearchPipeline initialized successfully.")

    def format_search_results(self, search_results: list[dict]) -> str:
        """Render retrieved documents as a single ``<information>`` block for the agent prompt.

        Each document's ``contents`` is truncated to ``_MAX_CONTENT_LEN`` characters, with
        ``...`` appended when cut. A missing or ``None`` ``contents`` is shown as ``N/A``.
        Entries are numbered from 1 and separated by a ``-------`` line.
        """
        if not search_results:
            return "<information>No results found.</information>"
        limit = self._MAX_CONTENT_LEN
        entries = []
        for idx, result in enumerate(search_results, start=1):
            contents = result.get("contents")
            text = "N/A" if contents is None else str(contents)
            ellipsis = "..." if len(text) > limit else ""
            entries.append(f"Result {idx}: {text[:limit]}{ellipsis}")
        formatted = "\n-------\n".join(entries)
        return f"<information>{formatted}</information>"

    def extract_search_query(self, text: str) -> str | None:
        """Return the stripped text of the last ``<search>...</search>`` span, or ``None``."""
        matches = self._SEARCH_TAG_RE.findall(text)
        return matches[-1].strip() if matches else None

    def extract_answer(self, text: str) -> str | None:
        """Return the stripped text of the last ``<answer>...</answer>`` span, or ``None``."""
        matches = self._ANSWER_TAG_RE.findall(text)
        return matches[-1].strip() if matches else None

    @staticmethod
    def _set_pred_for_all(dataset, value: str) -> None:
        """Set the ``pred`` of every item (attribute-style or dict) to ``value``, used as an error marker."""
        for item in dataset:
            if hasattr(item, "update_output"):
                item.update_output("pred", value)
            elif isinstance(item, dict):
                item["pred"] = value

    def _apply_responses(self, dataset, final_responses) -> None:
        """Store each agent response and its extracted ``<answer>`` on the matching dataset item.

        Responses are matched to items by position. If the counts differ, every item is
        marked ``RESPONSE_COUNT_MISMATCH`` instead. A response without an ``<answer>`` tag
        yields an empty prediction.
        """
        if len(final_responses) != len(dataset):
            logger.error(
                f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
            )
            self._set_pred_for_all(dataset, "RESPONSE_COUNT_MISMATCH")
            return

        num_updated = 0
        num_missing_answers = 0
        for i, (item, response) in enumerate(zip(dataset, final_responses)):
            answer = self.extract_answer(response)
            if answer is None:
                num_missing_answers += 1
                pred_to_save = ""
            else:
                pred_to_save = answer

            if hasattr(item, "update_output"):
                item.update_output("pred", pred_to_save)
                item.update_output("final_response", response)
                num_updated += 1
            elif isinstance(item, dict):
                item["pred"] = pred_to_save
                item["final_response"] = response
                num_updated += 1
            else:
                logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")

        logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
        if num_missing_answers > 0:
            logger.warning(f"{num_missing_answers} items had no <answer> tag.")

    def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
        """Run the agent on every question in ``dataset``, store predictions, and optionally evaluate.

        Items are updated in place. If agent inference fails, every item's prediction is
        set to ``AGENT_ERROR`` and evaluation still runs on those markers.

        Args:
            dataset: Sized iterable of items exposing ``question`` as an attribute or dict key.
            do_eval: Whether to call ``self.evaluate`` after predictions are stored.
            pred_process_fun: Forwarded to ``self.evaluate``.

        Returns:
            The dataset, either the updated input or the object returned by ``self.evaluate``.
        """
        logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")

        try:
            questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
        except (KeyError, AttributeError, TypeError) as e:
            logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
            logger.error("Ensure dataset items have a 'question' key or attribute.")
            return dataset

        agent_max_generations = getattr(self.config, "agent_max_generations", 32)
        generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)

        # Stays None if inference fails. Previously the name was left unbound and the
        # extraction step below crashed with UnboundLocalError, so evaluation never ran.
        final_responses = None
        try:
            logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
            agent_outputs: AgenticOutputs = self.agent.run_agent(
                generate_fn=self.generator.generate,
                tokenizer=self.tokenizer,
                questions=questions,
                max_generations=agent_max_generations,
                max_new_tokens=generator_max_output_len,
            )
            responses = agent_outputs.final_response_str
            logger.info(f"✅ Agent inference completed. Received {len(responses)} final responses.")
            final_responses = responses
        except Exception as e:
            logger.error(f"Agent run failed during inference: {e}", exc_info=True)
            logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
            self._set_pred_for_all(dataset, "AGENT_ERROR")

        if final_responses is not None:
            logger.info("📝 Extracting answers and updating dataset items...")
            self._apply_responses(dataset, final_responses)

        if do_eval:
            logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
            try:
                dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
                logger.info("✅ Evaluation completed via base class method.")
            except Exception as e:
                logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
                logger.warning("Evaluation may be incomplete.")
        else:
            logger.info("Skipping evaluation step as do_eval=False.")

        logger.info("✅ ReSearch pipeline run finished.")
        return dataset


class SGLRemoteGenerator(BaseGenerator):
    """Decoder-only generator backed by a remote SGLang HTTP service (``/generate``)."""

    # Seconds to wait for one /generate request and for the startup health probe.
    request_timeout = 120
    health_check_timeout = 5

    def __init__(self, config: Config):
        """Configure the remote endpoint, load the tokenizer, and probe the service's health.

        ``config.sgl_remote_url`` may be ``host:port`` or a full ``http(s)://`` URL. It
        defaults to ``localhost:8002``.

        Raises:
            ValueError: If ``generator_model_path`` is missing from the config.
            Exception: Re-raised if the tokenizer cannot be loaded.
        """
        super().__init__(config)
        logger.info("🔧 Initializing SGLRemoteGenerator...")
        base_url = str(getattr(config, "sgl_remote_url", "localhost:8002")).strip().rstrip("/")
        # Only add a scheme when missing, to avoid "http://http://..." for full URLs.
        if not base_url.startswith(("http://", "https://")):
            base_url = f"http://{base_url}"
        self.sgl_remote_url = f"{base_url}/generate"
        self.health_check_url = f"{base_url}/health"
        logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
        self.model_path = getattr(config, "generator_model_path", None)
        if not self.model_path:
            logger.error("generator_model_path not found in config!")
            raise ValueError("generator_model_path is required for SGLRemoteGenerator")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            logger.info("✅ Tokenizer loaded for generator.")
        except Exception as e:
            logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
            raise

        self.generation_params = getattr(config, "generation_params", {})
        self.config = config

        self._check_health()

    def _check_health(self):
        """Probe the ``/health`` endpoint. Failures are logged as warnings, not raised."""
        try:
            test_response = requests.get(self.health_check_url, timeout=self.health_check_timeout)
            test_response.raise_for_status()
            logger.info("✅ Remote generator service is available")
        except requests.exceptions.RequestException as e:
            logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}")
            logger.warning("Please ensure the SGLang service is running and accessible.")

    def generate(
        self,
        input_list: list[str] | str,
        return_raw_output: bool = False,
        return_scores: bool = False,
        **params,
    ) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
        """Generate completions for one prompt or a batch of prompts.

        Args:
            input_list: A prompt string or a list of prompt strings.
            return_raw_output: Return the decoded JSON response list unchanged.
            return_scores: Also return, for each prompt, ``exp(logprob)`` of every output
                token (0.0 where a logprob is missing).
            **params: Per-call overrides of ``config.generation_params``. The keys used are
                ``do_sample`` (``False`` forces temperature 0), ``temperature``,
                ``max_new_tokens``, ``stop`` (str or sequence), ``top_p`` and ``top_k``.
                Other keys are ignored.

        Returns:
            The generated texts, with the matched stop string re-appended when generation
            ended on one of ``stop``. With ``return_scores``, returns ``(texts, scores)``.
            With ``return_raw_output``, returns the raw response list.

        Raises:
            ValueError: If ``input_list`` is not a string or list of strings. This is
                raised immediately and is not retried.
            Exception: The last request/response error, once the retried request has
                exhausted its attempts.
        """
        if isinstance(input_list, str):
            input_list = [input_list]
        if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
            raise ValueError("Input must be a string or a list of strings.")

        payload = {"text": input_list, "sampling_params": self._build_sampling_params(params)}
        if return_scores:
            payload["return_logprob"] = True
            payload["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)

        # Only the network round-trip and response parsing are retried. Input validation
        # and parameter building are deterministic, so retrying them only adds delay.
        return self._request_generation(payload, len(input_list), return_raw_output, return_scores)

    def _build_sampling_params(self, params: dict) -> dict:
        """Merge ``generation_params`` with per-call ``params`` into an SGLang ``sampling_params`` dict."""
        merged = {**(self.generation_params or {}), **params}

        sampling = {}
        if merged.get("do_sample", True) is False:
            sampling["temperature"] = 0.0
        else:
            sampling["temperature"] = merged.get("temperature", getattr(self.config, "temperature", 0.7))

        default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
        sampling["max_new_tokens"] = merged.get("max_new_tokens", default_max_tokens)

        stop = merged.get("stop") or []
        stop_list = [stop] if isinstance(stop, str) else list(stop)
        if stop_list:
            sampling["stop"] = stop_list

        for key in ("top_p", "top_k"):
            if key in merged:
                sampling[key] = merged[key]
        return sampling

    @retry(max_retries=5, sleep=2)
    def _request_generation(self, payload: dict, expected_count: int, return_raw_output: bool, return_scores: bool):
        """POST ``payload`` to the SGLang ``/generate`` endpoint and post-process the response."""
        try:
            response = requests.post(
                self.sgl_remote_url,
                json=payload,
                timeout=self.request_timeout,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
        except requests.exceptions.Timeout:
            logger.error(f"Generation request timed out after {self.request_timeout} seconds.")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error during generation: {str(e)}", exc_info=True)
            raise

        # Decoded in its own try. Recent `requests` versions raise a JSONDecodeError that
        # also subclasses RequestException, so a JSON handler placed after the
        # RequestException handler above could never run. All JSON decode errors
        # (json, simplejson, requests) subclass ValueError.
        try:
            response_list = response.json()
        except ValueError:
            logger.error(
                f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response.text[:500]}...",
                exc_info=True,
            )
            raise

        if isinstance(response_list, list) and len(response_list) != expected_count:
            logger.warning(f"Expected {expected_count} generation results but received {len(response_list)}.")

        if return_raw_output:
            return response_list

        try:
            stop_sequences = payload["sampling_params"].get("stop") or []
            return self._parse_response(response_list, stop_sequences, return_scores)
        except Exception as e:
            logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def _parse_response(response_list, stop_sequences, return_scores: bool):
        """Extract texts (and optionally per-token probabilities) from SGLang response items."""
        generated_text = []
        for item in response_list:
            text = item.get("text", "")
            meta_info = item.get("meta_info") or {}
            # finish_reason may be missing or None, and "matched" may be a non-string
            # (e.g. a token id). Only a matched *stop string* is re-appended, presumably
            # because the server omits it from "text".
            finish_reason = meta_info.get("finish_reason")
            matched = finish_reason.get("matched") if isinstance(finish_reason, dict) else None
            if matched and matched in stop_sequences:
                text += matched
            generated_text.append(text)

        if not return_scores:
            return generated_text

        scores = []
        for item in response_list:
            logprobs = (item.get("meta_info") or {}).get("output_token_logprobs") or []
            # Element 0 of each entry is the token's logprob; a missing/None one scores 0.0.
            scores.append([np.exp(lp[0]) if lp and lp[0] is not None else 0.0 for lp in logprobs])
        return generated_text, scores


def load_dataset_items(config: Config, split: str) -> list[dict | object]:
    """Load the items of one dataset split via flashrag's ``get_dataset``.

    Args:
        config: Run configuration passed to ``get_dataset``. ``dataset_name`` and
            ``data_dir`` are used only for log messages.
        split: Name of the split to return (e.g. ``"test"``).

    Returns:
        The split's items. Returns an empty list (after logging) when the split is
        unknown, was not loaded, or loading raised; callers treat an empty result as failure.
    """
    dataset_name = getattr(config, "dataset_name", "N/A")
    logger.info(f"📚 Loading dataset: {dataset_name}, Split: {split}")
    try:
        all_splits = get_dataset(config)
        if split not in all_splits:
            logger.error(
                f"Split '{split}' not found in dataset '{dataset_name}'. Available splits: {list(all_splits.keys())}"
            )
            return []
        dataset_items = all_splits[split]
        # NOTE(review): get_dataset appears able to map a configured split to None (e.g.
        # when its data file is missing). Report that explicitly instead of failing on len(None).
        if dataset_items is None:
            logger.error(
                f"Split '{split}' of dataset '{dataset_name}' was not loaded (no data returned). "
                "Check that its data file exists under the dataset path."
            )
            return []
        logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")

        return dataset_items
    except FileNotFoundError:
        # getattr: an absent data_dir must not raise a second error inside this handler.
        logger.error(
            f"Dataset files not found for '{dataset_name}' in '{getattr(config, 'data_dir', 'N/A')}'. Check config and paths."
        )
        return []
    except Exception as e:
        logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
        return []


def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
    """Save a plain-text run summary and a per-item JSON debug dump.

    Writes ``<save_dir>/<save_note>_summary.txt`` and ``<save_dir>/<save_note>_debug.json``.
    Errors while writing either file are logged, not raised, so one failure does not
    prevent the other file from being written.

    Args:
        args: Parsed CLI arguments. ``save_dir``, ``save_note`` and ``split`` are read.
        config: Run configuration. Reported fields fall back to ``N/A`` when absent.
        result_dataset: Sized iterable of dict items or attribute-style items.
        run_duration: Wall-clock duration of the run, in seconds.
    """
    logger.info("💾 Saving results...")
    summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
    debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")

    num_items = len(result_dataset)

    logger.info(f"Saving summary results to {summary_file}...")
    try:
        with open(summary_file, "w", encoding="utf-8") as f:
            f.write("EVALUATION SUMMARY\n")
            f.write("=================\n\n")
            f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Run Duration: {run_duration:.2f} seconds\n")
            # getattr throughout: a key absent from the config (e.g. when the retriever
            # URL fell back to its default) must not abort the whole summary.
            f.write(f"Dataset: {getattr(config, 'dataset_name', 'N/A')} ({args.split} split)\n")
            f.write(f"Model: {getattr(config, 'generator_model_path', 'N/A')}\n")
            f.write(f"Retriever: {getattr(config, 'remote_retriever_url', 'N/A')}\n")
            f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n")
            f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n")
            f.write(f"Total items processed: {num_items}\n")
            f.write("\nNote: Verification was skipped in this run.\n")
            f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n")

        logger.info(f"✅ Summary saved to {summary_file}")
    except Exception as e:
        logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)

    def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None:
        """Read ``key_or_attr`` from a dict item or an attribute-style item. Returns ``None`` if absent."""
        if isinstance(data_item, dict):
            return data_item.get(key_or_attr)
        return getattr(data_item, key_or_attr, None)

    def ground_truth_of(data_item) -> str | int | float | bool | None:
        """Return a JSON-friendly ground-truth answer for ``data_item``, or ``None``."""
        value = get_item_value(data_item, "answer")
        if value is None:
            # Fall back to the first entry of a list-valued answer field.
            # NOTE(review): "golden_answers" is the field flashrag dataset items are
            # expected to use; confirm against the installed flashrag version.
            for key in ("answers", "golden_answers"):
                candidates = get_item_value(data_item, key)
                if isinstance(candidates, list) and candidates:
                    value = candidates[0]
                    break
        if value is not None and not isinstance(value, (str, int, float, bool)):
            value = str(value)
        return value

    logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
    try:
        debug_data = []
        for i, item in enumerate(result_dataset):
            item_data: dict[str, object] = {
                "item_index": i,
                "question": get_item_value(item, "question"),
                "prediction": get_item_value(item, "pred"),
                "final_response": get_item_value(item, "final_response"),
            }

            try:
                item_data["ground_truth"] = ground_truth_of(item)
            except Exception as e:
                logger.warning(f"Could not safely get ground truth for item {i}: {e}")
                item_data["ground_truth"] = "ERROR_GETTING_ANSWER"

            try:
                eval_score_val = get_item_value(item, "score")
                if not isinstance(eval_score_val, (str, int, float, bool, type(None))):
                    eval_score_val = str(eval_score_val)
            except Exception as e:
                logger.warning(f"Could not safely get score for item {i}: {e}")
                eval_score_val = "ERROR_GETTING_SCORE"
            item_data["eval_score"] = eval_score_val

            debug_data.append(item_data)

        # Serialize before opening the file so an encoding failure cannot leave a
        # truncated debug file behind. default=str is deliberate: question/prediction/
        # final_response are not type-normalized above, and one odd value should not
        # lose the whole dump.
        serialized = json.dumps(debug_data, indent=2, ensure_ascii=False, default=str)
        with open(debug_file, "w", encoding="utf-8") as f:
            f.write(serialized)
        logger.info(f"✅ Debug information saved to {debug_file}")
    except Exception as e:
        logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)


def research(args: argparse.Namespace, config: Config):
    """Run the ReSearch evaluation end to end: load data, build and run the pipeline, save results.

    Returns early (after logging) if no test data is loaded or the pipeline cannot be
    built. If the pipeline run raises, the input items are saved instead; because the
    pipeline updates items in place, they may be partially updated.

    Args:
        args: Parsed CLI arguments. ``split`` is read here, and the rest is passed to ``save_results``.
        config: Fully resolved run configuration.
    """
    logger.info("🚀 Starting research pipeline execution...")
    start_time = time.time()

    test_data = load_dataset_items(config, args.split)
    if not test_data:
        logger.error("Failed to load test data. Exiting.")
        return

    try:
        logger.info("🏗️ Building ReSearchPipeline...")
        pipeline = ReSearchPipeline(config)
        logger.info("✅ Pipeline built successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
        return

    try:
        logger.info("🏃 Starting pipeline run...")
        result_dataset = pipeline.run(test_data, do_eval=True)
        logger.info("✅ Pipeline run completed.")
    except Exception as e:
        logger.error(f"Error during pipeline run: {e}", exc_info=True)
        result_dataset = test_data
        logger.warning("Pipeline run failed, attempting to save inputs/partial results.")

    run_duration = time.time() - start_time
    logger.info(f"Total run duration: {run_duration:.2f} seconds.")
    save_results(args, config, result_dataset, run_duration)

    logger.info("🏁 Research pipeline execution finished.")


if __name__ == "__main__":
    import sys

    # Arguments consumed by this script itself. Every other argument that was supplied
    # (non-None) is forwarded to Config as an override.
    _CLI_ONLY_ARGS = {"config_path", "dataset_name", "split", "save_dir", "save_note"}

    parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
    parser.add_argument(
        "--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="bamboogle",
        help="Name of the dataset (must match config or data_dir structure).",
    )
    parser.add_argument(
        "--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
    )
    parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
    parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")

    parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
    parser.add_argument(
        "--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
    )
    parser.add_argument(
        "--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
    )
    parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
    parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
    parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
    parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")

    args = parser.parse_args()
    logger.info(f"Starting evaluation script with arguments: {args}")

    try:
        os.makedirs(args.save_dir, exist_ok=True)
        logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
    except OSError as e:
        logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
        # sys.exit, not the site-provided exit(): the latter is meant for the interactive
        # shell and does not exist when Python runs without the site module (python -S).
        sys.exit(1)

    config_overrides = {k: v for k, v in vars(args).items() if v is not None and k not in _CLI_ONLY_ARGS}

    logger.info(f"🔧 Loading configuration from: {args.config_path}")
    try:
        config = Config(args.config_path, config_dict=config_overrides)
        config.dataset_name = args.dataset_name
        if args.data_dir:
            config.data_dir = args.data_dir

        logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
        logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
        logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
        logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")

        logger.info("✅ Config loaded and potentially overridden by CLI arguments.")

        config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)

    except FileNotFoundError:
        logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
        sys.exit(1)

    research(args, config)