You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
114 lines
3.8 KiB
114 lines
3.8 KiB
import argparse
|
|
import asyncio
|
|
from collections import deque
|
|
from typing import List, Tuple, Union
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from flashrag.config import Config
|
|
from flashrag.utils import get_retriever
|
|
from pydantic import BaseModel
|
|
|
|
from config import RETRIEVER_SERVER_PORT
|
|
|
|
app = FastAPI()

# Retriever instances loaded by init_retriever(); requests refer to one by its index.
retriever_list = list()

# Indices into retriever_list of the instances not currently serving a request.
available_retrievers = deque([])

# asyncio.Semaphore sized to the pool; stays None until init_retriever() runs.
retriever_semaphore = None
|
|
|
|
|
|
def init_retriever(args):
    """Load ``args.num_retriever`` retriever instances and set up the shared pool.

    Every instance is built by ``get_retriever`` from the same config file
    (``args.config``). Its index is registered in ``available_retrievers``, and
    ``retriever_semaphore`` is sized to the pool so request handlers can wait
    for a free instance.

    Args:
        args: parsed CLI namespace providing ``config`` (path passed to
            ``Config``) and ``num_retriever`` (pool size).

    Raises:
        ValueError: if ``args.num_retriever`` is less than 1. A zero-sized
            pool would leave every request waiting on the semaphore forever.
    """
    global retriever_semaphore

    if args.num_retriever < 1:
        raise ValueError(f"num_retriever must be at least 1, got {args.num_retriever}")

    config = Config(args.config)

    # Start from an empty pool. Otherwise a repeated call leaves duplicate
    # indices in available_retrievers, which can hand the same instance to two
    # concurrent requests.
    retriever_list.clear()
    available_retrievers.clear()

    for i in range(args.num_retriever):
        print(f"Initializing retriever {i + 1}/{args.num_retriever}")
        retriever_list.append(get_retriever(config))
        available_retrievers.append(i)

    # Create a semaphore to limit the number of retrievers that can be used concurrently.
    # NOTE(review): on Python < 3.10 an asyncio.Semaphore created before the
    # server's event loop starts is bound to a different loop; confirm the
    # deployment runs on 3.10+.
    retriever_semaphore = asyncio.Semaphore(args.num_retriever)
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service liveness together with retriever-pool occupancy.

    ``total`` is the number of loaded retriever instances; ``available`` is how
    many of their indices are currently in the idle deque.
    """
    pool = {
        "total": len(retriever_list),
        "available": len(available_retrievers),
    }
    return {"status": "healthy", "retrievers": pool}
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for ``/search``: one query string."""

    # Text to retrieve documents for; the handler rejects blank strings.
    query: str
    # Number of documents requested from the retriever.
    top_n: int = 10
    # When true, the response also carries the retriever's scores.
    return_score: bool = False
|
|
|
|
|
|
class BatchQueryRequest(BaseModel):
    """Request body for ``/batch_search``: several query strings at once."""

    # Queries handed together to the retriever's batch_search().
    query: List[str]
    # Number of documents requested per query.
    top_n: int = 10
    # When true, the response also carries the retriever's scores.
    return_score: bool = False
|
|
|
|
|
|
class Document(BaseModel):
    """A retrieved document as returned to clients."""

    # Copied from the retriever result's "id" entry.
    id: str
    # Copied from the retriever result's "contents" entry.
    contents: str
|
|
|
|
|
|
@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
async def search(request: QueryRequest):
    """Retrieve the top-N documents for a single query.

    Borrows one retriever instance from the pool for the duration of the call.

    Returns:
        A list of ``Document``; when ``return_score`` is set, a tuple of that
        list and the scores the retriever returned alongside the results.

    Raises:
        HTTPException: 400 if the query is empty or whitespace-only.
    """
    query = request.query
    top_n = request.top_n
    return_score = request.return_score

    if not query or not query.strip():
        print(f"Query content cannot be empty: {query}")
        raise HTTPException(status_code=400, detail="Query content cannot be empty")

    loop = asyncio.get_running_loop()

    # The semaphore guarantees an index is waiting in the deque. There is no
    # await between acquiring and registering the release callback, so this
    # request cannot be cancelled while holding an unreleasable slot.
    await retriever_semaphore.acquire()
    retriever_idx = available_retrievers.popleft()

    def _release_retriever(_job):
        # Done-callbacks run on the event-loop thread, like the rest of the pool bookkeeping.
        available_retrievers.append(retriever_idx)
        retriever_semaphore.release()

    try:
        # retriever.search() is synchronous. Calling it inline would block the
        # event loop and serialize every request regardless of pool size, so it
        # runs in the default thread-pool executor instead.
        # NOTE(review): assumes separate retriever instances may run in
        # parallel threads; confirm for the configured retriever backend.
        job = loop.run_in_executor(
            None, retriever_list[retriever_idx].search, query, top_n, return_score
        )
    except BaseException:
        _release_retriever(None)
        raise
    # Return the retriever only when the worker thread is truly finished with
    # it. The shield keeps a cancelled request from cancelling `job` and
    # triggering an early release while the thread still uses the instance.
    job.add_done_callback(_release_retriever)
    output = await asyncio.shield(job)

    if return_score:
        results, scores = output
        return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
    return [Document(id=result["id"], contents=result["contents"]) for result in output]
|
|
|
|
|
|
@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
async def batch_search(request: BatchQueryRequest):
    """Retrieve the top-N documents for every query in a batch.

    Borrows one retriever instance from the pool and hands it the whole batch.

    Returns:
        One list of ``Document`` per query; when ``return_score`` is set, a
        tuple of those lists and the scores the retriever returned with them.
    """
    query = request.query
    top_n = request.top_n
    return_score = request.return_score

    loop = asyncio.get_running_loop()

    # The semaphore guarantees an index is waiting in the deque. There is no
    # await between acquiring and registering the release callback, so this
    # request cannot be cancelled while holding an unreleasable slot.
    await retriever_semaphore.acquire()
    retriever_idx = available_retrievers.popleft()

    def _release_retriever(_job):
        # Done-callbacks run on the event-loop thread, like the rest of the pool bookkeeping.
        available_retrievers.append(retriever_idx)
        retriever_semaphore.release()

    try:
        # retriever.batch_search() is synchronous. Calling it inline would block
        # the event loop and serialize every request regardless of pool size,
        # so it runs in the default thread-pool executor instead.
        # NOTE(review): assumes separate retriever instances may run in
        # parallel threads; confirm for the configured retriever backend.
        job = loop.run_in_executor(
            None, retriever_list[retriever_idx].batch_search, query, top_n, return_score
        )
    except BaseException:
        _release_retriever(None)
        raise
    # Return the retriever only when the worker thread is truly finished with
    # it. The shield keeps a cancelled request from cancelling `job` and
    # triggering an early release while the thread still uses the instance.
    job.add_done_callback(_release_retriever)
    output = await asyncio.shield(job)

    if return_score:
        results, scores = output
        documents = [
            [Document(id=result["id"], contents=result["contents"]) for result in per_query]
            for per_query in results
        ]
        return documents, scores
    return [
        [Document(id=result["id"], contents=result["contents"]) for result in per_query]
        for per_query in output
    ]
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: load the retriever pool, then serve the app with uvicorn.
    parser = argparse.ArgumentParser(description="Serve FlashRAG retrievers over HTTP.")
    parser.add_argument(
        "--config",
        type=str,
        default="./retriever_config.yaml",
        help="Path to the FlashRAG config file used to build each retriever.",
    )
    parser.add_argument(
        "--num_retriever",
        type=int,
        default=1,
        help="Number of retriever instances to load; bounds concurrent searches.",
    )
    # Previously hard-coded; the default keeps the old bind-all behaviour.
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Interface to bind the server to (default: all interfaces).",
    )
    parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
    args = parser.parse_args()

    init_retriever(args)

    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port)
|