Добавлено начальное окружение и модели для работы с базой данных, схемы запросов и ответов, а также функции для загрузки и регистрации инструментов.
parent
f7174411e4
commit
d7399ef4ef
@ -0,0 +1,3 @@
|
|||||||
|
LM_STUDIO_API_URL=http://192.168.0.104:1234/v1/chat/completions
|
||||||
|
LM_STUDIO_MODEL=lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
|
||||||
|
DATABASE_URL=sqlite:///./app.db
|
@ -0,0 +1,113 @@
|
|||||||
|
# Langchain multi-tool LLM service
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
> First, configure `.env` file for your LM Studio MODEL and HOST!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
### List of the Tools
|
||||||
|
|
||||||
|
```bash
|
||||||
|
GET /tools
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"name": "string",
|
||||||
|
"description": "string"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Create a Tool
|
||||||
|
|
||||||
|
```bash
|
||||||
|
POST /tools
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"name": "Calculator",
|
||||||
|
"description": "Useful for performing mathematical calculations. Input should be a valid mathematical expression.",
|
||||||
|
"function_code": "def tool_function(input: str) -> str:\n try:\n aeval = Interpreter()\n result = aeval(input)\n return str(result)\n except Exception as e:\n return f\"Error evaluating expression: {e}\""
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
<Object>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"name": "string",
|
||||||
|
"description": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Get the Tool
|
||||||
|
|
||||||
|
```bash
|
||||||
|
GET /tools/{id}
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
<Object>
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"name": "string",
|
||||||
|
"description": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Submit a Query
|
||||||
|
|
||||||
|
```bash
|
||||||
|
POST /query
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"input": "What is the capital of France and what is 15 multiplied by 3?"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
"output": "Your request is being processed."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Get processed Answer
|
||||||
|
|
||||||
|
```bash
|
||||||
|
GET /answer/{question_id}
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"query": "string",
|
||||||
|
"answer": "string",
|
||||||
|
"timestamp": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,160 @@
|
|||||||
|
import requests
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.agents import initialize_agent, AgentType, Tool
|
||||||
|
from pydantic import Field
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from datetime import datetime
|
||||||
|
import wikipedia
|
||||||
|
from asteval import Interpreter # For a safer calculator
|
||||||
|
import logging
|
||||||
|
from .tools import tools_registry
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
LM_STUDIO_API_URL = os.getenv("LM_STUDIO_API_URL", "http://192.168.0.104:1234/v1/chat/completions")
|
||||||
|
MODEL_NAME = os.getenv("LM_STUDIO_MODEL", "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf")
|
||||||
|
CONTENT_TYPE = "application/json"
|
||||||
|
|
||||||
|
class LMStudioLLM(LLM):
|
||||||
|
"""
|
||||||
|
Custom LangChain LLM to interface with LM Studio API.
|
||||||
|
"""
|
||||||
|
api_url: str = Field(default=LM_STUDIO_API_URL, description="The API endpoint for LM Studio.")
|
||||||
|
model: str = Field(default=MODEL_NAME, description="The model path/name.")
|
||||||
|
temperature: float = Field(default=0.7, description="Sampling temperature.")
|
||||||
|
max_tokens: Optional[int] = Field(default=4096, description="Maximum number of tokens to generate.")
|
||||||
|
streaming: bool = Field(default=False, alias="stream", description="Whether to use streaming responses.")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
populate_by_name = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "lmstudio"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identifying_params(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"api_url": self.api_url,
|
||||||
|
"model": self.model,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"stream": self.streaming,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
|
"""
|
||||||
|
Generate a response from the LM Studio model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The input prompt.
|
||||||
|
stop (Optional[List[str]]): Stop sequences.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated response.
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Content-Type": CONTENT_TYPE,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens if self.max_tokens is not None else -1,
|
||||||
|
"stream": self.streaming, # Uses alias 'stream'
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Payload: {payload}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
self.api_url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=60,
|
||||||
|
stream=self.streaming
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.info(f"Response content: {response.text}")
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.error(f"Failed to connect to LM Studio API: {e}")
|
||||||
|
raise RuntimeError(f"Failed to connect to LM Studio API: {e}")
|
||||||
|
|
||||||
|
if self.streaming:
|
||||||
|
return self._handle_stream(response)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
response_json = response.json()
|
||||||
|
choices = response_json.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
raise ValueError("No choices found in the response.")
|
||||||
|
|
||||||
|
# Extract the first response's content
|
||||||
|
content = choices[0].get("message", {}).get("content", "")
|
||||||
|
return content.strip()
|
||||||
|
except (ValueError, KeyError) as e:
|
||||||
|
logger.error(f"Invalid response format: {e}")
|
||||||
|
raise RuntimeError(f"Invalid response format: {e}")
|
||||||
|
|
||||||
|
def _handle_stream(self, response: requests.Response) -> str:
|
||||||
|
"""
|
||||||
|
Process streaming responses from the LM Studio API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response (requests.Response): The streaming response object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The concatenated content from the stream.
|
||||||
|
"""
|
||||||
|
content = ""
|
||||||
|
try:
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
decoded_line = line.decode('utf-8')
|
||||||
|
if decoded_line.startswith("data: "):
|
||||||
|
data = decoded_line[6:]
|
||||||
|
if data == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
json_data = requests.utils.json.loads(data)
|
||||||
|
choices = json_data.get("choices", [])
|
||||||
|
for chunk in choices:
|
||||||
|
delta = chunk.get("delta", {})
|
||||||
|
piece = delta.get("content", "")
|
||||||
|
content += piece
|
||||||
|
except requests.utils.json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
return content.strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing streaming response: {e}")
|
||||||
|
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||||
|
|
||||||
|
def create_agent(tools: List[Tool]) -> Any:
|
||||||
|
"""
|
||||||
|
Initialize the LangChain agent with the provided tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools (List[Tool]): List of LangChain Tool objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: Initialized agent.
|
||||||
|
"""
|
||||||
|
llm = LMStudioLLM()
|
||||||
|
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools=tools,
|
||||||
|
llm=llm,
|
||||||
|
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
|
verbose=False,
|
||||||
|
handle_parsing_errors=True,
|
||||||
|
)
|
||||||
|
return agent
|
@ -0,0 +1,15 @@
|
|||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./app.db")
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}
|
||||||
|
)
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
Base = declarative_base()
|
@ -0,0 +1,175 @@
|
|||||||
|
# app/main.py
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List
|
||||||
|
from . import models, schemas, database, agent, tools
|
||||||
|
from .tools import load_tools, add_tool, tools_registry
|
||||||
|
from .agent import create_agent
|
||||||
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Initialize logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create the database tables
|
||||||
|
models.Base.metadata.create_all(bind=database.engine)
|
||||||
|
|
||||||
|
app = FastAPI(title="LangChain FastAPI Service")
|
||||||
|
|
||||||
|
# Dependency to get DB session
|
||||||
|
def get_db():
|
||||||
|
db = database.SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Load existing tools at startup
|
||||||
|
@app.on_event("startup")
|
||||||
|
def startup_event():
|
||||||
|
load_tools()
|
||||||
|
|
||||||
|
# Helper function to create LangChain tools from registry
|
||||||
|
def get_langchain_tools() -> List:
|
||||||
|
tools_list = []
|
||||||
|
for name, func in tools_registry.items():
|
||||||
|
tool = schemas.ToolOut(name=name, description=func.__doc__ or "No description provided.")
|
||||||
|
lc_tool = agent.Tool(
|
||||||
|
name=name,
|
||||||
|
func=func,
|
||||||
|
description=tool.description
|
||||||
|
)
|
||||||
|
tools_list.append(lc_tool)
|
||||||
|
return tools_list
|
||||||
|
|
||||||
|
# Endpoint to submit a query
|
||||||
|
@app.post("/query", response_model=schemas.QueryResponse)
|
||||||
|
def submit_query(request: schemas.QueryRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Submit a user query to the agent.
|
||||||
|
"""
|
||||||
|
user_input = request.input
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
# Save the question to the database
|
||||||
|
question = models.QuestionModel(query=user_input, answer="", timestamp=timestamp)
|
||||||
|
db.add(question)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(question)
|
||||||
|
|
||||||
|
# Define background task for processing
|
||||||
|
background_tasks.add_task(process_query, question.id, user_input, db)
|
||||||
|
|
||||||
|
return schemas.QueryResponse(output="Your request is being processed.")
|
||||||
|
|
||||||
|
def process_query(question_id: int, user_input: str, db: Session):
|
||||||
|
"""
|
||||||
|
Process the user query and save the answer.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create agent with current tools
|
||||||
|
lc_tools = []
|
||||||
|
for name, func in tools_registry.items():
|
||||||
|
tool = agent.Tool(
|
||||||
|
name=name,
|
||||||
|
func=func,
|
||||||
|
description=func.__doc__ or "No description provided."
|
||||||
|
)
|
||||||
|
lc_tools.append(tool)
|
||||||
|
|
||||||
|
langchain_agent = agent.create_agent(lc_tools)
|
||||||
|
|
||||||
|
# Invoke the agent
|
||||||
|
response = langchain_agent({"input": user_input})
|
||||||
|
answer = response["output"]
|
||||||
|
|
||||||
|
# Update the question with the answer
|
||||||
|
db_question = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()
|
||||||
|
db_question.answer = answer
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# Optionally, save each step as separate answers
|
||||||
|
answer_model = models.AnswerModel(question_id=question_id, content=answer, timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||||
|
db.add(answer_model)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Processed query {question_id}: {user_input} -> {answer}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing query {question_id}: {e}")
|
||||||
|
db_question = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()
|
||||||
|
db_question.answer = f"Error processing the query: {e}"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# Endpoint to register a new tool
|
||||||
|
@app.post("/tools", response_model=schemas.ToolOut)
|
||||||
|
def register_tool(tool: schemas.ToolCreate, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Register a new tool.
|
||||||
|
"""
|
||||||
|
# Check if tool with the same name exists
|
||||||
|
existing_tool = db.query(models.ToolModel).filter(models.ToolModel.name == tool.name).first()
|
||||||
|
if existing_tool:
|
||||||
|
raise HTTPException(status_code=400, detail="Tool with this name already exists.")
|
||||||
|
|
||||||
|
# Create a new tool
|
||||||
|
new_tool = models.ToolModel(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
function_code=tool.function_code
|
||||||
|
)
|
||||||
|
db.add(new_tool)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(new_tool)
|
||||||
|
|
||||||
|
# Add to the registry
|
||||||
|
add_tool(new_tool)
|
||||||
|
|
||||||
|
return schemas.ToolOut(id=new_tool.id, name=new_tool.name, description=new_tool.description)
|
||||||
|
|
||||||
|
# Endpoint to list all tools
|
||||||
|
@app.get("/tools", response_model=List[schemas.ToolOut])
|
||||||
|
def list_tools(db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
List all registered tools.
|
||||||
|
"""
|
||||||
|
tool_models = db.query(models.ToolModel).all()
|
||||||
|
return [schemas.ToolOut(id=tool.id, name=tool.name, description=tool.description) for tool in tool_models]
|
||||||
|
|
||||||
|
# Endpoint to get a specific tool by ID
|
||||||
|
@app.get("/tools/{tool_id}", response_model=schemas.ToolOut)
|
||||||
|
def get_tool(tool_id: int, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Get details of a specific tool.
|
||||||
|
"""
|
||||||
|
tool = db.query(models.ToolModel).filter(models.ToolModel.id == tool_id).first()
|
||||||
|
if not tool:
|
||||||
|
raise HTTPException(status_code=404, detail="Tool not found.")
|
||||||
|
return schemas.ToolOut(id=tool.id, name=tool.name, description=tool.description)
|
||||||
|
|
||||||
|
@app.get("/answer/{question_id}", response_model=schemas.AnswerResponse)
|
||||||
|
def get_answer(question_id: int, db: Session = Depends(get_db)):
|
||||||
|
"""
|
||||||
|
Получить ответ по ID вопроса.
|
||||||
|
"""
|
||||||
|
# Поиск вопроса в базе данных по ID
|
||||||
|
question = db.query(models.QuestionModel).filter(models.QuestionModel.id == question_id).first()
|
||||||
|
|
||||||
|
if not question:
|
||||||
|
raise HTTPException(status_code=404, detail="Вопрос не найден.")
|
||||||
|
|
||||||
|
if not question.answer:
|
||||||
|
return schemas.AnswerResponse(
|
||||||
|
id=question.id,
|
||||||
|
query=question.query,
|
||||||
|
answer="Ответ ещё обрабатывается.",
|
||||||
|
timestamp=question.timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
return schemas.AnswerResponse(
|
||||||
|
id=question.id,
|
||||||
|
query=question.query,
|
||||||
|
answer=question.answer,
|
||||||
|
timestamp=question.timestamp
|
||||||
|
)
|
@ -0,0 +1,33 @@
|
|||||||
|
# app/models.py
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, String, Text, ForeignKey
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from .database import Base
|
||||||
|
|
||||||
|
class ToolModel(Base):
|
||||||
|
__tablename__ = "tools"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
name = Column(String, unique=True, index=True, nullable=False)
|
||||||
|
description = Column(Text, nullable=False)
|
||||||
|
function_code = Column(Text, nullable=False) # Store function as code (for simplicity)
|
||||||
|
|
||||||
|
class QuestionModel(Base):
|
||||||
|
__tablename__ = "questions"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
query = Column(Text, nullable=False)
|
||||||
|
answer = Column(Text, nullable=False)
|
||||||
|
timestamp = Column(String, nullable=False)
|
||||||
|
|
||||||
|
class AnswerModel(Base):
|
||||||
|
__tablename__ = "answers"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
question_id = Column(Integer, ForeignKey("questions.id"))
|
||||||
|
content = Column(Text, nullable=False)
|
||||||
|
timestamp = Column(String, nullable=False)
|
||||||
|
|
||||||
|
question = relationship("QuestionModel", back_populates="answers")
|
||||||
|
|
||||||
|
QuestionModel.answers = relationship("AnswerModel", back_populates="question")
|
@ -0,0 +1,30 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class ToolCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
function_code: str # The Python code for the tool function
|
||||||
|
|
||||||
|
class ToolOut(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
class QueryRequest(BaseModel):
|
||||||
|
input: str
|
||||||
|
|
||||||
|
class QueryResponse(BaseModel):
|
||||||
|
output: str
|
||||||
|
|
||||||
|
class AnswerResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
query: str
|
||||||
|
answer: str
|
||||||
|
timestamp: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
@ -0,0 +1,55 @@
|
|||||||
|
# app/tools.py
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Dict
|
||||||
|
from .database import SessionLocal
|
||||||
|
from . import models
|
||||||
|
import wikipedia
|
||||||
|
from asteval import Interpreter # Ensure necessary imports are available
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Dictionary to store tool functions
|
||||||
|
tools_registry: Dict[str, Callable[[str], str]] = {}
|
||||||
|
|
||||||
|
def load_tools():
|
||||||
|
"""
|
||||||
|
Load tools from the database and register them.
|
||||||
|
Assumes that function_code contains the body of the function.
|
||||||
|
"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
tool_models = db.query(models.ToolModel).all()
|
||||||
|
for tool in tool_models:
|
||||||
|
if tool.name not in tools_registry:
|
||||||
|
# Dynamically create function from code
|
||||||
|
try:
|
||||||
|
namespace = {}
|
||||||
|
exec(tool.function_code, globals(), namespace)
|
||||||
|
func = namespace.get('tool_function')
|
||||||
|
if func:
|
||||||
|
tools_registry[tool.name] = func
|
||||||
|
logger.info(f"Loaded tool: {tool.name}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Function 'tool_function' not defined in tool: {tool.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading tool {tool.name}: {e}")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
def add_tool(tool: models.ToolModel):
|
||||||
|
"""
|
||||||
|
Add a tool to the registry.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
namespace = {}
|
||||||
|
exec(tool.function_code, globals(), namespace)
|
||||||
|
func = namespace.get('tool_function')
|
||||||
|
if func:
|
||||||
|
tools_registry[tool.name] = func
|
||||||
|
logger.info(f"Registered new tool: {tool.name}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Function 'tool_function' not defined in tool: {tool.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding tool {tool.name}: {e}")
|
Loading…
Reference in new issue