From 2064896f039efffbbaed9e6a9c882ca0d6955327 Mon Sep 17 00:00:00 2001
From: Nicolas Nahas <45595586+nicorne@users.noreply.github.com>
Date: Tue, 30 Jul 2024 22:54:52 +0200
Subject: [PATCH] output validation

---
 .../strcutured_output/output_validator.py     | 73 +++++++++++++++++
 swarms/structs/agent.py                       | 81 +++++++++++++++++--
 2 files changed, 148 insertions(+), 6 deletions(-)
 create mode 100644 playground/strcutured_output/output_validator.py

diff --git a/playground/strcutured_output/output_validator.py b/playground/strcutured_output/output_validator.py
new file mode 100644
index 00000000..9dc11a9a
--- /dev/null
+++ b/playground/strcutured_output/output_validator.py
@@ -0,0 +1,73 @@
+"""
+* WORKING
+What this script does:
+Structured output example with validation function
+
+Requirements:
+Add the following API key(s) in your .env file:
+   - OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure)
+"""
+
+################ Adding project root to PYTHONPATH ################################
+# If you are running playground examples in the project files directly, use this:
+
+import sys
+import os
+
+sys.path.insert(0, os.getcwd())
+
+################ Adding project root to PYTHONPATH ################################
+
+from swarms import Agent, OpenAIChat
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+from pydantic import AfterValidator
+
+
+def symbol_must_exists(symbol: str) -> str:
+    """Return ``symbol`` unchanged if it is in the allowed list, else raise ValueError."""
+    symbols = [
+        "AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B",
+        "JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO"
+    ]
+    if symbol not in symbols:
+        raise ValueError(f"symbol must exists in the list: {symbols}")
+
+    return symbol
+
+
+# Initialize the schema for the stock's information
+class StockInfo(BaseModel):
+    """
+    Describing a stock and its information
+    """
+    name: str = Field(..., title="Name of the company")
+    description: str = Field(..., title="Description of the company")
+    symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(..., title="stock symbol of the company")
+
+
+# Define the task to generate a company's information
+task = "Generate an existing S&P500's company information"
+
+# Initialize the agent
+agent = Agent(
+    agent_name="Stock Information Generator",
+    system_prompt=(
+        "Generate a public comapany's information"
+    ),
+    llm=OpenAIChat(),
+    max_loops=1,
+    streaming_on=False, # TODO code breaks when this is True
+    verbose=True,
+    # List of schemas that the agent can handle
+    list_base_models=[StockInfo],
+    #agent_ops_on=True,
+    output_validation=True,
+)
+
+# Run the agent to generate the company's information
+generated_data = agent.run(task)
+
+# Print the generated data
+print(f"Generated data: {generated_data}")
diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index 1fbac15d..83ab7355 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -7,7 +7,7 @@ import random
 import sys
 import time
 import uuid
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
 
 import yaml
 from loguru import logger
@@ -244,6 +244,7 @@ class Agent(BaseStructure):
         output_cleaner: Optional[Callable] = None,
         function_calling_format_type: Optional[str] = "OpenAI",
         list_base_models: Optional[List[BaseModel]] = None,
+        output_validation: Optional[bool] = False,
         metadata_output_type: str = "json",
         state_save_file_type: str = "json",
         chain_of_thoughts: bool = False,
@@ -328,6 +329,7 @@ class Agent(BaseStructure):
         self.output_cleaner = output_cleaner
         self.function_calling_format_type = function_calling_format_type
         self.list_base_models = list_base_models
+        self.output_validation = output_validation
         self.metadata_output_type = metadata_output_type
         self.state_save_file_type = state_save_file_type
         self.chain_of_thoughts = chain_of_thoughts
@@ -523,6 +525,67 @@ class Agent(BaseStructure):
                 colored(f"Error dynamically changing temperature: {error}")
             )
 
+    def extract_json(self, response: str) -> Tuple[str, str]:
+        """
+        Extract the function call from the first matching JSON object in the response.
+
+        Args:
+            response (str): Text that embeds a JSON object with the keys
+                "function_call" and "parameters".
+
+        Returns:
+            Tuple[str, str]: The "function_call" value and the "parameters"
+                value serialized back to a JSON string.
+
+        Raises:
+            ValueError: If no such JSON object can be found in the response.
+        """
+        decoder = json.JSONDecoder()
+        for start, char in enumerate(response):
+            if char != "{":
+                continue
+            try:
+                # raw_decode parses exactly one JSON value starting at `start`, so
+                # braces inside string values cannot confuse the extraction.
+                full_json, _ = decoder.raw_decode(response, start)
+            except json.JSONDecodeError:
+                continue
+            if "function_call" in full_json and "parameters" in full_json:
+                function_call = full_json["function_call"]
+                parameters = json.dumps(full_json["parameters"])
+                return function_call, parameters
+        raise ValueError("No valid JSON structure found in the input string")
+
+    def pydantic_validation(self, response: str) -> BaseModel:
+        """
+        Validate the function call found in the response against its schema.
+
+        Args:
+            response (str): Text containing the function call JSON.
+
+        Returns:
+            BaseModel: Instance of the model in `list_base_models` whose class
+                name equals the extracted function call.
+
+        Raises:
+            ValueError: If no model in `list_base_models` matches the function
+                call, or if pydantic rejects the parameters (ValidationError).
+        """
+        function_call, parameters = self.extract_json(response)
+        selected_base_model: Optional[Type[BaseModel]] = next(
+            (
+                model
+                for model in self.list_base_models or []
+                if model.__name__ == function_call
+            ),
+            None,
+        )
+        if selected_base_model is None:
+            raise ValueError(
+                f"No base model named '{function_call}' in list_base_models"
+            )
+        return selected_base_model.model_validate_json(parameters, strict=True)
+
     def format_prompt(self, template, **kwargs: Any) -> str:
         """Format the template with the provided kwargs using f-string interpolation."""
         return template.format(**kwargs)
@@ -671,13 +734,13 @@ class Agent(BaseStructure):
 
     ########################## FUNCTION CALLING ##########################
 
-    def readability(self, response:str) -> str:
+    def printtier(self, response: str) -> None:
         """
         Specifies the name of the agent in capital letters in pink and the response text in blue.
         Add space above.
         """
         print("\n")
-        return f"\033[1;34m{self.name.upper()}:\033[0m {response}"
+        print(f"\033[1;34m{self.name.upper()}:\033[0m {response}")
 
     def run(
         self,
@@ -734,7 +797,7 @@ class Agent(BaseStructure):
                         response = self.llm(
                             task_prompt, *args, **kwargs
                         )
-                        print(self.readability(response))
+                        self.printtier(response)
 
                         self.short_memory.add(
                             role=self.agent_name, content=response
@@ -754,13 +817,13 @@ class Agent(BaseStructure):
                             if img is None
                             else (task_prompt, img, *args)
                         )
-                        response = self.readability(self.llm(*response_args, **kwargs))
+                        response = self.llm(*response_args, **kwargs)
 
                         # Print
                         if self.streaming_on is True:
                             response = self.stream_response(response)
                         else:
-                            print(response)
+                            self.printtier(response)
 
                         # Add the response to the memory
                         self.short_memory.add(
@@ -784,6 +847,10 @@ class Agent(BaseStructure):
                                 role=self.agent_name,
                                 content=evaluated_response,
                             )
+
+                        if self.output_validation:
+                            instance = None
+                            instance = self.pydantic_validation(response)
 
                         # Sentiment analysis
                         if self.sentiment_analyzer:
@@ -863,5 +930,7 @@ class Agent(BaseStructure):
 
             if self.return_history:
                 return self.short_memory.return_history_as_string()
+            elif self.output_validation:
+                return instance
             else:
                 return response