import argparse
import asyncio
from collections import deque
from typing import List, Tuple, Union

from fastapi import FastAPI, HTTPException
from flashrag.config import Config
from flashrag.utils import get_retriever
from pydantic import BaseModel

from config import RETRIEVER_SERVER_PORT

# FastAPI application; the route decorators below register their handlers on it.
app = FastAPI()

# Retriever instances, appended to by init_retriever().
retriever_list = []
# Indices into retriever_list that are not currently checked out by a request handler.
available_retrievers = deque()
# Replaced by init_retriever() with an asyncio.Semaphore sized to the number of
# retrievers; stays None until init_retriever() has run.
retriever_semaphore = None


def init_retriever(args):
    """Build the retriever pool and the semaphore that guards it.

    Creates ``args.num_retriever`` retrievers from the configuration at
    ``args.config``, stores them in ``retriever_list``, marks every index as
    free in ``available_retrievers`` and sets the module-level
    ``retriever_semaphore``.

    Args:
        args: Namespace exposing ``config`` (handed to ``Config``) and
            ``num_retriever`` (how many retriever instances to create).

    Raises:
        ValueError: If ``args.num_retriever`` is smaller than 1.
    """
    global retriever_semaphore
    num_retriever = args.num_retriever
    # Fail fast, before any retriever is loaded: a pool of size 0 would yield a
    # semaphore that can never be acquired, so every search request would hang
    # while /health still reported "healthy".
    if num_retriever < 1:
        raise ValueError(f"num_retriever must be at least 1, got {num_retriever}")

    config = Config(args.config)
    for i in range(num_retriever):
        print(f"Initializing retriever {i + 1}/{num_retriever}")
        retriever_list.append(get_retriever(config))
        available_retrievers.append(i)
    # create a semaphore to limit the number of retrievers that can be used concurrently
    # NOTE(review): on Python < 3.10 asyncio primitives bind to the event loop
    # that is current at construction time; this runs before uvicorn starts its
    # loop, so confirm the deployment targets Python >= 3.10.
    retriever_semaphore = asyncio.Semaphore(num_retriever)


@app.get("/health")
async def health_check():
    """Report liveness together with the size and free count of the retriever pool."""
    pool_stats = {
        "total": len(retriever_list),
        "available": len(available_retrievers),
    }
    return {"status": "healthy", "retrievers": pool_stats}


class QueryRequest(BaseModel):
    # Request body for POST /search.
    # Query text; the /search handler rejects empty or whitespace-only values with HTTP 400.
    query: str
    # Forwarded unchanged to the retriever's search call
    # (presumably the number of documents to return -- confirm against the retriever API).
    top_n: int = 10
    # When true, the handler returns the retriever's scores alongside the documents.
    return_score: bool = False


class BatchQueryRequest(BaseModel):
    # Request body for POST /batch_search.
    # Query strings, passed as one list to the retriever's batch_search call.
    query: List[str]
    # Forwarded unchanged to the retriever's batch_search call
    # (presumably the number of documents to return per query -- confirm against the retriever API).
    top_n: int = 10
    # When true, the handler returns the retriever's scores alongside the documents.
    return_score: bool = False


class Document(BaseModel):
    # One retrieved item in a response; the handlers fill both fields from the
    # "id" and "contents" keys of each retriever result.
    id: str
    contents: str


@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
async def search(request: QueryRequest):
    """Run a single query against one retriever checked out from the pool.

    Args:
        request: Body carrying ``query``, ``top_n`` and ``return_score``.

    Returns:
        A list of ``Document`` objects, or a ``(documents, scores)`` tuple when
        ``request.return_score`` is true.

    Raises:
        HTTPException: 400 when the query is empty or only whitespace.
    """
    query = request.query
    top_n = request.top_n
    return_score = request.return_score

    if not query or not query.strip():
        print(f"Query content cannot be empty: {query}")
        raise HTTPException(status_code=400, detail="Query content cannot be empty")

    loop = asyncio.get_running_loop()
    async with retriever_semaphore:
        # The semaphore is sized to the pool, so an index is available here.
        retriever_idx = available_retrievers.popleft()
        try:
            # The retriever call is synchronous. Running it directly in this
            # coroutine would block the event loop, so requests would be
            # served one at a time no matter how many retrievers exist.
            # Offload it to the default thread pool instead.
            # NOTE(review): if this task is cancelled mid-search, the index is
            # returned to the pool while the worker thread may still be running.
            output = await loop.run_in_executor(
                None, retriever_list[retriever_idx].search, query, top_n, return_score
            )
        finally:
            available_retrievers.append(retriever_idx)

    # Build the response after releasing the retriever so it is free sooner.
    if return_score:
        results, scores = output
        return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
    return [Document(id=result["id"], contents=result["contents"]) for result in output]


@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
async def batch_search(request: BatchQueryRequest):
    """Run a batch of queries against one retriever checked out from the pool.

    Args:
        request: Body carrying the ``query`` list, ``top_n`` and ``return_score``.

    Returns:
        One list of ``Document`` objects per entry of the retriever's result,
        or a ``(documents, scores)`` tuple when ``request.return_score`` is true.
    """
    query = request.query
    top_n = request.top_n
    return_score = request.return_score

    loop = asyncio.get_running_loop()
    async with retriever_semaphore:
        # The semaphore is sized to the pool, so an index is available here.
        retriever_idx = available_retrievers.popleft()
        try:
            # The retriever call is synchronous. Running it directly in this
            # coroutine would block the event loop, so requests would be
            # served one at a time no matter how many retrievers exist.
            # Offload it to the default thread pool instead.
            # NOTE(review): if this task is cancelled mid-search, the index is
            # returned to the pool while the worker thread may still be running.
            output = await loop.run_in_executor(
                None, retriever_list[retriever_idx].batch_search, query, top_n, return_score
            )
        finally:
            available_retrievers.append(retriever_idx)

    # Build the response after releasing the retriever so it is free sooner.
    if return_score:
        results, scores = output
    else:
        results, scores = output, None
    documents = [
        [Document(id=result["id"], contents=result["contents"]) for result in per_query]
        for per_query in results
    ]
    if return_score:
        return documents, scores
    return documents


if __name__ == "__main__":
    # Command-line options: retriever config path, pool size and listen port.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--config", str, "./retriever_config.yaml"),
        ("--num_retriever", int, 1),
        ("--port", int, RETRIEVER_SERVER_PORT),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    args = cli.parse_args()

    # Build the retriever pool before the server starts accepting requests.
    init_retriever(args)

    # Imported inside the script guard, so uvicorn is only needed when this
    # file is executed directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": args.port}
    uvicorn.run(app, **server_options)