You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
175 lines
5.9 KiB
175 lines
5.9 KiB
1 month ago
|
# app/main.py
|
||
|
|
||
|
from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
|
||
|
from sqlalchemy.orm import Session
|
||
|
from typing import List
|
||
|
from . import models, schemas, database, agent, tools
|
||
|
from .tools import load_tools, add_tool, tools_registry
|
||
|
from .agent import create_agent
|
||
|
from datetime import datetime
|
||
|
import logging
|
||
|
|
||
|
# Initialize logging
# Configure the root logger at INFO so the logger.info / logger.error calls in
# this module are actually emitted; `logger` is this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the database tables
# Runs at import time of this module, for every table registered on
# models.Base.metadata, using database.engine.
# NOTE(review): per SQLAlchemy's docs create_all only creates missing tables;
# it does not migrate existing ones — schema changes need a migration tool.
models.Base.metadata.create_all(bind=database.engine)

# The ASGI application object that all route decorators below attach to.
app = FastAPI(title="LangChain FastAPI Service")
|
||
|
|
||
|
# Dependency to get DB session
def get_db():
    """
    FastAPI dependency yielding a session created by ``database.SessionLocal``.

    The session is handed to the endpoint via ``yield`` and is closed in the
    dependency's cleanup phase, whether or not the endpoint raised.
    """
    session = database.SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even when the endpoint failed.
        session.close()
|
||
|
|
||
|
# Load existing tools at startup
@app.on_event("startup")
def startup_event():
    """Run ``load_tools()`` once when the application starts up."""
    # NOTE(review): presumably load_tools() fills tools_registry from the stored
    # ToolModel rows — confirm in app/tools.py. Newer FastAPI versions deprecate
    # on_event in favour of a lifespan handler.
    load_tools()
|
||
|
|
||
|
# Helper function to create LangChain tools from registry
def get_langchain_tools() -> List:
    """
    Build one ``agent.Tool`` per entry in ``tools_registry``.

    Each tool uses the registry key as its name, the registered callable as its
    function, and the callable's docstring as its description (falling back to
    "No description provided." when the docstring is missing or empty).

    Returns:
        A list of ``agent.Tool`` instances, in registry iteration order.
    """
    # The description is taken straight from the docstring: there is no need
    # to route it through a response schema such as schemas.ToolOut, which is
    # an API output model and may require fields (like ``id``) not known here.
    return [
        agent.Tool(
            name=name,
            func=func,
            description=func.__doc__ or "No description provided.",
        )
        for name, func in tools_registry.items()
    ]
|
||
|
|
||
|
def _process_query_in_own_session(question_id: int, user_input: str) -> None:
    """
    Run ``process_query`` with a session owned by the background task itself.

    The session injected into an endpoint by the ``get_db`` yield dependency is
    closed by that dependency's cleanup, and FastAPI documents that resources
    from yield dependencies should not be used inside background tasks. So the
    task opens a fresh session here and closes it when processing finishes.
    """
    db = database.SessionLocal()
    try:
        process_query(question_id, user_input, db)
    finally:
        db.close()


# Endpoint to submit a query
@app.post("/query", response_model=schemas.QueryResponse)
def submit_query(request: schemas.QueryRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """
    Submit a user query to the agent.

    The question is stored immediately with an empty answer, and the agent is
    run afterwards as a background task that fills the answer in. The response
    only acknowledges that processing has started.
    """
    user_input = request.input
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Save the question with an empty answer; get_answer reports an empty
    # answer as "still being processed".
    question = models.QuestionModel(query=user_input, answer="", timestamp=timestamp)
    db.add(question)
    db.commit()
    db.refresh(question)

    # Do not pass the request-scoped `db` to the background task: it belongs to
    # the get_db dependency and may already be closed when the task runs.
    background_tasks.add_task(_process_query_in_own_session, question.id, user_input)

    # NOTE(review): question.id is not returned, so clients cannot poll
    # GET /answer/{question_id}. Consider adding it to schemas.QueryResponse.
    return schemas.QueryResponse(output="Your request is being processed.")
|
||
|
|
||
|
def process_query(question_id: int, user_input: str, db: Session):
    """
    Process the user query and save the answer.

    Builds an agent from the current ``tools_registry`` and runs it on
    ``user_input``. The result is stored on the ``QuestionModel`` row
    ``question_id`` and also recorded as a new ``AnswerModel`` row. If anything
    fails, the error text is written to the question's ``answer`` instead, so
    ``GET /answer/{question_id}`` shows the failure rather than a pending state.

    Args:
        question_id: Primary key of the question row to update.
        user_input: Query text passed to the agent as ``{"input": user_input}``.
        db: Session used for all reads and writes. It is not closed here; the
            caller owns its lifetime.
    """
    try:
        # Build the tool list on every call so that tools added to
        # tools_registry after startup (e.g. by register_tool) are included.
        lc_tools = [
            agent.Tool(
                name=name,
                func=func,
                description=func.__doc__ or "No description provided.",
            )
            for name, func in tools_registry.items()
        ]
        langchain_agent = agent.create_agent(lc_tools)

        # Invoke the agent
        response = langchain_agent({"input": user_input})
        answer = response["output"]

        # Update the question with the answer
        db_question = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()
        if db_question is None:
            # The row disappeared while the agent was running; nothing to update.
            logger.warning("Question %s not found; discarding its answer", question_id)
            return
        db_question.answer = answer
        db.commit()

        # Also record the answer as a separate AnswerModel row.
        answer_model = models.AnswerModel(
            question_id=question_id,
            content=answer,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        db.add(answer_model)
        db.commit()

        logger.info("Processed query %s: %s -> %s", question_id, user_input, answer)
    except Exception as e:
        # Top-level boundary of a background task: keep the traceback in the
        # log and persist the failure so the question is not left pending.
        logger.exception("Error processing query %s: %s", question_id, e)
        _store_query_error(db, question_id, e)


def _store_query_error(db: Session, question_id: int, error: Exception) -> None:
    """Best-effort: store ``error`` as the answer of question ``question_id``."""
    try:
        # If the original failure came from a flush/commit, the session refuses
        # further work until it is rolled back. When nothing is pending this
        # is harmless.
        db.rollback()
        db_question = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()
        if db_question is None:
            return
        db_question.answer = f"Error processing the query: {error}"
        db.commit()
    except Exception:
        # Recording the error must not crash the background task a second time.
        logger.exception("Could not record the failure of query %s", question_id)
|
||
|
|
||
|
# Endpoint to register a new tool
@app.post("/tools", response_model=schemas.ToolOut)
def register_tool(tool: schemas.ToolCreate, db: Session = Depends(get_db)):
    """
    Register a new tool.

    The tool is persisted as a ``models.ToolModel`` row and then handed to
    ``add_tool`` to add it to the registry.

    Raises:
        HTTPException: 400 if a tool with the same name already exists.
        Exception: whatever ``add_tool`` raises. In that case the row that was
            just stored is deleted again before the exception propagates.
    """
    # SECURITY NOTE(review): `function_code` comes straight from the request
    # body. If add_tool/load_tools execute it (likely, given the field name —
    # confirm in app/tools.py), this endpoint permits arbitrary code execution
    # and must be restricted to trusted, authenticated callers.

    # Check if tool with the same name exists
    existing_tool = db.query(models.ToolModel).filter(models.ToolModel.name == tool.name).first()
    if existing_tool:
        raise HTTPException(status_code=400, detail="Tool with this name already exists.")

    # Create a new tool
    new_tool = models.ToolModel(
        name=tool.name,
        description=tool.description,
        function_code=tool.function_code,
    )
    db.add(new_tool)
    db.commit()
    db.refresh(new_tool)

    # Add to the registry
    try:
        add_tool(new_tool)
    except Exception:
        # Keep DB and registry consistent: a stored row whose tool never made
        # it into the registry would block re-registration under the same name
        # (see the existence check above) without ever being usable.
        logger.exception("Failed to add tool %r to the registry; removing it", new_tool.name)
        db.delete(new_tool)
        db.commit()
        raise

    return schemas.ToolOut(id=new_tool.id, name=new_tool.name, description=new_tool.description)
|
||
|
|
||
|
# Endpoint to list all tools
@app.get("/tools", response_model=List[schemas.ToolOut])
def list_tools(db: Session = Depends(get_db)):
    """
    Return every tool stored in the ``ToolModel`` table.

    The listing comes from the database, not from the in-memory registry.
    """
    def to_schema(row):
        # Expose only the public fields; function_code is never returned.
        return schemas.ToolOut(id=row.id, name=row.name, description=row.description)

    return list(map(to_schema, db.query(models.ToolModel).all()))
|
||
|
|
||
|
# Endpoint to get a specific tool by ID
@app.get("/tools/{tool_id}", response_model=schemas.ToolOut)
def get_tool(tool_id: int, db: Session = Depends(get_db)):
    """
    Look up a single tool by its database ID.

    Raises:
        HTTPException: 404 when no ``ToolModel`` row has ``tool_id``.
    """
    found = (
        db.query(models.ToolModel)
        .filter(models.ToolModel.id == tool_id)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail="Tool not found.")
    return schemas.ToolOut(
        id=found.id,
        name=found.name,
        description=found.description,
    )
|
||
|
|
||
|
@app.get("/answer/{question_id}", response_model=schemas.AnswerResponse)
def get_answer(question_id: int, db: Session = Depends(get_db)):
    """
    Get the answer for a question by its ID.

    While the stored answer is still empty (the background task has not
    written a result yet), a placeholder text is returned in ``answer``.

    Raises:
        HTTPException: 404 when no question has ``question_id``.
    """
    # Look the question up in the database by its ID.
    question = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()

    if not question:
        # Detail text means "Question not found."
        raise HTTPException(status_code=404, detail="Вопрос не найден.")

    # An empty answer means processing has not finished; the placeholder
    # means "The answer is still being processed."
    answer_text = question.answer if question.answer else "Ответ ещё обрабатывается."

    return schemas.AnswerResponse(
        id=question.id,
        query=question.query,
        answer=answer_text,
        timestamp=question.timestamp,
    )
|