You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

175 lines
5.9 KiB

# app/main.py
import logging
from datetime import datetime
from typing import List, Optional

from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session

from . import models, schemas, database, agent, tools
from .tools import load_tools, add_tool, tools_registry
from .agent import create_agent
# Initialize logging: configure the root logger at INFO level and create the
# logger used by this module. Runs at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the database tables for the models registered on models.Base, using
# the engine from the database module. Also runs at import time.
models.Base.metadata.create_all(bind=database.engine)
# The FastAPI application; the route and event decorators below attach to it.
app = FastAPI(title="LangChain FastAPI Service")
# Dependency to get DB session
def get_db():
    """Yield a database session and close it once the consumer is done.

    Written as a generator so it can be used with ``Depends``: a session is
    opened from ``database.SessionLocal``, handed to the caller, and closed
    afterwards whether or not an exception was raised in between.
    """
    session = database.SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even if the consumer raised.
        session.close()
# Load existing tools at startup
@app.on_event("startup")
def startup_event():
    """Call ``load_tools()`` once when the application starts.

    NOTE(review): presumably this fills ``tools_registry`` from previously
    stored tools -- confirm against the ``tools`` module.
    """
    load_tools()
# Helper function to create LangChain tools from registry
def get_langchain_tools() -> List:
    """Build one ``agent.Tool`` per entry in ``tools_registry``.

    Each registered callable is wrapped under its registry name. The tool
    description is the callable's docstring, or a fixed placeholder when the
    callable has none.

    Returns:
        A list of ``agent.Tool`` objects in registry iteration order.
    """
    # The description is computed directly instead of going through a
    # throwaway ``schemas.ToolOut``. That model is the API response schema and
    # is built with an ``id`` everywhere else in this module; instantiating it
    # here without one only added a validation step that could fail and whose
    # result was immediately read back unchanged.
    return [
        agent.Tool(
            name=name,
            func=func,
            description=func.__doc__ or "No description provided.",
        )
        for name, func in tools_registry.items()
    ]
# Endpoint to submit a query
@app.post("/query", response_model=schemas.QueryResponse)
def submit_query(request: schemas.QueryRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """
    Submit a user query to the agent.

    The question is stored immediately with an empty answer, the agent run is
    scheduled as a background task, and a fixed acknowledgement is returned.
    The answer is written to the stored question later by ``process_query``.
    """
    user_input = request.input
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Save the question to the database
    question = models.QuestionModel(query=user_input, answer="", timestamp=timestamp)
    db.add(question)
    db.commit()
    db.refresh(question)
    # Hand only plain values to the background task. The request-scoped
    # session is deliberately not passed along: FastAPI (0.106+) runs the
    # cleanup of ``yield`` dependencies before background tasks start, so the
    # task would receive a session that ``get_db`` has already closed.
    background_tasks.add_task(process_query, question.id, user_input)
    return schemas.QueryResponse(output="Your request is being processed.")


def process_query(question_id: int, user_input: str, db: Optional[Session] = None):
    """
    Process the user query and save the answer.

    Runs the agent on ``user_input`` with one tool per ``tools_registry``
    entry, stores the agent's output on the question row, and records it as an
    ``AnswerModel`` row. If anything fails, the error text is stored as the
    question's answer instead.

    Args:
        question_id: ID of the ``QuestionModel`` row to update.
        user_input: The query text passed to the agent.
        db: Optional session to use. When omitted (the normal case for the
            background task) a dedicated session is opened here and closed
            before returning. A session passed in by the caller is used as-is
            and left open for the caller to close.
    """
    owns_session = db is None
    session = database.SessionLocal() if owns_session else db
    try:
        # Create agent with current tools
        lc_tools = [
            agent.Tool(
                name=name,
                func=func,
                description=func.__doc__ or "No description provided.",
            )
            for name, func in tools_registry.items()
        ]
        langchain_agent = agent.create_agent(lc_tools)
        # Invoke the agent
        response = langchain_agent({"input": user_input})
        answer = response["output"]
        db_question = (
            session.query(models.QuestionModel)
            .filter(models.QuestionModel.id == question_id)
            .first()
        )
        if db_question is None:
            # The row was removed while the agent was running; there is
            # nothing left to attach the answer to.
            logger.warning("Question %s no longer exists; answer discarded", question_id)
            return
        # Update the question and record the answer row in one commit so the
        # two writes succeed or fail together.
        db_question.answer = answer
        session.add(
            models.AnswerModel(
                question_id=question_id,
                content=answer,
                timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            )
        )
        session.commit()
        logger.info("Processed query %s: %s -> %s", question_id, user_input, answer)
    except Exception as e:
        # Top-level boundary of the background task: log with traceback and
        # surface the failure to the user through the stored answer.
        logger.exception("Error processing query %s: %s", question_id, e)
        # A failed flush/commit leaves the session unusable until rolled back.
        session.rollback()
        db_question = (
            session.query(models.QuestionModel)
            .filter(models.QuestionModel.id == question_id)
            .first()
        )
        if db_question is not None:
            db_question.answer = f"Error processing the query: {e}"
            session.commit()
    finally:
        if owns_session:
            session.close()
# Endpoint to register a new tool
@app.post("/tools", response_model=schemas.ToolOut)
def register_tool(tool: schemas.ToolCreate, db: Session = Depends(get_db)):
    """
    Store a new tool and pass it to the tool registry.

    Responds with HTTP 400 when a tool with the same name is already stored.
    Otherwise the tool is saved, handed to ``add_tool``, and returned with its
    ID, name and description.
    """
    # Reject duplicates by name before inserting anything.
    same_name = models.ToolModel.name == tool.name
    duplicate = db.query(models.ToolModel).filter(same_name).first()
    if duplicate:
        raise HTTPException(status_code=400, detail="Tool with this name already exists.")
    # Persist the tool, then refresh so the stored row's fields are loaded.
    record = models.ToolModel(
        name=tool.name,
        description=tool.description,
        function_code=tool.function_code,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    # Hand the stored row to the registry.
    # NOTE(review): function_code is supplied by the API caller -- confirm how
    # add_tool turns it into a callable and whether that is safe.
    add_tool(record)
    return schemas.ToolOut(id=record.id, name=record.name, description=record.description)
# Endpoint to list all tools
@app.get("/tools", response_model=List[schemas.ToolOut])
def list_tools(db: Session = Depends(get_db)):
    """
    Return every stored tool as a ``ToolOut`` (ID, name, description).
    """
    stored = db.query(models.ToolModel).all()
    return [
        schemas.ToolOut(id=row.id, name=row.name, description=row.description)
        for row in stored
    ]
# Endpoint to get a specific tool by ID
@app.get("/tools/{tool_id}", response_model=schemas.ToolOut)
def get_tool(tool_id: int, db: Session = Depends(get_db)):
    """
    Return the stored tool with the given ID.

    Responds with HTTP 404 when no tool has that ID.
    """
    by_id = models.ToolModel.id == tool_id
    record = db.query(models.ToolModel).filter(by_id).first()
    if record:
        return schemas.ToolOut(id=record.id, name=record.name, description=record.description)
    raise HTTPException(status_code=404, detail="Tool not found.")
@app.get("/answer/{question_id}", response_model=schemas.AnswerResponse)
def get_answer(question_id: int, db: Session = Depends(get_db)):
    """
    Return the answer stored for a question ID.

    Responds with HTTP 404 when no question has that ID. While the stored
    answer is still empty, a placeholder "still being processed" message is
    returned in its place.
    """
    # Look the question up in the database by ID.
    record = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="Вопрос не найден.")
    # An empty answer is reported with the placeholder text instead.
    answer_text = record.answer or "Ответ ещё обрабатывается."
    return schemas.AnswerResponse(
        id=record.id,
        query=record.query,
        answer=answer_text,
        timestamp=record.timestamp,
    )