structured-output

pull/571/head
Nicolas Nahas 1 year ago
parent 103d3937d3
commit e95090cbba

@ -0,0 +1,77 @@
"""
* WORKING
What this script does:
Structured output example with validation function
Requirements:
pip install openai
pip install pydantic
Add the folowing API key(s) in your .env file:
- OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure)
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from swarms import Agent, OpenAIChat
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from pydantic import AfterValidator
def symbol_must_exists(symbol= str) -> str:
symbols = [
"AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B",
"JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO"
]
if symbol not in symbols:
raise ValueError(f"symbol must exists in the list: {symbols}")
return symbol
# Initialize the schema for the person's information
class StockInfo(BaseModel):
"""
To create a StockInfo, you need to return a JSON object with the following format:
{
"function_call": "StockInfo",
"parameters": {
...
}
}
"""
name: str = Field(..., title="Name of the company")
description: str = Field(..., title="Description of the company")
symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(..., title="stock symbol of the company")
# Define the task to generate a person's information
task = "Generate an existing S&P500's company information"
# Initialize the agent
agent = Agent(
agent_name="Stock Information Generator",
system_prompt=(
"Generate a public comapany's information"
),
llm=OpenAIChat(),
max_loops=1,
verbose=True,
# List of schemas that the agent can handle
list_base_models=[StockInfo],
output_validation=True,
)
# Run the agent to generate the person's information
generated_data = agent.run(task)
# Print the generated data
print(f"Generated data: {generated_data}")

@ -0,0 +1,65 @@
"""
* WORKING
What this script does:
Structured output example
Requirements:
Add the folowing API key(s) in your .env file:
- OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure)
"""
################ Adding project root to PYTHONPATH ################################
# If you are running playground examples in the project files directly, use this:
import sys
import os
sys.path.insert(0, os.getcwd())
################ Adding project root to PYTHONPATH ################################
from pydantic import BaseModel, Field
from swarms import Agent, OpenAIChat
# Initialize the schema for the person's information
class PersonInfo(BaseModel):
"""
To create a PersonInfo, you need to return a JSON object with the following format:
{
"function_call": "PersonInfo",
"parameters": {
...
}
}
"""
name: str = Field(..., title="Name of the person")
age: int = Field(..., title="Age of the person")
is_student: bool = Field(..., title="Whether the person is a student")
courses: list[str] = Field(
..., title="List of courses the person is taking"
)
# Initialize the agent
agent = Agent(
agent_name="Person Information Generator",
system_prompt=(
"Generate a person's information"
),
llm=OpenAIChat(),
max_loops=1,
verbose=True,
# List of pydantic models that the agent can use
list_base_models=[PersonInfo],
output_validation=True
)
# Define the task to generate a person's information
task = "Generate a person's information for Paul Graham 56 years old and is a student at Harvard University and is taking 3 courses: Math, Science, and History."
# Run the agent to generate the person's information
generated_data = agent.run(task)
# Print the generated data
print(type(generated_data))
print(f"Generated data: {generated_data}")

@ -30,4 +30,5 @@ pandas>=2.2.2
fastapi>=0.110.1 fastapi>=0.110.1
networkx networkx
swarms-memory swarms-memory
swarms-cloud
pre-commit pre-commit

@ -7,7 +7,7 @@ import random
import sys import sys
import time import time
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
import toml import toml
import yaml import yaml
@ -51,6 +51,7 @@ from swarms.tools.tool_parse_exec import parse_and_execute_json
from swarms.utils.data_to_text import data_to_text from swarms.utils.data_to_text import data_to_text
from swarms.utils.file_processing import create_file_in_folder from swarms.utils.file_processing import create_file_in_folder
from swarms.utils.parse_code import extract_code_from_markdown from swarms.utils.parse_code import extract_code_from_markdown
from swarms.tools.json_utils import extract_json_from_str
from swarms.utils.pdf_to_text import pdf_to_text from swarms.utils.pdf_to_text import pdf_to_text
@ -238,7 +239,8 @@ class Agent:
function_calling_type: str = "json", function_calling_type: str = "json",
output_cleaner: Optional[Callable] = None, output_cleaner: Optional[Callable] = None,
function_calling_format_type: Optional[str] = "OpenAI", function_calling_format_type: Optional[str] = "OpenAI",
list_base_models: Optional[List[BaseModel]] = None, list_base_models: Optional[List[Type[BaseModel]]] = None,
output_validation: Optional[bool] = False,
metadata_output_type: str = "json", metadata_output_type: str = "json",
state_save_file_type: str = "json", state_save_file_type: str = "json",
chain_of_thoughts: bool = False, chain_of_thoughts: bool = False,
@ -333,6 +335,7 @@ class Agent:
self.output_cleaner = output_cleaner self.output_cleaner = output_cleaner
self.function_calling_format_type = function_calling_format_type self.function_calling_format_type = function_calling_format_type
self.list_base_models = list_base_models self.list_base_models = list_base_models
self.output_validation = output_validation
self.metadata_output_type = metadata_output_type self.metadata_output_type = metadata_output_type
self.state_save_file_type = state_save_file_type self.state_save_file_type = state_save_file_type
self.chain_of_thoughts = chain_of_thoughts self.chain_of_thoughts = chain_of_thoughts
@ -577,6 +580,14 @@ class Agent:
print( print(
colored(f"Error dynamically changing temperature: {error}") colored(f"Error dynamically changing temperature: {error}")
) )
def pydantic_validation(self, response:str)-> Type[BaseModel]:
"""Validates the response using Pydantic."""
parsed_json = extract_json_from_str(response)
function_call = parsed_json["function_call"]
parameters = json.dumps(parsed_json["parameters"])
selected_base_model = next((model for model in self.list_base_models if model.__name__ == function_call), None)
return selected_base_model.__pydantic_validator__.validate_json(parameters, strict=True)
def format_prompt(self, template, **kwargs: Any) -> str: def format_prompt(self, template, **kwargs: Any) -> str:
"""Format the template with the provided kwargs using f-string interpolation.""" """Format the template with the provided kwargs using f-string interpolation."""
@ -675,6 +686,7 @@ class Agent:
# Clear the short memory # Clear the short memory
response = None response = None
result = None
all_responses = [] all_responses = []
steps_pool = [] steps_pool = []
@ -717,7 +729,7 @@ class Agent:
if self.streaming_on is True: if self.streaming_on is True:
response = self.stream_response(response) response = self.stream_response(response)
else: else:
print(response) self.printtier(response)
# Add the response to the memory # Add the response to the memory
self.short_memory.add( self.short_memory.add(
@ -733,8 +745,8 @@ class Agent:
# TODO: Implement reliablity check # TODO: Implement reliablity check
if self.tools is not None: if self.tools is not None:
# self.parse_function_call_and_execute(response) # self.parse_and_execute_tools(response)
self.parse_and_execute_tools(response) result = self.parse_function_call_and_execute(response)
if self.code_interpreter is True: if self.code_interpreter is True:
# Parse the code and execute # Parse the code and execute
@ -773,6 +785,9 @@ class Agent:
# all_responses.append(evaluated_response) # all_responses.append(evaluated_response)
if self.output_validation:
result = self.pydantic_validation(response)
# Sentiment analysis # Sentiment analysis
if self.sentiment_analyzer: if self.sentiment_analyzer:
logger.info("Analyzing sentiment...") logger.info("Analyzing sentiment...")
@ -858,6 +873,9 @@ class Agent:
# logger.info(f"Final Response: {final_response}") # logger.info(f"Final Response: {final_response}")
if self.return_history: if self.return_history:
return self.short_memory.return_history_as_string() return self.short_memory.return_history_as_string()
elif self.output_validation:
return result
elif self.return_step_meta: elif self.return_step_meta:
log = ManySteps( log = ManySteps(
@ -1303,6 +1321,14 @@ class Agent:
except Exception as error: except Exception as error:
print(colored(f"Error retrying function: {error}", "red")) print(colored(f"Error retrying function: {error}", "red"))
def printtier(self, response:str) -> str:
"""
Specifies the name of the agent in capital letters in pink and the response text in blue.
Add space above.
"""
print("\n")
return print(f"\033[1;34m{self.name.upper()}:\033[0m {response}")
def update_system_prompt(self, system_prompt: str): def update_system_prompt(self, system_prompt: str):
"""Upddate the system message""" """Upddate the system message"""
self.system_prompt = system_prompt self.system_prompt = system_prompt
@ -1496,7 +1522,7 @@ class Agent:
) )
response = self.llm(task_prompt, *args, **kwargs) response = self.llm(task_prompt, *args, **kwargs)
print(response) self.printtier(response)
self.short_memory.add( self.short_memory.add(
role=self.agent_name, content=response role=self.agent_name, content=response
@ -1884,9 +1910,7 @@ class Agent:
full_memory = self.short_memory.return_history_as_string() full_memory = self.short_memory.return_history_as_string()
prompt_tokens = self.tokenizer.count_tokens(full_memory) prompt_tokens = self.tokenizer.count_tokens(full_memory)
completion_tokens = self.tokenizer.count_tokens(response) completion_tokens = self.tokenizer.count_tokens(response)
total_tokens = self.tokenizer.count_tokens( total_tokens = prompt_tokens + completion_tokens
prompt_tokens + completion_tokens
)
logger.info("Logging step metadata...") logger.info("Logging step metadata...")

@ -30,9 +30,12 @@ def extract_json_from_str(response: str):
Raises: Raises:
ValueError: If the string does not contain a valid JSON object. ValueError: If the string does not contain a valid JSON object.
""" """
json_start = response.index("{") try:
json_end = response.rfind("}") json_start = response.index("{")
return json.loads(response[json_start : json_end + 1]) json_end = response.rfind("}")
return json.loads(response[json_start : json_end + 1])
except Exception as e:
raise ValueError("No valid JSON structure found in the input string")
def str_to_json(response: str, indent: int = 3): def str_to_json(response: str, indent: int = 3):

Loading…
Cancel
Save