pull/343/head
parent
4055db314a
commit
208ab0e344
@ -1,7 +1,11 @@
|
|||||||
from swarms.agents.message import Message
|
from swarms.agents.message import Message
|
||||||
from swarms.agents.base import AbstractAgent
|
from swarms.agents.base import AbstractAgent
|
||||||
|
from swarms.agents.tool_agent import ToolAgent
|
||||||
|
from swarms.agents.simple_agent import SimpleAgent
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Message",
|
"Message",
|
||||||
"AbstractAgent",
|
"AbstractAgent",
|
||||||
|
"ToolAgent",
|
||||||
|
"SimpleAgent",
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,111 @@
|
|||||||
|
"""
|
||||||
|
Tool Agent
|
||||||
|
|
||||||
|
"""
|
||||||
|
from swarms.tools.format_tools import Jsonformer
|
||||||
|
from typing import Any
|
||||||
|
from swarms.models.base_llm import AbstractLLM
|
||||||
|
|
||||||
|
|
||||||
|
class ToolAgent(AbstractLLM):
|
||||||
|
"""
|
||||||
|
Represents a tool agent that performs a specific task using a model and tokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the tool agent.
|
||||||
|
description (str): A description of the tool agent.
|
||||||
|
model (Any): The model used by the tool agent.
|
||||||
|
tokenizer (Any): The tokenizer used by the tool agent.
|
||||||
|
json_schema (Any): The JSON schema used by the tool agent.
|
||||||
|
*args: Variable length arguments.
|
||||||
|
**kwargs: Keyword arguments.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): The name of the tool agent.
|
||||||
|
description (str): A description of the tool agent.
|
||||||
|
model (Any): The model used by the tool agent.
|
||||||
|
tokenizer (Any): The tokenizer used by the tool agent.
|
||||||
|
json_schema (Any): The JSON schema used by the tool agent.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
run: Runs the tool agent for a specific task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If an error occurs while running the tool agent.
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from swarms import ToolAgent
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-12b")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b")
|
||||||
|
|
||||||
|
json_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"is_student": {"type": "boolean"},
|
||||||
|
"courses": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = "Generate a person's information based on the following schema:"
|
||||||
|
agent = ToolAgent(model, tokenizer, json_schema, prompt)
|
||||||
|
generated_data = ToolAgent()
|
||||||
|
|
||||||
|
print(generated_data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
model: Any,
|
||||||
|
tokenizer: Any,
|
||||||
|
json_schema: Any,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.json_schema = json_schema
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the tool agent for the specified task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to be performed by the tool agent.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output of the tool agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If an error occurs during the execution of the tool agent.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.toolagent = Jsonformer(
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
self.json_schema,
|
||||||
|
task,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.toolagent()
|
||||||
|
return out
|
||||||
|
except Exception as error:
|
||||||
|
print(f"[Error] [ToolAgent] {error}")
|
||||||
|
raise error
|
@ -0,0 +1,277 @@
|
|||||||
|
from typing import List, Union, Dict, Any
|
||||||
|
|
||||||
|
from swarms.tools.logits_processor import (
|
||||||
|
NumberStoppingCriteria,
|
||||||
|
OutputNumbersTokens,
|
||||||
|
StringStoppingCriteria,
|
||||||
|
)
|
||||||
|
from termcolor import cprint
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
import json
|
||||||
|
|
||||||
|
GENERATION_MARKER = "|GENERATION|"
|
||||||
|
|
||||||
|
|
||||||
|
class Jsonformer:
|
||||||
|
value: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
json_schema: Dict[str, Any],
|
||||||
|
prompt: str,
|
||||||
|
*,
|
||||||
|
debug: bool = False,
|
||||||
|
max_array_length: int = 10,
|
||||||
|
max_number_tokens: int = 6,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
max_string_token_length: int = 10,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.json_schema = json_schema
|
||||||
|
self.prompt = prompt
|
||||||
|
|
||||||
|
self.number_logit_processor = OutputNumbersTokens(
|
||||||
|
self.tokenizer, self.prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
self.generation_marker = "|GENERATION|"
|
||||||
|
self.debug_on = debug
|
||||||
|
self.max_array_length = max_array_length
|
||||||
|
|
||||||
|
self.max_number_tokens = max_number_tokens
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_string_token_length = max_string_token_length
|
||||||
|
|
||||||
|
def debug(self, caller: str, value: str, is_prompt: bool = False):
|
||||||
|
if self.debug_on:
|
||||||
|
if is_prompt:
|
||||||
|
cprint(caller, "green", end=" ")
|
||||||
|
cprint(value, "yellow")
|
||||||
|
else:
|
||||||
|
cprint(caller, "green", end=" ")
|
||||||
|
cprint(value, "blue")
|
||||||
|
|
||||||
|
def generate_number(
|
||||||
|
self, temperature: Union[float, None] = None, iterations=0
|
||||||
|
):
|
||||||
|
prompt = self.get_prompt()
|
||||||
|
self.debug("[generate_number]", prompt, is_prompt=True)
|
||||||
|
input_tokens = self.tokenizer.encode(
|
||||||
|
prompt, return_tensors="pt"
|
||||||
|
).to(self.model.device)
|
||||||
|
response = self.model.generate(
|
||||||
|
input_tokens,
|
||||||
|
max_new_tokens=self.max_number_tokens,
|
||||||
|
num_return_sequences=1,
|
||||||
|
logits_processor=[self.number_logit_processor],
|
||||||
|
stopping_criteria=[
|
||||||
|
NumberStoppingCriteria(
|
||||||
|
self.tokenizer, len(input_tokens[0])
|
||||||
|
)
|
||||||
|
],
|
||||||
|
temperature=temperature or self.temperature,
|
||||||
|
pad_token_id=self.tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
response = self.tokenizer.decode(
|
||||||
|
response[0], skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response = response[len(prompt) :]
|
||||||
|
response = response.strip().rstrip(".")
|
||||||
|
self.debug("[generate_number]", response)
|
||||||
|
try:
|
||||||
|
return float(response)
|
||||||
|
except ValueError:
|
||||||
|
if iterations > 3:
|
||||||
|
raise ValueError("Failed to generate a valid number")
|
||||||
|
|
||||||
|
return self.generate_number(
|
||||||
|
temperature=self.temperature * 1.3,
|
||||||
|
iterations=iterations + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_boolean(self) -> bool:
|
||||||
|
prompt = self.get_prompt()
|
||||||
|
self.debug("[generate_boolean]", prompt, is_prompt=True)
|
||||||
|
|
||||||
|
input_tensor = self.tokenizer.encode(
|
||||||
|
prompt, return_tensors="pt"
|
||||||
|
)
|
||||||
|
output = self.model.forward(
|
||||||
|
input_tensor.to(self.model.device)
|
||||||
|
)
|
||||||
|
logits = output.logits[0, -1]
|
||||||
|
|
||||||
|
# todo: this assumes that "true" and "false" are both tokenized to a single token
|
||||||
|
# this is probably not true for all tokenizers
|
||||||
|
# this can be fixed by looking at only the first token of both "true" and "false"
|
||||||
|
true_token_id = self.tokenizer.convert_tokens_to_ids("true")
|
||||||
|
false_token_id = self.tokenizer.convert_tokens_to_ids("false")
|
||||||
|
|
||||||
|
result = logits[true_token_id] > logits[false_token_id]
|
||||||
|
|
||||||
|
self.debug("[generate_boolean]", result)
|
||||||
|
|
||||||
|
return result.item()
|
||||||
|
|
||||||
|
def generate_string(self) -> str:
|
||||||
|
prompt = self.get_prompt() + '"'
|
||||||
|
self.debug("[generate_string]", prompt, is_prompt=True)
|
||||||
|
input_tokens = self.tokenizer.encode(
|
||||||
|
prompt, return_tensors="pt"
|
||||||
|
).to(self.model.device)
|
||||||
|
|
||||||
|
response = self.model.generate(
|
||||||
|
input_tokens,
|
||||||
|
max_new_tokens=self.max_string_token_length,
|
||||||
|
num_return_sequences=1,
|
||||||
|
temperature=self.temperature,
|
||||||
|
stopping_criteria=[
|
||||||
|
StringStoppingCriteria(
|
||||||
|
self.tokenizer, len(input_tokens[0])
|
||||||
|
)
|
||||||
|
],
|
||||||
|
pad_token_id=self.tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some models output the prompt as part of the response
|
||||||
|
# This removes the prompt from the response if it is present
|
||||||
|
if (
|
||||||
|
len(response[0]) >= len(input_tokens[0])
|
||||||
|
and (
|
||||||
|
response[0][: len(input_tokens[0])] == input_tokens
|
||||||
|
).all()
|
||||||
|
):
|
||||||
|
response = response[0][len(input_tokens[0]) :]
|
||||||
|
if response.shape[0] == 1:
|
||||||
|
response = response[0]
|
||||||
|
|
||||||
|
response = self.tokenizer.decode(
|
||||||
|
response, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.debug("[generate_string]", "|" + response + "|")
|
||||||
|
|
||||||
|
if response.count('"') < 1:
|
||||||
|
return response
|
||||||
|
|
||||||
|
return response.split('"')[0].strip()
|
||||||
|
|
||||||
|
def generate_object(
|
||||||
|
self, properties: Dict[str, Any], obj: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
for key, schema in properties.items():
|
||||||
|
self.debug("[generate_object] generating value for", key)
|
||||||
|
obj[key] = self.generate_value(schema, obj, key)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def generate_value(
|
||||||
|
self,
|
||||||
|
schema: Dict[str, Any],
|
||||||
|
obj: Union[Dict[str, Any], List[Any]],
|
||||||
|
key: Union[str, None] = None,
|
||||||
|
) -> Any:
|
||||||
|
schema_type = schema["type"]
|
||||||
|
if schema_type == "number":
|
||||||
|
if key:
|
||||||
|
obj[key] = self.generation_marker
|
||||||
|
else:
|
||||||
|
obj.append(self.generation_marker)
|
||||||
|
return self.generate_number()
|
||||||
|
elif schema_type == "boolean":
|
||||||
|
if key:
|
||||||
|
obj[key] = self.generation_marker
|
||||||
|
else:
|
||||||
|
obj.append(self.generation_marker)
|
||||||
|
return self.generate_boolean()
|
||||||
|
elif schema_type == "string":
|
||||||
|
if key:
|
||||||
|
obj[key] = self.generation_marker
|
||||||
|
else:
|
||||||
|
obj.append(self.generation_marker)
|
||||||
|
return self.generate_string()
|
||||||
|
elif schema_type == "array":
|
||||||
|
new_array = []
|
||||||
|
obj[key] = new_array
|
||||||
|
return self.generate_array(schema["items"], new_array)
|
||||||
|
elif schema_type == "object":
|
||||||
|
new_obj = {}
|
||||||
|
if key:
|
||||||
|
obj[key] = new_obj
|
||||||
|
else:
|
||||||
|
obj.append(new_obj)
|
||||||
|
return self.generate_object(schema["properties"], new_obj)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported schema type: {schema_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_array(
|
||||||
|
self, item_schema: Dict[str, Any], obj: Dict[str, Any]
|
||||||
|
) -> list:
|
||||||
|
for _ in range(self.max_array_length):
|
||||||
|
# forces array to have at least one element
|
||||||
|
element = self.generate_value(item_schema, obj)
|
||||||
|
obj[-1] = element
|
||||||
|
|
||||||
|
obj.append(self.generation_marker)
|
||||||
|
input_prompt = self.get_prompt()
|
||||||
|
obj.pop()
|
||||||
|
input_tensor = self.tokenizer.encode(
|
||||||
|
input_prompt, return_tensors="pt"
|
||||||
|
)
|
||||||
|
output = self.model.forward(
|
||||||
|
input_tensor.to(self.model.device)
|
||||||
|
)
|
||||||
|
logits = output.logits[0, -1]
|
||||||
|
|
||||||
|
top_indices = logits.topk(30).indices
|
||||||
|
sorted_token_ids = top_indices[
|
||||||
|
logits[top_indices].argsort(descending=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
found_comma = False
|
||||||
|
found_close_bracket = False
|
||||||
|
|
||||||
|
for token_id in sorted_token_ids:
|
||||||
|
decoded_token = self.tokenizer.decode(token_id)
|
||||||
|
if "," in decoded_token:
|
||||||
|
found_comma = True
|
||||||
|
break
|
||||||
|
if "]" in decoded_token:
|
||||||
|
found_close_bracket = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if found_close_bracket or not found_comma:
|
||||||
|
break
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def get_prompt(self):
|
||||||
|
template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
|
||||||
|
progress = json.dumps(self.value)
|
||||||
|
gen_marker_index = progress.find(
|
||||||
|
f'"{self.generation_marker}"'
|
||||||
|
)
|
||||||
|
if gen_marker_index != -1:
|
||||||
|
progress = progress[:gen_marker_index]
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to find generation marker")
|
||||||
|
|
||||||
|
prompt = template.format(
|
||||||
|
prompt=self.prompt,
|
||||||
|
schema=json.dumps(self.json_schema),
|
||||||
|
progress=progress,
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def __call__(self) -> Dict[str, Any]:
|
||||||
|
self.value = {}
|
||||||
|
generated_data = self.generate_object(
|
||||||
|
self.json_schema["properties"], self.value
|
||||||
|
)
|
||||||
|
return generated_data
|
@ -0,0 +1,94 @@
|
|||||||
|
from transformers import (
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
LogitsWarper,
|
||||||
|
StoppingCriteria,
|
||||||
|
)
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class StringStoppingCriteria(StoppingCriteria):
|
||||||
|
def __init__(
|
||||||
|
self, tokenizer: PreTrainedTokenizer, prompt_length: int
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.prompt_length = prompt_length
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
_,
|
||||||
|
) -> bool:
|
||||||
|
if len(input_ids[0]) <= self.prompt_length:
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_token_id = input_ids[0][-1]
|
||||||
|
last_token = self.tokenizer.decode(
|
||||||
|
last_token_id, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
result = '"' in last_token
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class NumberStoppingCriteria(StoppingCriteria):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
prompt_length: int,
|
||||||
|
precision: int = 3,
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.precision = precision
|
||||||
|
self.prompt_length = prompt_length
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
scores: torch.FloatTensor,
|
||||||
|
) -> bool:
|
||||||
|
decoded = self.tokenizer.decode(
|
||||||
|
input_ids[0][self.prompt_length :],
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if decoded.count(".") > 1:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if (
|
||||||
|
decoded.count(".") == 1
|
||||||
|
and len(decoded.strip().split(".")[1]) > self.precision
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(decoded) > 1
|
||||||
|
and any(c.isdigit() for c in decoded)
|
||||||
|
and decoded[-1] in [" ", "\n"]
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class OutputNumbersTokens(LogitsWarper):
|
||||||
|
def __init__(self, tokenizer: PreTrainedTokenizer, prompt: str):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
|
||||||
|
vocab_size = len(tokenizer)
|
||||||
|
self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
|
||||||
|
|
||||||
|
for _, token_id in tokenizer.get_vocab().items():
|
||||||
|
token_str = tokenizer.decode(token_id).strip()
|
||||||
|
|
||||||
|
if token_str == "" or (
|
||||||
|
all(c.isdigit() or c == "." for c in token_str)
|
||||||
|
and token_str.count(".") <= 1
|
||||||
|
):
|
||||||
|
self.allowed_mask[token_id] = True
|
||||||
|
|
||||||
|
def __call__(self, _, scores):
|
||||||
|
mask = self.allowed_mask.expand_as(scores)
|
||||||
|
scores[~mask] = -float("inf")
|
||||||
|
|
||||||
|
return scores
|
@ -0,0 +1,101 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from swarms import ToolAgent
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_agent_init():
|
||||||
|
model = Mock(spec=AutoModelForCausalLM)
|
||||||
|
tokenizer = Mock(spec=AutoTokenizer)
|
||||||
|
json_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"is_student": {"type": "boolean"},
|
||||||
|
"courses": {"type": "array", "items": {"type": "string"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
name = "Test Agent"
|
||||||
|
description = "This is a test agent"
|
||||||
|
|
||||||
|
agent = ToolAgent(
|
||||||
|
name, description, model, tokenizer, json_schema
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.name == name
|
||||||
|
assert agent.description == description
|
||||||
|
assert agent.model == model
|
||||||
|
assert agent.tokenizer == tokenizer
|
||||||
|
assert agent.json_schema == json_schema
|
||||||
|
|
||||||
|
|
||||||
|
@patch.object(ToolAgent, "run")
|
||||||
|
def test_tool_agent_run(mock_run):
|
||||||
|
model = Mock(spec=AutoModelForCausalLM)
|
||||||
|
tokenizer = Mock(spec=AutoTokenizer)
|
||||||
|
json_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"is_student": {"type": "boolean"},
|
||||||
|
"courses": {"type": "array", "items": {"type": "string"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
name = "Test Agent"
|
||||||
|
description = "This is a test agent"
|
||||||
|
task = (
|
||||||
|
"Generate a person's information based on the following"
|
||||||
|
" schema:"
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = ToolAgent(
|
||||||
|
name, description, model, tokenizer, json_schema
|
||||||
|
)
|
||||||
|
agent.run(task)
|
||||||
|
|
||||||
|
mock_run.assert_called_once_with(task)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_agent_init_with_kwargs():
|
||||||
|
model = Mock(spec=AutoModelForCausalLM)
|
||||||
|
tokenizer = Mock(spec=AutoTokenizer)
|
||||||
|
json_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"is_student": {"type": "boolean"},
|
||||||
|
"courses": {"type": "array", "items": {"type": "string"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
name = "Test Agent"
|
||||||
|
description = "This is a test agent"
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"debug": True,
|
||||||
|
"max_array_length": 20,
|
||||||
|
"max_number_tokens": 12,
|
||||||
|
"temperature": 0.5,
|
||||||
|
"max_string_token_length": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
agent = ToolAgent(
|
||||||
|
name, description, model, tokenizer, json_schema, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.name == name
|
||||||
|
assert agent.description == description
|
||||||
|
assert agent.model == model
|
||||||
|
assert agent.tokenizer == tokenizer
|
||||||
|
assert agent.json_schema == json_schema
|
||||||
|
assert agent.debug == kwargs["debug"]
|
||||||
|
assert agent.max_array_length == kwargs["max_array_length"]
|
||||||
|
assert agent.max_number_tokens == kwargs["max_number_tokens"]
|
||||||
|
assert agent.temperature == kwargs["temperature"]
|
||||||
|
assert (
|
||||||
|
agent.max_string_token_length
|
||||||
|
== kwargs["max_string_token_length"]
|
||||||
|
)
|
Loading…
Reference in new issue