"""
Simple CLI inference script with search functionality.

This script allows interaction with a model (with optional LoRA adapter)
and provides search functionality for data retrieval.
"""

import argparse
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from unsloth import FastLanguageModel
from vllm import SamplingParams

from src.config import MODEL_NAME, OUTPUT_DIR, logger
from src.search_module import load_vectorstore, search


def find_latest_checkpoint(search_dir=None):
    """
    Locate the most recent checkpoint or adapter file inside a directory.

    Candidates are tried in the following order; the first kind that yields
    a match wins:

    1. ``checkpoint-<step>`` entries: the one with the highest step number.
    2. ``saved_adapter_*.bin`` files: the most recently modified one.
    3. Any ``*.bin`` file: the most recently modified one.

    Args:
        search_dir: Directory to search. When None, a hard-coded trainer
            output directory name is used (resolved relative to the current
            working directory).

    Returns:
        Path to the latest checkpoint/adapter, or None when the directory
        does not exist or holds no candidate.
    """
    import glob

    # NOTE(review): the fallback below is one specific training-run directory,
    # not OUTPUT_DIR -- confirm this is still the intended default.
    if search_dir is not None:
        logger.info(f"Searching for checkpoints in: {search_dir}")
    else:
        search_dir = "trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236"
        logger.info(f"No search directory provided, using default: {search_dir}")

    # Bail out early when there is nothing to scan.
    if not os.path.exists(search_dir):
        logger.warning(f"Search directory {search_dir} does not exist")
        return None

    # Preferred layout: checkpoint-{step}; keep only names ending in digits.
    numbered = []
    for candidate in glob.glob(os.path.join(search_dir, "checkpoint-*")):
        step = re.search(r"checkpoint-(\d+)$", candidate)
        if step:
            numbered.append((int(step.group(1)), candidate))

    if numbered:
        # Tuples compare by step number first, so max() picks the latest step.
        latest = max(numbered)[1]
        logger.info(f"Found latest checkpoint: {latest}")
        return latest

    # Fallbacks, most specific pattern first; newest modification time wins.
    for pattern, label in (("saved_adapter_*.bin", "adapter file"), ("*.bin", ".bin file")):
        matches = glob.glob(os.path.join(search_dir, pattern))
        if matches:
            latest = max(matches, key=os.path.getmtime)
            logger.info(f"Found latest {label}: {latest}")
            return latest

    logger.warning(f"No checkpoints found in {search_dir}")
    return None


def setup_model_and_tokenizer():
    """
    Load MODEL_NAME through unsloth and attach LoRA adapters to it.

    The base model is requested with 4-bit loading, fast inference enabled,
    an 8192-token maximum sequence length and a GPU memory utilization of
    0.6. LoRA is then configured with rank 64 (``lora_alpha`` equal to the
    rank) on the attention and MLP projection modules listed below.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    max_seq_length = 4096 * 2
    lora_rank = 64
    lora_target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]

    logger.info(f"Setting up model {MODEL_NAME} with LoRA support...")
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
        fast_inference=True,
        max_lora_rank=lora_rank,
        gpu_memory_utilization=0.6,
    )

    # Wrap the base model so LoRA weights can be attached to it.
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora_rank,
        target_modules=lora_target_modules,
        lora_alpha=lora_rank,
        use_gradient_checkpointing=True,
        random_state=3407,
    )

    logger.info("Model and tokenizer setup complete.")
    return peft_model, tokenizer


def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> SamplingParams:
    """
    Build the sampling configuration used for generation.

    Args:
        temperature: Sampling temperature.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        A ``SamplingParams`` instance with ``top_p`` fixed at 0.95.
    """
    sampling_options = {
        "temperature": temperature,
        "top_p": 0.95,
        "max_tokens": max_tokens,
    }
    return SamplingParams(**sampling_options)


def extract_function_calls(text: str) -> List[Dict[str, Any]]:
    """
    Collect JSON objects embedded in ``text`` that carry a "function" key.

    Brace-delimited spans are located with a regular expression that
    tolerates up to three levels of ``{}`` nesting; spans that are not valid
    JSON are ignored.

    Args:
        text: Free-form text that may contain JSON function calls.

    Returns:
        The parsed objects, in order of appearance.
    """
    # Matches a {...} span containing at most two further nested brace levels.
    brace_span = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"

    calls = []
    for span in re.finditer(brace_span, text):
        try:
            parsed = json.loads(span.group(0))
        except json.JSONDecodeError:
            continue
        if "function" in parsed:
            calls.append(parsed)

    return calls


def build_user_prompt(q):
    """
    Compose the user message sent to the model for a question.

    The message contains the research-assistant instructions, the
    ``SEARCH_TOOL_DEFINITION`` schema serialized as indented JSON, and the
    question itself.

    Args:
        q (str): The question to ask

    Returns:
        str: Formatted user prompt
    """
    # Serialize the tool schema once, then drop it into the template.
    tool_schema = json.dumps(SEARCH_TOOL_DEFINITION, indent=2)

    return f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
Given a question, answer it using by doing searches using the search_corpus tool.
To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.

PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.

You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.

{tool_schema}

Question: {q}
"""


def format_search_results(results: Union[str, List[str]]) -> str:
    """
    Wrap search results in a banner for display.

    Args:
        results: Either pre-formatted text, or a list of result strings. List
            items are numbered from 1 and each is followed by a ``------``
            separator line.

    Returns:
        The results framed by header and footer banner lines.
    """
    if not isinstance(results, list):
        body = results
    else:
        numbered = (f"Result {position}:\n{item}\n------" for position, item in enumerate(results, start=1))
        body = "\n".join(numbered)

    return f"\n===== SEARCH RESULTS =====\n{body}\n===========================\n"


class DeepSearchCLI:
    """
    CLI for interacting with the model and search functionality.

    Conversation state is kept in two lists:

    * ``history`` -- alternating ``user`` / ``assistant`` entries, appended as
      a pair at the end of every ``generate`` call.
    * ``search_history`` -- one entry per executed search, tagged with the
      zero-based turn index (``len(history) // 2`` at the time of the search).
    """

    def __init__(
        self,
        lora_path: Optional[str] = None,
        temperature: float = 0.7,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize the CLI.

        Loads the model and tokenizer, builds the sampling parameters and,
        when ``lora_path`` is given, asks the model to load that adapter.

        Args:
            lora_path: Path to LoRA weights (None for base model)
            temperature: Sampling temperature
            system_prompt: Optional system prompt to guide the model's behavior;
                when empty or None a default tool-calling prompt stamped with
                today's date is used
        """
        self.model, self.tokenizer = setup_model_and_tokenizer()
        self.lora_path = lora_path
        self.temperature = temperature
        self.sampling_params = get_sampling_params(temperature)
        self.lora_request = None
        self.history = []
        self.search_history = []
        self.system_prompt = (
            system_prompt
            or f"""Cutting Knowledge Date: December 2023
Today Date: {datetime.now().strftime("%d %b %Y")}

When you receive a tool call response, use the output to format an answer to the original user question.

You are a helpful assistant with tool calling capabilities."""
        )

        # Load LoRA if specified
        if self.lora_path:
            logger.info(f"Loading LoRA adapter from {self.lora_path}...")
            self.lora_request = self.model.load_lora(self.lora_path)
            if self.lora_request:
                logger.info(f"LoRA adapter loaded successfully: {self.lora_request}")
            else:
                # Generation falls back to the base model when no request is set.
                logger.error("Failed to load LoRA adapter")

    def generate(self, prompt: str, max_generations: int = 20) -> str:
        """
        Generate a response to the prompt using agentic mechanism.

        Every call starts a fresh chat made of the system prompt and the user
        prompt built by ``build_user_prompt``; entries already stored in
        ``self.history`` are not sent to the model. The loop alternates
        between generating and executing the requested search until the
        model answers without a JSON object or ``max_generations`` is reached.

        Args:
            prompt: The prompt to generate a response to
            max_generations: Maximum number of turns in the conversation

        Returns:
            The content of the last message in the chat once the loop stops
        """
        chat_state = {
            "messages": [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": build_user_prompt(prompt)},
            ],
            "finished": False,
        }

        # NOTE(review): when a tool call fails, `_run_tool_calls` marks the chat
        # finished but the next iteration still generates once before breaking,
        # so the model's reply to the error message becomes the final response.
        for _ in range(max_generations):
            chat_state = self._run_agent_generation(chat_state)

            chat_state = self._check_finished_chat(chat_state)
            if chat_state.get("finished"):
                break

            chat_state = self._run_tool_calls(chat_state)

        if not chat_state.get("finished"):
            # The budget ran out mid-search; the last message is not a final answer.
            logger.warning(f"Reached the limit of {max_generations} generations without a final answer.")

        final_response = chat_state["messages"][-1]["content"]

        # Record the turn (raw user prompt, not the expanded tool prompt).
        self.history.append({"role": "user", "content": prompt})
        self.history.append({"role": "assistant", "content": final_response})

        return final_response

    def _run_agent_generation(self, chat_state: dict) -> dict:
        """
        Run a single generation step and append the assistant message.

        Args:
            chat_state: Chat state dict with a "messages" list

        Returns:
            The same chat state, with the new assistant message appended
        """
        formatted_prompt = self.tokenizer.apply_chat_template(
            chat_state["messages"],
            tokenize=False,
            add_generation_prompt=True,
        )

        # Only pass a LoRA request when an adapter was actually loaded.
        generate_kwargs = {"sampling_params": self.sampling_params}
        if self.lora_request:
            generate_kwargs["lora_request"] = self.lora_request

        start_time = time.time()
        response = self.model.fast_generate([formatted_prompt], **generate_kwargs)
        gen_time = time.time() - start_time
        logger.debug(f"Generation completed in {gen_time:.2f} seconds")

        # Accept either request-output objects or plain strings.
        if hasattr(response[0], "outputs"):
            response_text = response[0].outputs[0].text
        else:
            response_text = response[0]

        # Keep only the text after the last assistant header, if one is present.
        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]

        chat_state["messages"].append({"role": "assistant", "content": assistant_response})

        return chat_state

    def _check_finished_chat(self, chat_state: dict) -> dict:
        """
        Mark the chat finished when the last assistant message has no JSON object.

        Args:
            chat_state: Chat state whose last message must be from the assistant

        Returns:
            The same chat state, possibly with "finished" set to True
        """
        if chat_state.get("finished"):
            return chat_state

        assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"

        assistant_response = chat_state["messages"][-1]["content"]
        if not extract_json_objects(assistant_response):
            chat_state["finished"] = True

        return chat_state

    def _run_tool_calls(self, chat_state: dict) -> dict:
        """
        Execute the search requested by the last assistant message.

        Exactly one JSON object shaped like
        ``{"function": {"parameters": {"query": ...}}}`` is expected. The
        search output is appended as an "ipython" message and recorded in
        ``search_history``. Any failure (several objects, missing keys, search
        error) appends a system error message and marks the chat finished.

        Args:
            chat_state: Chat state whose last message is the assistant's reply

        Returns:
            The same chat state, updated in place
        """
        if chat_state.get("finished"):
            return chat_state

        try:
            assistant_response = chat_state["messages"][-1]["content"]
            function_calls = extract_json_objects(assistant_response)

            if len(function_calls) > 1:
                logger.warning("Multiple function calls found in assistant response")
                raise ValueError("Expected only one function call in assistant response")

            if len(function_calls) == 1:
                query = function_calls[0]["function"]["parameters"]["query"]
                logger.info(f"🔍 Search Query: {query}")

                results = search(query, return_type=str, results=2)

                chat_state["messages"].append({"role": "ipython", "content": results})
                self._record_search(query, results)

        except Exception as e:
            logger.error(f"Error during tool call: {e}")
            chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {e}"})
            chat_state["finished"] = True

        return chat_state

    def _record_search(self, query: str, results) -> None:
        """Append a search to ``search_history`` tagged with the current turn index."""
        self.search_history.append(
            {
                "turn": len(self.history) // 2,
                "searches": [{"query": query, "results": results}],
            }
        )

    def _collect_turns(self) -> List[Dict[str, Any]]:
        """
        Group ``history`` and ``search_history`` into per-turn records.

        Returns:
            One dict per user/assistant pair with the keys "turn" (1-based),
            "user", "searches" (all searches tagged with that turn, in the
            order they were recorded) and "assistant" ("" when missing).
        """
        searches_by_turn = {}
        for entry in self.search_history:
            searches_by_turn.setdefault(entry["turn"], []).extend(entry["searches"])

        turns = []
        for start in range(0, len(self.history), 2):
            pair = self.history[start : start + 2]
            turn_index = start // 2
            turns.append(
                {
                    "turn": turn_index + 1,
                    "user": pair[0]["content"],
                    "searches": searches_by_turn.get(turn_index, []),
                    "assistant": pair[1]["content"] if len(pair) > 1 else "",
                }
            )
        return turns

    def _default_history_path(self, extension: str) -> str:
        """Build a timestamped history path under OUTPUT_DIR and create its directory."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_type = "lora" if self.lora_path else "base"
        filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.{extension}")

        # os.makedirs("") raises, so skip it when the path has no directory part.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        return filepath

    def _run_manual_search(self, query: str) -> None:
        """Run the ``search`` CLI command: query the corpus, log and record the results."""
        if not query:
            logger.warning("Please provide a search query.")
            return

        try:
            results = search(query, return_type=str)
            logger.info(format_search_results(results))
            # Tagged with the index of the upcoming turn, since no turn is in progress.
            self._record_search(query, results)
        except Exception as e:
            logger.error(f"Error searching: {e}")

    def clear_history(self):
        """Clear the conversation history."""
        self.history = []
        self.search_history = []
        logger.info("Conversation history cleared.")

    def set_system_prompt(self, prompt: str):
        """
        Set a new system prompt.

        An empty prompt is rejected and the current prompt is left unchanged.

        Args:
            prompt: The new system prompt
        """
        if not prompt:
            logger.warning("System prompt cannot be empty. Keeping the current prompt.")
            return

        self.system_prompt = prompt
        logger.info("System prompt updated.")
        logger.info(f"New system prompt: {self.system_prompt}")

    def display_welcome(self):
        """Display welcome message."""
        model_type = "LoRA" if self.lora_path else "Base"
        logger.info(f"\n{'=' * 50}")
        logger.info(f"DeepSearch CLI - {model_type} Model")
        logger.info(f"Model: {MODEL_NAME}")
        logger.info(f"Temperature: {self.temperature}")
        if self.lora_path:
            logger.info(f"LoRA Path: {self.lora_path}")
        logger.info(f"System Prompt: {self.system_prompt}")
        logger.info(f"{'=' * 50}")
        logger.info("Type 'help' to see available commands.")

    def print_pretty_chat_history(self):
        """Print the full chat history in a pretty format, including searches."""
        if not self.history:
            logger.info("No chat history available.")
            return

        logger.info("\n" + "=" * 80)
        logger.info("CHAT HISTORY WITH SEARCH DETAILS")
        logger.info("=" * 80)

        for turn in self._collect_turns():
            logger.info(f"\n[Turn {turn['turn']}] USER: ")
            logger.info("-" * 40)
            logger.info(turn["user"])

            # Searches are numbered across the whole turn, as in the saved files.
            for number, record in enumerate(turn["searches"], start=1):
                logger.info(f'\n🔍 SEARCH {number}: "{record["query"]}"')
                logger.info("-" * 40)
                logger.info(record["results"])

            logger.info(f"\n[Turn {turn['turn']}] ASSISTANT: ")
            logger.info("-" * 40)
            logger.info(turn["assistant"])

        logger.info("\n" + "=" * 80 + "\n")

    def save_chat_history(self, filepath=None):
        """
        Save chat history to a text file.

        Args:
            filepath: Path to save file (if None, auto-generate based on timestamp)

        Returns:
            Path to the saved file, or None when there is no history or the
            file could not be written
        """
        if not self.history:
            logger.info("No chat history to save.")
            return None

        if filepath is None:
            filepath = self._default_history_path("txt")

        turns = self._collect_turns()

        try:
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(f"{'=' * 80}\n")
                f.write("DEEPSEARCH CHAT HISTORY\n")
                f.write(f"Model: {MODEL_NAME}\n")
                f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
                f.write(f"Temperature: {self.temperature}\n")
                f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"{'=' * 80}\n\n")

                for turn in turns:
                    f.write(f"[Turn {turn['turn']}] USER:\n")
                    f.write(f"{'-' * 40}\n")
                    f.write(f"{turn['user']}\n\n")

                    for number, record in enumerate(turn["searches"], start=1):
                        f.write(f'🔍 SEARCH {number}: "{record["query"]}"\n')
                        f.write(f"{'-' * 40}\n")
                        f.write(f"{record['results']}\n\n")

                    f.write(f"[Turn {turn['turn']}] ASSISTANT:\n")
                    f.write(f"{'-' * 40}\n")
                    f.write(f"{turn['assistant']}\n\n")
                    f.write(f"{'=' * 40}\n\n")

            logger.info(f"Chat history saved to: {filepath}")
            return filepath

        except Exception as e:
            logger.error(f"Error saving chat history: {e}")
            return None

    def save_chat_history_json(self, filepath=None):
        """
        Save chat history to a JSON file.

        Args:
            filepath: Path to save file (if None, auto-generate based on timestamp)

        Returns:
            Path to the saved file, or None when there is no history or the
            file could not be written
        """
        if not self.history:
            logger.info("No chat history to save.")
            return None

        if filepath is None:
            filepath = self._default_history_path("json")

        history_data = {
            "model": MODEL_NAME,
            "lora_path": self.lora_path if self.lora_path else None,
            "temperature": self.temperature,
            "timestamp": datetime.now().isoformat(),
            "turns": self._collect_turns(),
        }

        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(history_data, f, indent=2, ensure_ascii=False)

            logger.info(f"Chat history saved to JSON: {filepath}")
            return filepath

        except Exception as e:
            logger.error(f"Error saving chat history to JSON: {e}")
            return None

    def display_help(self):
        """Display help information."""
        logger.info("\n===== Commands =====")
        logger.info("search <query>  - Search for information")
        logger.info("system <prompt> - Set a new system prompt")
        logger.info("clear           - Clear conversation history")
        logger.info("history         - Display full chat history with searches")
        logger.info("save            - Save chat history to a text file")
        logger.info("savejson        - Save chat history to a JSON file")
        logger.info("help            - Display this help message")
        logger.info("exit/quit       - Exit the program")
        logger.info("Any other input will be treated as a prompt to the model.")
        logger.info("===================\n")

    def run(self):
        """
        Run the interactive loop until the user exits.

        Blank lines are ignored. ``exit``/``quit``, Ctrl-C and end-of-input
        stop the loop. Single-word commands must be typed alone; ``system``
        and ``search`` take the rest of the line as their argument. Anything
        else is sent to the model.
        """
        self.display_welcome()

        # Commands that take no argument, matched against the whole input line.
        simple_commands = {
            "help": self.display_help,
            "clear": self.clear_history,
            "history": self.print_pretty_chat_history,
            "save": self.save_chat_history,
            "savejson": self.save_chat_history_json,
        }

        while True:
            try:
                user_input = input("\n> ").strip()

                if not user_input:
                    continue

                lowered = user_input.lower()
                if lowered in ("exit", "quit"):
                    logger.info("Exiting...")
                    break

                if lowered in simple_commands:
                    simple_commands[lowered]()
                    continue

                # Split off the first word so a bare "system"/"search" is still
                # recognized as a command instead of being sent to the model.
                command, _, argument = user_input.partition(" ")
                command = command.lower()
                argument = argument.strip()

                if command == "system":
                    self.set_system_prompt(argument)
                    continue

                if command == "search":
                    self._run_manual_search(argument)
                    continue

                # Process as a prompt to the model
                logger.info("\nGenerating response...")
                response = self.generate(user_input)
                logger.info("\n----- Response -----")
                logger.info(response)

            except (KeyboardInterrupt, EOFError):
                # EOFError means stdin was closed; retrying input() would loop forever.
                logger.info("\nExiting...")
                break
            except Exception as e:
                logger.error(f"Error: {e}")


def extract_json_objects(text):
    """
    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.

    The text is scanned for ``{`` characters and a JSON decode is attempted at
    each one. Using the real JSON parser (rather than counting braces) means
    braces inside string values do not truncate an object, and a valid object
    that follows, or sits inside, an invalid or unclosed ``{`` is still found.
    Objects nested inside a successfully parsed object are not returned
    separately.

    Args:
        text (str): The input text possibly containing JSON objects.

    Returns:
        list: A list of parsed JSON objects (dictionaries) extracted from the
        text, in order of appearance.
    """
    decoder = json.JSONDecoder()
    results = []

    position = text.find("{")
    while position != -1:
        try:
            obj, end = decoder.raw_decode(text, position)
        except json.JSONDecodeError:
            # No valid object starts at this brace; retry from the next one.
            position = text.find("{", position + 1)
            continue

        if isinstance(obj, dict):
            results.append(obj)
        # Resume after the parsed object so its inner objects are skipped.
        position = text.find("{", end)

    return results


# Tool definition for search corpus.
# build_user_prompt() serializes this schema with json.dumps and embeds it in
# the user prompt so the model knows how to call the tool. The only argument
# is a required string "query".
SEARCH_TOOL_DEFINITION = {
    "type": "function",
    "function": {
        "name": "search_corpus",
        "description": "Search over the knowledge corpus with a given query",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The query to search the knowledge corpus with",
                },
            },
            "required": ["query"],
        },
    },
}


def main():
    """
    Parse command-line arguments and start the interactive CLI.

    ``--lora_path`` accepts a path, ``none`` (base model) or ``auto`` (use
    ``find_latest_checkpoint``); both keywords are matched case-insensitively.
    When auto-detection finds nothing, the base model is used.
    """
    parser = argparse.ArgumentParser(description="DeepSearch CLI")
    parser.add_argument(
        "--lora_path",
        type=str,
        default="auto",
        help="Path to LoRA weights (None for base model, 'auto' for auto-detection)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)",
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default=None,
        help="System prompt to guide model behavior",
    )
    args = parser.parse_args()

    # Resolve the adapter path: empty/"none" -> base model, "auto" -> detect.
    lora_path = None
    requested = args.lora_path
    keyword = requested.lower() if requested else "none"
    if keyword == "auto":
        detected_path = find_latest_checkpoint()
        if detected_path:
            lora_path = detected_path
            logger.info(f"Auto-detected LoRA path: {lora_path}")
        else:
            logger.warning("No LoRA checkpoint found. Using base model.")
    elif keyword != "none":
        lora_path = requested

    # Initialize and run the CLI
    cli = DeepSearchCLI(
        lora_path=lora_path,
        temperature=args.temperature,
        system_prompt=args.system_prompt,
    )
    cli.run()


if __name__ == "__main__":
    # Load the vectorstore before starting the CLI. A failed load is only
    # reported as a warning; the CLI still starts.
    if load_vectorstore() is None:
        logger.warning("FAISS vectorstore could not be loaded. Search functionality may not work.")

    main()