Merge pull request #630 from Occupying-Mars/krishna/openai_assistant
base openai_assistantpull/662/head
commit
89567acd15
@ -0,0 +1,135 @@
|
||||
# OpenAI Assistant
|
||||
|
||||
The OpenAI Assistant class provides a wrapper around OpenAI's Assistants API, integrating it with the swarms framework.
|
||||
|
||||
## Overview
|
||||
|
||||
The `OpenAIAssistant` class allows you to create and interact with OpenAI Assistants, providing a simple interface for:
|
||||
|
||||
- Creating assistants with specific roles and capabilities
|
||||
- Adding custom functions that the assistant can call
|
||||
- Managing conversation threads
|
||||
- Handling tool calls and function execution
|
||||
- Getting responses from the assistant
|
||||
|
||||
## Insstallation
|
||||
|
||||
```bash
|
||||
pip install swarms
|
||||
```
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
|
||||
from swarms import OpenAIAssistant
|
||||
|
||||
#Create an assistant
|
||||
assistant = OpenAIAssistant(
|
||||
name="Math Tutor",
|
||||
instructions="You are a helpful math tutor.",
|
||||
model="gpt-4o",
|
||||
tools=[{"type": "code_interpreter"}]
|
||||
)
|
||||
|
||||
#Run a Task
|
||||
response = assistant.run("Solve the equation: 3x + 11 = 14")
|
||||
print(response)
|
||||
|
||||
# Continue the conversation in the same thread
|
||||
follow_up = assistant.run("Now explain how you solved it")
|
||||
print(follow_up)
|
||||
```
|
||||
|
||||
## Function Calling
|
||||
|
||||
The assistant supports custom function integration:
|
||||
|
||||
```python
|
||||
|
||||
def get_weather(location: str, unit: str = "celsius") -> str:
|
||||
# Mock weather function
|
||||
return f"The weather in {location} is 22 degrees {unit}"
|
||||
|
||||
# Add function to assistant
|
||||
assistant.add_function(
|
||||
description="Get the current weather in a location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City name"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Constructor
|
||||
|
||||
```python
|
||||
OpenAIAssistant(
|
||||
name: str,
|
||||
instructions: Optional[str] = None,
|
||||
model: str = "gpt-4o",
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
)
|
||||
```
|
||||
|
||||
### Methods
|
||||
|
||||
#### run(task: str) -> str
|
||||
Sends a task to the assistant and returns its response. The conversation thread is maintained between calls.
|
||||
|
||||
#### add_function(func: Callable, description: str, parameters: Dict[str, Any]) -> None
|
||||
Adds a callable function that the assistant can use during conversations.
|
||||
|
||||
#### add_message(content: str, file_ids: Optional[List[str]] = None) -> None
|
||||
Adds a message to the current conversation thread.
|
||||
|
||||
## Error Handling
|
||||
|
||||
The assistant implements robust error handling:
|
||||
- Retries on rate limits
|
||||
- Graceful handling of API errors
|
||||
- Clear error messages for debugging
|
||||
- Status monitoring for runs and completions
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. Thread Management
|
||||
- Use the same assistant instance for related conversations
|
||||
- Create new instances for unrelated tasks
|
||||
- Monitor thread status during long-running operations
|
||||
|
||||
2. Function Integration
|
||||
- Keep functions simple and focused
|
||||
- Provide clear descriptions and parameter schemas
|
||||
- Handle errors gracefully in custom functions
|
||||
- Test functions independently before integration
|
||||
|
||||
3. Performance
|
||||
- Reuse assistant instances when possible
|
||||
- Monitor and handle rate limits appropriately
|
||||
- Use appropriate polling intervals for status checks
|
||||
- Consider implementing timeouts for long-running operations
|
||||
|
||||
## References
|
||||
|
||||
- [OpenAI Assistants API Documentation](https://platform.openai.com/docs/assistants/overview)
|
||||
- [OpenAI Function Calling Guide](https://platform.openai.com/docs/guides/function-calling)
|
||||
- [OpenAI Rate Limits](https://platform.openai.com/docs/guides/rate-limits)
|
||||
|
||||
|
||||
|
@ -0,0 +1,264 @@
|
||||
from typing import Optional, List, Dict, Any, Callable
|
||||
import time
|
||||
from openai import OpenAI
|
||||
from swarms.structs.agent import Agent
|
||||
import json
|
||||
|
||||
class OpenAIAssistant(Agent):
|
||||
"""
|
||||
OpenAI Assistant wrapper for the swarms framework.
|
||||
Integrates OpenAI's Assistants API with the swarms architecture.
|
||||
|
||||
Example:
|
||||
>>> assistant = OpenAIAssistant(
|
||||
... name="Math Tutor",
|
||||
... instructions="You are a personal math tutor.",
|
||||
... model="gpt-4o",
|
||||
... tools=[{"type": "code_interpreter"}]
|
||||
... )
|
||||
>>> response = assistant.run("Solve 3x + 11 = 14")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
instructions: Optional[str] = None,
|
||||
model: str = "gpt-4o",
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""Initialize the OpenAI Assistant.
|
||||
|
||||
Args:
|
||||
name: Name of the assistant
|
||||
instructions: System instructions for the assistant
|
||||
model: Model to use (default: gpt-4-turbo-preview)
|
||||
tools: List of tools to enable (code_interpreter, retrieval)
|
||||
file_ids: List of file IDs to attach
|
||||
metadata: Additional metadata
|
||||
functions: List of custom functions to make available
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Initialize tools list with any provided functions
|
||||
self.tools = tools or []
|
||||
if functions:
|
||||
for func in functions:
|
||||
self.tools.append({
|
||||
"type": "function",
|
||||
"function": func
|
||||
})
|
||||
|
||||
# Create the OpenAI Assistant
|
||||
self.client = OpenAI()
|
||||
self.assistant = self.client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
model=model,
|
||||
tools=self.tools,
|
||||
file_ids=file_ids or [],
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# Store available functions
|
||||
self.available_functions: Dict[str, Callable] = {}
|
||||
|
||||
def add_function(self, func: Callable, description: str, parameters: Dict[str, Any]) -> None:
|
||||
"""Add a function that the assistant can call.
|
||||
|
||||
Args:
|
||||
func: The function to make available to the assistant
|
||||
description: Description of what the function does
|
||||
parameters: JSON schema describing the function parameters
|
||||
"""
|
||||
func_dict = {
|
||||
"name": func.__name__,
|
||||
"description": description,
|
||||
"parameters": parameters
|
||||
}
|
||||
|
||||
# Add to tools list
|
||||
self.tools.append({
|
||||
"type": "function",
|
||||
"function": func_dict
|
||||
})
|
||||
|
||||
# Store function reference
|
||||
self.available_functions[func.__name__] = func
|
||||
|
||||
# Update assistant with new tools
|
||||
self.assistant = self.client.beta.assistants.update(
|
||||
assistant_id=self.assistant.id,
|
||||
tools=self.tools
|
||||
)
|
||||
|
||||
def _handle_tool_calls(self, run, thread_id: str) -> None:
|
||||
"""Handle any required tool calls during a run.
|
||||
|
||||
This method processes any tool calls required by the assistant during execution.
|
||||
It extracts function calls, executes them with provided arguments, and submits
|
||||
the results back to the assistant.
|
||||
|
||||
Args:
|
||||
run: The current run object from the OpenAI API
|
||||
thread_id: ID of the current conversation thread
|
||||
|
||||
Returns:
|
||||
Updated run object after processing tool calls
|
||||
|
||||
Raises:
|
||||
Exception: If there are errors executing the tool calls
|
||||
"""
|
||||
while run.status == "requires_action":
|
||||
tool_calls = run.required_action.submit_tool_outputs.tool_calls
|
||||
tool_outputs = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.type == "function":
|
||||
# Get function details
|
||||
function_name = tool_call.function.name
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
# Call function if available
|
||||
if function_name in self.available_functions:
|
||||
function_response = self.available_functions[function_name](**function_args)
|
||||
tool_outputs.append({
|
||||
"tool_call_id": tool_call.id,
|
||||
"output": str(function_response)
|
||||
})
|
||||
|
||||
# Submit outputs back to the run
|
||||
run = self.client.beta.threads.runs.submit_tool_outputs(
|
||||
thread_id=thread_id,
|
||||
run_id=run.id,
|
||||
tool_outputs=tool_outputs
|
||||
)
|
||||
|
||||
# Wait for processing
|
||||
run = self._wait_for_run(run)
|
||||
|
||||
return run
|
||||
|
||||
def _wait_for_run(self, run) -> Any:
|
||||
"""Wait for a run to complete and handle any required actions.
|
||||
|
||||
This method polls the OpenAI API to check the status of a run until it completes
|
||||
or fails. It handles intermediate states like required actions and implements
|
||||
exponential backoff.
|
||||
|
||||
Args:
|
||||
run: The run object to monitor
|
||||
|
||||
Returns:
|
||||
The completed run object
|
||||
|
||||
Raises:
|
||||
Exception: If the run fails or expires
|
||||
"""
|
||||
while True:
|
||||
run = self.client.beta.threads.runs.retrieve(
|
||||
thread_id=run.thread_id,
|
||||
run_id=run.id
|
||||
)
|
||||
|
||||
if run.status == "completed":
|
||||
break
|
||||
elif run.status == "requires_action":
|
||||
run = self._handle_tool_calls(run, run.thread_id)
|
||||
if run.status == "completed":
|
||||
break
|
||||
elif run.status in ["failed", "expired"]:
|
||||
raise Exception(f"Run failed with status: {run.status}")
|
||||
|
||||
time.sleep(3) # Wait 3 seconds before checking again
|
||||
|
||||
return run
|
||||
|
||||
def _ensure_thread(self):
|
||||
"""Ensure a thread exists for the conversation.
|
||||
|
||||
This method checks if there is an active thread for the current conversation.
|
||||
If no thread exists, it creates a new one. This maintains conversation context
|
||||
across multiple interactions.
|
||||
|
||||
Side Effects:
|
||||
Sets self.thread if it doesn't exist
|
||||
"""
|
||||
if not self.thread:
|
||||
self.thread = self.client.beta.threads.create()
|
||||
|
||||
def add_message(self, content: str, file_ids: Optional[List[str]] = None) -> None:
|
||||
"""Add a message to the thread.
|
||||
|
||||
This method adds a new user message to the conversation thread. It ensures
|
||||
a thread exists before adding the message and handles file attachments.
|
||||
|
||||
Args:
|
||||
content: The text content of the message to add
|
||||
file_ids: Optional list of file IDs to attach to the message. These must be
|
||||
files that have been previously uploaded to OpenAI.
|
||||
|
||||
Side Effects:
|
||||
Creates a new thread if none exists
|
||||
Adds the message to the thread in OpenAI's system
|
||||
"""
|
||||
self._ensure_thread()
|
||||
self.client.beta.threads.messages.create(
|
||||
thread_id=self.thread.id,
|
||||
role="user",
|
||||
content=content,
|
||||
file_ids=file_ids or []
|
||||
)
|
||||
|
||||
def _get_response(self) -> str:
|
||||
"""Get the latest assistant response from the thread."""
|
||||
messages = self.client.beta.threads.messages.list(
|
||||
thread_id=self.thread.id,
|
||||
order="desc",
|
||||
limit=1
|
||||
)
|
||||
|
||||
if not messages.data:
|
||||
return ""
|
||||
|
||||
message = messages.data[0]
|
||||
if message.role == "assistant":
|
||||
return message.content[0].text.value
|
||||
return ""
|
||||
|
||||
def run(self, task: str, *args, **kwargs) -> str:
|
||||
"""Run a task using the OpenAI Assistant.
|
||||
|
||||
Args:
|
||||
task: The task or prompt to send to the assistant
|
||||
|
||||
Returns:
|
||||
The assistant's response as a string
|
||||
"""
|
||||
self._ensure_thread()
|
||||
|
||||
# Add the user message
|
||||
self.add_message(task)
|
||||
|
||||
# Create and run the assistant
|
||||
run = self.client.beta.threads.runs.create(
|
||||
thread_id=self.thread.id,
|
||||
assistant_id=self.assistant.id,
|
||||
instructions=self.instructions
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
run = self._wait_for_run(run)
|
||||
|
||||
# Only get and return the response if run completed successfully
|
||||
if run.status == "completed":
|
||||
return self._get_response()
|
||||
return ""
|
||||
|
||||
def call(self, task: str, *args, **kwargs) -> str:
|
||||
"""Alias for run() to maintain compatibility with different agent interfaces."""
|
||||
return self.run(task, *args, **kwargs)
|
@ -0,0 +1,62 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Optional
|
||||
from swarms.structs.base_workflow import BaseWorkflow
|
||||
from swarms.structs.agent import Agent
|
||||
from swarms.utils.loguru_logger import logger
|
||||
|
||||
class AsyncWorkflow(BaseWorkflow):
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "AsyncWorkflow",
|
||||
agents: List[Agent] = None,
|
||||
max_workers: int = 5,
|
||||
dashboard: bool = False,
|
||||
autosave: bool = False,
|
||||
verbose: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(agents=agents, **kwargs)
|
||||
self.name = name
|
||||
self.agents = agents or []
|
||||
self.max_workers = max_workers
|
||||
self.dashboard = dashboard
|
||||
self.autosave = autosave
|
||||
self.verbose = verbose
|
||||
self.task_pool = []
|
||||
self.results = []
|
||||
self.loop = None
|
||||
|
||||
async def _execute_agent_task(self, agent: Agent, task: str) -> Any:
|
||||
"""Execute a single agent task asynchronously"""
|
||||
try:
|
||||
if self.verbose:
|
||||
logger.info(f"Agent {agent.agent_name} processing task: {task}")
|
||||
result = await agent.arun(task)
|
||||
if self.verbose:
|
||||
logger.info(f"Agent {agent.agent_name} completed task")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent {agent.agent_name}: {str(e)}")
|
||||
return str(e)
|
||||
|
||||
async def run(self, task: str) -> List[Any]:
|
||||
"""Run the workflow with all agents processing the task concurrently"""
|
||||
if not self.agents:
|
||||
raise ValueError("No agents provided to the workflow")
|
||||
|
||||
try:
|
||||
# Create tasks for all agents
|
||||
tasks = [self._execute_agent_task(agent, task) for agent in self.agents]
|
||||
|
||||
# Execute all tasks concurrently
|
||||
self.results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
if self.autosave:
|
||||
# TODO: Implement autosave logic here
|
||||
pass
|
||||
|
||||
return self.results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in workflow execution: {str(e)}")
|
||||
raise
|
Loading…
Reference in new issue