diff --git a/playground/agents/command_r_tool_agent.py b/playground/agents/command_r_tool_agent.py index cd9644db..9cbd73ad 100644 --- a/playground/agents/command_r_tool_agent.py +++ b/playground/agents/command_r_tool_agent.py @@ -16,6 +16,7 @@ model = AutoModelForCausalLM.from_pretrained( # Load the pre-trained model and tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) + # Initialize the schema for the person's information class APIExampleRequestSchema(BaseModel): endpoint: str = Field( @@ -31,22 +32,25 @@ class APIExampleRequestSchema(BaseModel): ..., description="The body of the example request" ) response: dict = Field( - ..., description="The expected response of the example request" + ..., + description="The expected response of the example request", ) + # Convert the schema to a JSON string api_example_schema = base_model_to_json(APIExampleRequestSchema) # Convert the schema to a JSON string # Define the task to generate a person's information -task = ( - "Generate an example API request using this code:\n" -) +task = "Generate an example API request using this code:\n" # Create an instance of the ToolAgent class agent = ToolAgent( name="Command R Tool Agent", - description="An agent that generates an API request using the Command R model.", + description=( + "An agent that generates an API request using the Command R" + " model." 
+ ), model=model, tokenizer=tokenizer, json_schema=api_example_schema, @@ -56,4 +60,4 @@ agent = ToolAgent( generated_data = agent.run(task) # Print the generated data -print(f"Generated data: {generated_data}") \ No newline at end of file +print(f"Generated data: {generated_data}") diff --git a/playground/agents/jamba_tool_agent.py b/playground/agents/jamba_tool_agent.py index 1bc2666c..3ca293cd 100644 --- a/playground/agents/jamba_tool_agent.py +++ b/playground/agents/jamba_tool_agent.py @@ -16,6 +16,7 @@ model = AutoModelForCausalLM.from_pretrained( # Load the pre-trained model and tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) + # Initialize the schema for the person's information class APIExampleRequestSchema(BaseModel): endpoint: str = Field( @@ -31,22 +32,25 @@ class APIExampleRequestSchema(BaseModel): ..., description="The body of the example request" ) response: dict = Field( - ..., description="The expected response of the example request" + ..., + description="The expected response of the example request", ) + # Convert the schema to a JSON string api_example_schema = base_model_to_json(APIExampleRequestSchema) # Convert the schema to a JSON string # Define the task to generate a person's information -task = ( - "Generate an example API request using this code:\n" -) +task = "Generate an example API request using this code:\n" # Create an instance of the ToolAgent class agent = ToolAgent( name="Command R Tool Agent", - description="An agent that generates an API request using the Command R model.", + description=( + "An agent that generates an API request using the Jamba" + " model." 
+ ), model=model, tokenizer=tokenizer, json_schema=api_example_schema, @@ -56,4 +60,4 @@ agent = ToolAgent( generated_data = agent(task) # Print the generated data -print(f"Generated data: {generated_data}") \ No newline at end of file +print(f"Generated data: {generated_data}") diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index b053db3f..4637d332 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -32,6 +32,7 @@ from swarms.models.popular_llms import ( from swarms.models.popular_llms import ( ReplicateLLM as Replicate, ) +from swarms.models.popular_llms import OctoAIChat from swarms.models.qwen import QwenVLMultiModal # noqa: E402 from swarms.models.sampling_params import SamplingParams, SamplingType @@ -79,4 +80,5 @@ __all__ = [ "AudioModality", "ImageModality", "VideoModality", + "OctoAIChat", ] diff --git a/swarms/models/popular_llms.py b/swarms/models/popular_llms.py index 449080b5..2f043445 100644 --- a/swarms/models/popular_llms.py +++ b/swarms/models/popular_llms.py @@ -11,6 +11,7 @@ from langchain_community.llms import ( OpenAI, Replicate, ) +from langchain_community.llms.octoai_endpoint import OctoAIEndpoint class AnthropicChat(Anthropic): @@ -46,3 +47,8 @@ class AzureOpenAILLM(AzureChatOpenAI): class OpenAIChatLLM(OpenAIChat): def __call__(self, *args, **kwargs): return self.invoke(*args, **kwargs) + + +class OctoAIChat(OctoAIEndpoint): + def __call__(self, *args, **kwargs): + return self.invoke(*args, **kwargs) diff --git a/swarms/utils/__init__.py b/swarms/utils/__init__.py index 945479fe..329d95ec 100644 --- a/swarms/utils/__init__.py +++ b/swarms/utils/__init__.py @@ -40,6 +40,7 @@ from swarms.utils.remove_json_whitespace import ( remove_whitespace_from_yaml, ) from swarms.utils.save_logs import parse_log_file + # from swarms.utils.supervision_visualizer import MarkVisualizer from swarms.utils.try_except_wrapper import try_except_wrapper from swarms.utils.yaml_output_parser import YamlOutputParser