From 8052cb62d1b5e9aec9b937509ca3d77c69e7c3cb Mon Sep 17 00:00:00 2001 From: Nicolas Nahas <45595586+nicorne@users.noreply.github.com> Date: Wed, 7 Aug 2024 00:19:37 +0200 Subject: [PATCH] Pydantic validation --- .../advanced_output_validation.py | 76 +++++++++++++++++++ .../simple_structured_output.py | 71 +++++++++++++++++ swarms/structs/agent.py | 35 +++++++-- swarms/tools/json_utils.py | 10 ++- 4 files changed, 183 insertions(+), 9 deletions(-) create mode 100644 playground/strcutured_output/advanced_output_validation.py create mode 100644 playground/strcutured_output/simple_structured_output.py diff --git a/playground/strcutured_output/advanced_output_validation.py b/playground/strcutured_output/advanced_output_validation.py new file mode 100644 index 00000000..49dd02de --- /dev/null +++ b/playground/strcutured_output/advanced_output_validation.py @@ -0,0 +1,76 @@ +""" +* WORKING +What this script does: +Structured output example with validation function + +Requirements: +Add the folowing API key(s) in your .env file: + - OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure) +""" + +################ Adding project root to PYTHONPATH ################################ +# If you are running playground examples in the project files directly, use this: + +import sys +import os + +sys.path.insert(0, os.getcwd()) + +################ Adding project root to PYTHONPATH ################################ + +from swarms import Agent, OpenAIChat + +from pydantic import BaseModel, Field +from typing_extensions import Annotated +from pydantic import AfterValidator + + +def symbol_must_exists(symbol= str) -> str: + symbols = [ + "AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B", + "JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO" + ] + if symbol not in symbols: + raise ValueError(f"symbol must exists in the list: {symbols}") + + return symbol + + +# Initialize the schema for the person's information +class StockInfo(BaseModel): + """ + To create a StockInfo, you need to return a JSON object with the following format: + { + "function_call": "StockInfo", + "parameters": { + ... + } + } + """ + name: str = Field(..., title="Name of the company") + description: str = Field(..., title="Description of the company") + symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(..., title="stock symbol of the company") + + +# Define the task to generate a person's information +task = "Generate an existing S&P500's company information" + +# Initialize the agent +agent = Agent( + agent_name="Stock Information Generator", + system_prompt=( + "Generate a public comapany's information" + ), + llm=OpenAIChat(), + max_loops=1, + verbose=True, + # List of schemas that the agent can handle + list_base_models=[StockInfo], + output_validation=True, +) + +# Run the agent to generate the person's information +generated_data = agent.run(task) + +# Print the generated data +print(f"Generated data: {generated_data}") diff --git a/playground/strcutured_output/simple_structured_output.py b/playground/strcutured_output/simple_structured_output.py new file mode 100644 index 00000000..d509107b --- /dev/null +++ b/playground/strcutured_output/simple_structured_output.py @@ -0,0 +1,71 @@ +""" +* WORKING + +What this script does: +Structured output example + +Requirements: +Add the folowing API key(s) in your .env file: + - OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure) + +Note: +If you are running playground examples in the project files directly (without swarms installed via PIP), +make sure to add the project root to your PYTHONPATH by running the following command in the project's root directory: + 'export PYTHONPATH=$(pwd):$PYTHONPATH' +""" + +################ Adding project root to PYTHONPATH ################################ +# If you are running playground examples in the project files directly, use this: + +import sys +import os + +sys.path.insert(0, os.getcwd()) + +################ Adding project root to PYTHONPATH ################################ + +from pydantic import BaseModel, Field +from swarms import Agent, OpenAIChat + + +# Initialize the schema for the person's information +class PersonInfo(BaseModel): + """ + To create a PersonInfo, you need to return a JSON object with the following format: + { + "function_call": "PersonInfo", + "parameters": { + ... + } + } + """ + name: str = Field(..., title="Name of the person") + agent: int = Field(..., title="Age of the person") + is_student: bool = Field(..., title="Whether the person is a student") + courses: list[str] = Field( + ..., title="List of courses the person is taking" + ) + +# Initialize the agent +agent = Agent( + agent_name="Person Information Generator", + system_prompt=( + "Generate a person's information" + ), + llm=OpenAIChat(), + max_loops=1, + verbose=True, + # List of pydantic models that the agent can use + list_base_models=[PersonInfo], + output_validation=True +) + +# Define the task to generate a person's information +task = "Generate a person's information" + +# Run the agent to generate the person's information +generated_data = agent.run(task) + +# Print the generated data +print(type(generated_data)) +print(f"Generated data: {generated_data}") diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index fccfef2d..b3809db1 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -7,7 +7,7 @@ import random import sys import time import uuid -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type import yaml from loguru import logger @@ -43,6 +43,7 @@ from swarms.tools.pydantic_to_json import ( from swarms.tools.tool_parse_exec import parse_and_execute_json from swarms.utils.data_to_text import data_to_text from swarms.utils.parse_code import extract_code_from_markdown +from swarms.tools.json_utils import extract_json_from_str from swarms.utils.pdf_to_text import pdf_to_text from swarms.tools.prebuilt.code_executor import CodeExecutor from swarms.models.popular_llms import OpenAIChat @@ -251,6 +252,7 @@ class Agent(BaseStructure): output_cleaner: Optional[Callable] = None, function_calling_format_type: Optional[str] = "OpenAI", list_base_models: Optional[List[BaseModel]] = None, + output_validation: Optional[bool] = False, metadata_output_type: str = "json", state_save_file_type: str = "json", chain_of_thoughts: bool = False, @@ -311,6 +313,7 @@ class Agent(BaseStructure): self.multi_modal = multi_modal self.pdf_path = pdf_path self.list_of_pdf = list_of_pdf + self.output_validation = output_validation self.tokenizer = tokenizer self.long_term_memory = long_term_memory self.preset_stopping_token = preset_stopping_token @@ -570,6 +573,14 @@ class Agent(BaseStructure): colored(f"Error dynamically changing temperature: {error}") ) + def pydantic_validation(self, response:str)-> Type[BaseModel]: + """Validates the response using Pydantic.""" + parsed_json = extract_json_from_str(response) + function_call = parsed_json["function_call"] + parameters = json.dumps(parsed_json["parameters"]) + selected_base_model = next((model for model in self.list_base_models if model.__name__ == function_call), None) + return selected_base_model.__pydantic_validator__.validate_json(parameters, strict=True) + def format_prompt(self, template, **kwargs: Any) -> str: """Format the template with the provided kwargs using f-string interpolation.""" return template.format(**kwargs) @@ -782,7 +793,7 @@ class Agent(BaseStructure): response = self.llm( task_prompt, *args, **kwargs ) - print(response) + self.printtier(response) # Add to memory self.short_memory.add( @@ -813,8 +824,8 @@ class Agent(BaseStructure): # TODO: Implement reliablity check if self.tools is not None: - # self.parse_function_call_and_execute(response) - self.parse_and_execute_tools(response) + result = None + result = self.parse_function_call_and_execute(response) if self.code_interpreter is True: # Parse the code and execute @@ -839,6 +850,10 @@ class Agent(BaseStructure): ) # all_responses.append(evaluated_response) + + if self.output_validation: + result = None + result = self.pydantic_validation(response) # Sentiment analysis if self.sentiment_analyzer: @@ -926,6 +941,8 @@ class Agent(BaseStructure): if self.return_history: return self.short_memory.return_history_as_string() + elif self.output_validation: + return result else: return final_response @@ -1305,6 +1322,14 @@ class Agent(BaseStructure): for word in self.response_filters: response = response.replace(word, "[FILTERED]") return response + + def printtier(self, response:str) -> str: + """ + Specifies the name of the agent in capital letters in pink and the response text in blue. + Add space above. + """ + print("\n") + return print(f"\033[1;34m{self.name.upper()}:\033[0m {response}") def filtered_run(self, task: str) -> str: """ @@ -1734,7 +1759,7 @@ class Agent(BaseStructure): ) response = self.llm(task_prompt, *args, **kwargs) - print(response) + self.printtier(response) self.short_memory.add( role=self.agent_name, content=response diff --git a/swarms/tools/json_utils.py b/swarms/tools/json_utils.py index 0902d2c7..581fe7ab 100644 --- a/swarms/tools/json_utils.py +++ b/swarms/tools/json_utils.py @@ -30,10 +30,12 @@ def extract_json_from_str(response: str): Raises: ValueError: If the string does not contain a valid JSON object. """ - json_start = response.index("{") - json_end = response.rfind("}") - return json.loads(response[json_start : json_end + 1]) - + try: + json_start = response.index("{") + json_end = response.rfind("}") + return json.loads(response[json_start : json_end + 1]) + except Exception as e: + raise ValueError("No valid JSON structure found in the input string") def str_to_json(response: str, indent: int = 3): """