structured-output

Pydantic-validation-2
Nicolas Nahas 5 months ago
parent 103d3937d3
commit e95090cbba

@ -0,0 +1,77 @@
"""
* WORKING
What this script does:
Structured output example with validation function
Requirements:
pip install openai
pip install pydantic
Add the following API key(s) in your .env file:
- OPENAI_API_KEY (this example works best with OpenAI because it uses OpenAI's function-calling structure)
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from swarms import Agent, OpenAIChat
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from pydantic import AfterValidator
def symbol_must_exists(symbol: str) -> str:
    """Pydantic AfterValidator: ensure ``symbol`` is a supported ticker.

    Fixes the original signature ``symbol= str``, which made the ``str``
    type the *default value* instead of a type annotation.

    Args:
        symbol: Stock ticker symbol produced by the model.

    Returns:
        The symbol unchanged when it is in the supported list.

    Raises:
        ValueError: If the symbol is not one of the supported tickers.
    """
    # Whitelist of the 20 supported S&P 500 tickers.
    symbols = [
        "AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B",
        "JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO",
    ]

    if symbol not in symbols:
        raise ValueError(f"symbol must exist in the list: {symbols}")

    return symbol
# Schema for the stock information the agent must emit (validated by Pydantic).
# NOTE: the docstring below is left untouched — it is injected into the
# function-calling prompt, so its exact text is part of runtime behavior.
class StockInfo(BaseModel):
    """
    To create a StockInfo, you need to return a JSON object with the following format:
    {
        "function_call": "StockInfo",
        "parameters": {
            ...
        }
    }
    """

    # Human-readable company name.
    name: str = Field(..., title="Name of the company")
    # Short description of what the company does.
    description: str = Field(..., title="Description of the company")
    # Ticker symbol; AfterValidator rejects anything outside the supported list.
    symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(..., title="stock symbol of the company")
# Task prompt: ask the agent for a real S&P 500 company's details.
task = "Generate an existing S&P500's company information"

# Initialize the agent. Registering StockInfo in list_base_models together
# with output_validation=True makes agent.run() return a validated StockInfo
# instance instead of raw text.
agent = Agent(
    agent_name="Stock Information Generator",
    # Fixed typo in the original prompt ("comapany") so the LLM sees clean text.
    system_prompt=(
        "Generate a public company's information"
    ),
    llm=OpenAIChat(),
    max_loops=1,
    verbose=True,
    # List of schemas that the agent can handle
    list_base_models=[StockInfo],
    output_validation=True,
)

# Run the agent; the result is a validated StockInfo model.
generated_data = agent.run(task)

# Print the generated data
print(f"Generated data: {generated_data}")

@ -0,0 +1,65 @@
"""
* WORKING
What this script does:
Structured output example
Requirements:
Add the following API key(s) in your .env file:
- OPENAI_API_KEY (this example works best with OpenAI because it uses OpenAI's function-calling structure)
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from pydantic import BaseModel, Field
from swarms import Agent, OpenAIChat
# Schema for the person record the agent must emit (validated by Pydantic).
# NOTE: the docstring below is left untouched — it is injected into the
# function-calling prompt, so its exact text is part of runtime behavior.
class PersonInfo(BaseModel):
    """
    To create a PersonInfo, you need to return a JSON object with the following format:
    {
        "function_call": "PersonInfo",
        "parameters": {
            ...
        }
    }
    """

    # Person's full name.
    name: str = Field(..., title="Name of the person")
    # Age in years.
    age: int = Field(..., title="Age of the person")
    # Whether the person is currently enrolled as a student.
    is_student: bool = Field(..., title="Whether the person is a student")
    # Names of the courses the person is taking.
    courses: list[str] = Field(
        ..., title="List of courses the person is taking"
    )
# Build the structured-output agent: PersonInfo is registered as the only
# schema, and output_validation=True makes agent.run() hand back a validated
# PersonInfo instance rather than the raw model text.
agent = Agent(
    agent_name="Person Information Generator",
    system_prompt="Generate a person's information",
    llm=OpenAIChat(),
    max_loops=1,
    verbose=True,
    # Pydantic models the agent is allowed to emit.
    list_base_models=[PersonInfo],
    output_validation=True,
)

# Prompt describing the person whose record we want back as a PersonInfo.
task = "Generate a person's information for Paul Graham 56 years old and is a student at Harvard University and is taking 3 courses: Math, Science, and History."

# Execute the task, then show both the concrete return type and the payload.
generated_data = agent.run(task)
print(type(generated_data))
print(f"Generated data: {generated_data}")

@ -30,4 +30,5 @@ pandas>=2.2.2
fastapi>=0.110.1
networkx
swarms-memory
swarms-cloud
pre-commit

@ -7,7 +7,7 @@ import random
import sys
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
import toml
import yaml
@ -51,6 +51,7 @@ from swarms.tools.tool_parse_exec import parse_and_execute_json
from swarms.utils.data_to_text import data_to_text
from swarms.utils.file_processing import create_file_in_folder
from swarms.utils.parse_code import extract_code_from_markdown
from swarms.tools.json_utils import extract_json_from_str
from swarms.utils.pdf_to_text import pdf_to_text
@ -238,7 +239,8 @@ class Agent:
function_calling_type: str = "json",
output_cleaner: Optional[Callable] = None,
function_calling_format_type: Optional[str] = "OpenAI",
list_base_models: Optional[List[BaseModel]] = None,
list_base_models: Optional[List[Type[BaseModel]]] = None,
output_validation: Optional[bool] = False,
metadata_output_type: str = "json",
state_save_file_type: str = "json",
chain_of_thoughts: bool = False,
@ -333,6 +335,7 @@ class Agent:
self.output_cleaner = output_cleaner
self.function_calling_format_type = function_calling_format_type
self.list_base_models = list_base_models
self.output_validation = output_validation
self.metadata_output_type = metadata_output_type
self.state_save_file_type = state_save_file_type
self.chain_of_thoughts = chain_of_thoughts
@ -577,6 +580,14 @@ class Agent:
print(
colored(f"Error dynamically changing temperature: {error}")
)
def pydantic_validation(self, response: str) -> BaseModel:
    """Validate an LLM response against one of the registered schemas.

    Extracts the first JSON object from ``response``, looks up the Pydantic
    model named by its ``"function_call"`` key in ``self.list_base_models``,
    and strictly validates the ``"parameters"`` payload against that model.

    Args:
        response: Raw LLM output expected to contain a JSON object of the
            form ``{"function_call": "<ModelName>", "parameters": {...}}``.

    Returns:
        The validated Pydantic model instance.

    Raises:
        ValueError: If the response JSON lacks the required keys, or if no
            registered model matches ``function_call`` (the original code
            crashed with AttributeError on ``None`` in that case).
    """
    parsed_json = extract_json_from_str(response)
    try:
        function_call = parsed_json["function_call"]
        parameters = json.dumps(parsed_json["parameters"])
    except KeyError as e:
        raise ValueError(
            f"Response JSON is missing required key: {e}"
        ) from e
    # Match the requested schema by class name among the registered models.
    selected_base_model = next(
        (
            model
            for model in self.list_base_models
            if model.__name__ == function_call
        ),
        None,
    )
    if selected_base_model is None:
        raise ValueError(
            f"No registered base model named {function_call!r}"
        )
    return selected_base_model.__pydantic_validator__.validate_json(
        parameters, strict=True
    )
def format_prompt(self, template, **kwargs: Any) -> str:
"""Format the template with the provided kwargs using f-string interpolation."""
@ -675,6 +686,7 @@ class Agent:
# Clear the short memory
response = None
result = None
all_responses = []
steps_pool = []
@ -717,7 +729,7 @@ class Agent:
if self.streaming_on is True:
response = self.stream_response(response)
else:
print(response)
self.printtier(response)
# Add the response to the memory
self.short_memory.add(
@ -733,8 +745,8 @@ class Agent:
# TODO: Implement reliablity check
if self.tools is not None:
# self.parse_function_call_and_execute(response)
self.parse_and_execute_tools(response)
# self.parse_and_execute_tools(response)
result = self.parse_function_call_and_execute(response)
if self.code_interpreter is True:
# Parse the code and execute
@ -773,6 +785,9 @@ class Agent:
# all_responses.append(evaluated_response)
if self.output_validation:
result = self.pydantic_validation(response)
# Sentiment analysis
if self.sentiment_analyzer:
logger.info("Analyzing sentiment...")
@ -858,6 +873,9 @@ class Agent:
# logger.info(f"Final Response: {final_response}")
if self.return_history:
return self.short_memory.return_history_as_string()
elif self.output_validation:
return result
elif self.return_step_meta:
log = ManySteps(
@ -1303,6 +1321,14 @@ class Agent:
except Exception as error:
print(colored(f"Error retrying function: {error}", "red"))
def printtier(self, response: str) -> str:
    """Print the agent's name in bold blue caps followed by the response.

    A blank line is printed first to visually separate turns. The original
    implementation did ``return print(...)``, which always returned ``None``
    despite the ``-> str`` annotation; the formatted line is now returned so
    the annotation holds (existing callers ignore the return value, so this
    is backward-compatible).

    Args:
        response: The text produced by the agent.

    Returns:
        The exact formatted line that was printed.
    """
    print("\n")
    # NOTE(review): uses self.name, but the constructor stores agent_name —
    # confirm Agent actually defines .name, otherwise this should read
    # self.agent_name.
    formatted = f"\033[1;34m{self.name.upper()}:\033[0m {response}"
    print(formatted)
    return formatted
def update_system_prompt(self, system_prompt: str):
"""Upddate the system message"""
self.system_prompt = system_prompt
@ -1496,7 +1522,7 @@ class Agent:
)
response = self.llm(task_prompt, *args, **kwargs)
print(response)
self.printtier(response)
self.short_memory.add(
role=self.agent_name, content=response
@ -1884,9 +1910,7 @@ class Agent:
full_memory = self.short_memory.return_history_as_string()
prompt_tokens = self.tokenizer.count_tokens(full_memory)
completion_tokens = self.tokenizer.count_tokens(response)
total_tokens = self.tokenizer.count_tokens(
prompt_tokens + completion_tokens
)
total_tokens = prompt_tokens + completion_tokens
logger.info("Logging step metadata...")

@ -30,9 +30,12 @@ def extract_json_from_str(response: str):
Raises:
ValueError: If the string does not contain a valid JSON object.
"""
json_start = response.index("{")
json_end = response.rfind("}")
return json.loads(response[json_start : json_end + 1])
try:
json_start = response.index("{")
json_end = response.rfind("}")
return json.loads(response[json_start : json_end + 1])
except Exception as e:
raise ValueError("No valid JSON structure found in the input string")
def str_to_json(response: str, indent: int = 3):

Loading…
Cancel
Save