Pydantic validation

structured_output
Nicolas Nahas 6 months ago
parent 82681bee15
commit 8052cb62d1

@ -0,0 +1,76 @@
"""
* WORKING
What this script does:
Structured output example with validation function
Requirements:
Add the folowing API key(s) in your .env file:
- OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure)
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from swarms import Agent, OpenAIChat
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from pydantic import AfterValidator
def symbol_must_exists(symbol= str) -> str:
symbols = [
"AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B",
"JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO"
]
if symbol not in symbols:
raise ValueError(f"symbol must exists in the list: {symbols}")
return symbol
# Initialize the schema for the person's information
class StockInfo(BaseModel):
"""
To create a StockInfo, you need to return a JSON object with the following format:
{
"function_call": "StockInfo",
"parameters": {
...
}
}
"""
name: str = Field(..., title="Name of the company")
description: str = Field(..., title="Description of the company")
symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(..., title="stock symbol of the company")
# Define the task to generate a person's information
task = "Generate an existing S&P500's company information"
# Initialize the agent
agent = Agent(
agent_name="Stock Information Generator",
system_prompt=(
"Generate a public comapany's information"
),
llm=OpenAIChat(),
max_loops=1,
verbose=True,
# List of schemas that the agent can handle
list_base_models=[StockInfo],
output_validation=True,
)
# Run the agent to generate the person's information
generated_data = agent.run(task)
# Print the generated data
print(f"Generated data: {generated_data}")

@ -0,0 +1,71 @@
"""
* WORKING
What this script does:
Structured output example
Requirements:
Add the folowing API key(s) in your .env file:
- OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure)
Note:
If you are running playground examples in the project files directly (without swarms installed via PIP),
make sure to add the project root to your PYTHONPATH by running the following command in the project's root directory:
'export PYTHONPATH=$(pwd):$PYTHONPATH'
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from pydantic import BaseModel, Field
from swarms import Agent, OpenAIChat
# Initialize the schema for the person's information
class PersonInfo(BaseModel):
"""
To create a PersonInfo, you need to return a JSON object with the following format:
{
"function_call": "PersonInfo",
"parameters": {
...
}
}
"""
name: str = Field(..., title="Name of the person")
agent: int = Field(..., title="Age of the person")
is_student: bool = Field(..., title="Whether the person is a student")
courses: list[str] = Field(
..., title="List of courses the person is taking"
)
# Initialize the agent
agent = Agent(
agent_name="Person Information Generator",
system_prompt=(
"Generate a person's information"
),
llm=OpenAIChat(),
max_loops=1,
verbose=True,
# List of pydantic models that the agent can use
list_base_models=[PersonInfo],
output_validation=True
)
# Define the task to generate a person's information
task = "Generate a person's information"
# Run the agent to generate the person's information
generated_data = agent.run(task)
# Print the generated data
print(type(generated_data))
print(f"Generated data: {generated_data}")

@ -7,7 +7,7 @@ import random
import sys
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
import yaml
from loguru import logger
@ -43,6 +43,7 @@ from swarms.tools.pydantic_to_json import (
from swarms.tools.tool_parse_exec import parse_and_execute_json
from swarms.utils.data_to_text import data_to_text
from swarms.utils.parse_code import extract_code_from_markdown
from swarms.tools.json_utils import extract_json_from_str
from swarms.utils.pdf_to_text import pdf_to_text
from swarms.tools.prebuilt.code_executor import CodeExecutor
from swarms.models.popular_llms import OpenAIChat
@ -251,6 +252,7 @@ class Agent(BaseStructure):
output_cleaner: Optional[Callable] = None,
function_calling_format_type: Optional[str] = "OpenAI",
list_base_models: Optional[List[BaseModel]] = None,
output_validation: Optional[bool] = False,
metadata_output_type: str = "json",
state_save_file_type: str = "json",
chain_of_thoughts: bool = False,
@ -311,6 +313,7 @@ class Agent(BaseStructure):
self.multi_modal = multi_modal
self.pdf_path = pdf_path
self.list_of_pdf = list_of_pdf
self.output_validation = output_validation
self.tokenizer = tokenizer
self.long_term_memory = long_term_memory
self.preset_stopping_token = preset_stopping_token
@ -570,6 +573,14 @@ class Agent(BaseStructure):
colored(f"Error dynamically changing temperature: {error}")
)
def pydantic_validation(self, response:str)-> Type[BaseModel]:
"""Validates the response using Pydantic."""
parsed_json = extract_json_from_str(response)
function_call = parsed_json["function_call"]
parameters = json.dumps(parsed_json["parameters"])
selected_base_model = next((model for model in self.list_base_models if model.__name__ == function_call), None)
return selected_base_model.__pydantic_validator__.validate_json(parameters, strict=True)
def format_prompt(self, template, **kwargs: Any) -> str:
"""Format the template with the provided kwargs using f-string interpolation."""
return template.format(**kwargs)
@ -782,7 +793,7 @@ class Agent(BaseStructure):
response = self.llm(
task_prompt, *args, **kwargs
)
print(response)
self.printtier(response)
# Add to memory
self.short_memory.add(
@ -813,8 +824,8 @@ class Agent(BaseStructure):
# TODO: Implement reliablity check
if self.tools is not None:
# self.parse_function_call_and_execute(response)
self.parse_and_execute_tools(response)
result = None
result = self.parse_function_call_and_execute(response)
if self.code_interpreter is True:
# Parse the code and execute
@ -839,6 +850,10 @@ class Agent(BaseStructure):
)
# all_responses.append(evaluated_response)
if self.output_validation:
result = None
result = self.pydantic_validation(response)
# Sentiment analysis
if self.sentiment_analyzer:
@ -926,6 +941,8 @@ class Agent(BaseStructure):
if self.return_history:
return self.short_memory.return_history_as_string()
elif self.output_validation:
return result
else:
return final_response
@ -1305,6 +1322,14 @@ class Agent(BaseStructure):
for word in self.response_filters:
response = response.replace(word, "[FILTERED]")
return response
def printtier(self, response:str) -> str:
"""
Specifies the name of the agent in capital letters in pink and the response text in blue.
Add space above.
"""
print("\n")
return print(f"\033[1;34m{self.name.upper()}:\033[0m {response}")
def filtered_run(self, task: str) -> str:
"""
@ -1734,7 +1759,7 @@ class Agent(BaseStructure):
)
response = self.llm(task_prompt, *args, **kwargs)
print(response)
self.printtier(response)
self.short_memory.add(
role=self.agent_name, content=response

@ -30,10 +30,12 @@ def extract_json_from_str(response: str):
Raises:
ValueError: If the string does not contain a valid JSON object.
"""
json_start = response.index("{")
json_end = response.rfind("}")
return json.loads(response[json_start : json_end + 1])
try:
json_start = response.index("{")
json_end = response.rfind("}")
return json.loads(response[json_start : json_end + 1])
except Exception as e:
raise ValueError("No valid JSON structure found in the input string")
def str_to_json(response: str, indent: int = 3):
"""

Loading…
Cancel
Save