You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/playground/agents/tool_agent_pydantic.py

50 lines
1.4 KiB

10 months ago
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
from swarms import ToolAgent
from swarms.utils.json_utils import base_model_to_json
# The causal-LM weights and the tokenizer are fetched from the same
# checkpoint, so the identifier is written once and shared by both calls.
_CHECKPOINT = "databricks/dolly-v2-12b"

# Request 4-bit weights and let the library decide device placement.
# NOTE(review): newer transformers releases appear to prefer
# quantization_config=BitsAndBytesConfig(load_in_4bit=True) over the bare
# load_in_4bit flag - confirm against the version this project pins.
model = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    load_in_4bit=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
# Schema describing the person record the agent is asked to generate.
# Documented with comments rather than a class docstring: pydantic copies a
# model's docstring into the generated JSON schema as its "description",
# which would change the schema that is later handed to the agent.
class Schema(BaseModel):
    name: str = Field(..., title="Name of the person")
    # Field name now agrees with its title ("Age of the person").
    age: int = Field(..., title="Age of the person")
    is_student: bool = Field(
        ..., title="Whether the person is a student"
    )
    courses: list[str] = Field(
        ..., title="List of courses the person is taking"
    )
# Serialise the pydantic schema to JSON so it can be handed to the agent.
tool_schema = base_model_to_json(Schema)

# Prompt for the agent. The schema is not embedded in this text; it is
# supplied separately through the `json_schema` argument below.
task = "Generate a person's information based on the following schema:"

# Build the ToolAgent around the loaded model, tokenizer and JSON schema.
agent = ToolAgent(
    name="dolly-function-agent",
    description="An agent to create child data",
    model=model,
    tokenizer=tokenizer,
    json_schema=tool_schema,
)

# Run the task through the agent and show whatever it produced.
generated_data = agent.run(task)
print(f"Generated data: {generated_data}")