From e95090cbbabf6a8b9342e681f99ed61ba20dedb1 Mon Sep 17 00:00:00 2001
From: Nicolas Nahas <45595586+nicorne@users.noreply.github.com>
Date: Mon, 19 Aug 2024 19:59:49 +0200
Subject: [PATCH] structured-output

---
 .../tools/advanced_output_validation.py | 77 +++++++++++++++++++
 playground/tools/output_validation.py   | 65 ++++++++++++++++
 requirements.txt                        |  1 +
 swarms/structs/agent.py                 | 42 +++++++---
 swarms/tools/json_utils.py              |  9 ++-
 5 files changed, 182 insertions(+), 12 deletions(-)
 create mode 100644 playground/tools/advanced_output_validation.py
 create mode 100644 playground/tools/output_validation.py

diff --git a/playground/tools/advanced_output_validation.py b/playground/tools/advanced_output_validation.py
new file mode 100644
index 00000000..4c288c2b
--- /dev/null
+++ b/playground/tools/advanced_output_validation.py
@@ -0,0 +1,77 @@
+"""
+* WORKING
+What this script does:
+Structured output example with a validation function
+Requirements:
+pip install openai
+pip install pydantic
+Add the following API key(s) in your .env file:
+ - OPENAI_API_KEY (this example works best with OpenAI because it uses the OpenAI function-calling structure)
+"""
+
+################ Adding project root to PYTHONPATH ################################
+# If you are running playground examples in the project files directly, use this:
+
+import sys
+import os
+
+sys.path.insert(0, os.getcwd())
+
+################ Adding project root to PYTHONPATH ################################
+
+from swarms import Agent, OpenAIChat
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+from pydantic import AfterValidator
+
+
+def symbol_must_exists(symbol: str) -> str:
+    symbols = [
+        "AAPL", "MSFT", "AMZN", "GOOGL", "GOOG", "META", "TSLA", "NVDA", "BRK.B",
+        "JPM", "JNJ", "V", "PG", "UNH", "MA", "HD", "BAC", "XOM", "DIS", "CSCO"
+    ]
+    if symbol not in symbols:
+        raise ValueError(f"symbol must exist in the list: {symbols}")
+
+    return symbol
+
+
+# Initialize the schema for the company's information
+class StockInfo(BaseModel):
+    """
+    To create a StockInfo, you need to return a JSON object with the following format:
+    {
+        "function_call": "StockInfo",
+        "parameters": {
+            ...
+ } + } + """ + name: str = Field(..., title="Name of the company") + description: str = Field(..., title="Description of the company") + symbol: Annotated[str, AfterValidator(symbol_must_exists)] = Field(..., title="stock symbol of the company") + + +# Define the task to generate a person's information +task = "Generate an existing S&P500's company information" + +# Initialize the agent +agent = Agent( + agent_name="Stock Information Generator", + system_prompt=( + "Generate a public comapany's information" + ), + llm=OpenAIChat(), + max_loops=1, + verbose=True, + # List of schemas that the agent can handle + list_base_models=[StockInfo], + output_validation=True, +) + +# Run the agent to generate the person's information +generated_data = agent.run(task) + +# Print the generated data +print(f"Generated data: {generated_data}") \ No newline at end of file diff --git a/playground/tools/output_validation.py b/playground/tools/output_validation.py new file mode 100644 index 00000000..72e95244 --- /dev/null +++ b/playground/tools/output_validation.py @@ -0,0 +1,65 @@ +""" +* WORKING +What this script does: +Structured output example +Requirements: +Add the folowing API key(s) in your .env file: + - OPENAI_API_KEY (this example works best with Openai bc it uses openai function calling structure) + +""" + +################ Adding project root to PYTHONPATH ################################ +# If you are running playground examples in the project files directly, use this: + +import sys +import os + +sys.path.insert(0, os.getcwd()) + +################ Adding project root to PYTHONPATH ################################ + +from pydantic import BaseModel, Field +from swarms import Agent, OpenAIChat + + +# Initialize the schema for the person's information +class PersonInfo(BaseModel): + """ + To create a PersonInfo, you need to return a JSON object with the following format: + { + "function_call": "PersonInfo", + "parameters": { + ... + } + } + """ + name: str = Field(..., title="Name of the person") + age: int = Field(..., title="Age of the person") + is_student: bool = Field(..., title="Whether the person is a student") + courses: list[str] = Field( + ..., title="List of courses the person is taking" + ) + +# Initialize the agent +agent = Agent( + agent_name="Person Information Generator", + system_prompt=( + "Generate a person's information" + ), + llm=OpenAIChat(), + max_loops=1, + verbose=True, + # List of pydantic models that the agent can use + list_base_models=[PersonInfo], + output_validation=True +) + +# Define the task to generate a person's information +task = "Generate a person's information for Paul Graham 56 years old and is a student at Harvard University and is taking 3 courses: Math, Science, and History." 
+
+# Run the agent to generate the person's information
+generated_data = agent.run(task)
+
+# Print the generated data
+print(type(generated_data))
+print(f"Generated data: {generated_data}")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index e42c5372..3d90735a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,5 @@ pandas>=2.2.2
 fastapi>=0.110.1
 networkx
 swarms-memory
+swarms-cloud
 pre-commit
\ No newline at end of file
diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index e33df302..8476e681 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -7,7 +7,7 @@ import random
 import sys
 import time
 import uuid
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
 
 import toml
 import yaml
@@ -51,6 +51,7 @@ from swarms.tools.tool_parse_exec import parse_and_execute_json
 from swarms.utils.data_to_text import data_to_text
 from swarms.utils.file_processing import create_file_in_folder
 from swarms.utils.parse_code import extract_code_from_markdown
+from swarms.tools.json_utils import extract_json_from_str
 from swarms.utils.pdf_to_text import pdf_to_text
 
 
@@ -238,7 +239,8 @@ class Agent:
         function_calling_type: str = "json",
         output_cleaner: Optional[Callable] = None,
         function_calling_format_type: Optional[str] = "OpenAI",
-        list_base_models: Optional[List[BaseModel]] = None,
+        list_base_models: Optional[List[Type[BaseModel]]] = None,
+        output_validation: Optional[bool] = False,
         metadata_output_type: str = "json",
         state_save_file_type: str = "json",
         chain_of_thoughts: bool = False,
@@ -333,6 +335,7 @@ class Agent:
         self.output_cleaner = output_cleaner
         self.function_calling_format_type = function_calling_format_type
         self.list_base_models = list_base_models
+        self.output_validation = output_validation
         self.metadata_output_type = metadata_output_type
         self.state_save_file_type = state_save_file_type
         self.chain_of_thoughts = chain_of_thoughts
@@ -577,6 +580,14 @@ class Agent:
             print(
                 colored(f"Error dynamically changing temperature: {error}")
             )
+
+    def pydantic_validation(self, response: str) -> BaseModel:
+        """Validates the response against the matching Pydantic model."""
+        parsed_json = extract_json_from_str(response)
+        function_call = parsed_json["function_call"]
+        parameters = json.dumps(parsed_json["parameters"])
+        selected_base_model = next((model for model in self.list_base_models if model.__name__ == function_call), None)
+        return selected_base_model.__pydantic_validator__.validate_json(parameters, strict=True)
 
     def format_prompt(self, template, **kwargs: Any) -> str:
         """Format the template with the provided kwargs using f-string interpolation."""
@@ -675,6 +686,7 @@ class Agent:
 
         # Clear the short memory
         response = None
+        result = None
         all_responses = []
         steps_pool = []
 
@@ -717,7 +729,7 @@ class Agent:
                 if self.streaming_on is True:
                     response = self.stream_response(response)
                 else:
-                    print(response)
+                    self.printtier(response)
 
                 # Add the response to the memory
                 self.short_memory.add(
@@ -733,8 +745,8 @@ class Agent:
 
                 # TODO: Implement reliablity check
                 if self.tools is not None:
-                    # self.parse_function_call_and_execute(response)
-                    self.parse_and_execute_tools(response)
+                    # self.parse_and_execute_tools(response)
+                    result = self.parse_function_call_and_execute(response)
 
                 if self.code_interpreter is True:
                     # Parse the code and execute
@@ -773,6 +785,9 @@ class Agent:
 
                 # all_responses.append(evaluated_response)
 
+                if self.output_validation:
+                    result = self.pydantic_validation(response)
+
                 # Sentiment analysis
                 if self.sentiment_analyzer:
                     logger.info("Analyzing sentiment...")
@@ -858,6 +873,9 @@ class Agent:
             # logger.info(f"Final Response: {final_response}")
             if self.return_history:
                 return self.short_memory.return_history_as_string()
+
+            elif self.output_validation:
+                return result
 
             elif self.return_step_meta:
                 log = ManySteps(
@@ -1303,6 +1321,14 @@ class Agent:
         except Exception as error:
             print(colored(f"Error retrying function: {error}", "red"))
 
+    def printtier(self, response: str) -> None:
+        """
+        Prints the agent's name in uppercase bold blue, followed by the response text.
+        Adds a blank line above the output.
+        """
+        print("\n")
+        return print(f"\033[1;34m{self.name.upper()}:\033[0m {response}")
+
     def update_system_prompt(self, system_prompt: str):
         """Upddate the system message"""
         self.system_prompt = system_prompt
@@ -1496,7 +1522,7 @@ class Agent:
             )
 
             response = self.llm(task_prompt, *args, **kwargs)
-            print(response)
+            self.printtier(response)
 
             self.short_memory.add(
                 role=self.agent_name, content=response
@@ -1884,9 +1910,7 @@ class Agent:
             full_memory = self.short_memory.return_history_as_string()
             prompt_tokens = self.tokenizer.count_tokens(full_memory)
             completion_tokens = self.tokenizer.count_tokens(response)
-            total_tokens = self.tokenizer.count_tokens(
-                prompt_tokens + completion_tokens
-            )
+            total_tokens = prompt_tokens + completion_tokens
 
             logger.info("Logging step metadata...")
 
diff --git a/swarms/tools/json_utils.py b/swarms/tools/json_utils.py
index 0902d2c7..1ccbfc36 100644
--- a/swarms/tools/json_utils.py
+++ b/swarms/tools/json_utils.py
@@ -30,9 +30,12 @@ def extract_json_from_str(response: str):
     Raises:
         ValueError: If the string does not contain a valid JSON object.
     """
-    json_start = response.index("{")
-    json_end = response.rfind("}")
-    return json.loads(response[json_start : json_end + 1])
+    try:
+        json_start = response.index("{")
+        json_end = response.rfind("}")
+        return json.loads(response[json_start : json_end + 1])
+    except Exception as e:
+        raise ValueError("No valid JSON structure found in the input string") from e
 
 
 def str_to_json(response: str, indent: int = 3):
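
A minimal standalone sketch, not part of the patch and assuming Pydantic v2, of the round trip that the new Agent.pydantic_validation performs together with extract_json_from_str: slice the JSON payload out of a raw LLM reply, pick the registered model whose name matches "function_call", and strictly validate its parameters. The raw_reply string is invented for illustration.

import json

from pydantic import BaseModel, Field


class PersonInfo(BaseModel):
    name: str = Field(..., title="Name of the person")
    age: int = Field(..., title="Age of the person")
    is_student: bool = Field(..., title="Whether the person is a student")
    courses: list[str] = Field(..., title="List of courses the person is taking")


# Hypothetical raw LLM reply in the format the playground system prompts ask for.
raw_reply = (
    'Here is the requested JSON: {"function_call": "PersonInfo", "parameters": '
    '{"name": "Paul Graham", "age": 56, "is_student": true, '
    '"courses": ["Math", "Science", "History"]}}'
)

# Same idea as extract_json_from_str: keep everything from the first "{" to the last "}".
payload = json.loads(raw_reply[raw_reply.index("{"): raw_reply.rfind("}") + 1])

# Same idea as Agent.pydantic_validation: match the model by name, then validate strictly.
models = {model.__name__: model for model in [PersonInfo]}
selected = models[payload["function_call"]]
validated = selected.__pydantic_validator__.validate_json(
    json.dumps(payload["parameters"]), strict=True
)

print(type(validated).__name__)  # PersonInfo
print(validated.courses)         # ['Math', 'Science', 'History']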