parent
7bc9cc2d07
commit
f7be413781
@ -0,0 +1,90 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class ToolException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class BaseTool(ABC):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
|
class Tool(BaseTool):
|
||||||
|
def __init__(self, name: str, description: str, func: Callable[..., Any]):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
|
except ToolException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
return await self.func(*args, **kwargs)
|
||||||
|
except ToolException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
class StructuredTool(BaseTool):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
args_schema: Type[BaseModel],
|
||||||
|
func: Callable[..., Any]
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.args_schema = args_schema
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
|
except ToolException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
return await self.func(*args, **kwargs)
|
||||||
|
except ToolException as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def tool(
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
|
return_direct: bool = False,
|
||||||
|
infer_schema: bool = True
|
||||||
|
) -> Callable:
|
||||||
|
def decorator(func: Callable[..., Any]) -> Union[Tool, StructuredTool]:
|
||||||
|
nonlocal name, description
|
||||||
|
|
||||||
|
if name is None:
|
||||||
|
name = func.__name__
|
||||||
|
if description is None:
|
||||||
|
description = func.__doc__ or ""
|
||||||
|
|
||||||
|
if args_schema or infer_schema:
|
||||||
|
if args_schema is None:
|
||||||
|
args_schema = BaseModel
|
||||||
|
|
||||||
|
return StructuredTool(name, description, args_schema, func)
|
||||||
|
else:
|
||||||
|
return Tool(name, description, func)
|
||||||
|
|
||||||
|
return decorator
|
Loading…
Reference in new issue