output validation

pydantic_validation
Nicolas Nahas 6 months ago
parent de12e0c0a7
commit 2064896f03

@ -0,0 +1,72 @@
"""
* WORKING
What this script does:
Structured output example with validation function
Requirements:
Add the following API key(s) to your .env file:
- OPENAI_API_KEY (this example works best with OpenAI because it uses the OpenAI function-calling structure)
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from swarms import Agent, OpenAIChat
from pydantic import AfterValidator, BaseModel, Field
from typing_extensions import Annotated
def symbol_must_exists(symbol: str) -> str:
    """Validator: accept only symbols from a fixed S&P 500 allow-list."""
    symbols = [
        "AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B",
        "JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO",
    ]
    if symbol not in symbols:
        raise ValueError(f"symbol must exist in the list: {symbols}")
    return symbol
# Define the schema for the stock information
class StockInfo(BaseModel):
    """
    Schema describing a stock and its information
    """

    name: str = Field(..., title="Name of the company")
    description: str = Field(..., title="Description of the company")
    symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(
        ..., title="Stock symbol of the company"
    )
# Define the task to generate a company's information
task = "Generate an existing S&P 500 company's information"
# Initialize the agent
agent = Agent(
agent_name="Stock Information Generator",
    system_prompt=(
        "Generate a public company's information"
    ),
llm=OpenAIChat(),
max_loops=1,
    streaming_on=False,  # TODO: code breaks when this is True
verbose=True,
# List of schemas that the agent can handle
list_base_models=[StockInfo],
#agent_ops_on=True,
output_validation=True,
)
# Run the agent to generate the stock information
generated_data = agent.run(task)
# Print the generated data
print(f"Generated data: {generated_data}")
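A quick standalone check (not part of the commit) of what the validator buys you: constructing a StockInfo with a symbol outside the allow-list raises a pydantic ValidationError, so output validation fails fast on hallucinated tickers.

# Standalone sketch: AfterValidator runs on construction, so an unknown
# symbol is rejected immediately.
from pydantic import ValidationError

try:
    StockInfo(name="Foo Corp", description="A made-up company", symbol="FOO")
except ValidationError as err:
    print(err)  # symbol must exist in the list: [...]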

@ -7,7 +7,7 @@ import random
import sys
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
import yaml
from loguru import logger
@ -244,6 +244,7 @@ class Agent(BaseStructure):
output_cleaner: Optional[Callable] = None,
function_calling_format_type: Optional[str] = "OpenAI",
list_base_models: Optional[List[BaseModel]] = None,
output_validation: Optional[bool] = False,
metadata_output_type: str = "json",
state_save_file_type: str = "json",
chain_of_thoughts: bool = False,
@ -328,6 +329,7 @@ class Agent(BaseStructure):
self.output_cleaner = output_cleaner
self.function_calling_format_type = function_calling_format_type
self.list_base_models = list_base_models
self.output_validation = output_validation
self.metadata_output_type = metadata_output_type
self.state_save_file_type = state_save_file_type
self.chain_of_thoughts = chain_of_thoughts
@ -523,6 +525,25 @@ class Agent(BaseStructure):
colored(f"Error dynamically changing temperature: {error}")
)
def extract_json(self, response: str) -> Tuple[str, str]:
    """Extract the outermost JSON object from the response and return its
    "function_call" name plus the "parameters" payload re-serialized as JSON."""
    stack = []
    for i, char in enumerate(response):
        if char == "{":
            stack.append(i)
        elif char == "}" and stack:
            start = stack.pop()
            if not stack:  # outermost braces balanced: parse the full object
                full_json = json.loads(response[start : i + 1])
                function_call = full_json["function_call"]
                parameters = json.dumps(full_json["parameters"])
                return function_call, parameters
    raise ValueError("No valid JSON structure found in the input string")

def pydantic_validation(self, response: str) -> BaseModel:
    """Validate the extracted parameters against the matching schema in
    list_base_models and return the validated model instance."""
    function_call, parameters = self.extract_json(response)
    selected_base_model = next(
        (m for m in self.list_base_models if m.__name__ == function_call),
        None,
    )
    if selected_base_model is None:
        raise ValueError(f"No base model named {function_call!r} in list_base_models")
    return selected_base_model.__pydantic_validator__.validate_json(
        parameters, strict=True
    )
def format_prompt(self, template, **kwargs: Any) -> str:
"""Format the template with the provided kwargs using f-string interpolation."""
return template.format(**kwargs)
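For context on the two helpers above: extract_json scans for the first outermost balanced {...} in the raw completion, so the response is expected to embed an object whose top-level keys are "function_call" (the name of a schema in list_base_models) and "parameters" (its field values). A minimal sketch with made-up values, assuming agent is an Agent configured with list_base_models=[StockInfo] as in the example script:

# Illustrative only; the shape mirrors what extract_json reads above.
response = (
    '{"function_call": "StockInfo", "parameters": {"name": "Apple Inc.", '
    '"description": "Designs and sells consumer electronics.", "symbol": "AAPL"}}'
)
name, params = agent.extract_json(response)      # ("StockInfo", '{"name": ...}')
instance = agent.pydantic_validation(response)   # validated StockInfo instance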
@ -671,13 +692,13 @@ class Agent(BaseStructure):
########################## FUNCTION CALLING ##########################
def readability(self, response:str) -> str:
def printtier(self, response: str) -> None:
"""
Prints the agent's name in uppercase bold blue, followed by the response text.
Adds a blank line above.
"""
print("\n")
return f"\033[1;34m{self.name.upper()}:\033[0m {response}"
print(f"\033[1;34m{self.name.upper()}:\033[0m {response}")
def run(
self,
@ -734,7 +755,7 @@ class Agent(BaseStructure):
response = self.llm(
task_prompt, *args, **kwargs
)
print(self.readability(response))
self.printtier(response)
self.short_memory.add(
role=self.agent_name, content=response
@ -754,13 +775,13 @@ class Agent(BaseStructure):
if img is None
else (task_prompt, img, *args)
)
response = self.readability(self.llm(*response_args, **kwargs))
response = self.llm(*response_args, **kwargs)
# Print
if self.streaming_on is True:
response = self.stream_response(response)
else:
print(response)
self.printtier(response)
# Add the response to the memory
self.short_memory.add(
@ -784,6 +805,10 @@ class Agent(BaseStructure):
role=self.agent_name,
content=evaluated_response,
)
if self.output_validation:
    # Validate the response against the registered Pydantic schemas
    instance = self.pydantic_validation(response)
# Sentiment analysis
if self.sentiment_analyzer:
@ -863,6 +888,8 @@ class Agent(BaseStructure):
if self.return_history:
return self.short_memory.return_history_as_string()
elif self.output_validation:
return instance
else:
return response
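Taken together with the example script at the top, the new return path means run() hands back the validated Pydantic instance rather than the raw completion whenever output_validation is set. A hedged usage sketch:

# Hypothetical usage, assuming the agent from the example script
# (output_validation=True, list_base_models=[StockInfo]).
stock = agent.run(task)
print(stock.symbol)  # guaranteed to be one of the allow-listed tickers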
