parent
a16a584093
commit
3292174dbc
@ -0,0 +1,17 @@
|
|||||||
|
from swarms import Agent, OpenAIChat
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="API Requester",
|
||||||
|
agent_description="This agent is responsible for making API requests.",
|
||||||
|
system_prompt="You're a helpful API Requester agent. ",
|
||||||
|
llm=OpenAIChat(),
|
||||||
|
autosave=True,
|
||||||
|
max_loops="auto",
|
||||||
|
dashboard=True,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the agent
|
||||||
|
out = agent.run("Create an api request to OpenAI in python.")
|
||||||
|
print(out)
|
@ -0,0 +1,219 @@
|
|||||||
|
# Agent that picks up your intent
|
||||||
|
# Depending on your intent it routes you to an agent that can help you with your request.
|
||||||
|
# Account management agent and product support agent
|
||||||
|
# Account Management Agent --> Talk about the user, their account. Just understand the user's intent and route them to the right agent.
|
||||||
|
|
||||||
|
|
||||||
|
from swarms import Agent
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
from swarms import BaseLLM, base_model_to_openai_function
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
## Pydantic model for the tool schema
|
||||||
|
class HASSchema(BaseModel):
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
title="Name",
|
||||||
|
description="The name of the agent to send the task to.",
|
||||||
|
)
|
||||||
|
task: str = Field(
|
||||||
|
...,
|
||||||
|
title="Task",
|
||||||
|
description="The task to send to the agent.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
swarm_schema = base_model_to_openai_function(HASSchema, output_str=True)
|
||||||
|
|
||||||
|
ACCOUNT_MANAGEMENT_SYSTEM_PROMPT = """
|
||||||
|
|
||||||
|
You are an Account Management Agent. Your primary role is to engage with users regarding their accounts. Your main tasks include understanding the user's intent, addressing their immediate needs, and routing them to the appropriate agent for further assistance. Be simple and direct in your communication.
|
||||||
|
|
||||||
|
When a user contacts you, start by greeting them and asking how you can assist with their account. Listen carefully to their concerns, questions, or issues. If the user provides information that is specific to their account, acknowledge it and ask any necessary follow-up questions to clarify their needs. Ensure that you fully understand their intent before proceeding.
|
||||||
|
|
||||||
|
Once you have a clear understanding of the user's request or issue, determine the best course of action. If you can resolve the issue yourself, do so efficiently. If the issue requires specialized assistance, explain to the user that you will route them to the appropriate agent who can help further. Ensure the user feels heard and understood throughout the process.
|
||||||
|
|
||||||
|
Your ultimate goal is to provide a seamless and positive experience for the user by effectively managing their inquiries and directing them to the right resource for resolution. Always maintain a polite and professional tone, and ensure that the user feels supported and valued.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
PRODUCT_SUPPORT_QA_SYSTEM_PROMPT = """
|
||||||
|
|
||||||
|
|
||||||
|
You are a Product Support Agent.
|
||||||
|
Your primary role is to provide assistance to users who have questions or issues related to the product. Your main tasks include understanding the user's needs, providing accurate information, and resolving any problems they may encounter. Be clear and concise in your communication.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class llama3Hosted(BaseLLM):
|
||||||
|
"""
|
||||||
|
A class representing a hosted version of the Llama3 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The name or path of the Llama3 model to use.
|
||||||
|
temperature (float): The temperature parameter for generating responses.
|
||||||
|
max_tokens (int): The maximum number of tokens in the generated response.
|
||||||
|
system_prompt (str): The system prompt to use for generating responses.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model (str): The name or path of the Llama3 model.
|
||||||
|
temperature (float): The temperature parameter for generating responses.
|
||||||
|
max_tokens (int): The maximum number of tokens in the generated response.
|
||||||
|
system_prompt (str): The system prompt for generating responses.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
run(task, *args, **kwargs): Generates a response for the given task.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
temperature: float = 0.8,
|
||||||
|
max_tokens: int = 4000,
|
||||||
|
system_prompt: str = "You are a helpful assistant.",
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.model = model
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Generates a response for the given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The user's task or input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated response from the Llama3 model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = "http://34.204.8.31:30001/v1/chat/completions"
|
||||||
|
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": self.system_prompt},
|
||||||
|
{"role": "user", "content": task},
|
||||||
|
],
|
||||||
|
"stop_token_ids": [128009, 128001],
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
response = requests.request(
|
||||||
|
"POST", url, headers=headers, data=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assistant_message = response_json["choices"][0]["message"][
|
||||||
|
"content"
|
||||||
|
]
|
||||||
|
|
||||||
|
return assistant_message
|
||||||
|
|
||||||
|
|
||||||
|
def select_agent_and_send_task(name: str = None, task: str = None):
|
||||||
|
"""
|
||||||
|
Select an agent and send a task to them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the agent to send the task to.
|
||||||
|
task (str): The task to send to the agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the agent.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if name == "Product Support Agent":
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Product Support Agent",
|
||||||
|
system_prompt=PRODUCT_SUPPORT_QA_SYSTEM_PROMPT,
|
||||||
|
llm=llama3Hosted(),
|
||||||
|
max_loops=2,
|
||||||
|
autosave=True,
|
||||||
|
dashboard=False,
|
||||||
|
streaming_on=True,
|
||||||
|
verbose=True,
|
||||||
|
output_type=str,
|
||||||
|
metadata_output_type="json",
|
||||||
|
function_calling_format_type="OpenAI",
|
||||||
|
function_calling_type="json",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return "Invalid agent name. Please select 'Account Management Agent' or 'Product Support Agent'."
|
||||||
|
|
||||||
|
response = agent.run(task)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_then_activate_agent(json_data: str):
|
||||||
|
"""
|
||||||
|
Parse the JSON data and activate the appropriate agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_data (str): The JSON data containing the agent name and task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the agent.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(json_data)
|
||||||
|
name = data.get("name")
|
||||||
|
task = data.get("task")
|
||||||
|
|
||||||
|
response = select_agent_and_send_task(name, task)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return "Invalid JSON data."
|
||||||
|
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Account Management Agent",
|
||||||
|
system_prompt=ACCOUNT_MANAGEMENT_SYSTEM_PROMPT,
|
||||||
|
# sop_list=[GLOSSARY_PROMPTS, FEW_SHORT_PROMPTS],
|
||||||
|
# sop=list_tool_schemas_json,
|
||||||
|
llm=llama3Hosted(
|
||||||
|
max_tokens=3000,
|
||||||
|
),
|
||||||
|
max_loops="auto",
|
||||||
|
interactive=True,
|
||||||
|
autosave=True,
|
||||||
|
dashboard=False,
|
||||||
|
streaming_on=True,
|
||||||
|
# interactive=True,
|
||||||
|
# tools=[search_weather], # or list of tools
|
||||||
|
verbose=True,
|
||||||
|
# Set the output type to the tool schema which is a BaseModel
|
||||||
|
list_base_models=[HASSchema],
|
||||||
|
output_type=str, # or dict, or str
|
||||||
|
metadata_output_type="json",
|
||||||
|
# List of schemas that the agent can handle
|
||||||
|
function_calling_format_type="OpenAI",
|
||||||
|
function_calling_type="json", # or soon yaml
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the agent to generate the person's information
|
||||||
|
generated_data = agent.run("I need help with my modem.")
|
||||||
|
parse_json_then_activate_agent(generated_data)
|
||||||
|
|
||||||
|
|
||||||
|
# Print the generated data
|
||||||
|
print(f"Generated data: {generated_data}")
|
@ -0,0 +1,19 @@
|
|||||||
|
from swarms import tool
|
||||||
|
|
||||||
|
|
||||||
|
# Create the wrapper to wrap the function
|
||||||
|
@tool(
|
||||||
|
name="Geo Coordinates Locator",
|
||||||
|
description=("Locates geo coordinates with a city and or zip code"),
|
||||||
|
return_string=False,
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
def send_api_request_to_get_geo_coordinates(
|
||||||
|
city: str = None, zip: int = None
|
||||||
|
):
|
||||||
|
return "Test"
|
||||||
|
|
||||||
|
|
||||||
|
# Run the function to get the schema
|
||||||
|
out = send_api_request_to_get_geo_coordinates()
|
||||||
|
print(out)
|
@ -0,0 +1,81 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
from swarms import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
|
class llama3Hosted(BaseLLM):
|
||||||
|
"""
|
||||||
|
A class representing a hosted version of the Llama3 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): The name or path of the Llama3 model to use.
|
||||||
|
temperature (float): The temperature parameter for generating responses.
|
||||||
|
max_tokens (int): The maximum number of tokens in the generated response.
|
||||||
|
system_prompt (str): The system prompt to use for generating responses.
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
model (str): The name or path of the Llama3 model.
|
||||||
|
temperature (float): The temperature parameter for generating responses.
|
||||||
|
max_tokens (int): The maximum number of tokens in the generated response.
|
||||||
|
system_prompt (str): The system prompt for generating responses.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
run(task, *args, **kwargs): Generates a response for the given task.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
temperature: float = 0.8,
|
||||||
|
max_tokens: int = 4000,
|
||||||
|
system_prompt: str = "You are a helpful assistant.",
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.model = model
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Generates a response for the given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The user's task or input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated response from the Llama3 model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = "http://34.204.8.31:30001/v1/chat/completions"
|
||||||
|
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": self.system_prompt},
|
||||||
|
{"role": "user", "content": task},
|
||||||
|
],
|
||||||
|
"stop_token_ids": [128009, 128001],
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
response = requests.request(
|
||||||
|
"POST", url, headers=headers, data=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assistant_message = response_json["choices"][0]["message"][
|
||||||
|
"content"
|
||||||
|
]
|
||||||
|
|
||||||
|
return assistant_message
|
@ -1,191 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
||||||
|
|
||||||
|
|
||||||
class MPT7B:
|
|
||||||
"""
|
|
||||||
MPT class for generating text using a pre-trained model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): Name of the model to use.
|
|
||||||
tokenizer_name (str): Name of the tokenizer to use.
|
|
||||||
max_tokens (int): Maximum number of tokens to generate.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model_name (str): Name of the model to use.
|
|
||||||
tokenizer_name (str): Name of the tokenizer to use.
|
|
||||||
tokenizer (transformers.AutoTokenizer): Tokenizer object.
|
|
||||||
model (transformers.AutoModelForCausalLM): Model object.
|
|
||||||
pipe (transformers.pipelines.TextGenerationPipeline): Text generation pipeline.
|
|
||||||
max_tokens (int): Maximum number of tokens to generate.
|
|
||||||
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
|
|
||||||
>>> mpt_instance("generate", "Once upon a time in a land far, far away...")
|
|
||||||
'Once upon a time in a land far, far away...'
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
tokenizer_name: str,
|
|
||||||
max_tokens: int = 100,
|
|
||||||
):
|
|
||||||
# Loading model and tokenizer details
|
|
||||||
self.model_name = model_name
|
|
||||||
self.tokenizer_name = tokenizer_name
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
config = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name, trust_remote_code=True
|
|
||||||
).config
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name, config=config, trust_remote_code=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initializing a text-generation pipeline
|
|
||||||
self.pipe = pipeline(
|
|
||||||
"text-generation",
|
|
||||||
model=self.model,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
device="cuda:0",
|
|
||||||
)
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
|
|
||||||
def run(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Run the model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): Task to run.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
|
|
||||||
>>> mpt_instance("generate", "Once upon a time in a land far, far away...")
|
|
||||||
'Once upon a time in a land far, far away...'
|
|
||||||
>>> mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7)
|
|
||||||
['In the deep jungles,',
|
|
||||||
'At the heart of the city,']
|
|
||||||
>>> mpt_instance.freeze_model()
|
|
||||||
>>> mpt_instance.unfreeze_model()
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
if task == "generate":
|
|
||||||
return self.generate(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Task '{task}' not recognized!")
|
|
||||||
|
|
||||||
async def run_async(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Run the model asynchronously
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): Task to run.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
|
|
||||||
>>> mpt_instance("generate", "Once upon a time in a land far, far away...")
|
|
||||||
'Once upon a time in a land far, far away...'
|
|
||||||
>>> mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7)
|
|
||||||
['In the deep jungles,',
|
|
||||||
'At the heart of the city,']
|
|
||||||
>>> mpt_instance.freeze_model()
|
|
||||||
>>> mpt_instance.unfreeze_model()
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Wrapping synchronous calls with async
|
|
||||||
return self.run(task, *args, **kwargs)
|
|
||||||
|
|
||||||
def generate(self, prompt: str) -> str:
|
|
||||||
"""
|
|
||||||
|
|
||||||
Generate Text
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): Prompt to generate text from.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
|
||||||
return self.pipe(
|
|
||||||
prompt,
|
|
||||||
max_new_tokens=self.max_tokens,
|
|
||||||
do_sample=True,
|
|
||||||
use_cache=True,
|
|
||||||
)[0]["generated_text"]
|
|
||||||
|
|
||||||
async def generate_async(self, prompt: str) -> str:
|
|
||||||
"""Generate Async"""
|
|
||||||
return self.generate(prompt)
|
|
||||||
|
|
||||||
def __call__(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""Call the model"""
|
|
||||||
return self.run(task, *args, **kwargs)
|
|
||||||
|
|
||||||
async def __call_async__(self, task: str, *args, **kwargs) -> str:
|
|
||||||
"""Call the model asynchronously""" ""
|
|
||||||
return await self.run_async(task, *args, **kwargs)
|
|
||||||
|
|
||||||
def batch_generate(
|
|
||||||
self, prompts: list, temperature: float = 1.0
|
|
||||||
) -> list:
|
|
||||||
"""Batch generate text"""
|
|
||||||
self.logger.info(f"Generating text for {len(prompts)} prompts...")
|
|
||||||
results = []
|
|
||||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
|
||||||
for prompt in prompts:
|
|
||||||
result = self.pipe(
|
|
||||||
prompt,
|
|
||||||
max_new_tokens=self.max_tokens,
|
|
||||||
do_sample=True,
|
|
||||||
use_cache=True,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
results.append(result[0]["generated_text"])
|
|
||||||
return results
|
|
||||||
|
|
||||||
def unfreeze_model(self):
|
|
||||||
"""Unfreeze the model"""
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param.requires_grad = True
|
|
||||||
self.logger.info("Model has been unfrozen.")
|
|
||||||
|
|
||||||
|
|
||||||
# # Example usage:
|
|
||||||
# mpt_instance = MPT(
|
|
||||||
# "mosaicml/mpt-7b-storywriter", "EleutherAI/gpt-neox-20b", max_tokens=150
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # For synchronous calls
|
|
||||||
# print(mpt_instance("generate", "Once upon a time in a land far, far away..."))
|
|
||||||
|
|
||||||
# For asynchronous calls, use an event loop or similar async framework
|
|
||||||
# For example:
|
|
||||||
# # import asyncio
|
|
||||||
# # asyncio.run(mpt_instance.__call_async__("generate", "Once upon a time in a land far, far away..."))
|
|
||||||
# # Example usage:
|
|
||||||
# mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
|
|
||||||
|
|
||||||
# # For synchronous calls
|
|
||||||
# print(mpt_instance("generate", "Once upon a time in a land far, far away..."))
|
|
||||||
# print(mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7))
|
|
||||||
|
|
||||||
# # Freezing and unfreezing the model
|
|
||||||
# mpt_instance.freeze_model()
|
|
||||||
# mpt_instance.unfreeze_model()
|
|
@ -1,167 +0,0 @@
|
|||||||
import base64
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import uuid
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
stable_api_key = os.environ.get("STABLE_API_KEY")
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
|
||||||
"""
|
|
||||||
A class to interact with the Stable Diffusion API for generating images from text prompts.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
-----------
|
|
||||||
api_key : str
|
|
||||||
The API key for accessing the Stable Diffusion API.
|
|
||||||
api_host : str
|
|
||||||
The host URL for the Stable Diffusion API.
|
|
||||||
engine_id : str
|
|
||||||
The engine ID for the Stable Diffusion API.
|
|
||||||
cfg_scale : int
|
|
||||||
Configuration scale for image generation.
|
|
||||||
height : int
|
|
||||||
The height of the generated image.
|
|
||||||
width : int
|
|
||||||
The width of the generated image.
|
|
||||||
samples : int
|
|
||||||
The number of samples to generate.
|
|
||||||
steps : int
|
|
||||||
The number of steps for the generation process.
|
|
||||||
output_dir : str
|
|
||||||
Directory where the generated images will be saved.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
--------
|
|
||||||
__init__(self, api_key: str, api_host: str, cfg_scale: int, height: int, width: int, samples: int, steps: int):
|
|
||||||
Initializes the StableDiffusion instance with provided parameters.
|
|
||||||
|
|
||||||
generate_image(self, task: str) -> List[str]:
|
|
||||||
Generates an image based on the provided text prompt and returns the paths of the saved images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str = stable_api_key,
|
|
||||||
api_host: str = "https://api.stability.ai",
|
|
||||||
cfg_scale: int = 7,
|
|
||||||
height: int = 1024,
|
|
||||||
width: int = 1024,
|
|
||||||
samples: int = 1,
|
|
||||||
steps: int = 30,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the StableDiffusion class with API configurations.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
api_key : str
|
|
||||||
The API key for accessing the Stable Diffusion API.
|
|
||||||
api_host : str
|
|
||||||
The host URL for the Stable Diffusion API.
|
|
||||||
cfg_scale : int
|
|
||||||
Configuration scale for image generation.
|
|
||||||
height : int
|
|
||||||
The height of the generated image.
|
|
||||||
width : int
|
|
||||||
The width of the generated image.
|
|
||||||
samples : int
|
|
||||||
The number of samples to generate.
|
|
||||||
steps : int
|
|
||||||
The number of steps for the generation process.
|
|
||||||
"""
|
|
||||||
self.api_key = api_key
|
|
||||||
self.api_host = api_host
|
|
||||||
self.engine_id = "stable-diffusion-v1-6"
|
|
||||||
self.cfg_scale = cfg_scale
|
|
||||||
self.height = height
|
|
||||||
self.width = width
|
|
||||||
self.samples = samples
|
|
||||||
self.steps = steps
|
|
||||||
self.headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept": "application/json",
|
|
||||||
}
|
|
||||||
self.output_dir = "images"
|
|
||||||
os.makedirs(self.output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
def run(self, task: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Generates an image based on a given text prompt.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
task : str
|
|
||||||
The text prompt based on which the image will be generated.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
List[str]:
|
|
||||||
A list of file paths where the generated images are saved.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
-------
|
|
||||||
Exception:
|
|
||||||
If the API request fails and returns a non-200 response.
|
|
||||||
"""
|
|
||||||
response = requests.post(
|
|
||||||
f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image",
|
|
||||||
headers=self.headers,
|
|
||||||
json={
|
|
||||||
"text_prompts": [{"text": task}],
|
|
||||||
"cfg_scale": self.cfg_scale,
|
|
||||||
"height": self.height,
|
|
||||||
"width": self.width,
|
|
||||||
"samples": self.samples,
|
|
||||||
"steps": self.steps,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise Exception(f"Non-200 response: {response.text}")
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
image_paths = []
|
|
||||||
for i, image in enumerate(data["artifacts"]):
|
|
||||||
unique_id = uuid.uuid4() # Generate a unique identifier
|
|
||||||
image_path = os.path.join(
|
|
||||||
self.output_dir, f"{unique_id}_v1_txt2img_{i}.png"
|
|
||||||
)
|
|
||||||
with open(image_path, "wb") as f:
|
|
||||||
f.write(base64.b64decode(image["base64"]))
|
|
||||||
image_paths.append(image_path)
|
|
||||||
|
|
||||||
return image_paths
|
|
||||||
|
|
||||||
def generate_and_move_image(self, prompt, iteration, folder_path):
|
|
||||||
"""
|
|
||||||
Generates an image based on the given prompt and moves it to the specified folder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): The prompt used to generate the image.
|
|
||||||
iteration (int): The iteration number.
|
|
||||||
folder_path (str): The path to the folder where the image will be moved.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The path of the moved image.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Generate the image
|
|
||||||
image_paths = self.run(prompt)
|
|
||||||
if not image_paths:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Move the image to the specified folder
|
|
||||||
src_image_path = image_paths[0]
|
|
||||||
dst_image_path = os.path.join(
|
|
||||||
folder_path, f"image_{iteration}.jpg"
|
|
||||||
)
|
|
||||||
shutil.move(src_image_path, dst_image_path)
|
|
||||||
return dst_image_path
|
|
@ -1,62 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import timm
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from swarms.models.base_multimodal_model import BaseMultiModalModel
|
|
||||||
|
|
||||||
|
|
||||||
class TimmModel(BaseMultiModalModel):
|
|
||||||
"""
|
|
||||||
TimmModel is a class that wraps the timm library to provide a consistent
|
|
||||||
interface for creating and running models.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: A string representing the name of the model to be created.
|
|
||||||
pretrained: A boolean indicating whether to use a pretrained model.
|
|
||||||
in_chans: An integer representing the number of input channels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A TimmModel instance.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
model = TimmModel('resnet18', pretrained=True, in_chans=3)
|
|
||||||
output_shape = model(input_tensor)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
pretrained: bool,
|
|
||||||
in_chans: int,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.model_name = model_name
|
|
||||||
self.pretrained = pretrained
|
|
||||||
self.in_chans = in_chans
|
|
||||||
self.models = self._get_supported_models()
|
|
||||||
|
|
||||||
def _get_supported_models(self) -> List[str]:
|
|
||||||
"""Retrieve the list of supported models from timm."""
|
|
||||||
return timm.list_models()
|
|
||||||
|
|
||||||
def __call__(self, task: Tensor, *args, **kwargs) -> torch.Size:
|
|
||||||
"""
|
|
||||||
Create and run a model specified by `model_info` on `input_tensor`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_info: An instance of TimmModelInfo containing model specifications.
|
|
||||||
input_tensor: A torch tensor representing the input data.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The shape of the output from the model.
|
|
||||||
"""
|
|
||||||
model = timm.create_model(self.model_name, *args, **kwargs)
|
|
||||||
return model(task)
|
|
||||||
|
|
||||||
def list_models(self):
|
|
||||||
return timm.list_models()
|
|
@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from beartype import beartype
|
||||||
|
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
from swarms.structs.base_swarm import BaseSwarm
|
||||||
|
from swarms.utils.loguru_logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class HiearchicalSwarm(BaseSwarm):
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
director: Agent = None,
|
||||||
|
agents: List[Agent] = None,
|
||||||
|
max_loops: int = 1,
|
||||||
|
long_term_memory_system: BaseSwarm = None,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.director = director
|
||||||
|
self.agents = agents
|
||||||
|
self.max_loops = max_loops
|
||||||
|
self.long_term_memory_system = long_term_memory_system
|
||||||
|
|
||||||
|
# Set the director to max_one loop
|
||||||
|
self.director.max_loops = 1
|
||||||
|
|
||||||
|
# Set the long term memory system of every agent to long term memory system
|
||||||
|
if long_term_memory_system is True:
|
||||||
|
for agent in agents:
|
||||||
|
agent.long_term_memory = long_term_memory_system
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def parse_function_activate_agent(
|
||||||
|
self, json_data: str = None, *args, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Parse the JSON data and activate the selected agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_data (str): The JSON data containing the agent name and task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the activated agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
json.JSONDecodeError: If the JSON data is invalid.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(json_data)
|
||||||
|
name = data.get("name")
|
||||||
|
task = data.get("task")
|
||||||
|
|
||||||
|
response = self.select_agent_and_send_task(
|
||||||
|
name, task, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error("Invalid JSON data, try again.")
|
||||||
|
raise json.JSONDecodeError
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def select_agent_and_send_task(
|
||||||
|
self, name: str = None, task: str = None, *args, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Select an agent from the list and send a task to them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the agent to send the task to.
|
||||||
|
task (str): The task to send to the agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the agent name is not found in the list of agents.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check to see if the agent name is in the list of agents
|
||||||
|
if name in self.agents:
|
||||||
|
agent = self.agents[name]
|
||||||
|
else:
|
||||||
|
return "Invalid agent name. Please select 'Account Management Agent' or 'Product Support Agent'."
|
||||||
|
|
||||||
|
response = agent.run(task, *args, **kwargs)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def run(self, task: str = None, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the hierarchical swarm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (str): The task to send to the director agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the director agent.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If an error occurs while running the swarm.
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
loop = 0
|
||||||
|
|
||||||
|
# While the loop is less than max loops
|
||||||
|
while loop < self.max_loops:
|
||||||
|
# Run the director
|
||||||
|
response = self.director.run(task, *args, **kwargs)
|
||||||
|
|
||||||
|
# Run agents
|
||||||
|
response = self.parse_function_activate_agent(response)
|
||||||
|
|
||||||
|
loop += 1
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
raise e
|
Loading…
Reference in new issue