code quality

Former-commit-id: 9014683f9a
discord-bot-framework
Kye 1 year ago
parent bf3c6ac72c
commit 1739d44b37

@ -17,12 +17,15 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
class SwarmInput(BaseModel): class SwarmInput(BaseModel):
api_key: str api_key: str
objective: str objective: str
app = FastAPI() app = FastAPI()
@app.on_event("startup") @app.on_event("startup")
async def startup(): async def startup():
redis_host = os.getenv("REDIS_HOST", "localhost") redis_host = os.getenv("REDIS_HOST", "localhost")
@ -31,6 +34,7 @@ async def startup():
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache", coder=JsonCoder()) FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache", coder=JsonCoder())
await FastAPILimiter.init(f"redis://{redis_host}:{redis_port}") await FastAPILimiter.init(f"redis://{redis_host}:{redis_port}")
@app.post("/chat", dependencies=[Depends(RateLimiter(times=2, minutes=1))]) @app.post("/chat", dependencies=[Depends(RateLimiter(times=2, minutes=1))])
@cache(expire=60) # Cache results for 1 minute @cache(expire=60) # Cache results for 1 minute
async def run(swarm_input: SwarmInput): async def run(swarm_input: SwarmInput):

@ -55,8 +55,6 @@ file_handler = FileHandler(handlers=handlers, path=BASE_DIR)
templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates") templates = Jinja2Templates(directory=BASE_DIR / "api" / "templates")
uploader = StaticUploader.from_settings( uploader = StaticUploader.from_settings(path=BASE_DIR / "static", endpoint="static")
path=BASE_DIR / "static", endpoint="static"
)
reload_dirs = [BASE_DIR / "core", BASE_DIR / "api"] reload_dirs = [BASE_DIR / "core", BASE_DIR / "api"]

@ -11,8 +11,15 @@ from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel from pydantic import BaseModel
from api.olds.container import agent_manager, file_handler, reload_dirs, templates, uploader from api.olds.container import (
agent_manager,
file_handler,
reload_dirs,
templates,
uploader,
)
from api.olds.worker import get_task_result, start_worker, task_execute from api.olds.worker import get_task_result, start_worker, task_execute
# from env import settings # from env import settings
app = FastAPI() app = FastAPI()

@ -3,28 +3,34 @@ from langchain.llms import OpenAIChat
from swarms.agents import OmniModalAgent from swarms.agents import OmniModalAgent
# Setup # Setup
TOKEN = 'YOUR_DISCORD_BOT_TOKEN' TOKEN = "YOUR_DISCORD_BOT_TOKEN"
bot = commands.Bot(command_prefix='!') bot = commands.Bot(command_prefix="!")
# Initialize the OmniModalAgent # Initialize the OmniModalAgent
llm = OpenAIChat(model_name="gpt-4") llm = OpenAIChat(model_name="gpt-4")
agent = OmniModalAgent(llm) agent = OmniModalAgent(llm)
@bot.event @bot.event
async def on_ready(): async def on_ready():
print(f'We have logged in as {bot.user}') print(f"We have logged in as {bot.user}")
@bot.command() @bot.command()
async def greet(ctx): async def greet(ctx):
"""Greets the user.""" """Greets the user."""
await ctx.send(f'Hello, {ctx.author.name}!') await ctx.send(f"Hello, {ctx.author.name}!")
@bot.command() @bot.command()
async def run(ctx, *, description: str): async def run(ctx, *, description: str):
"""Generates a video based on the given description.""" """Generates a video based on the given description."""
response = agent.run(description) # Assuming the response provides information or a link to the generated video response = agent.run(
description
) # Assuming the response provides information or a link to the generated video
await ctx.send(response) await ctx.send(response)
@bot.command() @bot.command()
async def help_me(ctx): async def help_me(ctx):
"""Provides a list of commands and their descriptions.""" """Provides a list of commands and their descriptions."""
@ -35,4 +41,5 @@ async def help_me(ctx):
""" """
await ctx.send(help_text) await ctx.send(help_text)
bot.run(TOKEN) bot.run(TOKEN)

@ -1,4 +1,4 @@
#Import required libraries # Import required libraries
from gradio import Interface, Textbox, HTML from gradio import Interface, Textbox, HTML
import threading import threading
import os import os
@ -7,32 +7,36 @@ import base64
from langchain.llms import OpenAIChat from langchain.llms import OpenAIChat
from swarms.agents import OmniModalAgent from swarms.agents import OmniModalAgent
#Function to convert image to base64
# Function to convert image to base64
def image_to_base64(image_path): def image_to_base64(image_path):
with open(image_path, "rb") as image_file: with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode() return base64.b64encode(image_file.read()).decode()
#Function to get the most recently created image in the directory
# Function to get the most recently created image in the directory
def get_latest_image(): def get_latest_image():
list_of_files = glob.glob('./*.png') # Replace with your image file type list_of_files = glob.glob("./*.png") # Replace with your image file type
if not list_of_files: if not list_of_files:
return None return None
latest_file = max(list_of_files, key=os.path.getctime) latest_file = max(list_of_files, key=os.path.getctime)
return latest_file return latest_file
#Initialize your OmniModalAgent
# Initialize your OmniModalAgent
llm = OpenAIChat(model_name="gpt-4") # Replace with your actual initialization llm = OpenAIChat(model_name="gpt-4") # Replace with your actual initialization
agent = OmniModalAgent(llm) # Replace with your actual initialization agent = OmniModalAgent(llm) # Replace with your actual initialization
#Global variable to store chat history # Global variable to store chat history
chat_history = [] chat_history = []
#Function to update chat
# Function to update chat
def update_chat(user_input): def update_chat(user_input):
global chat_history global chat_history
chat_history.append({"type": "user", "content": user_input}) chat_history.append({"type": "user", "content": user_input})
#Get agent response # Get agent response
agent_response = agent.run(user_input) agent_response = agent.run(user_input)
# Handle the case where agent_response is not in the expected dictionary format # Handle the case where agent_response is not in the expected dictionary format
@ -48,38 +52,43 @@ def update_chat(user_input):
return render_chat(chat_history) return render_chat(chat_history)
#Function to render chat as HTML
# Function to render chat as HTML
def render_chat(chat_history): def render_chat(chat_history):
chat_str = "<div style='max-height:400px;overflow-y:scroll;'>" chat_str = "<div style='max-height:400px;overflow-y:scroll;'>"
for message in chat_history: for message in chat_history:
if message['type'] == 'user': if message["type"] == "user":
chat_str += f"<p><strong>User:</strong> {message['content']}</p>" chat_str += f"<p><strong>User:</strong> {message['content']}</p>"
elif message['type'] == 'text': elif message["type"] == "text":
chat_str += f"<p><strong>Agent:</strong> {message['content']}</p>" chat_str += f"<p><strong>Agent:</strong> {message['content']}</p>"
elif message['type'] == 'image': elif message["type"] == "image":
img_path = os.path.join(".", message['content']) img_path = os.path.join(".", message["content"])
base64_img = image_to_base64(img_path) base64_img = image_to_base64(img_path)
chat_str += f"<p><strong>Agent:</strong> <img src='data:image/png;base64,{base64_img}' alt='image' width='200'/></p>" chat_str += f"<p><strong>Agent:</strong> <img src='data:image/png;base64,{base64_img}' alt='image' width='200'/></p>"
chat_str += "</div>" chat_str += "</div>"
return chat_str return chat_str
#Define Gradio interface
# Define Gradio interface
iface = Interface( iface = Interface(
fn=update_chat, fn=update_chat,
inputs=Textbox(label="Your Message", type="text"), inputs=Textbox(label="Your Message", type="text"),
outputs=HTML(label="Chat History"), outputs=HTML(label="Chat History"),
live=True live=True,
) )
#Function to update the chat display
# Function to update the chat display
def update_display(): def update_display():
global chat_history global chat_history
while True: while True:
iface.update(render_chat(chat_history)) iface.update(render_chat(chat_history))
#Run the update_display function in a separate thread
# Run the update_display function in a separate thread
threading.Thread(target=update_display).start() threading.Thread(target=update_display).start()
#Run Gradio interface # Run Gradio interface
iface.launch() iface.launch()

@ -1,32 +1,19 @@
from swarms import Model, Agent, vectorstore, tools, orchestrator from swarms import Model, Agent, vectorstore, tools, orchestrator
#1 model # 1 model
Model(openai) Model(openai)
#2 agent level # 2 agent level
Agent( Agent(model, vectorstore, tools)
model,
vectorstore,
tools
)
#3 worker infrastructure level # 3 worker infrastructure level
worker_node( worker_node(Agent, human_input, tools)
Agent,
human_input,
tools
)
#4 swarm level basically handling infrastructure for multiple worker node # 4 swarm level basically handling infrastructure for multiple worker node
swarm = orchestrator( swarm = orchestrator(worker_node, 100) # nodes
worker_node,
100 # nodes
)
#5 # 5
hivemind = Hivemind( hivemind = Hivemind(swarm * 100)
swarm * 100
)
#a market different pre built worker or boss agent that have access to different tools and memory, proompts # a market different pre built worker or boss agent that have access to different tools and memory, proompts

@ -1,22 +1,17 @@
from langchain.llms import OpenAIChat from langchain.llms import OpenAIChat
from swarms import Worker from swarms import Worker
llm = OpenAIChat( llm = OpenAIChat(model_name="gpt-4", openai_api_key="api-key", temperature=0.5)
model_name='gpt-4',
openai_api_key="api-key",
temperature=0.5
)
node = Worker( node = Worker(
llm=llm, llm=llm,
ai_name="Optimus Prime", ai_name="Optimus Prime",
ai_role="Worker in a swarm", ai_role="Worker in a swarm",
external_tools = None, external_tools=None,
human_in_the_loop = False, human_in_the_loop=False,
temperature = 0.5, temperature=0.5,
) )
task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times." task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times."
response = node.run(task) response = node.run(task)
print(response) print(response)

@ -8,13 +8,13 @@ swarm = HierarchicalSwarm(
use_vectorstore=False, use_vectorstore=False,
use_async=False, use_async=False,
human_in_the_loop=False, human_in_the_loop=False,
logging_enabled=False logging_enabled=False,
) )
#run the swarm with an objective # run the swarm with an objective
result = swarm.run("Design a new car") result = swarm.run("Design a new car")
#or huggingface # or huggingface
swarm = HierarchicalSwarm( swarm = HierarchicalSwarm(
model_type="huggingface", model_type="huggingface",
model_id="tiaueu/falcon", model_id="tiaueu/falcon",

@ -1,8 +1,6 @@
from swarms.agents import MultiModalAgent from swarms.agents import MultiModalAgent
load_dict = { load_dict = {"ImageCaptioning": "cuda"}
"ImageCaptioning": "cuda"
}
node = MultiModalAgent(load_dict) node = MultiModalAgent(load_dict)
@ -12,5 +10,5 @@ img = node.run_img("/image1", "What is this image about?")
chat = node.chat( chat = node.chat(
"What is your name? Generate a picture of yourself. What is this image about?", "What is your name? Generate a picture of yourself. What is this image about?",
streaming=True streaming=True,
) )

@ -1,11 +1,8 @@
#pip3 install exxa # pip3 install exxa
from exa import Inference from exa import Inference
from swarms.agents import OmniModalAgent from swarms.agents import OmniModalAgent
llm = Inference( llm = Inference(model_id="mistralai/Mistral-7B-v0.1", quantize=True)
model_id="mistralai/Mistral-7B-v0.1",
quantize=True
)
agent = OmniModalAgent(llm) agent = OmniModalAgent(llm)

@ -1,9 +1,6 @@
from swarms.models import Mistral from swarms.models import Mistral
model = Mistral( model = Mistral(device="cuda", use_flash_attention=True)
device="cuda",
use_flash_attention=True
)
prompt = "My favourite condiment is" prompt = "My favourite condiment is"
result = model.run(prompt) result = model.run(prompt)

@ -7,12 +7,12 @@ prompt2 = "Develop a self attention using pytorch"
task1 = Task("task1", prompt) task1 = Task("task1", prompt)
task2 = Task("task2", prompt2, parents=[task1]) task2 = Task("task2", prompt2, parents=[task1])
#add tasks to workflow # add tasks to workflow
workflow = NonLinearWorkflow(agent) workflow = NonLinearWorkflow(agent)
#add tasks to tree # add tasks to tree
workflow.add(task1) workflow.add(task1)
workflow.add(task2) workflow.add(task2)
#run # run
workflow.run() workflow.run()

@ -5,4 +5,3 @@ auto_scaler.start()
for i in range(100): for i in range(100):
auto_scaler.add_task(f"Task {i}") auto_scaler.add_task(f"Task {i}")

@ -1,11 +1,7 @@
from swarms import Orchestrator, Worker from swarms import Orchestrator, Worker
# Instantiate the Orchestrator with 10 agents # Instantiate the Orchestrator with 10 agents
orchestrator = Orchestrator( orchestrator = Orchestrator(Worker, agent_list=[Worker] * 10, task_queue=[])
Worker,
agent_list=[Worker]*10,
task_queue=[]
)
# Agent 1 sends a message to Agent 2 # Agent 1 sends a message to Agent 2
orchestrator.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!") orchestrator.chat(sender_id=1, receiver_id=2, message="Hello, Agent 2!")

@ -89,6 +89,7 @@ class DialogueSimulator:
return speaker.name, message return speaker.name, message
class BiddingDialogueAgent(DialogueAgent): class BiddingDialogueAgent(DialogueAgent):
def __init__( def __init__(
self, self,
@ -114,6 +115,7 @@ class BiddingDialogueAgent(DialogueAgent):
bid_string = self.model([SystemMessage(content=prompt)]).content bid_string = self.model([SystemMessage(content=prompt)]).content
return bid_string return bid_string
character_names = ["Donald Trump", "Kanye West", "Elizabeth Warren"] character_names = ["Donald Trump", "Kanye West", "Elizabeth Warren"]
topic = "transcontinental high speed rail" topic = "transcontinental high speed rail"
word_limit = 50 word_limit = 50
@ -203,8 +205,6 @@ for (
print(f"\n{character_system_message.content}") print(f"\n{character_system_message.content}")
class BidOutputParser(RegexParser): class BidOutputParser(RegexParser):
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return "Your response should be an integer delimited by angled brackets, like this: <int>." return "Your response should be an integer delimited by angled brackets, like this: <int>."
@ -214,6 +214,7 @@ bid_parser = BidOutputParser(
regex=r"<(\d+)>", output_keys=["bid"], default_output_key="bid" regex=r"<(\d+)>", output_keys=["bid"], default_output_key="bid"
) )
def generate_character_bidding_template(character_header): def generate_character_bidding_template(character_header):
bidding_template = f"""{character_header} bidding_template = f"""{character_header}
@ -232,6 +233,7 @@ def generate_character_bidding_template(character_header):
""" """
return bidding_template return bidding_template
character_bidding_templates = [ character_bidding_templates = [
generate_character_bidding_template(character_header) generate_character_bidding_template(character_header)
for character_header in character_headers for character_header in character_headers
@ -263,6 +265,7 @@ specified_topic = ChatOpenAI(temperature=1.0)(topic_specifier_prompt).content
print(f"Original topic:\n{topic}\n") print(f"Original topic:\n{topic}\n")
print(f"Detailed topic:\n{specified_topic}\n") print(f"Detailed topic:\n{specified_topic}\n")
@tenacity.retry( @tenacity.retry(
stop=tenacity.stop_after_attempt(2), stop=tenacity.stop_after_attempt(2),
wait=tenacity.wait_none(), # No waiting time between retries wait=tenacity.wait_none(), # No waiting time between retries
@ -280,6 +283,7 @@ def ask_for_bid(agent) -> str:
bid = int(bid_parser.parse(bid_string)["bid"]) bid = int(bid_parser.parse(bid_string)["bid"])
return bid return bid
def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int: def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int:
bids = [] bids = []
for agent in agents: for agent in agents:
@ -300,6 +304,7 @@ def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int:
print("\n") print("\n")
return idx return idx
characters = [] characters = []
for character_name, character_system_message, bidding_template in zip( for character_name, character_system_message, bidding_template in zip(
character_names, character_system_messages, character_bidding_templates character_names, character_system_messages, character_bidding_templates

@ -9,7 +9,7 @@ collab = DialogueSimulator(
) )
collab.run( collab.run(
max_iters = 4, max_iters=4,
name = "plinus", name="plinus",
message = "how can we enable multi agent collaboration", message="how can we enable multi agent collaboration",
) )

@ -5,4 +5,3 @@ api_key = "APIKEY"
objective = "What is the capital of the UK?" objective = "What is the capital of the UK?"
result = swarm(api_key, objective) result = swarm(api_key, objective)
print(result) # Prints: "The capital of the UK is London." print(result) # Prints: "The capital of the UK is London."

@ -1,4 +1,3 @@
from langchain.models import Anthropic, GooglePalm, OpenAIChat from langchain.models import Anthropic, GooglePalm, OpenAIChat
from swarms.swarms import GodMode from swarms.swarms import GodMode
@ -7,11 +6,7 @@ palm = GooglePalm(google_api_key="")
gpt = OpenAIChat(openai_api_key="") gpt = OpenAIChat(openai_api_key="")
# Usage # Usage
llms = [ llms = [claude, palm, gpt]
claude,
palm,
gpt
]
god_mode = GodMode(llms) god_mode = GodMode(llms)

@ -1,2 +1 @@
from swarms.swarms import GroupChat from swarms.swarms import GroupChat

@ -2,44 +2,36 @@ from langchain.llms import OpenAIChat
from swarms.swarms import GroupChat, GroupChatManager from swarms.swarms import GroupChat, GroupChatManager
from swarms.workers import Worker from swarms.workers import Worker
llm = OpenAIChat( llm = OpenAIChat(model_name="gpt-4", openai_api_key="api-key", temperature=0.5)
model_name='gpt-4',
openai_api_key="api-key",
temperature=0.5
)
node = Worker( node = Worker(
llm=llm, llm=llm,
ai_name="Optimus Prime", ai_name="Optimus Prime",
ai_role="Worker in a swarm", ai_role="Worker in a swarm",
external_tools = None, external_tools=None,
human_in_the_loop = False, human_in_the_loop=False,
temperature = 0.5, temperature=0.5,
) )
node2 = Worker( node2 = Worker(
llm=llm, llm=llm,
ai_name="Optimus Prime", ai_name="Optimus Prime",
ai_role="Worker in a swarm", ai_role="Worker in a swarm",
external_tools = None, external_tools=None,
human_in_the_loop = False, human_in_the_loop=False,
temperature = 0.5, temperature=0.5,
) )
node3 = Worker( node3 = Worker(
llm=llm, llm=llm,
ai_name="Optimus Prime", ai_name="Optimus Prime",
ai_role="Worker in a swarm", ai_role="Worker in a swarm",
external_tools = None, external_tools=None,
human_in_the_loop = False, human_in_the_loop=False,
temperature = 0.5, temperature=0.5,
) )
nodes = [ nodes = [node, node2, node3]
node,
node2,
node3
]
messages = [ messages = [
{ {

@ -1,9 +1,11 @@
from swarms import DialogueSimulator, Worker from swarms import DialogueSimulator, Worker
def select_next_speaker(step: int, agents) -> int: def select_next_speaker(step: int, agents) -> int:
idx = (step) % len(agents) idx = (step) % len(agents)
return idx return idx
debate = DialogueSimulator(Worker, select_next_speaker) debate = DialogueSimulator(Worker, select_next_speaker)
debate.run() debate.run()

@ -5,11 +5,7 @@ worker1 = Worker(openai_api_key="", ai_name="Optimus Prime")
worker2 = Worker(openai_api_key="", ai_name="Bumblebee") worker2 = Worker(openai_api_key="", ai_name="Bumblebee")
worker3 = Worker(openai_api_key="", ai_name="Megatron") worker3 = Worker(openai_api_key="", ai_name="Megatron")
agents = [ agents = [worker1, worker2, worker3]
worker1,
worker2,
worker3
]
# Initialize multi-agent debate with the selection function # Initialize multi-agent debate with the selection function
debate = MultiAgentDebate(agents, select_speaker) debate = MultiAgentDebate(agents, select_speaker)

@ -3,12 +3,11 @@ from swarms import Worker, Orchestrator
node = Worker( node = Worker(
openai_api_key="", openai_api_key="",
ai_name="Optimus Prime", ai_name="Optimus Prime",
) )
# Instantiate the Orchestrator with 10 agents # Instantiate the Orchestrator with 10 agents
orchestrator = Orchestrator(node, agent_list=[node]*10, task_queue=[]) orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[])
# Agent 7 sends a message to Agent 9 # Agent 7 sends a message to Agent 9
orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?") orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?")

@ -3,12 +3,11 @@ from swarms import Worker, Orchestrator
node = Worker( node = Worker(
openai_api_key="", openai_api_key="",
ai_name="Optimus Prime", ai_name="Optimus Prime",
) )
# Instantiate the Orchestrator with 10 agents # Instantiate the Orchestrator with 10 agents
orchestrator = Orchestrator(node, agent_list=[node]*10, task_queue=[]) orchestrator = Orchestrator(node, agent_list=[node] * 10, task_queue=[])
# Agent 7 sends a message to Agent 9 # Agent 7 sends a message to Agent 9
orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?") orchestrator.chat(sender_id=7, receiver_id=9, message="Can you help me with this task?")

@ -1,19 +1,15 @@
from langchain.models import OpenAIChat from langchain.models import OpenAIChat
from swarms import Worker from swarms import Worker
llm = OpenAIChat( llm = OpenAIChat(model_name="gpt-4", openai_api_key="api-key", temperature=0.5)
model_name='gpt-4',
openai_api_key="api-key",
temperature=0.5
)
node = Worker( node = Worker(
llm=llm, llm=llm,
ai_name="Optimus Prime", ai_name="Optimus Prime",
ai_role="Worker in a swarm", ai_role="Worker in a swarm",
external_tools = None, external_tools=None,
human_in_the_loop = False, human_in_the_loop=False,
temperature = 0.5, temperature=0.5,
) )
task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times." task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times."

@ -1,4 +1,3 @@
from swarms import Workflow from swarms import Workflow
from swarms.tools.autogpt import ChatOpenAI from swarms.tools.autogpt import ChatOpenAI

@ -1,50 +1,50 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
setup( setup(
name = 'swarms', name="swarms",
packages = find_packages(exclude=[]), packages=find_packages(exclude=[]),
version = '1.4.1', version="1.4.1",
license='MIT', license="MIT",
description = 'Swarms - Pytorch', description="Swarms - Pytorch",
author = 'Kye Gomez', author="Kye Gomez",
author_email = 'kye@apac.ai', author_email="kye@apac.ai",
long_description_content_type = 'text/markdown', long_description_content_type="text/markdown",
url = 'https://github.com/kyegomez/swarms', url="https://github.com/kyegomez/swarms",
keywords = [ keywords=[
'artificial intelligence', "artificial intelligence",
'deep learning', "deep learning",
'optimizers', "optimizers",
"Prompt Engineering" "Prompt Engineering",
], ],
install_requires=[ install_requires=[
'transformers', "transformers",
'openai', "openai",
'langchain==0.0.240', "langchain==0.0.240",
'asyncio', "asyncio",
'nest_asyncio', "nest_asyncio",
'pegasusx', "pegasusx",
'google-generativeai', "google-generativeai",
'oceandb', "oceandb",
'langchain-experimental', "langchain-experimental",
'playwright', "playwright",
'duckduckgo_search', "duckduckgo_search",
'faiss-cpu', "faiss-cpu",
'wget', "wget",
'httpx', "httpx",
'ggl', "ggl",
'beautifulsoup4', "beautifulsoup4",
'pydantic', "pydantic",
'tenacity', "tenacity",
'celery', "celery",
'redis', "redis",
'google-search-results==2.4.2', "google-search-results==2.4.2",
'Pillow', "Pillow",
], ],
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', "Development Status :: 4 - Beta",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'Topic :: Scientific/Engineering :: Artificial Intelligence', "Topic :: Scientific/Engineering :: Artificial Intelligence",
'License :: OSI Approved :: MIT License', "License :: OSI Approved :: MIT License",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
], ],
) )

@ -7,6 +7,7 @@ from swarms import models
from swarms.workers.worker import Worker from swarms.workers.worker import Worker
from swarms import workers from swarms import workers
from swarms.logo import logo2 from swarms.logo import logo2
print(logo2) print(logo2)
# worker # worker

@ -1,4 +1,3 @@
"""Agent Infrastructure, models, memory, utils, tools""" """Agent Infrastructure, models, memory, utils, tools"""
# agents # agents

@ -4,7 +4,9 @@ import time
import openai import openai
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,11 +27,13 @@ class OpenAI:
raise Exception("Please provide OpenAI API key") raise Exception("Please provide OpenAI API key")
if api_base == "" or api_base is None: if api_base == "" or api_base is None:
api_base = os.environ.get("OPENAI_API_BASE", "") # if not set, use the default base path of "https://api.openai.com/v1" api_base = os.environ.get(
"OPENAI_API_BASE", ""
) # if not set, use the default base path of "https://api.openai.com/v1"
if api_base != "": if api_base != "":
# e.g. https://api.openai.com/v1/ or your custom url # e.g. https://api.openai.com/v1/ or your custom url
openai.api_base = api_base openai.api_base = api_base
print(f'Using custom api_base {api_base}') print(f"Using custom api_base {api_base}")
if api_model == "" or api_model is None: if api_model == "" or api_model is None:
api_model = os.environ.get("OPENAI_API_MODEL", "") api_model = os.environ.get("OPENAI_API_MODEL", "")
@ -37,29 +41,17 @@ class OpenAI:
self.api_model = api_model self.api_model = api_model
else: else:
self.api_model = "text-davinci-003" self.api_model = "text-davinci-003"
print(f'Using api_model {self.api_model}') print(f"Using api_model {self.api_model}")
self.use_chat_api = 'gpt' in self.api_model self.use_chat_api = "gpt" in self.api_model
self.strategy = strategy self.strategy = strategy
self.evaluation_strategy = evaluation_strategy self.evaluation_strategy = evaluation_strategy
def run( def run(self, prompt, max_tokens, temperature, k=1, stop=None):
self,
prompt,
max_tokens,
temperature,
k=1,
stop=None
):
while True: while True:
try: try:
if self.use_chat_api: if self.use_chat_api:
messages = [ messages = [{"role": "user", "content": prompt}]
{
"role": "user",
"content": prompt
}
]
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.api_model, model=self.api_model,
messages=messages, messages=messages,
@ -75,17 +67,21 @@ class OpenAI:
stop=stop, stop=stop,
temperature=temperature, temperature=temperature,
) )
with open("openai.logs", 'a') as log_file: with open("openai.logs", "a") as log_file:
log_file.write("\n" + "-----------" + '\n' + "Prompt : " + prompt + "\n") log_file.write(
"\n" + "-----------" + "\n" + "Prompt : " + prompt + "\n"
)
return response return response
except openai.error.RateLimitError as e: except openai.error.RateLimitError as e:
sleep_duratoin = os.environ.get("OPENAI_RATE_TIMEOUT", 30) sleep_duratoin = os.environ.get("OPENAI_RATE_TIMEOUT", 30)
print(f'{str(e)}, sleep for {sleep_duratoin}s, set it by env OPENAI_RATE_TIMEOUT') print(
f"{str(e)}, sleep for {sleep_duratoin}s, set it by env OPENAI_RATE_TIMEOUT"
)
time.sleep(sleep_duratoin) time.sleep(sleep_duratoin)
def openai_choice2text_handler(self, choice): def openai_choice2text_handler(self, choice):
if self.use_chat_api: if self.use_chat_api:
text = choice['message']['content'] text = choice["message"]["content"]
else: else:
text = choice.text.strip() text = choice.text.strip()
return text return text
@ -102,20 +98,16 @@ class OpenAI:
else: else:
response = self.run(prompt, 300, 0.5, k) response = self.run(prompt, 300, 0.5, k)
thoughts = [self.openai_choice2text_handler(choice) for choice in response.choices] thoughts = [
self.openai_choice2text_handler(choice) for choice in response.choices
]
return thoughts return thoughts
def generate_thoughts( def generate_thoughts(self, state, k, initial_prompt, rejected_solutions=None):
self, if isinstance(state, str):
state,
k,
initial_prompt,
rejected_solutions=None
):
if (isinstance(state, str)):
state_text = state state_text = state
else: else:
state_text = '\n'.join(state) state_text = "\n".join(state)
print("New state generating thought:", state, "\n\n") print("New state generating thought:", state, "\n\n")
prompt = f""" prompt = f"""
Accomplish the task below by decomposing it as many very explicit subtasks as possible, be very explicit and thorough denoted by Accomplish the task below by decomposing it as many very explicit subtasks as possible, be very explicit and thorough denoted by
@ -135,14 +127,10 @@ class OpenAI:
# print(f"Generated thoughts: {thoughts}") # print(f"Generated thoughts: {thoughts}")
return thoughts return thoughts
def generate_solution(self, def generate_solution(self, initial_prompt, state, rejected_solutions=None):
initial_prompt,
state,
rejected_solutions=None):
try: try:
if isinstance(state, list): if isinstance(state, list):
state_text = '\n'.join(state) state_text = "\n".join(state)
else: else:
state_text = state state_text = state
@ -156,7 +144,7 @@ class OpenAI:
###{rejected_solutions}###, ###{rejected_solutions}###,
complete the {initial_prompt} without making the same mistakes you did with the evaluated rejected solutions. Be simple. Be direct. Provide intuitive solutions as soon as you think of them.""" complete the {initial_prompt} without making the same mistakes you did with the evaluated rejected solutions. Be simple. Be direct. Provide intuitive solutions as soon as you think of them."""
answer = self.generate_text(prompt, 1) answer = self.generate_text(prompt, 1)
print(f'Generated Solution Summary {answer}') print(f"Generated Solution Summary {answer}")
return answer return answer
except Exception as e: except Exception as e:
logger.error(f"Error in generate_solutions: {e}") logger.error(f"Error in generate_solutions: {e}")
@ -166,14 +154,20 @@ class OpenAI:
if not states: if not states:
return {} return {}
if self.evaluation_strategy == 'value': if self.evaluation_strategy == "value":
state_values = {} state_values = {}
for state in states: for state in states:
if (isinstance(state, str)): if isinstance(state, str):
state_text = state state_text = state
else: else:
state_text = '\n'.join(state) state_text = "\n".join(state)
print("We receive a state of type", type(state), "For state: ", state, "\n\n") print(
"We receive a state of type",
type(state),
"For state: ",
state,
"\n\n",
)
prompt = f""" To achieve the following goal: '{initial_prompt}', pessimistically value the context of the past solutions and more importantly the latest generated solution you had AS A FLOAT BETWEEN 0 AND 1\n prompt = f""" To achieve the following goal: '{initial_prompt}', pessimistically value the context of the past solutions and more importantly the latest generated solution you had AS A FLOAT BETWEEN 0 AND 1\n
Past solutions:\n\n Past solutions:\n\n
{state_text}\n {state_text}\n
@ -244,7 +238,11 @@ class AoTAgent:
for next_state in thoughts: for next_state in thoughts:
state_value = self.evaluated_thoughts[next_state] state_value = self.evaluated_thoughts[next_state]
if state_value > self.value_threshold: if state_value > self.value_threshold:
child = (state, next_state) if isinstance(state, str) else (*state, next_state) child = (
(state, next_state)
if isinstance(state, str)
else (*state, next_state)
)
self.dfs(child, step + 1) self.dfs(child, step + 1)
# backtracking # backtracking
@ -255,17 +253,18 @@ class AoTAgent:
def generate_and_filter_thoughts(self, state): def generate_and_filter_thoughts(self, state):
thoughts = self.model.generate_thoughts( thoughts = self.model.generate_thoughts(
state, state, self.num_thoughts, self.initial_prompt
self.num_thoughts,
self.initial_prompt
) )
self.evaluated_thoughts = self.model.evaluate_states( self.evaluated_thoughts = self.model.evaluate_states(
thoughts, thoughts, self.initial_prompt
self.initial_prompt
) )
filtered_thoughts = [thought for thought in thoughts if self.evaluated_thoughts[thought] >= self.pruning_threshold] filtered_thoughts = [
thought
for thought in thoughts
if self.evaluated_thoughts[thought] >= self.pruning_threshold
]
print(f"filtered_thoughts: {filtered_thoughts}") print(f"filtered_thoughts: {filtered_thoughts}")
return filtered_thoughts return filtered_thoughts

@ -18,7 +18,7 @@ class AbstractAgent:
self, self,
name: str, name: str,
# tools: List[Tool], # tools: List[Tool],
#memory: Memory # memory: Memory
): ):
""" """
Args: Args:
@ -51,10 +51,7 @@ class AbstractAgent:
def chat(self, messages: List[Dict]): def chat(self, messages: List[Dict]):
"""Chat with the agent""" """Chat with the agent"""
def _achat( def _achat(self, messages: List[Dict]):
self,
messages: List[Dict]
):
"""Asynchronous Chat""" """Asynchronous Chat"""
def step(self, message: str): def step(self, message: str):

@ -43,7 +43,9 @@ class ConversableAgent(Agent):
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"model": DEFAULT_MODEL, "model": DEFAULT_MODEL,
} }
MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change) MAX_CONSECUTIVE_AUTO_REPLY = (
100 # maximum number of consecutive auto replies (subject to future change)
)
def __init__( def __init__(
self, self,
@ -103,7 +105,9 @@ class ConversableAgent(Agent):
self._oai_messages = defaultdict(list) self._oai_messages = defaultdict(list)
self._oai_system_message = [{"content": system_message, "role": "system"}] self._oai_system_message = [{"content": system_message, "role": "system"}]
self._is_termination_msg = ( self._is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "TERMINATE") is_termination_msg
if is_termination_msg is not None
else (lambda x: x.get("content") == "TERMINATE")
) )
if llm_config is False: if llm_config is False:
self.llm_config = False self.llm_config = False
@ -112,21 +116,33 @@ class ConversableAgent(Agent):
if isinstance(llm_config, dict): if isinstance(llm_config, dict):
self.llm_config.update(llm_config) self.llm_config.update(llm_config)
self._code_execution_config = {} if code_execution_config is None else code_execution_config self._code_execution_config = (
{} if code_execution_config is None else code_execution_config
)
self.human_input_mode = human_input_mode self.human_input_mode = human_input_mode
self._max_consecutive_auto_reply = ( self._max_consecutive_auto_reply = (
max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY max_consecutive_auto_reply
if max_consecutive_auto_reply is not None
else self.MAX_CONSECUTIVE_AUTO_REPLY
) )
self._consecutive_auto_reply_counter = defaultdict(int) self._consecutive_auto_reply_counter = defaultdict(int)
self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply) self._max_consecutive_auto_reply_dict = defaultdict(
self.max_consecutive_auto_reply
)
self._function_map = {} if function_map is None else function_map self._function_map = {} if function_map is None else function_map
self._default_auto_reply = default_auto_reply self._default_auto_reply = default_auto_reply
self._reply_func_list = [] self._reply_func_list = []
self.reply_at_receive = defaultdict(bool) self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply) self.register_reply(
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) [Agent, None], ConversableAgent.generate_code_execution_reply
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) )
self.register_reply(
[Agent, None], ConversableAgent.generate_function_call_reply
)
self.register_reply(
[Agent, None], ConversableAgent.check_termination_and_human_reply
)
def register_reply( def register_reply(
self, self,
@ -170,7 +186,9 @@ class ConversableAgent(Agent):
The function returns None. Signature: ```def reset_config(config: Any)``` The function returns None. Signature: ```def reset_config(config: Any)```
""" """
if not isinstance(trigger, (type, str, Agent, Callable, list)): if not isinstance(trigger, (type, str, Agent, Callable, list)):
raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") raise ValueError(
"trigger must be a class, a string, an agent, a callable or a list."
)
self._reply_func_list.insert( self._reply_func_list.insert(
position, position,
{ {
@ -195,7 +213,9 @@ class ConversableAgent(Agent):
""" """
self._oai_system_message[0]["content"] = system_message self._oai_system_message[0]["content"] = system_message
def update_max_consecutive_auto_reply(self, value: int, sender: Optional[Agent] = None): def update_max_consecutive_auto_reply(
self, value: int, sender: Optional[Agent] = None
):
"""Update the maximum number of consecutive auto replies. """Update the maximum number of consecutive auto replies.
Args: Args:
@ -211,7 +231,11 @@ class ConversableAgent(Agent):
def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int: def max_consecutive_auto_reply(self, sender: Optional[Agent] = None) -> int:
"""The maximum number of consecutive auto replies.""" """The maximum number of consecutive auto replies."""
return self._max_consecutive_auto_reply if sender is None else self._max_consecutive_auto_reply_dict[sender] return (
self._max_consecutive_auto_reply
if sender is None
else self._max_consecutive_auto_reply_dict[sender]
)
@property @property
def chat_messages(self) -> Dict[Agent, List[Dict]]: def chat_messages(self) -> Dict[Agent, List[Dict]]:
@ -236,7 +260,9 @@ class ConversableAgent(Agent):
if n_conversations == 1: if n_conversations == 1:
for conversation in self._oai_messages.values(): for conversation in self._oai_messages.values():
return conversation[-1] return conversation[-1]
raise ValueError("More than one conversation is found. Please specify the sender to get the last message.") raise ValueError(
"More than one conversation is found. Please specify the sender to get the last message."
)
return self._oai_messages[agent][-1] return self._oai_messages[agent][-1]
@property @property
@ -244,7 +270,11 @@ class ConversableAgent(Agent):
"""Bool value of whether to use docker to execute the code, """Bool value of whether to use docker to execute the code,
or str value of the docker image name to use, or None when code execution is disabled. or str value of the docker image name to use, or None when code execution is disabled.
""" """
return None if self._code_execution_config is False else self._code_execution_config.get("use_docker") return (
None
if self._code_execution_config is False
else self._code_execution_config.get("use_docker")
)
@staticmethod @staticmethod
def _message_to_dict(message: Union[Dict, str]): def _message_to_dict(message: Union[Dict, str]):
@ -257,7 +287,9 @@ class ConversableAgent(Agent):
else: else:
return message return message
def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool: def _append_oai_message(
self, message: Union[Dict, str], role, conversation_id: Agent
) -> bool:
"""Append a message to the ChatCompletion conversation. """Append a message to the ChatCompletion conversation.
If the message received is a string, it will be put in the "content" field of the new dictionary. If the message received is a string, it will be put in the "content" field of the new dictionary.
@ -275,16 +307,24 @@ class ConversableAgent(Agent):
""" """
message = self._message_to_dict(message) message = self._message_to_dict(message)
# create oai message to be appended to the oai conversation that can be passed to oai directly. # create oai message to be appended to the oai conversation that can be passed to oai directly.
oai_message = {k: message[k] for k in ("content", "function_call", "name", "context") if k in message} oai_message = {
k: message[k]
for k in ("content", "function_call", "name", "context")
if k in message
}
if "content" not in oai_message: if "content" not in oai_message:
if "function_call" in oai_message: if "function_call" in oai_message:
oai_message["content"] = None # if only function_call is provided, content will be set to None. oai_message[
"content"
] = None # if only function_call is provided, content will be set to None.
else: else:
return False return False
oai_message["role"] = "function" if message.get("role") == "function" else role oai_message["role"] = "function" if message.get("role") == "function" else role
if "function_call" in oai_message: if "function_call" in oai_message:
oai_message["role"] = "assistant" # only messages with role 'assistant' can have a function call. oai_message[
"role"
] = "assistant" # only messages with role 'assistant' can have a function call.
self._oai_messages[conversation_id].append(oai_message) self._oai_messages[conversation_id].append(oai_message)
return True return True
@ -390,7 +430,9 @@ class ConversableAgent(Agent):
# print the message received # print the message received
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
if message.get("role") == "function": if message.get("role") == "function":
func_print = f"***** Response from calling function \"{message['name']}\" *****" func_print = (
f"***** Response from calling function \"{message['name']}\" *****"
)
print(colored(func_print, "green"), flush=True) print(colored(func_print, "green"), flush=True)
print(message["content"], flush=True) print(message["content"], flush=True)
print(colored("*" * len(func_print), "green"), flush=True) print(colored("*" * len(func_print), "green"), flush=True)
@ -401,7 +443,8 @@ class ConversableAgent(Agent):
content = oai.ChatCompletion.instantiate( content = oai.ChatCompletion.instantiate(
content, content,
message["context"], message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False), self.llm_config
and self.llm_config.get("allow_format_str_template", False),
) )
print(content, flush=True) print(content, flush=True)
if "function_call" in message: if "function_call" in message:
@ -457,7 +500,11 @@ class ConversableAgent(Agent):
ValueError: if the message can't be converted into a valid ChatCompletion message. ValueError: if the message can't be converted into a valid ChatCompletion message.
""" """
self._process_received_message(message, sender, silent) self._process_received_message(message, sender, silent)
if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: if (
request_reply is False
or request_reply is None
and self.reply_at_receive[sender] is False
):
return return
reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender)
if reply is not None: if reply is not None:
@ -493,7 +540,11 @@ class ConversableAgent(Agent):
ValueError: if the message can't be converted into a valid ChatCompletion message. ValueError: if the message can't be converted into a valid ChatCompletion message.
""" """
self._process_received_message(message, sender, silent) self._process_received_message(message, sender, silent)
if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: if (
request_reply is False
or request_reply is None
and self.reply_at_receive[sender] is False
):
return return
reply = await self.a_generate_reply(sender=sender) reply = await self.a_generate_reply(sender=sender)
if reply is not None: if reply is not None:
@ -551,7 +602,9 @@ class ConversableAgent(Agent):
"message" needs to be provided if the `generate_init_message` method is not overridden. "message" needs to be provided if the `generate_init_message` method is not overridden.
""" """
self._prepare_chat(recipient, clear_history) self._prepare_chat(recipient, clear_history)
await self.a_send(self.generate_init_message(**context), recipient, silent=silent) await self.a_send(
self.generate_init_message(**context), recipient, silent=silent
)
def reset(self): def reset(self):
"""Reset the agent.""" """Reset the agent."""
@ -604,7 +657,9 @@ class ConversableAgent(Agent):
# TODO: #1143 handle token limit exceeded error # TODO: #1143 handle token limit exceeded error
response = oai.ChatCompletion.create( response = oai.ChatCompletion.create(
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **llm_config context=messages[-1].pop("context", None),
messages=self._oai_system_message + messages,
**llm_config,
) )
return True, oai.ChatCompletion.extract_text_or_function_call(response)[0] return True, oai.ChatCompletion.extract_text_or_function_call(response)[0]
@ -615,7 +670,9 @@ class ConversableAgent(Agent):
config: Optional[Any] = None, config: Optional[Any] = None,
): ):
"""Generate a reply using code execution.""" """Generate a reply using code execution."""
code_execution_config = config if config is not None else self._code_execution_config code_execution_config = (
config if config is not None else self._code_execution_config
)
if code_execution_config is False: if code_execution_config is False:
return False, None return False, None
if messages is None: if messages is None:
@ -634,7 +691,9 @@ class ConversableAgent(Agent):
# found code blocks, execute code and push "last_n_messages" back # found code blocks, execute code and push "last_n_messages" back
exitcode, logs = self.execute_code_blocks(code_blocks) exitcode, logs = self.execute_code_blocks(code_blocks)
code_execution_config["last_n_messages"] = last_n_messages code_execution_config["last_n_messages"] = last_n_messages
exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" exitcode2str = (
"execution succeeded" if exitcode == 0 else "execution failed"
)
return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}" return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}"
# no code blocks are found, push last_n_messages back and return. # no code blocks are found, push last_n_messages back and return.
@ -681,7 +740,10 @@ class ConversableAgent(Agent):
# if the human input is empty, and the message is a termination message, then we will terminate the conversation # if the human input is empty, and the message is a termination message, then we will terminate the conversation
reply = reply if reply or not self._is_termination_msg(message) else "exit" reply = reply if reply or not self._is_termination_msg(message) else "exit"
else: else:
if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]: if (
self._consecutive_auto_reply_counter[sender]
>= self._max_consecutive_auto_reply_dict[sender]
):
if self.human_input_mode == "NEVER": if self.human_input_mode == "NEVER":
reply = "exit" reply = "exit"
else: else:
@ -776,7 +838,12 @@ class ConversableAgent(Agent):
if asyncio.coroutines.iscoroutinefunction(reply_func): if asyncio.coroutines.iscoroutinefunction(reply_func):
continue continue
if self._match_trigger(reply_func_tuple["trigger"], sender): if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) final, reply = reply_func(
self,
messages=messages,
sender=sender,
config=reply_func_tuple["config"],
)
if final: if final:
return reply return reply
return self._default_auto_reply return self._default_auto_reply
@ -827,10 +894,18 @@ class ConversableAgent(Agent):
if self._match_trigger(reply_func_tuple["trigger"], sender): if self._match_trigger(reply_func_tuple["trigger"], sender):
if asyncio.coroutines.iscoroutinefunction(reply_func): if asyncio.coroutines.iscoroutinefunction(reply_func):
final, reply = await reply_func( final, reply = await reply_func(
self, messages=messages, sender=sender, config=reply_func_tuple["config"] self,
messages=messages,
sender=sender,
config=reply_func_tuple["config"],
) )
else: else:
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) final, reply = reply_func(
self,
messages=messages,
sender=sender,
config=reply_func_tuple["config"],
)
if final: if final:
return reply return reply
return self._default_auto_reply return self._default_auto_reply
@ -897,10 +972,12 @@ class ConversableAgent(Agent):
flush=True, flush=True,
) )
if lang in ["bash", "shell", "sh"]: if lang in ["bash", "shell", "sh"]:
exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config) exitcode, logs, image = self.run_code(
code, lang=lang, **self._code_execution_config
)
elif lang in ["python", "Python"]: elif lang in ["python", "Python"]:
if code.startswith("# filename: "): if code.startswith("# filename: "):
filename = code[11: code.find("\n")].strip() filename = code[11 : code.find("\n")].strip()
else: else:
filename = None filename = None
exitcode, logs, image = self.run_code( exitcode, logs, image = self.run_code(

@ -66,7 +66,9 @@ class CocoGroundingEvaluator(object):
def synchronize_between_processes(self): def synchronize_between_processes(self):
for iou_type in self.iou_types: for iou_type in self.iou_types:
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) create_common_coco_eval(
self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]
)
def accumulate(self): def accumulate(self):
for coco_eval in self.coco_eval.values(): for coco_eval in self.coco_eval.values():
@ -127,7 +129,9 @@ class CocoGroundingEvaluator(object):
labels = prediction["labels"].tolist() labels = prediction["labels"].tolist()
rles = [ rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] mask_util.encode(
np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F")
)[0]
for mask in masks for mask in masks
] ]
for rle in rles: for rle in rles:
@ -227,7 +231,9 @@ def evaluate(self):
# add backward compatibility if useSegm is specified in params # add backward compatibility if useSegm is specified in params
if p.useSegm is not None: if p.useSegm is not None:
p.iouType = "segm" if p.useSegm == 1 else "bbox" p.iouType = "segm" if p.useSegm == 1 else "bbox"
print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)) print(
"useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)
)
# print('Evaluate annotation type *{}*'.format(p.iouType)) # print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds)) p.imgIds = list(np.unique(p.imgIds))
if p.useCats: if p.useCats:
@ -246,7 +252,8 @@ def evaluate(self):
self.ious = { self.ious = {
(imgId, catId): computeIoU(imgId, catId) (imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds for imgId in p.imgIds
for catId in catIds} for catId in catIds
}
evaluateImg = self.evaluateImg evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1] maxDet = p.maxDets[-1]

@ -38,7 +38,7 @@ def crop(image, target, region):
if "masks" in target: if "masks" in target:
# FIXME should we update the area here if there are no boxes? # FIXME should we update the area here if there are no boxes?
target["masks"] = target["masks"][:, i: i + h, j: j + w] target["masks"] = target["masks"][:, i : i + h, j : j + w]
fields.append("masks") fields.append("masks")
# remove elements for which the boxes or masks that have zero area # remove elements for which the boxes or masks that have zero area
@ -73,9 +73,9 @@ def hflip(image, target):
target = target.copy() target = target.copy()
if "boxes" in target: if "boxes" in target:
boxes = target["boxes"] boxes = target["boxes"]
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
[w, 0, w, 0] [-1, 1, -1, 1]
) ) + torch.as_tensor([w, 0, w, 0])
target["boxes"] = boxes target["boxes"] = boxes
if "masks" in target: if "masks" in target:
@ -119,7 +119,9 @@ def resize(image, target, size, max_size=None):
if target is None: if target is None:
return rescaled_image, None return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) ratios = tuple(
float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
)
ratio_width, ratio_height = ratios ratio_width, ratio_height = ratios
target = target.copy() target = target.copy()
@ -140,7 +142,8 @@ def resize(image, target, size, max_size=None):
if "masks" in target: if "masks" in target:
target["masks"] = ( target["masks"] = (
interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0]
> 0.5
) )
return rescaled_image, target return rescaled_image, target
@ -155,7 +158,9 @@ def pad(image, target, padding):
# should we do something wrt the original size? # should we do something wrt the original size?
target["size"] = torch.tensor(padded_image.size[::-1]) target["size"] = torch.tensor(padded_image.size[::-1])
if "masks" in target: if "masks" in target:
target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) target["masks"] = torch.nn.functional.pad(
target["masks"], (0, padding[0], 0, padding[1])
)
return padded_image, target return padded_image, target

@ -47,14 +47,27 @@ class FrozenBatchNorm2d(torch.nn.Module):
self.register_buffer("running_var", torch.ones(n)) self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict( def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
): ):
num_batches_tracked_key = prefix + "num_batches_tracked" num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict: if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key] del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict( super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
) )
def forward(self, x): def forward(self, x):
@ -91,7 +104,11 @@ class BackboneBase(nn.Module):
return_layers = {} return_layers = {}
for idx, layer_index in enumerate(return_interm_indices): for idx, layer_index in enumerate(return_interm_indices):
return_layers.update( return_layers.update(
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)} {
"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(
layer_index
)
}
) )
# if len: # if len:
@ -136,10 +153,13 @@ class Backbone(BackboneBase):
else: else:
raise NotImplementedError("Why you can get here with name {}".format(name)) raise NotImplementedError("Why you can get here with name {}".format(name))
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." assert name not in (
"resnet18",
"resnet34",
), "Only resnet50 and resnet101 are available."
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
num_channels_all = [256, 512, 1024, 2048] num_channels_all = [256, 512, 1024, 2048]
num_channels = num_channels_all[4 - len(return_interm_indices):] num_channels = num_channels_all[4 - len(return_interm_indices) :]
super().__init__(backbone, train_backbone, num_channels, return_interm_indices) super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
@ -204,7 +224,7 @@ def build_backbone(args):
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
) )
bb_num_channels = backbone.num_features[4 - len(return_interm_indices):] bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
else: else:
raise NotImplementedError("Unknown backbone {}".format(args.backbone)) raise NotImplementedError("Unknown backbone {}".format(args.backbone))

@ -33,7 +33,9 @@ class PositionEmbeddingSine(nn.Module):
used by the Attention is all you need paper, generalized to work on images. used by the Attention is all you need paper, generalized to work on images.
""" """
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): def __init__(
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
):
super().__init__() super().__init__()
self.num_pos_feats = num_pos_feats self.num_pos_feats = num_pos_feats
self.temperature = temperature self.temperature = temperature
@ -82,7 +84,12 @@ class PositionEmbeddingSineHW(nn.Module):
""" """
def __init__( def __init__(
self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None self,
num_pos_feats=64,
temperatureH=10000,
temperatureW=10000,
normalize=False,
scale=None,
): ):
super().__init__() super().__init__()
self.num_pos_feats = num_pos_feats self.num_pos_feats = num_pos_feats
@ -111,11 +118,15 @@ class PositionEmbeddingSineHW(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats) dim_tx = self.temperatureW ** (
2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats
)
pos_x = x_embed[:, :, :, None] / dim_tx pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats) dim_ty = self.temperatureH ** (
2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats
)
pos_y = y_embed[:, :, :, None] / dim_ty pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack( pos_x = torch.stack(

@ -25,7 +25,12 @@ class Mlp(nn.Module):
"""Multilayer perceptron.""" """Multilayer perceptron."""
def __init__( def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
): ):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
@ -54,7 +59,9 @@ def window_partition(x, window_size):
""" """
B, H, W, C = x.shape B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows return windows
@ -69,7 +76,9 @@ def window_reverse(windows, window_size, H, W):
x: (B, H, W, C) x: (B, H, W, C)
""" """
B = int(windows.shape[0] / (H * W / window_size / window_size)) B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = windows.view(
B, H // window_size, W // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x return x
@ -97,7 +106,6 @@ class WindowAttention(nn.Module):
attn_drop=0.0, attn_drop=0.0,
proj_drop=0.0, proj_drop=0.0,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.window_size = window_size # Wh, Ww self.window_size = window_size # Wh, Ww
@ -115,8 +123,12 @@ class WindowAttention(nn.Module):
coords_w = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = (
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
@ -143,7 +155,11 @@ class WindowAttention(nn.Module):
.reshape(B_, N, 3, self.num_heads, C // self.num_heads) .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4) .permute(2, 0, 3, 1, 4)
) )
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale q = q * self.scale
attn = q @ k.transpose(-2, -1) attn = q @ k.transpose(-2, -1)
@ -151,7 +167,9 @@ class WindowAttention(nn.Module):
relative_position_bias = self.relative_position_bias_table[ relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1) self.relative_position_index.view(-1)
].view( ].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH ) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute( relative_position_bias = relative_position_bias.permute(
2, 0, 1 2, 0, 1
@ -160,7 +178,9 @@ class WindowAttention(nn.Module):
if mask is not None: if mask is not None:
nW = mask.shape[0] nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N) attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn) attn = self.softmax(attn)
else: else:
@ -212,7 +232,9 @@ class SwinTransformerBlock(nn.Module):
self.window_size = window_size self.window_size = window_size
self.shift_size = shift_size self.shift_size = shift_size
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" assert (
0 <= self.shift_size < self.window_size
), "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = WindowAttention( self.attn = WindowAttention(
@ -229,7 +251,10 @@ class SwinTransformerBlock(nn.Module):
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp( self.mlp = Mlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
) )
self.H = None self.H = None
@ -259,7 +284,9 @@ class SwinTransformerBlock(nn.Module):
# cyclic shift # cyclic shift
if self.shift_size > 0: if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
)
attn_mask = mask_matrix attn_mask = mask_matrix
else: else:
shifted_x = x shifted_x = x
@ -274,7 +301,9 @@ class SwinTransformerBlock(nn.Module):
) # nW*B, window_size*window_size, C ) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA # W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C attn_windows = self.attn(
x_windows, mask=attn_mask
) # nW*B, window_size*window_size, C
# merge windows # merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
@ -282,7 +311,9 @@ class SwinTransformerBlock(nn.Module):
# reverse cyclic shift # reverse cyclic shift
if self.shift_size > 0: if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
)
else: else:
x = shifted_x x = shifted_x
@ -393,7 +424,9 @@ class BasicLayer(nn.Module):
qk_scale=qk_scale, qk_scale=qk_scale,
drop=drop, drop=drop,
attn_drop=attn_drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, drop_path=drop_path[i]
if isinstance(drop_path, list)
else drop_path,
norm_layer=norm_layer, norm_layer=norm_layer,
) )
for i in range(depth) for i in range(depth)
@ -473,7 +506,9 @@ class PatchEmbed(nn.Module):
self.in_chans = in_chans self.in_chans = in_chans
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if norm_layer is not None: if norm_layer is not None:
self.norm = norm_layer(embed_dim) self.norm = norm_layer(embed_dim)
else: else:
@ -614,7 +649,7 @@ class SwinTransformer(nn.Module):
qk_scale=qk_scale, qk_scale=qk_scale,
drop=drop_rate, drop=drop_rate,
attn_drop=attn_drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]): sum(depths[: i_layer + 1])], drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer, norm_layer=norm_layer,
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
downsample=downsamplelist[i_layer], downsample=downsamplelist[i_layer],
@ -700,7 +735,11 @@ class SwinTransformer(nn.Module):
norm_layer = getattr(self, f"norm{i}") norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out) x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() out = (
x_out.view(-1, H, W, self.num_features[i])
.permute(0, 3, 1, 2)
.contiguous()
)
outs.append(out) outs.append(out)
# in: # in:
# torch.Size([2, 3, 1024, 1024]) # torch.Size([2, 3, 1024, 1024])
@ -735,7 +774,11 @@ class SwinTransformer(nn.Module):
norm_layer = getattr(self, f"norm{i}") norm_layer = getattr(self, f"norm{i}")
x_out = norm_layer(x_out) x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() out = (
x_out.view(-1, H, W, self.num_features[i])
.permute(0, 3, 1, 2)
.contiguous()
)
outs.append(out) outs.append(out)
# in: # in:
# torch.Size([2, 3, 1024, 1024]) # torch.Size([2, 3, 1024, 1024])
@ -748,7 +791,9 @@ class SwinTransformer(nn.Module):
for idx, out_i in enumerate(outs): for idx, out_i in enumerate(outs):
m = tensor_list.mask m = tensor_list.mask
assert m is not None assert m is not None
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0] mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[
0
]
outs_dict[idx] = NestedTensor(out_i, mask) outs_dict[idx] = NestedTensor(out_i, mask)
return outs_dict return outs_dict
@ -776,13 +821,22 @@ def build_swin_transformer(modelname, pretrain_img_size, **kw):
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7 embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
), ),
"swin_B_384_22k": dict( "swin_B_384_22k": dict(
embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12 embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=12,
), ),
"swin_L_224_22k": dict( "swin_L_224_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7 embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=7,
), ),
"swin_L_384_22k": dict( "swin_L_384_22k": dict(
embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12 embed_dim=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
), ),
} }
kw_cgf = model_para_dict[modelname] kw_cgf = model_para_dict[modelname]

@ -61,14 +61,18 @@ class BertModelWarper(nn.Module):
decoding (see :obj:`past_key_values`). decoding (see :obj:`past_key_values`).
""" """
output_attentions = ( output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions output_attentions
if output_attentions is not None
else self.config.output_attentions
) )
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states
if output_hidden_states is not None if output_hidden_states is not None
else self.config.output_hidden_states else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if self.config.is_decoder: if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
@ -76,7 +80,9 @@ class BertModelWarper(nn.Module):
use_cache = False use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None: elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
@ -109,11 +115,17 @@ class BertModelWarper(nn.Module):
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None: if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() (
encoder_batch_size,
encoder_sequence_length,
_,
) = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None: if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) encoder_extended_attention_mask = self.invert_attention_mask(
encoder_attention_mask
)
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO': # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
@ -147,7 +159,9 @@ class BertModelWarper(nn.Module):
return_dict=return_dict, return_dict=return_dict,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = (
self.pooler(sequence_output) if self.pooler is not None else None
)
if not return_dict: if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:] return (sequence_output, pooled_output) + encoder_outputs[1:]
@ -193,7 +207,10 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer
# generate attention mask and positional ids # generate attention mask and positional ids
attention_mask = ( attention_mask = (
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1) torch.eye(num_token, device=input_ids.device)
.bool()
.unsqueeze(0)
.repeat(bs, 1, 1)
) )
position_ids = torch.zeros((bs, num_token), device=input_ids.device) position_ids = torch.zeros((bs, num_token), device=input_ids.device)
previous_col = 0 previous_col = 0
@ -203,8 +220,10 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer
attention_mask[row, col, col] = True attention_mask[row, col, col] = True
position_ids[row, col] = 0 position_ids[row, col] = 0
else: else:
attention_mask[row, previous_col + 1: col + 1, previous_col + 1: col + 1] = True attention_mask[
position_ids[row, previous_col + 1: col + 1] = torch.arange( row, previous_col + 1 : col + 1, previous_col + 1 : col + 1
] = True
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
0, col - previous_col, device=input_ids.device 0, col - previous_col, device=input_ids.device
) )
@ -217,7 +236,9 @@ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer
return attention_mask, position_ids.to(torch.long) return attention_mask, position_ids.to(torch.long)
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer): def generate_masks_with_special_tokens_and_transfer_map(
tokenized, special_tokens_list, tokenizer
):
"""Generate attention mask between each pair of special tokens """Generate attention mask between each pair of special tokens
Args: Args:
input_ids (torch.Tensor): input ids. Shape: [bs, num_token] input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
@ -237,7 +258,10 @@ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_token
# generate attention mask and positional ids # generate attention mask and positional ids
attention_mask = ( attention_mask = (
torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1) torch.eye(num_token, device=input_ids.device)
.bool()
.unsqueeze(0)
.repeat(bs, 1, 1)
) )
position_ids = torch.zeros((bs, num_token), device=input_ids.device) position_ids = torch.zeros((bs, num_token), device=input_ids.device)
cate_to_token_mask_list = [[] for _ in range(bs)] cate_to_token_mask_list = [[] for _ in range(bs)]
@ -248,12 +272,14 @@ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_token
attention_mask[row, col, col] = True attention_mask[row, col, col] = True
position_ids[row, col] = 0 position_ids[row, col] = 0
else: else:
attention_mask[row, previous_col + 1: col + 1, previous_col + 1: col + 1] = True attention_mask[
position_ids[row, previous_col + 1: col + 1] = torch.arange( row, previous_col + 1 : col + 1, previous_col + 1 : col + 1
] = True
position_ids[row, previous_col + 1 : col + 1] = torch.arange(
0, col - previous_col, device=input_ids.device 0, col - previous_col, device=input_ids.device
) )
c2t_maski = torch.zeros((num_token), device=input_ids.device).bool() c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
c2t_maski[previous_col + 1: col] = True c2t_maski[previous_col + 1 : col] = True
cate_to_token_mask_list[row].append(c2t_maski) cate_to_token_mask_list[row].append(c2t_maski)
previous_col = col previous_col = col

@ -127,7 +127,11 @@ class BiMultiHeadAttention(nn.Module):
self._reset_parameters() self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def _reset_parameters(self): def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight)
@ -171,7 +175,9 @@ class BiMultiHeadAttention(nn.Module):
value_l_states = value_l_states.view(*proj_shape) value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1) src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt attn_weights = torch.bmm(
query_states, key_states.transpose(1, 2)
) # bs*nhead, nimg, ntxt
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError( raise ValueError(
@ -191,7 +197,9 @@ class BiMultiHeadAttention(nn.Module):
) # Do not increase 50000, data type half has quite limited range ) # Do not increase 50000, data type half has quite limited range
attn_weights_T = attn_weights.transpose(1, 2) attn_weights_T = attn_weights.transpose(1, 2)
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0] attn_weights_l = (
attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
)
if self.clamp_min_for_underflow: if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp( attn_weights_l = torch.clamp(
attn_weights_l, min=-50000 attn_weights_l, min=-50000
@ -204,7 +212,9 @@ class BiMultiHeadAttention(nn.Module):
# mask vison for language # mask vison for language
if attention_mask_v is not None: if attention_mask_v is not None:
attention_mask_v = ( attention_mask_v = (
attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) attention_mask_v[:, None, None, :]
.repeat(1, self.num_heads, 1, 1)
.flatten(0, 1)
) )
attn_weights_l.masked_fill_(attention_mask_v, float("-inf")) attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
@ -213,7 +223,9 @@ class BiMultiHeadAttention(nn.Module):
# mask language for vision # mask language for vision
if attention_mask_l is not None: if attention_mask_l is not None:
attention_mask_l = ( attention_mask_l = (
attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1) attention_mask_l[:, None, None, :]
.repeat(1, self.num_heads, 1, 1)
.flatten(0, 1)
) )
attn_weights.masked_fill_(attention_mask_l, float("-inf")) attn_weights.masked_fill_(attention_mask_l, float("-inf"))
attn_weights_v = attn_weights.softmax(dim=-1) attn_weights_v = attn_weights.softmax(dim=-1)
@ -275,13 +287,21 @@ class BiAttentionBlock(nn.Module):
self.layer_norm_v = nn.LayerNorm(v_dim) self.layer_norm_v = nn.LayerNorm(v_dim)
self.layer_norm_l = nn.LayerNorm(l_dim) self.layer_norm_l = nn.LayerNorm(l_dim)
self.attn = BiMultiHeadAttention( self.attn = BiMultiHeadAttention(
v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout v_dim=v_dim,
l_dim=l_dim,
embed_dim=embed_dim,
num_heads=num_heads,
dropout=dropout,
) )
# add layer scale for training stability # add layer scale for training stability
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True) self.gamma_v = nn.Parameter(
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True) init_values * torch.ones((v_dim)), requires_grad=True
)
self.gamma_l = nn.Parameter(
init_values * torch.ones((l_dim)), requires_grad=True
)
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None): def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
v = self.layer_norm_v(v) v = self.layer_norm_v(v)

@ -100,13 +100,17 @@ class GroundingDINO(nn.Module):
self.bert.pooler.dense.bias.requires_grad_(False) self.bert.pooler.dense.bias.requires_grad_(False)
self.bert = BertModelWarper(bert_model=self.bert) self.bert = BertModelWarper(bert_model=self.bert)
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True) self.feat_map = nn.Linear(
self.bert.config.hidden_size, self.hidden_dim, bias=True
)
nn.init.constant_(self.feat_map.bias.data, 0) nn.init.constant_(self.feat_map.bias.data, 0)
nn.init.xavier_uniform_(self.feat_map.weight.data) nn.init.xavier_uniform_(self.feat_map.weight.data)
# freeze # freeze
# special tokens # special tokens
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"]) self.specical_tokens = self.tokenizer.convert_tokens_to_ids(
["[CLS]", "[SEP]", ".", "?"]
)
# prepare input projection layers # prepare input projection layers
if num_feature_levels > 1: if num_feature_levels > 1:
@ -123,14 +127,18 @@ class GroundingDINO(nn.Module):
for _ in range(num_feature_levels - num_backbone_outs): for _ in range(num_feature_levels - num_backbone_outs):
input_proj_list.append( input_proj_list.append(
nn.Sequential( nn.Sequential(
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), nn.Conv2d(
in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
),
nn.GroupNorm(32, hidden_dim), nn.GroupNorm(32, hidden_dim),
) )
) )
in_channels = hidden_dim in_channels = hidden_dim
self.input_proj = nn.ModuleList(input_proj_list) self.input_proj = nn.ModuleList(input_proj_list)
else: else:
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!" assert (
two_stage_type == "no"
), "two_stage_type should be no if num_feature_levels=1 !!!"
self.input_proj = nn.ModuleList( self.input_proj = nn.ModuleList(
[ [
nn.Sequential( nn.Sequential(
@ -157,12 +165,17 @@ class GroundingDINO(nn.Module):
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
if dec_pred_bbox_embed_share: if dec_pred_bbox_embed_share:
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)] box_embed_layerlist = [
_bbox_embed for i in range(transformer.num_decoder_layers)
]
else: else:
box_embed_layerlist = [ box_embed_layerlist = [
copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers) copy.deepcopy(_bbox_embed)
for i in range(transformer.num_decoder_layers)
]
class_embed_layerlist = [
_class_embed for i in range(transformer.num_decoder_layers)
] ]
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
self.bbox_embed = nn.ModuleList(box_embed_layerlist) self.bbox_embed = nn.ModuleList(box_embed_layerlist)
self.class_embed = nn.ModuleList(class_embed_layerlist) self.class_embed = nn.ModuleList(class_embed_layerlist)
self.transformer.decoder.bbox_embed = self.bbox_embed self.transformer.decoder.bbox_embed = self.bbox_embed
@ -170,9 +183,10 @@ class GroundingDINO(nn.Module):
# two stage # two stage
self.two_stage_type = two_stage_type self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format( assert two_stage_type in [
two_stage_type "no",
) "standard",
], "unknown param {} of two_stage_type".format(two_stage_type)
if two_stage_type != "no": if two_stage_type != "no":
if two_stage_bbox_embed_share: if two_stage_bbox_embed_share:
assert dec_pred_bbox_embed_share assert dec_pred_bbox_embed_share
@ -237,12 +251,18 @@ class GroundingDINO(nn.Module):
] ]
position_ids = position_ids[:, : self.max_text_len] position_ids = position_ids[:, : self.max_text_len]
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len] tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len] tokenized["attention_mask"] = tokenized["attention_mask"][
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len] :, : self.max_text_len
]
tokenized["token_type_ids"] = tokenized["token_type_ids"][
:, : self.max_text_len
]
# extract text embeddings # extract text embeddings
if self.sub_sentence_present: if self.sub_sentence_present:
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"} tokenized_for_encoder = {
k: v for k, v in tokenized.items() if k != "attention_mask"
}
tokenized_for_encoder["attention_mask"] = text_self_attention_masks tokenized_for_encoder["attention_mask"] = text_self_attention_masks
tokenized_for_encoder["position_ids"] = position_ids tokenized_for_encoder["position_ids"] = position_ids
else: else:
@ -251,7 +271,9 @@ class GroundingDINO(nn.Module):
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768 bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model encoded_text = self.feat_map(
bert_output["last_hidden_state"]
) # bs, 195, d_model
text_token_mask = tokenized.attention_mask.bool() # bs, 195 text_token_mask = tokenized.attention_mask.bool() # bs, 195
# text_token_mask: True for nomask, False for mask # text_token_mask: True for nomask, False for mask
# text_self_attention_masks: True for nomask, False for mask # text_self_attention_masks: True for nomask, False for mask
@ -292,7 +314,9 @@ class GroundingDINO(nn.Module):
else: else:
src = self.input_proj[l](srcs[-1]) src = self.input_proj[l](srcs[-1])
m = samples.mask m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
torch.bool
)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src) srcs.append(src)
masks.append(mask) masks.append(mask)
@ -350,7 +374,6 @@ class GroundingDINO(nn.Module):
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino") @MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args): def build_groundingdino(args):
backbone = build_backbone(args) backbone = build_backbone(args)
transformer = build_transformer(args) transformer = build_transformer(args)

@ -34,7 +34,9 @@ except BaseException:
# helpers # helpers
def _is_power_of_2(n): def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0): if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) raise ValueError(
"invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
)
return (n & (n - 1) == 0) and n != 0 return (n & (n - 1) == 0) and n != 0
@ -96,7 +98,6 @@ def multi_scale_deformable_attn_pytorch(
sampling_locations: torch.Tensor, sampling_locations: torch.Tensor,
attention_weights: torch.Tensor, attention_weights: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
bs, _, num_heads, embed_dims = value.shape bs, _, num_heads, embed_dims = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
@ -108,7 +109,10 @@ def multi_scale_deformable_attn_pytorch(
# bs, num_heads*embed_dims, H_*W_ -> # bs, num_heads*embed_dims, H_*W_ ->
# bs*num_heads, embed_dims, H_, W_ # bs*num_heads, embed_dims, H_, W_
value_l_ = ( value_l_ = (
value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) value_list[level]
.flatten(2)
.transpose(1, 2)
.reshape(bs * num_heads, embed_dims, H_, W_)
) )
# bs, num_queries, num_heads, num_points, 2 -> # bs, num_queries, num_heads, num_points, 2 ->
# bs, num_heads, num_queries, num_points, 2 -> # bs, num_heads, num_queries, num_points, 2 ->
@ -116,7 +120,11 @@ def multi_scale_deformable_attn_pytorch(
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
# bs*num_heads, embed_dims, num_queries, num_points # bs*num_heads, embed_dims, num_queries, num_points
sampling_value_l_ = F.grid_sample( sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False value_l_,
sampling_grid_l_,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
) )
sampling_value_list.append(sampling_value_l_) sampling_value_list.append(sampling_value_l_)
# (bs, num_queries, num_heads, num_levels, num_points) -> # (bs, num_queries, num_heads, num_levels, num_points) ->
@ -184,8 +192,12 @@ class MultiScaleDeformableAttention(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
self.num_levels = num_levels self.num_levels = num_levels
self.num_points = num_points self.num_points = num_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2) self.sampling_offsets = nn.Linear(
self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points) embed_dim, num_heads * num_levels * num_points * 2
)
self.attention_weights = nn.Linear(
embed_dim, num_heads * num_levels * num_points
)
self.value_proj = nn.Linear(embed_dim, embed_dim) self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim) self.output_proj = nn.Linear(embed_dim, embed_dim)
@ -306,7 +318,9 @@ class MultiScaleDeformableAttention(nn.Module):
# bs, num_query, num_heads, num_levels, num_points, 2 # bs, num_query, num_heads, num_levels, num_points, 2
if reference_points.shape[-1] == 2: if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1
)
sampling_locations = ( sampling_locations = (
reference_points[:, :, None, :, None, :] reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@ -370,7 +384,9 @@ def create_dummy_class(klass, dependency, message=""):
Returns: Returns:
class: a class object class: a class object
""" """
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass) err = "Cannot import '{}', therefore '{}' is not available.".format(
dependency, klass
)
if message: if message:
err = err + " " + message err = err + " " + message
@ -399,7 +415,9 @@ def create_dummy_func(func, dependency, message=""):
Returns: Returns:
function: a function object function: a function object
""" """
err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func) err = "Cannot import '{}', therefore '{}' is not available.".format(
dependency, func
)
if message: if message:
err = err + " " + message err = err + " " + message

@ -82,7 +82,13 @@ class Transformer(nn.Module):
# choose encoder layer type # choose encoder layer type
encoder_layer = DeformableTransformerEncoderLayer( encoder_layer = DeformableTransformerEncoderLayer(
d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points d_model,
dim_feedforward,
dropout,
activation,
num_feature_levels,
nhead,
enc_n_points,
) )
if use_text_enhancer: if use_text_enhancer:
@ -154,7 +160,9 @@ class Transformer(nn.Module):
if num_feature_levels > 1: if num_feature_levels > 1:
if self.num_encoder_layers > 0: if self.num_encoder_layers > 0:
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) self.level_embed = nn.Parameter(
torch.Tensor(num_feature_levels, d_model)
)
else: else:
self.level_embed = None self.level_embed = None
@ -169,9 +177,10 @@ class Transformer(nn.Module):
# for two stage # for two stage
self.two_stage_type = two_stage_type self.two_stage_type = two_stage_type
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format( assert two_stage_type in [
two_stage_type "no",
) "standard",
], "unknown param {} of two_stage_type".format(two_stage_type)
if two_stage_type == "standard": if two_stage_type == "standard":
# anchor selection at the output of encoder # anchor selection at the output of encoder
self.enc_output = nn.Linear(d_model, d_model) self.enc_output = nn.Linear(d_model, d_model)
@ -208,7 +217,16 @@ class Transformer(nn.Module):
def init_ref_points(self, use_num_queries): def init_ref_points(self, use_num_queries):
self.refpoint_embed = nn.Embedding(use_num_queries, 4) self.refpoint_embed = nn.Embedding(use_num_queries, 4)
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None): def forward(
self,
srcs,
masks,
refpoint_embed,
pos_embeds,
tgt,
attn_mask=None,
text_dict=None,
):
""" """
Input: Input:
- srcs: List of multi features [bs, ci, hi, wi] - srcs: List of multi features [bs, ci, hi, wi]
@ -287,7 +305,9 @@ class Transformer(nn.Module):
output_memory = self.enc_output_norm(self.enc_output(output_memory)) output_memory = self.enc_output_norm(self.enc_output(output_memory))
if text_dict is not None: if text_dict is not None:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict) enc_outputs_class_unselected = self.enc_out_class_embed(
output_memory, text_dict
)
else: else:
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory) enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
@ -301,7 +321,9 @@ class Transformer(nn.Module):
# gather boxes # gather boxes
refpoint_embed_undetach = torch.gather( refpoint_embed_undetach = torch.gather(
enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) enc_outputs_coord_unselected,
1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
) # unsigmoid ) # unsigmoid
refpoint_embed_ = refpoint_embed_undetach.detach() refpoint_embed_ = refpoint_embed_undetach.detach()
init_box_proposal = torch.gather( init_box_proposal = torch.gather(
@ -310,7 +332,9 @@ class Transformer(nn.Module):
# gather tgt # gather tgt
tgt_undetach = torch.gather( tgt_undetach = torch.gather(
output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) output_memory,
1,
topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model),
) )
if self.embed_init_tgt: if self.embed_init_tgt:
tgt_ = ( tgt_ = (
@ -350,7 +374,9 @@ class Transformer(nn.Module):
init_box_proposal = refpoint_embed_.sigmoid() init_box_proposal = refpoint_embed_.sigmoid()
else: else:
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type)) raise NotImplementedError(
"unknown two_stage_type {}".format(self.two_stage_type)
)
######################################################### #########################################################
# End preparing tgt # End preparing tgt
# - tgt: bs, NQ, d_model # - tgt: bs, NQ, d_model
@ -432,7 +458,9 @@ class TransformerEncoder(nn.Module):
self.text_layers = [] self.text_layers = []
self.fusion_layers = [] self.fusion_layers = []
if num_layers > 0: if num_layers > 0:
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share) self.layers = _get_clones(
encoder_layer, num_layers, layer_share=enc_layer_share
)
if text_enhance_layer is not None: if text_enhance_layer is not None:
self.text_layers = _get_clones( self.text_layers = _get_clones(
@ -465,7 +493,6 @@ class TransformerEncoder(nn.Module):
def get_reference_points(spatial_shapes, valid_ratios, device): def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = [] reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes): for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid( ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
@ -534,7 +561,9 @@ class TransformerEncoder(nn.Module):
.unsqueeze(-1) .unsqueeze(-1)
.repeat(bs, 1, 1) .repeat(bs, 1, 1)
) )
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False) pos_text = get_sine_pos_embed(
pos_text, num_pos_feats=256, exchange_xy=False
)
if position_ids is not None: if position_ids is not None:
pos_text = get_sine_pos_embed( pos_text = get_sine_pos_embed(
position_ids[..., None], num_pos_feats=256, exchange_xy=False position_ids[..., None], num_pos_feats=256, exchange_xy=False
@ -662,7 +691,6 @@ class TransformerDecoder(nn.Module):
ref_points = [reference_points] ref_points = [reference_points]
for layer_id, layer in enumerate(self.layers): for layer_id, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4: if reference_points.shape[-1] == 4:
reference_points_input = ( reference_points_input = (
reference_points[:, :, None] reference_points[:, :, None]
@ -670,7 +698,9 @@ class TransformerDecoder(nn.Module):
) # nq, bs, nlevel, 4 ) # nq, bs, nlevel, 4
else: else:
assert reference_points.shape[-1] == 2 assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :] reference_points_input = (
reference_points[:, :, None] * valid_ratios[None, :]
)
query_sine_embed = gen_sineembed_for_position( query_sine_embed = gen_sineembed_for_position(
reference_points_input[:, :, 0, :] reference_points_input[:, :, 0, :]
) # nq, bs, 256*2 ) # nq, bs, 256*2
@ -777,7 +807,13 @@ class DeformableTransformerEncoderLayer(nn.Module):
return src return src
def forward( def forward(
self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None self,
src,
pos,
reference_points,
spatial_shapes,
level_start_index,
key_padding_mask=None,
): ):
# self attention # self attention
# import ipdb; ipdb.set_trace() # import ipdb; ipdb.set_trace()

@ -26,7 +26,9 @@ from .utils import (
class TextTransformer(nn.Module): class TextTransformer(nn.Module):
def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): def __init__(
self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1
):
super().__init__() super().__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.d_model = d_model self.d_model = d_model
@ -35,7 +37,10 @@ class TextTransformer(nn.Module):
self.norm = None self.norm = None
single_encoder_layer = TransformerEncoderLayer( single_encoder_layer = TransformerEncoderLayer(
d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout d_model=d_model,
nhead=nheads,
dim_feedforward=dim_feedforward,
dropout=dropout,
) )
self.layers = _get_clones(single_encoder_layer, num_layers) self.layers = _get_clones(single_encoder_layer, num_layers)

@ -39,14 +39,20 @@ def get_sine_pos_embed(
""" """
scale = 2 * math.pi scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) dim_t = temperature ** (
2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
)
def sine_func(x: torch.Tensor): def sine_func(x: torch.Tensor):
sin_x = x * scale / dim_t sin_x = x * scale / dim_t
sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2) sin_x = torch.stack(
(sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3
).flatten(2)
return sin_x return sin_x
pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)] pos_res = [
sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
]
if exchange_xy: if exchange_xy:
pos_res[0], pos_res[1] = pos_res[1], pos_res[0] pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = torch.cat(pos_res, dim=-1) pos_res = torch.cat(pos_res, dim=-1)
@ -70,7 +76,9 @@ def gen_encoder_output_proposals(
proposals = [] proposals = []
_cur = 0 _cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes): for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur: (_cur + H_ * W_)].view(N_, H_, W_, 1) mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
N_, H_, W_, 1
)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
@ -82,7 +90,9 @@ def gen_encoder_output_proposals(
) )
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
N_, 1, 1, 2
)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
if learnedwh is not None: if learnedwh is not None:
@ -99,15 +109,21 @@ def gen_encoder_output_proposals(
_cur += H_ * W_ _cur += H_ * W_
# import ipdb; ipdb.set_trace() # import ipdb; ipdb.set_trace()
output_proposals = torch.cat(proposals, 1) output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all( output_proposals_valid = (
-1, keepdim=True (output_proposals > 0.01) & (output_proposals < 0.99)
) ).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) output_proposals = output_proposals.masked_fill(
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) memory_padding_mask.unsqueeze(-1), float("inf")
)
output_proposals = output_proposals.masked_fill(
~output_proposals_valid, float("inf")
)
output_memory = memory output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(
memory_padding_mask.unsqueeze(-1), float(0)
)
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
# output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
@ -136,7 +152,12 @@ class RandomBoxPerturber:
def sigmoid_focal_loss( def sigmoid_focal_loss(
inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False inputs,
targets,
num_boxes,
alpha: float = 0.25,
gamma: float = 2,
no_reduction=False,
): ):
""" """
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
@ -206,23 +227,31 @@ def gen_sineembed_for_position(pos_tensor):
# sineembed_tensor = torch.zeros(n_query, bs, 256) # sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128) dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / 128)
x_embed = pos_tensor[:, :, 0] * scale x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_x = torch.stack(
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
).flatten(2)
pos_y = torch.stack(
(pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
).flatten(2)
if pos_tensor.size(-1) == 2: if pos_tensor.size(-1) == 2:
pos = torch.cat((pos_y, pos_x), dim=2) pos = torch.cat((pos_y, pos_x), dim=2)
elif pos_tensor.size(-1) == 4: elif pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale w_embed = pos_tensor[:, :, 2] * scale
pos_w = w_embed[:, :, None] / dim_t pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) pos_w = torch.stack(
(pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
).flatten(2)
h_embed = pos_tensor[:, :, 3] * scale h_embed = pos_tensor[:, :, 3] * scale
pos_h = h_embed[:, :, None] / dim_t pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos_h = torch.stack(
(pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
).flatten(2)
pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else: else:
@ -262,7 +291,9 @@ class ContrastiveEmbed(nn.Module):
res.masked_fill_(~text_token_mask[:, None, :], float("-inf")) res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
# padding to max_text_len # padding to max_text_len
new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device) new_res = torch.full(
(*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device
)
new_res[..., : res.shape[-1]] = res new_res[..., : res.shape[-1]] = res
return new_res return new_res

@ -57,7 +57,9 @@ class Registry(object):
if module_name is None: if module_name is None:
module_name = module_build_function.__name__ module_name = module_build_function.__name__
if not force and module_name in self._module_dict: if not force and module_name in self._module_dict:
raise KeyError("{} is already registered in {}".format(module_name, self.name)) raise KeyError(
"{} is already registered in {}".format(module_name, self.name)
)
self._module_dict[module_name] = module_build_function self._module_dict[module_name] = module_build_function
return module_build_function return module_build_function

@ -22,7 +22,9 @@ def get_tokenlizer(text_encoder_type):
def get_pretrained_language_model(text_encoder_type): def get_pretrained_language_model(text_encoder_type):
if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)): if text_encoder_type == "bert-base-uncased" or (
os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)
):
return BertModel.from_pretrained(text_encoder_type) return BertModel.from_pretrained(text_encoder_type)
if text_encoder_type == "roberta-base": if text_encoder_type == "roberta-base":
return RobertaModel.from_pretrained(text_encoder_type) return RobertaModel.from_pretrained(text_encoder_type)

@ -26,7 +26,9 @@ def preprocess_caption(caption: str) -> str:
return result + "." return result + "."
def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"): def load_model(
model_config_path: str, model_checkpoint_path: str, device: str = "cuda"
):
args = SLConfig.fromfile(model_config_path) args = SLConfig.fromfile(model_config_path)
args.device = device args.device = device
model = build_model(args) model = build_model(args)
@ -57,7 +59,7 @@ def predict(
box_threshold: float, box_threshold: float,
text_threshold: float, text_threshold: float,
device: str = "cuda", device: str = "cuda",
remove_combined: bool = False remove_combined: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption) caption = preprocess_caption(caption=caption)
@ -67,8 +69,12 @@ def predict(
with torch.no_grad(): with torch.no_grad():
outputs = model(image[None], captions=[caption]) outputs = model(image[None], captions=[caption])
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) prediction_logits = (
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) outputs["pred_logits"].cpu().sigmoid()[0]
) # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[
0
] # prediction_boxes.shape = (nq, 4)
mask = prediction_logits.max(dim=1)[0] > box_threshold mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256) logits = prediction_logits[mask] # logits.shape = (n, 256)
@ -78,7 +84,11 @@ def predict(
tokenized = tokenizer(caption) tokenized = tokenizer(caption)
if remove_combined: if remove_combined:
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]] sep_idx = [
i
for i in range(len(tokenized["input_ids"]))
if tokenized["input_ids"][i] in [101, 102, 1012]
]
phrases = [] phrases = []
for logit in logits: for logit in logits:
@ -86,32 +96,40 @@ def predict(
insert_idx = bisect.bisect_left(sep_idx, max_idx) insert_idx = bisect.bisect_left(sep_idx, max_idx)
right_idx = sep_idx[insert_idx] right_idx = sep_idx[insert_idx]
left_idx = sep_idx[insert_idx - 1] left_idx = sep_idx[insert_idx - 1]
phrases.append(get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer, left_idx, right_idx).replace('.', '')) phrases.append(
get_phrases_from_posmap(
logit > text_threshold, tokenized, tokenizer, left_idx, right_idx
).replace(".", "")
)
else: else:
phrases = [ phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') get_phrases_from_posmap(
for logit logit > text_threshold, tokenized, tokenizer
in logits ).replace(".", "")
for logit in logits
] ]
return boxes, logits.max(dim=1)[0], phrases return boxes, logits.max(dim=1)[0], phrases
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray: def annotate(
image_source: np.ndarray,
boxes: torch.Tensor,
logits: torch.Tensor,
phrases: List[str],
) -> np.ndarray:
h, w, _ = image_source.shape h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h]) boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
detections = sv.Detections(xyxy=xyxy) detections = sv.Detections(xyxy=xyxy)
labels = [ labels = [f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits)]
f"{phrase} {logit:.2f}"
for phrase, logit
in zip(phrases, logits)
]
box_annotator = sv.BoxAnnotator() box_annotator = sv.BoxAnnotator()
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR) annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) annotated_frame = box_annotator.annotate(
scene=annotated_frame, detections=detections, labels=labels
)
return annotated_frame return annotated_frame
@ -121,17 +139,13 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
class Model: class Model:
def __init__( def __init__(
self, self, model_config_path: str, model_checkpoint_path: str, device: str = "cuda"
model_config_path: str,
model_checkpoint_path: str,
device: str = "cuda"
): ):
self.model = load_model( self.model = load_model(
model_config_path=model_config_path, model_config_path=model_config_path,
model_checkpoint_path=model_checkpoint_path, model_checkpoint_path=model_checkpoint_path,
device=device device=device,
).to(device) ).to(device)
self.device = device self.device = device
@ -140,7 +154,7 @@ class Model:
image: np.ndarray, image: np.ndarray,
caption: str, caption: str,
box_threshold: float = 0.35, box_threshold: float = 0.35,
text_threshold: float = 0.25 text_threshold: float = 0.25,
) -> Tuple[sv.Detections, List[str]]: ) -> Tuple[sv.Detections, List[str]]:
""" """
import cv2 import cv2
@ -167,13 +181,12 @@ class Model:
caption=caption, caption=caption,
box_threshold=box_threshold, box_threshold=box_threshold,
text_threshold=text_threshold, text_threshold=text_threshold,
device=self.device) device=self.device,
)
source_h, source_w, _ = image.shape source_h, source_w, _ = image.shape
detections = Model.post_process_result( detections = Model.post_process_result(
source_h=source_h, source_h=source_h, source_w=source_w, boxes=boxes, logits=logits
source_w=source_w, )
boxes=boxes,
logits=logits)
return detections, phrases return detections, phrases
def predict_with_classes( def predict_with_classes(
@ -181,7 +194,7 @@ class Model:
image: np.ndarray, image: np.ndarray,
classes: List[str], classes: List[str],
box_threshold: float, box_threshold: float,
text_threshold: float text_threshold: float,
) -> sv.Detections: ) -> sv.Detections:
""" """
import cv2 import cv2
@ -210,13 +223,12 @@ class Model:
caption=caption, caption=caption,
box_threshold=box_threshold, box_threshold=box_threshold,
text_threshold=text_threshold, text_threshold=text_threshold,
device=self.device) device=self.device,
)
source_h, source_w, _ = image.shape source_h, source_w, _ = image.shape
detections = Model.post_process_result( detections = Model.post_process_result(
source_h=source_h, source_h=source_h, source_w=source_w, boxes=boxes, logits=logits
source_w=source_w, )
boxes=boxes,
logits=logits)
class_id = Model.phrases2classes(phrases=phrases, classes=classes) class_id = Model.phrases2classes(phrases=phrases, classes=classes)
detections.class_id = class_id detections.class_id = class_id
return detections return detections
@ -236,10 +248,7 @@ class Model:
@staticmethod @staticmethod
def post_process_result( def post_process_result(
source_h: int, source_h: int, source_w: int, boxes: torch.Tensor, logits: torch.Tensor
source_w: int,
boxes: torch.Tensor,
logits: torch.Tensor
) -> sv.Detections: ) -> sv.Detections:
boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h]) boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

@ -29,7 +29,9 @@ class _ColorfulFormatter(logging.Formatter):
# so that calling setup_logger multiple times won't add many handlers # so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache() @functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None): def setup_logger(
output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None
):
""" """
Initialize the detectron2 logger and set its verbosity level to "INFO". Initialize the detectron2 logger and set its verbosity level to "INFO".

@ -135,7 +135,9 @@ def all_gather_cpu(data):
# obtain Tensor size of each rank # obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)] size_list = [
torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
]
if cpu_group is None: if cpu_group is None:
dist.all_gather(size_list, local_size) dist.all_gather(size_list, local_size)
else: else:
@ -153,7 +155,9 @@ def all_gather_cpu(data):
for _ in size_list: for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
if local_size != max_size: if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device) padding = torch.empty(
size=(max_size - local_size,), dtype=torch.uint8, device=device
)
tensor = torch.cat((tensor, padding), dim=0) tensor = torch.cat((tensor, padding), dim=0)
if cpu_group is None: if cpu_group is None:
dist.all_gather(tensor_list, tensor) dist.all_gather(tensor_list, tensor)
@ -205,7 +209,9 @@ def all_gather(data):
for _ in size_list: for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size: if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") padding = torch.empty(
size=(max_size - local_size,), dtype=torch.uint8, device="cuda"
)
tensor = torch.cat((tensor, padding), dim=0) tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor) dist.all_gather(tensor_list, tensor)
@ -261,7 +267,9 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
)
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
@ -434,7 +442,9 @@ class NestedTensor(object):
return NestedTensor(cast_tensor, cast_mask) return NestedTensor(cast_tensor, cast_mask)
def to_img_list_single(self, tensor, mask): def to_img_list_single(self, tensor, mask):
assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim()) assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(
tensor.dim()
)
maxH = (~mask).sum(0).max() maxH = (~mask).sum(0).max()
maxW = (~mask).sum(1).max() maxW = (~mask).sum(1).max()
img = tensor[:, :maxH, :maxW] img = tensor[:, :maxH, :maxW]
@ -516,11 +526,15 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
padded_masks = [] padded_masks = []
for img in tensor_list: for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_img = torch.nn.functional.pad(
img, (0, padding[2], 0, padding[1], 0, padding[0])
)
padded_imgs.append(padded_img) padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) padded_mask = torch.nn.functional.pad(
m, (0, padding[2], 0, padding[1]), "constant", 1
)
padded_masks.append(padded_mask.to(torch.bool)) padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs) tensor = torch.stack(padded_imgs)
@ -575,7 +589,9 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args): def init_distributed_mode(args):
if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and if (
"WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != ""
): # 'RANK' in os.environ and
args.rank = int(os.environ["RANK"]) args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"]) args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"]) args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
@ -615,11 +631,17 @@ def init_distributed_mode(args):
args.local_rank = 0 args.local_rank = 0
return return
print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank)) print(
"world_size:{} rank:{} local_rank:{}".format(
args.world_size, args.rank, args.local_rank
)
)
args.distributed = True args.distributed = True
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
args.dist_backend = "nccl" args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) print(
"| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.dist_backend, backend=args.dist_backend,
@ -666,7 +688,9 @@ def accuracy_onehot(pred, gt):
return acc return acc
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): def interpolate(
input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
""" """
Equivalent to nn.functional.interpolate, but with support for empty batch sizes. Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
@ -675,13 +699,17 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne
""" """
if __torchvision_need_compat_flag < 0.7: if __torchvision_need_compat_flag < 0.7:
if input.numel() > 0: if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor) output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape) output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape) return _new_empty_tensor(input, output_shape)
else: else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) return torchvision.ops.misc.interpolate(
input, size, scale_factor, mode, align_corners
)
class color_sys: class color_sys:
@ -693,7 +721,12 @@ class color_sys:
lightness = (50 + np.random.rand() * 10) / 100.0 lightness = (50 + np.random.rand() * 10) / 100.0
saturation = (90 + np.random.rand() * 10) / 100.0 saturation = (90 + np.random.rand() * 10) / 100.0
colors.append( colors.append(
tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)]) tuple(
[
int(j * 255)
for j in colorsys.hls_to_rgb(hue, lightness, saturation)
]
)
) )
self.colors = colors self.colors = colors

@ -31,7 +31,9 @@ class ConfigDict(Dict):
try: try:
value = super(ConfigDict, self).__getattr__(name) value = super(ConfigDict, self).__getattr__(name)
except KeyError: except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'") ex = AttributeError(
f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
)
except Exception as e: except Exception as e:
ex = e ex = e
else: else:
@ -79,9 +81,11 @@ class SLConfig(object):
check_file_exist(filename) check_file_exist(filename)
if filename.lower().endswith(".py"): if filename.lower().endswith(".py"):
with tempfile.TemporaryDirectory() as temp_config_dir: with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py") temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix=".py"
)
temp_config_name = osp.basename(temp_config_file.name) temp_config_name = osp.basename(temp_config_file.name)
if os.name == 'nt': if os.name == "nt":
temp_config_file.close() temp_config_file.close()
shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name)) shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
temp_module_name = osp.splitext(temp_config_name)[0] temp_module_name = osp.splitext(temp_config_name)[0]
@ -90,7 +94,9 @@ class SLConfig(object):
mod = import_module(temp_module_name) mod = import_module(temp_module_name)
sys.path.pop(0) sys.path.pop(0)
cfg_dict = { cfg_dict = {
name: value for name, value in mod.__dict__.items() if not name.startswith("__") name: value
for name, value in mod.__dict__.items()
if not name.startswith("__")
} }
# delete imported module # delete imported module
del sys.modules[temp_module_name] del sys.modules[temp_module_name]
@ -111,7 +117,9 @@ class SLConfig(object):
if BASE_KEY in cfg_dict: if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename) cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY) base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(base_filename, list) else [base_filename] base_filename = (
base_filename if isinstance(base_filename, list) else [base_filename]
)
cfg_dict_list = list() cfg_dict_list = list()
cfg_text_list = list() cfg_text_list = list()
@ -156,7 +164,6 @@ class SLConfig(object):
b = b.copy() b = b.copy()
for k, v in a.items(): for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict) and not isinstance(b[k], list): if not isinstance(b[k], dict) and not isinstance(b[k], list):
# if : # if :
# import ipdb; ipdb.set_trace() # import ipdb; ipdb.set_trace()
@ -172,7 +179,8 @@ class SLConfig(object):
_ = int(k) _ = int(k)
except BaseException: except BaseException:
raise TypeError( raise TypeError(
f"b is a list, " f"index {k} should be an int when input but {type(k)}" f"b is a list, "
f"index {k} should be an int when input but {type(k)}"
) )
b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)]) b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
else: else:
@ -215,7 +223,6 @@ class SLConfig(object):
@property @property
def pretty_text(self): def pretty_text(self):
indent = 4 indent = 4
def _indent(s_, num_spaces): def _indent(s_, num_spaces):

@ -40,7 +40,9 @@ def renorm(
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W) # img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img # return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim() assert img.dim() == 3 or img.dim() == 4, (
"img.dim() should be 3 or 4 but %d" % img.dim()
)
if img.dim() == 3: if img.dim() == 3:
assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % ( assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
img.size(0), img.size(0),
@ -147,8 +149,12 @@ class CocoClassMapper:
"89": 79, "89": 79,
"90": 80, "90": 80,
} }
self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()} self.origin2compact_mapper = {
self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()} int(k): v - 1 for k, v in self.category_map_str.items()
}
self.compact2origin_mapper = {
int(v - 1): int(k) for k, v in self.category_map_str.items()
}
def origin2compact(self, idx): def origin2compact(self, idx):
return self.origin2compact_mapper[int(idx)] return self.origin2compact_mapper[int(idx)]
@ -271,6 +277,7 @@ def get_embedder(multires, i=0):
def embed(x, eo=embedder_obj): def embed(x, eo=embedder_obj):
return eo.embed(x) return eo.embed(x)
return embed, embedder_obj.out_dim return embed, embedder_obj.out_dim
@ -381,7 +388,9 @@ class NiceRepr:
return str(len(self)) return str(len(self))
else: else:
# In all other cases force the subclass to overload __nice__ # In all other cases force the subclass to overload __nice__
raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}") raise NotImplementedError(
f"Define the __nice__ method for {self.__class__!r}"
)
def __repr__(self): def __repr__(self):
"""str: the string of the module""" """str: the string of the module"""
@ -496,7 +505,9 @@ class ModelEma(torch.nn.Module):
ema_v.copy_(update_fn(ema_v, model_v)) ema_v.copy_(update_fn(ema_v, model_v))
def update(self, model): def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m) self._update(
model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
)
def set(self, model): def set(self, model):
self._update(model, update_fn=lambda e, m: m) self._update(model, update_fn=lambda e, m: m)
@ -594,16 +605,21 @@ def targets_to(targets: List[Dict[str, Any]], device):
"dataset_type", "dataset_type",
] ]
return [ return [
{k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets {k: v.to(device) if k not in excluded_keys else v for k, v in t.items()}
for t in targets
] ]
def get_phrases_from_posmap( def get_phrases_from_posmap(
posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer, left_idx: int = 0, right_idx: int = 255 posmap: torch.BoolTensor,
tokenized: Dict,
tokenizer: AutoTokenizer,
left_idx: int = 0,
right_idx: int = 255,
): ):
assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
if posmap.dim() == 1: if posmap.dim() == 1:
posmap[0: left_idx + 1] = False posmap[0 : left_idx + 1] = False
posmap[right_idx:] = False posmap[right_idx:] = False
non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
token_ids = [tokenized["input_ids"][i] for i in non_zero_idx] token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]

@ -23,7 +23,9 @@ def renorm(
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# img: tensor(3,H,W) or tensor(B,3,H,W) # img: tensor(3,H,W) or tensor(B,3,H,W)
# return: same as img # return: same as img
assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim() assert img.dim() == 3 or img.dim() == 4, (
"img.dim() should be 3 or 4 but %d" % img.dim()
)
if img.dim() == 3: if img.dim() == 3:
assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % ( assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
img.size(0), img.size(0),
@ -124,7 +126,10 @@ class COCOVisualizer:
) )
else: else:
savename = "{}/{}-{}-{}.png".format( savename = "{}/{}-{}-{}.png".format(
savedir, caption, int(image_id), str(datetime.datetime.now()).replace(" ", "-") savedir,
caption,
int(image_id),
str(datetime.datetime.now()).replace(" ", "-"),
) )
print("savename: {}".format(savename)) print("savename: {}".format(savename))
os.makedirs(os.path.dirname(savename), exist_ok=True) os.makedirs(os.path.dirname(savename), exist_ok=True)
@ -188,7 +193,9 @@ class COCOVisualizer:
) )
if "box_label" in tgt: if "box_label" in tgt:
assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, " assert (
len(tgt["box_label"]) == numbox
), f"{len(tgt['box_label'])} = {numbox}, "
for idx, bl in enumerate(tgt["box_label"]): for idx, bl in enumerate(tgt["box_label"]):
_string = str(bl) _string = str(bl)
bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx] bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
@ -214,7 +221,9 @@ class COCOVisualizer:
tgt["attn"] = [tgt["attn"]] tgt["attn"] = [tgt["attn"]]
for item in tgt["attn"]: for item in tgt["attn"]:
attn_map, basergb = item attn_map, basergb = item
attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3) attn_map = (attn_map - attn_map.min()) / (
attn_map.max() - attn_map.min() + 1e-3
)
attn_map = (attn_map * 255).astype(np.uint8) attn_map = (attn_map * 255).astype(np.uint8)
cm = ColorMap(basergb) cm = ColorMap(basergb)
heatmap = cm(attn_map) heatmap = cm(attn_map)
@ -310,7 +319,9 @@ class COCOVisualizer:
# p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) # p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
# ax.add_collection(p) # ax.add_collection(p)
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) p = PatchCollection(
polygons, facecolor="none", edgecolors=color, linewidths=2
)
ax.add_collection(p) ax.add_collection(p)
elif datasetType == "captions": elif datasetType == "captions":
for ann in anns: for ann in anns:

@ -16,7 +16,7 @@ def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
""" """
positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float) positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
for j, tok_list in enumerate(token_span): for j, tok_list in enumerate(token_span):
for (beg, end) in tok_list: for beg, end in tok_list:
beg_pos = tokenized.char_to_token(beg) beg_pos = tokenized.char_to_token(beg)
end_pos = tokenized.char_to_token(end - 1) end_pos = tokenized.char_to_token(end - 1)
if beg_pos is None: if beg_pos is None:
@ -41,7 +41,7 @@ def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
positive_map[j, beg_pos] = 1 positive_map[j, beg_pos] = 1
break break
else: else:
positive_map[j, beg_pos: end_pos + 1].fill_(1) positive_map[j, beg_pos : end_pos + 1].fill_(1)
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)

@ -52,7 +52,9 @@ parser.add_argument(
help="The path to the SAM checkpoint to use for mask generation.", help="The path to the SAM checkpoint to use for mask generation.",
) )
parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") parser.add_argument(
"--device", type=str, default="cuda", help="The device to run generation on."
)
parser.add_argument( parser.add_argument(
"--convert-to-rle", "--convert-to-rle",
@ -204,7 +206,9 @@ def main(args: argparse.Namespace) -> None:
targets = [args.input] targets = [args.input]
else: else:
targets = [ targets = [
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) f
for f in os.listdir(args.input)
if not os.path.isdir(os.path.join(args.input, f))
] ]
targets = [os.path.join(args.input, f) for f in targets] targets = [os.path.join(args.input, f) for f in targets]

@ -24,7 +24,10 @@ parser = argparse.ArgumentParser(
) )
parser.add_argument( parser.add_argument(
"--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint." "--checkpoint",
type=str,
required=True,
help="The path to the SAM model checkpoint.",
) )
parser.add_argument( parser.add_argument(
@ -129,7 +132,9 @@ def run_export(
mask_input_size = [4 * x for x in embed_size] mask_input_size = [4 * x for x in embed_size]
dummy_inputs = { dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), "point_coords": torch.randint(
low=0, high=1024, size=(1, 5, 2), dtype=torch.float
),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float), "has_mask_input": torch.tensor([1], dtype=torch.float),

@ -172,7 +172,9 @@ class SamAutomaticMaskGenerator:
# Encode masks # Encode masks
if self.output_mode == "coco_rle": if self.output_mode == "coco_rle":
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] mask_data["segmentations"] = [
coco_encode_rle(rle) for rle in mask_data["rles"]
]
elif self.output_mode == "binary_mask": elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else: else:
@ -242,7 +244,9 @@ class SamAutomaticMaskGenerator:
# Generate masks for this crop in batches # Generate masks for this crop in batches
data = MaskData() data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image): for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) batch_data = self._process_batch(
points, cropped_im_size, crop_box, orig_size
)
data.cat(batch_data) data.cat(batch_data)
del batch_data del batch_data
self.predictor.reset_image() self.predictor.reset_image()
@ -275,7 +279,9 @@ class SamAutomaticMaskGenerator:
# Run model on this batch # Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size) transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device) in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) in_labels = torch.ones(
in_points.shape[0], dtype=torch.int, device=in_points.device
)
masks, iou_preds, _ = self.predictor.predict_torch( masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :], in_points[:, None, :],
in_labels[:, None], in_labels[:, None],
@ -298,7 +304,9 @@ class SamAutomaticMaskGenerator:
# Calculate stability score # Calculate stability score
data["stability_score"] = calculate_stability_score( data["stability_score"] = calculate_stability_score(
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset data["masks"],
self.predictor.model.mask_threshold,
self.stability_score_offset,
) )
if self.stability_score_thresh > 0.0: if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh keep_mask = data["stability_score"] >= self.stability_score_thresh
@ -309,7 +317,9 @@ class SamAutomaticMaskGenerator:
data["boxes"] = batched_mask_to_box(data["masks"]) data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries # Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) keep_mask = ~is_box_near_crop_edge(
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
)
if not torch.all(keep_mask): if not torch.all(keep_mask):
data.filter(keep_mask) data.filter(keep_mask)

@ -8,7 +8,13 @@ import torch
from functools import partial from functools import partial
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer from .modeling import (
ImageEncoderViT,
MaskDecoder,
PromptEncoder,
Sam,
TwoWayTransformer,
)
def build_sam_vit_h(checkpoint=None): def build_sam_vit_h(checkpoint=None):

@ -66,7 +66,9 @@ class ImageEncoderViT(nn.Module):
if use_abs_pos: if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size. # Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter( self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) torch.zeros(
1, img_size // patch_size, img_size // patch_size, embed_dim
)
) )
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
@ -159,7 +161,9 @@ class Block(nn.Module):
) )
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.mlp = MLPBlock(
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
)
self.window_size = window_size self.window_size = window_size
@ -224,23 +228,34 @@ class Attention(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C) # qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) qkv = (
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C) # q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1) attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos: if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = add_decomposed_rel_pos(
attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
)
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = (
(attn @ v)
.view(B, self.num_heads, H, W, -1)
.permute(0, 2, 3, 1, 4)
.reshape(B, H, W, -1)
)
x = self.proj(x) x = self.proj(x)
return x return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: def window_partition(
x: torch.Tensor, window_size: int
) -> Tuple[torch.Tensor, Tuple[int, int]]:
""" """
Partition into non-overlapping windows with padding if needed. Partition into non-overlapping windows with padding if needed.
Args: Args:
@ -260,12 +275,17 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
Hp, Wp = H + pad_h, W + pad_w Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows, (Hp, Wp) return windows, (Hp, Wp)
def window_unpartition( def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] windows: torch.Tensor,
window_size: int,
pad_hw: Tuple[int, int],
hw: Tuple[int, int],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Window unpartition into original sequences and removing padding. Window unpartition into original sequences and removing padding.
@ -281,7 +301,9 @@ def window_unpartition(
Hp, Wp = pad_hw Hp, Wp = pad_hw
H, W = hw H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size) B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = windows.view(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W: if Hp > H or Wp > W:
@ -355,7 +377,9 @@ def add_decomposed_rel_pos(
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = ( attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] attn.view(B, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, None]
+ rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w) ).view(B, q_h * q_w, k_h * k_w)
return attn return attn

@ -51,10 +51,14 @@ class MaskDecoder(nn.Module):
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential( self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), nn.ConvTranspose2d(
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
),
LayerNorm2d(transformer_dim // 4), LayerNorm2d(transformer_dim // 4),
activation(), activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), nn.ConvTranspose2d(
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
),
activation(), activation(),
) )
self.output_hypernetworks_mlps = nn.ModuleList( self.output_hypernetworks_mlps = nn.ModuleList(
@ -118,8 +122,12 @@ class MaskDecoder(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details.""" """Predicts masks. See 'forward' for more details."""
# Concatenate output tokens # Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = torch.cat(
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) [self.iou_token.weight, self.mask_tokens.weight], dim=0
)
output_tokens = output_tokens.unsqueeze(0).expand(
sparse_prompt_embeddings.size(0), -1, -1
)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask # Expand per-image data in batch direction to be per-mask
@ -131,14 +139,16 @@ class MaskDecoder(nn.Module):
# Run the transformer # Run the transformer
hs, src = self.transformer(src, pos_src, tokens) hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :] iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens # Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w) src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src) upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = [] hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens): for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) hyper_in_list.append(
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
)
hyper_in = torch.stack(hyper_in_list, dim=1) hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

@ -43,11 +43,16 @@ class PromptEncoder(nn.Module):
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] point_embeddings = [
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
]
self.point_embeddings = nn.ModuleList(point_embeddings) self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim) self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) self.mask_input_size = (
4 * image_embedding_size[0],
4 * image_embedding_size[1],
)
self.mask_downscaling = nn.Sequential( self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4), LayerNorm2d(mask_in_chans // 4),
@ -83,7 +88,9 @@ class PromptEncoder(nn.Module):
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1) points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1) labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight point_embedding[labels == 0] += self.point_embeddings[0].weight
@ -94,7 +101,9 @@ class PromptEncoder(nn.Module):
"""Embeds box prompts.""" """Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2) coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) corner_embedding = self.pe_layer.forward_with_coords(
coords, self.input_image_size
)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding return corner_embedding
@ -149,7 +158,9 @@ class PromptEncoder(nn.Module):
Bx(embed_dim)x(embed_H)x(embed_W) Bx(embed_dim)x(embed_H)x(embed_W)
""" """
bs = self._get_batch_size(points, boxes, masks) bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) sparse_embeddings = torch.empty(
(bs, 0, self.embed_dim), device=self._get_device()
)
if points is not None: if points is not None:
coords, labels = points coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))

@ -43,7 +43,9 @@ class Sam(nn.Module):
self.image_encoder = image_encoder self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer(
"pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property @property
@ -94,7 +96,9 @@ class Sam(nn.Module):
shape BxCxHxW, where H=W=256. Can be passed as mask input shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction. to subsequent iterations of prediction.
""" """
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) input_images = torch.stack(
[self.preprocess(x["image"]) for x in batched_input], dim=0
)
image_embeddings = self.image_encoder(input_images) image_embeddings = self.image_encoder(input_images)
outputs = [] outputs = []
@ -158,7 +162,9 @@ class Sam(nn.Module):
align_corners=False, align_corners=False,
) )
masks = masks[..., : input_size[0], : input_size[1]] masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) masks = F.interpolate(
masks, original_size, mode="bilinear", align_corners=False
)
return masks return masks
def preprocess(self, x: torch.Tensor) -> torch.Tensor: def preprocess(self, x: torch.Tensor) -> torch.Tensor:

@ -198,7 +198,9 @@ class Attention(nn.Module):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." assert (
self.internal_dim % num_heads == 0
), "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim)

@ -55,7 +55,9 @@ class SamPredictor:
# Transform the image to the form expected by the model # Transform the image to the form expected by the model
input_image = self.transform.apply_image(image) input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device) input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
None, :, :, :
]
self.set_torch_image(input_image_torch, image.shape[:2]) self.set_torch_image(input_image_torch, image.shape[:2])
@ -131,7 +133,9 @@ class SamPredictor:
a subsequent iteration as mask input. a subsequent iteration as mask input.
""" """
if not self.is_image_set: if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
# Transform input prompts # Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
@ -140,15 +144,21 @@ class SamPredictor:
point_labels is not None point_labels is not None
), "point_labels must be supplied if point_coords is supplied." ), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size) point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) coords_torch = torch.as_tensor(
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) point_coords, dtype=torch.float, device=self.device
)
labels_torch = torch.as_tensor(
point_labels, dtype=torch.int, device=self.device
)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None: if box is not None:
box = self.transform.apply_boxes(box, self.original_size) box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :] box_torch = box_torch[None, :]
if mask_input is not None: if mask_input is not None:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) mask_input_torch = torch.as_tensor(
mask_input, dtype=torch.float, device=self.device
)
mask_input_torch = mask_input_torch[None, :, :, :] mask_input_torch = mask_input_torch[None, :, :, :]
masks, iou_predictions, low_res_masks = self.predict_torch( masks, iou_predictions, low_res_masks = self.predict_torch(
@ -211,7 +221,9 @@ class SamPredictor:
a subsequent iteration as mask input. a subsequent iteration as mask input.
""" """
if not self.is_image_set: if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
if point_coords is not None: if point_coords is not None:
points = (point_coords, point_labels) points = (point_coords, point_labels)
@ -235,7 +247,9 @@ class SamPredictor:
) )
# Upscale the masks to the original image resolution # Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) masks = self.model.postprocess_masks(
low_res_masks, self.input_size, self.original_size
)
if not return_logits: if not return_logits:
masks = masks > self.model.mask_threshold masks = masks > self.model.mask_threshold
@ -252,7 +266,9 @@ class SamPredictor:
raise RuntimeError( raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding." "An image must be set with .set_image(...) to generate an embedding."
) )
assert self.features is not None, "Features must exist if an image has been set." assert (
self.features is not None
), "Features must exist if an image has been set."
return self.features return self.features
@property @property

@ -101,7 +101,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
), "Batched iteration must have inputs of all the same size." ), "Batched iteration must have inputs of all the same size."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches): for b in range(n_batches):
yield [arg[b * batch_size: (b + 1) * batch_size] for arg in args] yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
@ -142,7 +142,7 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
idx = 0 idx = 0
parity = False parity = False
for count in rle["counts"]: for count in rle["counts"]:
mask[idx: idx + count] = parity mask[idx : idx + count] = parity
idx += count idx += count
parity ^= True parity ^= True
mask = mask.reshape(w, h) mask = mask.reshape(w, h)

@ -48,32 +48,43 @@ class SamOnnxModel(nn.Module):
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size return transformed_size
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: def _embed_points(
self, point_coords: torch.Tensor, point_labels: torch.Tensor
) -> torch.Tensor:
point_coords = point_coords + 0.5 point_coords = point_coords + 0.5
point_coords = point_coords / self.img_size point_coords = point_coords / self.img_size
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1) point_embedding = point_embedding * (point_labels != -1)
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( point_embedding = (
point_labels == -1 point_embedding
+ self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
) )
for i in range(self.model.prompt_encoder.num_point_embeddings): for i in range(self.model.prompt_encoder.num_point_embeddings):
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ point_embedding = (
i point_embedding
].weight * (point_labels == i) + self.model.prompt_encoder.point_embeddings[i].weight
* (point_labels == i)
)
return point_embedding return point_embedding
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: def _embed_masks(
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
) -> torch.Tensor:
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
input_mask
)
mask_embedding = mask_embedding + ( mask_embedding = mask_embedding + (
1 - has_mask_input 1 - has_mask_input
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
return mask_embedding return mask_embedding
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: def mask_postprocessing(
self, masks: torch.Tensor, orig_im_size: torch.Tensor
) -> torch.Tensor:
masks = F.interpolate( masks = F.interpolate(
masks, masks,
size=(self.img_size, self.img_size), size=(self.img_size, self.img_size),
@ -81,7 +92,9 @@ class SamOnnxModel(nn.Module):
align_corners=False, align_corners=False,
) )
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
torch.int64
)
masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
orig_im_size = orig_im_size.to(torch.int64) orig_im_size = orig_im_size.to(torch.int64)

@ -27,10 +27,14 @@ class ResizeLongestSide:
""" """
Expects a numpy array with shape HxWxC in uint8 format. Expects a numpy array with shape HxWxC in uint8 format.
""" """
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) target_size = self.get_preprocess_shape(
image.shape[0], image.shape[1], self.target_length
)
return np.array(resize(to_pil_image(image), target_size)) return np.array(resize(to_pil_image(image), target_size))
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: def apply_coords(
self, coords: np.ndarray, original_size: Tuple[int, ...]
) -> np.ndarray:
""" """
Expects a numpy array of length 2 in the final dimension. Requires the Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format. original image size in (H, W) format.
@ -44,7 +48,9 @@ class ResizeLongestSide:
coords[..., 1] = coords[..., 1] * (new_h / old_h) coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords return coords
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: def apply_boxes(
self, boxes: np.ndarray, original_size: Tuple[int, ...]
) -> np.ndarray:
""" """
Expects a numpy array shape Bx4. Requires the original image size Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format. in (H, W) format.
@ -59,7 +65,9 @@ class ResizeLongestSide:
the transformation expected by the model. the transformation expected by the model.
""" """
# Expects an image in BCHW format. May not exactly match apply_image. # Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) target_size = self.get_preprocess_shape(
image.shape[2], image.shape[3], self.target_length
)
return F.interpolate( return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True image, target_size, mode="bilinear", align_corners=False, antialias=True
) )
@ -91,7 +99,9 @@ class ResizeLongestSide:
return boxes.reshape(-1, 4) return boxes.reshape(-1, 4)
@staticmethod @staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: def get_preprocess_shape(
oldh: int, oldw: int, long_side_length: int
) -> Tuple[int, int]:
""" """
Compute the output size given input size and target long side length. Compute the output size given input size and target long side length.
""" """

File diff suppressed because it is too large Load Diff

@ -31,7 +31,7 @@ max_length = {
"davinci": 2049, "davinci": 2049,
"curie": 2049, "curie": 2049,
"babbage": 2049, "babbage": 2049,
"ada": 2049 "ada": 2049,
} }
@ -44,14 +44,14 @@ def get_max_context_length(model_name):
def get_token_ids_for_task_parsing(model_name): def get_token_ids_for_task_parsing(model_name):
text = '''{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}''' text = """{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}"""
res = encodings[model_name].encode(text) res = encodings[model_name].encode(text)
res = list(set(res)) res = list(set(res))
return res return res
def get_token_ids_for_choose_model(model_name): def get_token_ids_for_choose_model(model_name):
text = '''{"id": "reason"}''' text = """{"id": "reason"}"""
res = encodings[model_name].encode(text) res = encodings[model_name].encode(text)
res = list(set(res)) res = list(set(res))
return res return res

@ -65,7 +65,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setLevel(logging.INFO) handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
@ -100,10 +100,16 @@ def load_pipes(local_deployment):
if local_deployment in ["full"]: if local_deployment in ["full"]:
other_pipes = { other_pipes = {
"nlpconnect/vit-gpt2-image-captioning": { "nlpconnect/vit-gpt2-image-captioning": {
"model": VisionEncoderDecoderModel.from_pretrained(f"{local_fold}/nlpconnect/vit-gpt2-image-captioning"), "model": VisionEncoderDecoderModel.from_pretrained(
"feature_extractor": ViTImageProcessor.from_pretrained(f"{local_fold}/nlpconnect/vit-gpt2-image-captioning"), f"{local_fold}/nlpconnect/vit-gpt2-image-captioning"
"tokenizer": AutoTokenizer.from_pretrained(f"{local_fold}/nlpconnect/vit-gpt2-image-captioning"), ),
"device": device "feature_extractor": ViTImageProcessor.from_pretrained(
f"{local_fold}/nlpconnect/vit-gpt2-image-captioning"
),
"tokenizer": AutoTokenizer.from_pretrained(
f"{local_fold}/nlpconnect/vit-gpt2-image-captioning"
),
"device": device,
}, },
# "Salesforce/blip-image-captioning-large": { # "Salesforce/blip-image-captioning-large": {
# "model": BlipForConditionalGeneration.from_pretrained(f"{local_fold}/Salesforce/blip-image-captioning-large"), # "model": BlipForConditionalGeneration.from_pretrained(f"{local_fold}/Salesforce/blip-image-captioning-large"),
@ -111,8 +117,12 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"damo-vilab/text-to-video-ms-1.7b": { "damo-vilab/text-to-video-ms-1.7b": {
"model": DiffusionPipeline.from_pretrained(f"{local_fold}/damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"), "model": DiffusionPipeline.from_pretrained(
"device": device f"{local_fold}/damo-vilab/text-to-video-ms-1.7b",
torch_dtype=torch.float16,
variant="fp16",
),
"device": device,
}, },
# "facebook/maskformer-swin-large-ade": { # "facebook/maskformer-swin-large-ade": {
# "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_fold}/facebook/maskformer-swin-large-ade"), # "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_fold}/facebook/maskformer-swin-large-ade"),
@ -130,16 +140,22 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": { "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k": {
"model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"), "model": BaseModel.from_pretrained(
"device": device "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"
),
"device": device,
}, },
"espnet/kan-bayashi_ljspeech_vits": { "espnet/kan-bayashi_ljspeech_vits": {
"model": Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits"), "model": Text2Speech.from_pretrained(
"device": device "espnet/kan-bayashi_ljspeech_vits"
),
"device": device,
}, },
"lambdalabs/sd-image-variations-diffusers": { "lambdalabs/sd-image-variations-diffusers": {
"model": DiffusionPipeline.from_pretrained(f"{local_fold}/lambdalabs/sd-image-variations-diffusers"), # torch_dtype=torch.float16 "model": DiffusionPipeline.from_pretrained(
"device": device f"{local_fold}/lambdalabs/sd-image-variations-diffusers"
), # torch_dtype=torch.float16
"device": device,
}, },
# "CompVis/stable-diffusion-v1-4": { # "CompVis/stable-diffusion-v1-4": {
# "model": DiffusionPipeline.from_pretrained(f"{local_fold}/CompVis/stable-diffusion-v1-4"), # "model": DiffusionPipeline.from_pretrained(f"{local_fold}/CompVis/stable-diffusion-v1-4"),
@ -150,8 +166,10 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"runwayml/stable-diffusion-v1-5": { "runwayml/stable-diffusion-v1-5": {
"model": DiffusionPipeline.from_pretrained(f"{local_fold}/runwayml/stable-diffusion-v1-5"), "model": DiffusionPipeline.from_pretrained(
"device": device f"{local_fold}/runwayml/stable-diffusion-v1-5"
),
"device": device,
}, },
# "microsoft/speecht5_tts":{ # "microsoft/speecht5_tts":{
# "processor": SpeechT5Processor.from_pretrained(f"{local_fold}/microsoft/speecht5_tts"), # "processor": SpeechT5Processor.from_pretrained(f"{local_fold}/microsoft/speecht5_tts"),
@ -165,11 +183,19 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"microsoft/speecht5_vc": { "microsoft/speecht5_vc": {
"processor": SpeechT5Processor.from_pretrained(f"{local_fold}/microsoft/speecht5_vc"), "processor": SpeechT5Processor.from_pretrained(
"model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_fold}/microsoft/speecht5_vc"), f"{local_fold}/microsoft/speecht5_vc"
"vocoder": SpeechT5HifiGan.from_pretrained(f"{local_fold}/microsoft/speecht5_hifigan"), ),
"embeddings_dataset": load_dataset(f"{local_fold}/Matthijs/cmu-arctic-xvectors", split="validation"), "model": SpeechT5ForSpeechToSpeech.from_pretrained(
"device": device f"{local_fold}/microsoft/speecht5_vc"
),
"vocoder": SpeechT5HifiGan.from_pretrained(
f"{local_fold}/microsoft/speecht5_hifigan"
),
"embeddings_dataset": load_dataset(
f"{local_fold}/Matthijs/cmu-arctic-xvectors", split="validation"
),
"device": device,
}, },
# "julien-c/wine-quality": { # "julien-c/wine-quality": {
# "model": joblib.load(cached_download(hf_hub_url("julien-c/wine-quality", "sklearn_model.joblib"))) # "model": joblib.load(cached_download(hf_hub_url("julien-c/wine-quality", "sklearn_model.joblib")))
@ -180,15 +206,23 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"facebook/maskformer-swin-base-coco": { "facebook/maskformer-swin-base-coco": {
"feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_fold}/facebook/maskformer-swin-base-coco"), "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(
"model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_fold}/facebook/maskformer-swin-base-coco"), f"{local_fold}/facebook/maskformer-swin-base-coco"
"device": device ),
"model": MaskFormerForInstanceSegmentation.from_pretrained(
f"{local_fold}/facebook/maskformer-swin-base-coco"
),
"device": device,
}, },
"Intel/dpt-hybrid-midas": { "Intel/dpt-hybrid-midas": {
"model": DPTForDepthEstimation.from_pretrained(f"{local_fold}/Intel/dpt-hybrid-midas", low_cpu_mem_usage=True), "model": DPTForDepthEstimation.from_pretrained(
"feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_fold}/Intel/dpt-hybrid-midas"), f"{local_fold}/Intel/dpt-hybrid-midas", low_cpu_mem_usage=True
"device": device ),
} "feature_extractor": DPTFeatureExtractor.from_pretrained(
f"{local_fold}/Intel/dpt-hybrid-midas"
),
"device": device,
},
} }
if local_deployment in ["full", "standard"]: if local_deployment in ["full", "standard"]:
@ -198,36 +232,53 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"openai/whisper-base": { "openai/whisper-base": {
"model": pipeline(task="automatic-speech-recognition", model=f"{local_fold}/openai/whisper-base"), "model": pipeline(
"device": device task="automatic-speech-recognition",
model=f"{local_fold}/openai/whisper-base",
),
"device": device,
}, },
"microsoft/speecht5_asr": { "microsoft/speecht5_asr": {
"model": pipeline(task="automatic-speech-recognition", model=f"{local_fold}/microsoft/speecht5_asr"), "model": pipeline(
"device": device task="automatic-speech-recognition",
model=f"{local_fold}/microsoft/speecht5_asr",
),
"device": device,
}, },
"Intel/dpt-large": { "Intel/dpt-large": {
"model": pipeline(task="depth-estimation", model=f"{local_fold}/Intel/dpt-large"), "model": pipeline(
"device": device task="depth-estimation", model=f"{local_fold}/Intel/dpt-large"
),
"device": device,
}, },
# "microsoft/beit-base-patch16-224-pt22k-ft22k": { # "microsoft/beit-base-patch16-224-pt22k-ft22k": {
# "model": pipeline(task="image-classification", model=f"{local_fold}/microsoft/beit-base-patch16-224-pt22k-ft22k"), # "model": pipeline(task="image-classification", model=f"{local_fold}/microsoft/beit-base-patch16-224-pt22k-ft22k"),
# "device": device # "device": device
# }, # },
"facebook/detr-resnet-50-panoptic": { "facebook/detr-resnet-50-panoptic": {
"model": pipeline(task="image-segmentation", model=f"{local_fold}/facebook/detr-resnet-50-panoptic"), "model": pipeline(
"device": device task="image-segmentation",
model=f"{local_fold}/facebook/detr-resnet-50-panoptic",
),
"device": device,
}, },
"facebook/detr-resnet-101": { "facebook/detr-resnet-101": {
"model": pipeline(task="object-detection", model=f"{local_fold}/facebook/detr-resnet-101"), "model": pipeline(
"device": device task="object-detection",
model=f"{local_fold}/facebook/detr-resnet-101",
),
"device": device,
}, },
# "openai/clip-vit-large-patch14": { # "openai/clip-vit-large-patch14": {
# "model": pipeline(task="zero-shot-image-classification", model=f"{local_fold}/openai/clip-vit-large-patch14"), # "model": pipeline(task="zero-shot-image-classification", model=f"{local_fold}/openai/clip-vit-large-patch14"),
# "device": device # "device": device
# }, # },
"google/owlvit-base-patch32": { "google/owlvit-base-patch32": {
"model": pipeline(task="zero-shot-object-detection", model=f"{local_fold}/google/owlvit-base-patch32"), "model": pipeline(
"device": device task="zero-shot-object-detection",
model=f"{local_fold}/google/owlvit-base-patch32",
),
"device": device,
}, },
# "microsoft/DialoGPT-medium": { # "microsoft/DialoGPT-medium": {
# "model": pipeline(task="conversational", model=f"{local_fold}/microsoft/DialoGPT-medium"), # "model": pipeline(task="conversational", model=f"{local_fold}/microsoft/DialoGPT-medium"),
@ -270,86 +321,121 @@ def load_pipes(local_deployment):
# "device": device # "device": device
# }, # },
"impira/layoutlm-document-qa": { "impira/layoutlm-document-qa": {
"model": pipeline(task="document-question-answering", model=f"{local_fold}/impira/layoutlm-document-qa"), "model": pipeline(
"device": device task="document-question-answering",
model=f"{local_fold}/impira/layoutlm-document-qa",
),
"device": device,
}, },
"ydshieh/vit-gpt2-coco-en": { "ydshieh/vit-gpt2-coco-en": {
"model": pipeline(task="image-to-text", model=f"{local_fold}/ydshieh/vit-gpt2-coco-en"), "model": pipeline(
"device": device task="image-to-text", model=f"{local_fold}/ydshieh/vit-gpt2-coco-en"
),
"device": device,
}, },
"dandelin/vilt-b32-finetuned-vqa": { "dandelin/vilt-b32-finetuned-vqa": {
"model": pipeline(task="visual-question-answering", model=f"{local_fold}/dandelin/vilt-b32-finetuned-vqa"), "model": pipeline(
"device": device task="visual-question-answering",
} model=f"{local_fold}/dandelin/vilt-b32-finetuned-vqa",
),
"device": device,
},
} }
if local_deployment in ["full", "standard", "minimal"]: if local_deployment in ["full", "standard", "minimal"]:
controlnet = ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) controlnet = ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
controlnetpipe = StableDiffusionControlNetPipeline.from_pretrained( controlnetpipe = StableDiffusionControlNetPipeline.from_pretrained(
f"{local_fold}/runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 f"{local_fold}/runwayml/stable-diffusion-v1-5",
controlnet=controlnet,
torch_dtype=torch.float16,
) )
def mlsd_control_network(): def mlsd_control_network():
model = MobileV2_MLSD_Large() model = MobileV2_MLSD_Large()
model.load_state_dict(torch.load(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth"), strict=True) model.load_state_dict(
torch.load(
f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/mlsd_large_512_fp32.pth"
),
strict=True,
)
return MLSDdetector(model) return MLSDdetector(model)
hed_network = Network(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/network-bsds500.pth") hed_network = Network(
f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/network-bsds500.pth"
)
controlnet_sd_pipes = { controlnet_sd_pipes = {
"openpose-control": { "openpose-control": {
"model": OpenposeDetector(Body(f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/body_pose_model.pth")) "model": OpenposeDetector(
}, Body(
"mlsd-control": { f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/body_pose_model.pth"
"model": mlsd_control_network() )
}, )
"hed-control": {
"model": HEDdetector(hed_network)
},
"scribble-control": {
"model": HEDdetector(hed_network)
}, },
"mlsd-control": {"model": mlsd_control_network()},
"hed-control": {"model": HEDdetector(hed_network)},
"scribble-control": {"model": HEDdetector(hed_network)},
"midas-control": { "midas-control": {
"model": MidasDetector(model_path=f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt") "model": MidasDetector(
}, model_path=f"{local_fold}/lllyasviel/ControlNet/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
"canny-control": { )
"model": CannyDetector()
}, },
"canny-control": {"model": CannyDetector()},
"lllyasviel/sd-controlnet-canny": { "lllyasviel/sd-controlnet-canny": {
"control": controlnet, "control": controlnet,
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
}, },
"lllyasviel/sd-controlnet-depth": { "lllyasviel/sd-controlnet-depth": {
"control": ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16), "control": ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-depth",
torch_dtype=torch.float16,
),
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
}, },
"lllyasviel/sd-controlnet-hed": { "lllyasviel/sd-controlnet-hed": {
"control": ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16), "control": ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-hed",
torch_dtype=torch.float16,
),
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
}, },
"lllyasviel/sd-controlnet-mlsd": { "lllyasviel/sd-controlnet-mlsd": {
"control": ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16), "control": ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-mlsd",
torch_dtype=torch.float16,
),
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
}, },
"lllyasviel/sd-controlnet-openpose": { "lllyasviel/sd-controlnet-openpose": {
"control": ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16), "control": ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-openpose",
torch_dtype=torch.float16,
),
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
}, },
"lllyasviel/sd-controlnet-scribble": { "lllyasviel/sd-controlnet-scribble": {
"control": ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float16), "control": ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-scribble",
torch_dtype=torch.float16,
),
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
}, },
"lllyasviel/sd-controlnet-seg": { "lllyasviel/sd-controlnet-seg": {
"control": ControlNetModel.from_pretrained(f"{local_fold}/lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16), "control": ControlNetModel.from_pretrained(
f"{local_fold}/lllyasviel/sd-controlnet-seg",
torch_dtype=torch.float16,
),
"model": controlnetpipe, "model": controlnetpipe,
"device": device "device": device,
} },
} }
pipes = {**standard_pipes, **other_pipes, **controlnet_sd_pipes} pipes = {**standard_pipes, **other_pipes, **controlnet_sd_pipes}
return pipes return pipes
@ -363,14 +449,17 @@ during = end - start
print(f"[ ready ] {during}s") print(f"[ ready ] {during}s")
@app.route('/running', methods=['GET']) @app.route("/running", methods=["GET"])
def running(): def running():
return jsonify({"running": True}) return jsonify({"running": True})
@app.route('/status/<path:model_id>', methods=['GET']) @app.route("/status/<path:model_id>", methods=["GET"])
def status(model_id): def status(model_id):
disabled_models = ["microsoft/trocr-base-printed", "microsoft/trocr-base-handwritten"] disabled_models = [
"microsoft/trocr-base-printed",
"microsoft/trocr-base-handwritten",
]
if model_id in pipes.keys() and model_id not in disabled_models: if model_id in pipes.keys() and model_id not in disabled_models:
print(f"[ check {model_id} ] success") print(f"[ check {model_id} ] success")
return jsonify({"loaded": True}) return jsonify({"loaded": True})
@ -379,7 +468,7 @@ def status(model_id):
return jsonify({"loaded": False}) return jsonify({"loaded": False})
@app.route('/models/<path:model_id>', methods=['POST']) @app.route("/models/<path:model_id>", methods=["POST"])
def models(model_id): def models(model_id):
while "using" in pipes[model_id] and pipes[model_id]["using"]: while "using" in pipes[model_id] and pipes[model_id]["using"]:
print(f"[ inference {model_id} ] waiting") print(f"[ inference {model_id} ] waiting")
@ -402,23 +491,29 @@ def models(model_id):
try: try:
# text to video # text to video
if model_id == "damo-vilab/text-to-video-ms-1.7b": if model_id == "damo-vilab/text-to-video-ms-1.7b":
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config
)
# pipe.enable_model_cpu_offload() # pipe.enable_model_cpu_offload()
prompt = request.get_json()["text"] prompt = request.get_json()["text"]
video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
video_path = export_to_video(video_frames) video_path = export_to_video(video_frames)
file_name = str(uuid.uuid4())[:4] file_name = str(uuid.uuid4())[:4]
os.system(f"LD_LIBRARY_PATH=/usr/local/lib /usr/local/bin/ffmpeg -i {video_path} -vcodec libx264 public/videos/{file_name}.mp4") os.system(
f"LD_LIBRARY_PATH=/usr/local/lib /usr/local/bin/ffmpeg -i {video_path} -vcodec libx264 public/videos/{file_name}.mp4"
)
result = {"path": f"/videos/{file_name}.mp4"} result = {"path": f"/videos/{file_name}.mp4"}
# controlnet # controlnet
if model_id.startswith("lllyasviel/sd-controlnet-"): if model_id.startswith("lllyasviel/sd-controlnet-"):
pipe.controlnet.to('cpu') pipe.controlnet.to("cpu")
pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"]) pipe.controlnet = pipes[model_id]["control"].to(pipes[model_id]["device"])
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
control_image = load_image(request.get_json()["img_url"]) control_image = load_image(request.get_json()["img_url"])
# generator = torch.manual_seed(66) # generator = torch.manual_seed(66)
out_image: Image = pipe(request.get_json()["text"], num_inference_steps=20, image=control_image).images[0] out_image: Image = pipe(
request.get_json()["text"], num_inference_steps=20, image=control_image
).images[0]
file_name = str(uuid.uuid4())[:4] file_name = str(uuid.uuid4())[:4]
out_image.save(f"public/images/{file_name}.png") out_image.save(f"public/images/{file_name}.png")
result = {"path": f"/images/{file_name}.png"} result = {"path": f"/images/{file_name}.png"}
@ -441,7 +536,8 @@ def models(model_id):
file_name = str(uuid.uuid4())[:4] file_name = str(uuid.uuid4())[:4]
with open(f"public/images/{file_name}.png", "wb") as f: with open(f"public/images/{file_name}.png", "wb") as f:
f.write(request.data) f.write(request.data)
tform = transforms.Compose([ tform = transforms.Compose(
[
transforms.ToTensor(), transforms.ToTensor(),
transforms.Resize( transforms.Resize(
(224, 224), (224, 224),
@ -450,8 +546,10 @@ def models(model_id):
), ),
transforms.Normalize( transforms.Normalize(
[0.48145466, 0.4578275, 0.40821073], [0.48145466, 0.4578275, 0.40821073],
[0.26862954, 0.26130258, 0.27577711]), [0.26862954, 0.26130258, 0.27577711],
]) ),
]
)
inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0) inp = tform(im).to(pipes[model_id]["device"]).unsqueeze(0)
out = pipe(inp, guidance_scale=3) out = pipe(inp, guidance_scale=3)
out["images"][0].save(f"public/images/{file_name}.jpg") out["images"][0].save(f"public/images/{file_name}.jpg")
@ -459,30 +557,47 @@ def models(model_id):
# image to text # image to text
if model_id == "Salesforce/blip-image-captioning-large": if model_id == "Salesforce/blip-image-captioning-large":
raw_image = load_image(request.get_json()["img_url"]).convert('RGB') raw_image = load_image(request.get_json()["img_url"]).convert("RGB")
text = request.get_json()["text"] text = request.get_json()["text"]
inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(pipes[model_id]["device"]) inputs = pipes[model_id]["processor"](raw_image, return_tensors="pt").to(
pipes[model_id]["device"]
)
out = pipe.generate(**inputs) out = pipe.generate(**inputs)
caption = pipes[model_id]["processor"].decode(out[0], skip_special_tokens=True) caption = pipes[model_id]["processor"].decode(
out[0], skip_special_tokens=True
)
result = {"generated text": caption} result = {"generated text": caption}
if model_id == "ydshieh/vit-gpt2-coco-en": if model_id == "ydshieh/vit-gpt2-coco-en":
img_url = request.get_json()["img_url"] img_url = request.get_json()["img_url"]
generated_text = pipe(img_url)[0]['generated_text'] generated_text = pipe(img_url)[0]["generated_text"]
result = {"generated text": generated_text} result = {"generated text": generated_text}
if model_id == "nlpconnect/vit-gpt2-image-captioning": if model_id == "nlpconnect/vit-gpt2-image-captioning":
image = load_image(request.get_json()["img_url"]).convert("RGB") image = load_image(request.get_json()["img_url"]).convert("RGB")
pixel_values = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").pixel_values pixel_values = pipes[model_id]["feature_extractor"](
images=image, return_tensors="pt"
).pixel_values
pixel_values = pixel_values.to(pipes[model_id]["device"]) pixel_values = pixel_values.to(pipes[model_id]["device"])
generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1}) generated_ids = pipe.generate(
generated_text = pipes[model_id]["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0] pixel_values, **{"max_length": 200, "num_beams": 1}
)
generated_text = pipes[model_id]["tokenizer"].batch_decode(
generated_ids, skip_special_tokens=True
)[0]
result = {"generated text": generated_text} result = {"generated text": generated_text}
# image to text: OCR # image to text: OCR
if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten": if (
model_id == "microsoft/trocr-base-printed"
or model_id == "microsoft/trocr-base-handwritten"
):
image = load_image(request.get_json()["img_url"]).convert("RGB") image = load_image(request.get_json()["img_url"]).convert("RGB")
pixel_values = pipes[model_id]["processor"](image, return_tensors="pt").pixel_values pixel_values = pipes[model_id]["processor"](
image, return_tensors="pt"
).pixel_values
pixel_values = pixel_values.to(pipes[model_id]["device"]) pixel_values = pixel_values.to(pipes[model_id]["device"])
generated_ids = pipe.generate(pixel_values) generated_ids = pipe.generate(pixel_values)
generated_text = pipes[model_id]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0] generated_text = pipes[model_id]["processor"].batch_decode(
generated_ids, skip_special_tokens=True
)[0]
result = {"generated text": generated_text} result = {"generated text": generated_text}
# text to image # text to image
@ -494,9 +609,87 @@ def models(model_id):
result = {"path": f"/images/{file_name}.jpg"} result = {"path": f"/images/{file_name}.jpg"}
# object detection # object detection
if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101": if (
model_id == "google/owlvit-base-patch32"
or model_id == "facebook/detr-resnet-101"
):
img_url = request.get_json()["img_url"] img_url = request.get_json()["img_url"]
open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"] open_types = [
"cat",
"couch",
"person",
"car",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
]
result = pipe(img_url, candidate_labels=open_types) result = pipe(img_url, candidate_labels=open_types)
# VQA # VQA
@ -514,14 +707,16 @@ def models(model_id):
# depth-estimation # depth-estimation
if model_id == "Intel/dpt-large": if model_id == "Intel/dpt-large":
output = pipe(request.get_json()["img_url"]) output = pipe(request.get_json()["img_url"])
image = output['depth'] image = output["depth"]
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
image.save(f"public/images/{name}.jpg") image.save(f"public/images/{name}.jpg")
result = {"path": f"/images/{name}.jpg"} result = {"path": f"/images/{name}.jpg"}
if model_id == "Intel/dpt-hybrid-midas" and model_id == "Intel/dpt-large": if model_id == "Intel/dpt-hybrid-midas" and model_id == "Intel/dpt-large":
image = load_image(request.get_json()["img_url"]) image = load_image(request.get_json()["img_url"])
inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt") inputs = pipes[model_id]["feature_extractor"](
images=image, return_tensors="pt"
)
with torch.no_grad(): with torch.no_grad():
outputs = pipe(**inputs) outputs = pipe(**inputs)
predicted_depth = outputs.predicted_depth predicted_depth = outputs.predicted_depth
@ -550,11 +745,21 @@ def models(model_id):
text = request.get_json()["text"] text = request.get_json()["text"]
inputs = pipes[model_id]["processor"](text=text, return_tensors="pt") inputs = pipes[model_id]["processor"](text=text, return_tensors="pt")
embeddings_dataset = pipes[model_id]["embeddings_dataset"] embeddings_dataset = pipes[model_id]["embeddings_dataset"]
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(pipes[model_id]["device"]) speaker_embeddings = (
torch.tensor(embeddings_dataset[7306]["xvector"])
.unsqueeze(0)
.to(pipes[model_id]["device"])
)
pipes[model_id]["vocoder"].to(pipes[model_id]["device"]) pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"]) speech = pipe.generate_speech(
inputs["input_ids"].to(pipes[model_id]["device"]),
speaker_embeddings,
vocoder=pipes[model_id]["vocoder"],
)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000) sf.write(
f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000
)
result = {"path": f"/audios/{name}.wav"} result = {"path": f"/audios/{name}.wav"}
# ASR # ASR
@ -569,19 +774,31 @@ def models(model_id):
with torch.no_grad(): with torch.no_grad():
result_wav = pipe(wav.to(pipes[model_id]["device"])) result_wav = pipe(wav.to(pipes[model_id]["device"]))
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr) sf.write(
f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr
)
result = {"path": f"/audios/{name}.wav"} result = {"path": f"/audios/{name}.wav"}
if model_id == "microsoft/speecht5_vc": if model_id == "microsoft/speecht5_vc":
audio_url = request.get_json()["audio_url"] audio_url = request.get_json()["audio_url"]
wav, sr = torchaudio.load(audio_url) wav, sr = torchaudio.load(audio_url)
inputs = pipes[model_id]["processor"](audio=wav, sampling_rate=sr, return_tensors="pt") inputs = pipes[model_id]["processor"](
audio=wav, sampling_rate=sr, return_tensors="pt"
)
embeddings_dataset = pipes[model_id]["embeddings_dataset"] embeddings_dataset = pipes[model_id]["embeddings_dataset"]
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) speaker_embeddings = torch.tensor(
embeddings_dataset[7306]["xvector"]
).unsqueeze(0)
pipes[model_id]["vocoder"].to(pipes[model_id]["device"]) pipes[model_id]["vocoder"].to(pipes[model_id]["device"])
speech = pipe.generate_speech(inputs["input_ids"].to(pipes[model_id]["device"]), speaker_embeddings, vocoder=pipes[model_id]["vocoder"]) speech = pipe.generate_speech(
inputs["input_ids"].to(pipes[model_id]["device"]),
speaker_embeddings,
vocoder=pipes[model_id]["vocoder"],
)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000) sf.write(
f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000
)
result = {"path": f"/audios/{name}.wav"} result = {"path": f"/audios/{name}.wav"}
# segmentation # segmentation
@ -592,24 +809,44 @@ def models(model_id):
colors = [] colors = []
for i in range(len(segments)): for i in range(len(segments)):
colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50)) colors.append(
(
random.randint(100, 255),
random.randint(100, 255),
random.randint(100, 255),
50,
)
)
for segment in segments: for segment in segments:
mask = segment["mask"] mask = segment["mask"]
mask = mask.convert('L') mask = mask.convert("L")
layer = Image.new('RGBA', mask.size, colors[i]) layer = Image.new("RGBA", mask.size, colors[i])
image.paste(layer, (0, 0), mask) image.paste(layer, (0, 0), mask)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
image.save(f"public/images/{name}.jpg") image.save(f"public/images/{name}.jpg")
result = {"path": f"/images/{name}.jpg"} result = {"path": f"/images/{name}.jpg"}
if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade": if (
model_id == "facebook/maskformer-swin-base-coco"
or model_id == "facebook/maskformer-swin-large-ade"
):
image = load_image(request.get_json()["img_url"]) image = load_image(request.get_json()["img_url"])
inputs = pipes[model_id]["feature_extractor"](images=image, return_tensors="pt").to(pipes[model_id]["device"]) inputs = pipes[model_id]["feature_extractor"](
images=image, return_tensors="pt"
).to(pipes[model_id]["device"])
outputs = pipe(**inputs) outputs = pipe(**inputs)
result = pipes[model_id]["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0] result = pipes[model_id][
"feature_extractor"
].post_process_panoptic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[
0
]
predicted_panoptic_map = result["segmentation"].cpu().numpy() predicted_panoptic_map = result["segmentation"].cpu().numpy()
predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8)) predicted_panoptic_map = Image.fromarray(
predicted_panoptic_map.astype(np.uint8)
)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
predicted_panoptic_map.save(f"public/images/{name}.jpg") predicted_panoptic_map.save(f"public/images/{name}.jpg")
result = {"path": f"/images/{name}.jpg"} result = {"path": f"/images/{name}.jpg"}
@ -641,7 +878,7 @@ def models(model_id):
return jsonify(result) return jsonify(result)
if __name__ == '__main__': if __name__ == "__main__":
# temp folders # temp folders
if not os.path.exists("public/audios"): if not os.path.exists("public/audios"):
os.makedirs("public/audios") os.makedirs("public/audios")

@ -54,7 +54,7 @@ max_length = {
"davinci": 2049, "davinci": 2049,
"curie": 2049, "curie": 2049,
"babbage": 2049, "babbage": 2049,
"ada": 2049 "ada": 2049,
} }
@ -67,14 +67,14 @@ def get_max_context_length(model_name):
def get_token_ids_for_task_parsing(model_name): def get_token_ids_for_task_parsing(model_name):
text = '''{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}''' text = """{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}"""
res = encodings[model_name].encode(text) res = encodings[model_name].encode(text)
res = list(set(res)) res = list(set(res))
return res return res
def get_token_ids_for_choose_model(model_name): def get_token_ids_for_choose_model(model_name):
text = '''{"id": "reason"}''' text = """{"id": "reason"}"""
res = encodings[model_name].encode(text) res = encodings[model_name].encode(text)
res = list(set(res)) res = list(set(res))
return res return res
@ -82,7 +82,11 @@ def get_token_ids_for_choose_model(model_name):
######### #########
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="swarms/agents/workers/multi_modal_workers/omni_agent/config.yml") parser.add_argument(
"--config",
type=str,
default="swarms/agents/workers/multi_modal_workers/omni_agent/config.yml",
)
parser.add_argument("--mode", type=str, default="cli") parser.add_argument("--mode", type=str, default="cli")
args = parser.parse_args() args = parser.parse_args()
@ -102,7 +106,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler() handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter) handler.setFormatter(formatter)
if not config["debug"]: if not config["debug"]:
handler.setLevel(logging.CRITICAL) handler.setLevel(logging.CRITICAL)
@ -143,7 +147,9 @@ elif "azure" in config:
elif "openai" in config: elif "openai" in config:
API_TYPE = "openai" API_TYPE = "openai"
else: else:
logger.warning(f"No endpoint specified in {args.config}. The endpoint will be set dynamically according to the client.") logger.warning(
f"No endpoint specified in {args.config}. The endpoint will be set dynamically according to the client."
)
if args.mode in ["test", "cli"]: if args.mode in ["test", "cli"]:
assert API_TYPE, "Only server mode supports dynamic endpoint." assert API_TYPE, "Only server mode supports dynamic endpoint."
@ -157,9 +163,13 @@ elif API_TYPE == "azure":
API_KEY = config["azure"]["api_key"] API_KEY = config["azure"]["api_key"]
elif API_TYPE == "openai": elif API_TYPE == "openai":
API_ENDPOINT = f"https://api.openai.com/v1/{api_name}" API_ENDPOINT = f"https://api.openai.com/v1/{api_name}"
if config["openai"]["api_key"].startswith("sk-"): # Check for valid OpenAI key in config file if config["openai"]["api_key"].startswith(
"sk-"
): # Check for valid OpenAI key in config file
API_KEY = config["openai"]["api_key"] API_KEY = config["openai"]["api_key"]
elif "OPENAI_API_KEY" in os.environ and os.getenv("OPENAI_API_KEY").startswith("sk-"): # Check for environment variable OPENAI_API_KEY elif "OPENAI_API_KEY" in os.environ and os.getenv("OPENAI_API_KEY").startswith(
"sk-"
): # Check for environment variable OPENAI_API_KEY
API_KEY = os.getenv("OPENAI_API_KEY") API_KEY = os.getenv("OPENAI_API_KEY")
else: else:
raise ValueError(f"Incorrect OpenAI key. Please check your {args.config} file.") raise ValueError(f"Incorrect OpenAI key. Please check your {args.config} file.")
@ -175,7 +185,12 @@ inference_mode = config["inference_mode"]
# check the local_inference_endpoint # check the local_inference_endpoint
Model_Server = None Model_Server = None
if inference_mode != "huggingface": if inference_mode != "huggingface":
Model_Server = "http://" + config["local_inference_endpoint"]["host"] + ":" + str(config["local_inference_endpoint"]["port"]) Model_Server = (
"http://"
+ config["local_inference_endpoint"]["host"]
+ ":"
+ str(config["local_inference_endpoint"]["port"])
)
message = f"The server of local inference endpoints is not running, please start it first. (or using `inference_mode: huggingface` in {args.config} for a feature-limited experience)" message = f"The server of local inference endpoints is not running, please start it first. (or using `inference_mode: huggingface` in {args.config} for a feature-limited experience)"
try: try:
r = requests.get(Model_Server + "/running") r = requests.get(Model_Server + "/running")
@ -185,9 +200,15 @@ if inference_mode != "huggingface":
raise ValueError(message) raise ValueError(message)
parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read() parse_task_demos_or_presteps = open(
choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read() config["demos_or_presteps"]["parse_task"], "r"
response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read() ).read()
choose_model_demos_or_presteps = open(
config["demos_or_presteps"]["choose_model"], "r"
).read()
response_results_demos_or_presteps = open(
config["demos_or_presteps"]["response_results"], "r"
).read()
parse_task_prompt = config["prompt"]["parse_task"] parse_task_prompt = config["prompt"]["parse_task"]
choose_model_prompt = config["prompt"]["choose_model"] choose_model_prompt = config["prompt"]["choose_model"]
@ -209,37 +230,54 @@ for model in MODELS:
METADATAS[model["id"]] = model METADATAS[model["id"]] = model
HUGGINGFACE_HEADERS = {} HUGGINGFACE_HEADERS = {}
if config["huggingface"]["token"] and config["huggingface"]["token"].startswith("hf_"): # Check for valid huggingface token in config file if config["huggingface"]["token"] and config["huggingface"]["token"].startswith(
"hf_"
): # Check for valid huggingface token in config file
HUGGINGFACE_HEADERS = { HUGGINGFACE_HEADERS = {
"Authorization": f"Bearer {config['huggingface']['token']}", "Authorization": f"Bearer {config['huggingface']['token']}",
} }
elif "HUGGINGFACE_ACCESS_TOKEN" in os.environ and os.getenv("HUGGINGFACE_ACCESS_TOKEN").startswith("hf_"): # Check for environment variable HUGGINGFACE_ACCESS_TOKEN elif "HUGGINGFACE_ACCESS_TOKEN" in os.environ and os.getenv(
"HUGGINGFACE_ACCESS_TOKEN"
).startswith(
"hf_"
): # Check for environment variable HUGGINGFACE_ACCESS_TOKEN
HUGGINGFACE_HEADERS = { HUGGINGFACE_HEADERS = {
"Authorization": f"Bearer {os.getenv('HUGGINGFACE_ACCESS_TOKEN')}", "Authorization": f"Bearer {os.getenv('HUGGINGFACE_ACCESS_TOKEN')}",
} }
else: else:
raise ValueError(f"Incorrect HuggingFace token. Please check your {args.config} file.") raise ValueError(
f"Incorrect HuggingFace token. Please check your {args.config} file."
)
def convert_chat_to_completion(data): def convert_chat_to_completion(data):
messages = data.pop('messages', []) messages = data.pop("messages", [])
tprompt = "" tprompt = ""
if messages[0]['role'] == "system": if messages[0]["role"] == "system":
tprompt = messages[0]['content'] tprompt = messages[0]["content"]
messages = messages[1:] messages = messages[1:]
final_prompt = "" final_prompt = ""
for message in messages: for message in messages:
if message['role'] == "user": if message["role"] == "user":
final_prompt += ("<im_start>" + "user" + "\n" + message['content'] + "<im_end>\n") final_prompt += (
elif message['role'] == "assistant": "<im_start>" + "user" + "\n" + message["content"] + "<im_end>\n"
final_prompt += ("<im_start>" + "assistant" + "\n" + message['content'] + "<im_end>\n") )
elif message["role"] == "assistant":
final_prompt += (
"<im_start>" + "assistant" + "\n" + message["content"] + "<im_end>\n"
)
else: else:
final_prompt += ("<im_start>" + "system" + "\n" + message['content'] + "<im_end>\n") final_prompt += (
"<im_start>" + "system" + "\n" + message["content"] + "<im_end>\n"
)
final_prompt = tprompt + final_prompt final_prompt = tprompt + final_prompt
final_prompt = final_prompt + "<im_start>assistant" final_prompt = final_prompt + "<im_start>assistant"
data["prompt"] = final_prompt data["prompt"] = final_prompt
data['stop'] = data.get('stop', ["<im_end>"]) data["stop"] = data.get("stop", ["<im_end>"])
data['max_tokens'] = data.get('max_tokens', max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1)) data["max_tokens"] = data.get(
"max_tokens",
max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1),
)
return data return data
@ -250,14 +288,9 @@ def send_request(data):
if use_completion: if use_completion:
data = convert_chat_to_completion(data) data = convert_chat_to_completion(data)
if api_type == "openai": if api_type == "openai":
HEADER = { HEADER = {"Authorization": f"Bearer {api_key}"}
"Authorization": f"Bearer {api_key}"
}
elif api_type == "azure": elif api_type == "azure":
HEADER = { HEADER = {"api-key": api_key, "Content-Type": "application/json"}
"api-key": api_key,
"Content-Type": "application/json"
}
else: else:
HEADER = None HEADER = None
response = requests.post(api_endpoint, json=data, headers=HEADER, proxies=PROXY) response = requests.post(api_endpoint, json=data, headers=HEADER, proxies=PROXY)
@ -274,15 +307,17 @@ def replace_slot(text, entries):
for key, value in entries.items(): for key, value in entries.items():
if not isinstance(value, str): if not isinstance(value, str):
value = str(value) value = str(value)
text = text.replace("{{" + key + "}}", value.replace('"', "'").replace('\n', "")) text = text.replace(
"{{" + key + "}}", value.replace('"', "'").replace("\n", "")
)
return text return text
def find_json(s): def find_json(s):
s = s.replace("\'", "\"") s = s.replace("'", '"')
start = s.find("{") start = s.find("{")
end = s.rfind("}") end = s.rfind("}")
res = s[start:end + 1] res = s[start : end + 1]
res = res.replace("\n", "") res = res.replace("\n", "")
return res return res
@ -290,10 +325,10 @@ def find_json(s):
def field_extract(s, field): def field_extract(s, field):
try: try:
field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE) field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
extracted = field_rep.search(s).group(1).replace("\"", "\'") extracted = field_rep.search(s).group(1).replace('"', "'")
except BaseException: except BaseException:
field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE) field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
extracted = field_rep.search(s).group(1).replace("\"", "\'") extracted = field_rep.search(s).group(1).replace('"', "'")
return extracted return extracted
@ -377,7 +412,7 @@ def chitchat(messages, api_key, api_type, api_endpoint):
"messages": messages, "messages": messages,
"api_key": api_key, "api_key": api_key,
"api_type": api_type, "api_type": api_type,
"api_endpoint": api_endpoint "api_endpoint": api_endpoint,
} }
return send_request(data) return send_request(data)
@ -391,10 +426,7 @@ def parse_task(context, input, api_key, api_type, api_endpoint):
start = 0 start = 0
while start <= len(context): while start <= len(context):
history = context[start:] history = context[start:]
prompt = replace_slot(parse_task_prompt, { prompt = replace_slot(parse_task_prompt, {"input": input, "context": history})
"input": input,
"context": history
})
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages]) history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages])
num = count_tokens(LLM_encoding, history_text) num = count_tokens(LLM_encoding, history_text)
@ -408,25 +440,29 @@ def parse_task(context, input, api_key, api_type, api_endpoint):
"model": LLM, "model": LLM,
"messages": messages, "messages": messages,
"temperature": 0, "temperature": 0,
"logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids}, "logit_bias": {
item: config["logit_bias"]["parse_task"]
for item in task_parsing_highlight_ids
},
"api_key": api_key, "api_key": api_key,
"api_type": api_type, "api_type": api_type,
"api_endpoint": api_endpoint "api_endpoint": api_endpoint,
} }
return send_request(data) return send_request(data)
def choose_model(input, task, metas, api_key, api_type, api_endpoint): def choose_model(input, task, metas, api_key, api_type, api_endpoint):
prompt = replace_slot(choose_model_prompt, { prompt = replace_slot(
choose_model_prompt,
{
"input": input, "input": input,
"task": task, "task": task,
"metas": metas, "metas": metas,
}) },
demos_or_presteps = replace_slot(choose_model_demos_or_presteps, { )
"input": input, demos_or_presteps = replace_slot(
"task": task, choose_model_demos_or_presteps, {"input": input, "task": task, "metas": metas}
"metas": metas )
})
messages = json.loads(demos_or_presteps) messages = json.loads(demos_or_presteps)
messages.insert(0, {"role": "system", "content": choose_model_tprompt}) messages.insert(0, {"role": "system", "content": choose_model_tprompt})
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
@ -435,23 +471,28 @@ def choose_model(input, task, metas, api_key, api_type, api_endpoint):
"model": LLM, "model": LLM,
"messages": messages, "messages": messages,
"temperature": 0, "temperature": 0,
"logit_bias": {item: config["logit_bias"]["choose_model"] for item in choose_model_highlight_ids}, # 5 "logit_bias": {
item: config["logit_bias"]["choose_model"]
for item in choose_model_highlight_ids
}, # 5
"api_key": api_key, "api_key": api_key,
"api_type": api_type, "api_type": api_type,
"api_endpoint": api_endpoint "api_endpoint": api_endpoint,
} }
return send_request(data) return send_request(data)
def response_results(input, results, api_key, api_type, api_endpoint): def response_results(input, results, api_key, api_type, api_endpoint):
results = [v for k, v in sorted(results.items(), key=lambda item: item[0])] results = [v for k, v in sorted(results.items(), key=lambda item: item[0])]
prompt = replace_slot(response_results_prompt, { prompt = replace_slot(
"input": input, response_results_prompt,
}) {
demos_or_presteps = replace_slot(response_results_demos_or_presteps, {
"input": input, "input": input,
"processes": results },
}) )
demos_or_presteps = replace_slot(
response_results_demos_or_presteps, {"input": input, "processes": results}
)
messages = json.loads(demos_or_presteps) messages = json.loads(demos_or_presteps)
messages.insert(0, {"role": "system", "content": response_results_tprompt}) messages.insert(0, {"role": "system", "content": response_results_tprompt})
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
@ -462,7 +503,7 @@ def response_results(input, results, api_key, api_type, api_endpoint):
"temperature": 0, "temperature": 0,
"api_key": api_key, "api_key": api_key,
"api_type": api_type, "api_type": api_type,
"api_endpoint": api_endpoint "api_endpoint": api_endpoint,
} }
return send_request(data) return send_request(data)
@ -473,12 +514,23 @@ def huggingface_model_inference(model_id, data, task):
# NLP tasks # NLP tasks
if task == "question-answering": if task == "question-answering":
inputs = {"question": data["text"], "context": (data["context"] if "context" in data else "")} inputs = {
"question": data["text"],
"context": (data["context"] if "context" in data else ""),
}
result = inference(inputs) result = inference(inputs)
if task == "sentence-similarity": if task == "sentence-similarity":
inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]} inputs = {"source_sentence": data["text1"], "target_sentence": data["text2"]}
result = inference(inputs) result = inference(inputs)
if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]: if task in [
"text-classification",
"token-classification",
"text2text-generation",
"summarization",
"translation",
"conversational",
"text-generation",
]:
inputs = data["text"] inputs = data["text"]
result = inference(inputs) result = inference(inputs)
@ -492,7 +544,9 @@ def huggingface_model_inference(model_id, data, task):
json_data["inputs"] = {} json_data["inputs"] = {}
json_data["inputs"]["question"] = text json_data["inputs"]["question"] = text
json_data["inputs"]["image"] = img_base64 json_data["inputs"]["image"] = img_base64
result = requests.post(task_url, headers=HUGGINGFACE_HEADERS, json=json_data).json() result = requests.post(
task_url, headers=HUGGINGFACE_HEADERS, json=json_data
).json()
# result = inference(inputs) # not support # result = inference(inputs) # not support
if task == "image-to-image": if task == "image-to-image":
@ -520,15 +574,22 @@ def huggingface_model_inference(model_id, data, task):
predicted = inference(data=img_data) predicted = inference(data=img_data)
colors = [] colors = []
for i in range(len(predicted)): for i in range(len(predicted)):
colors.append((random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155)) colors.append(
(
random.randint(100, 255),
random.randint(100, 255),
random.randint(100, 255),
155,
)
)
for i, pred in enumerate(predicted): for i, pred in enumerate(predicted):
label = pred["label"] label = pred["label"]
mask = pred.pop("mask").encode("utf-8") mask = pred.pop("mask").encode("utf-8")
mask = base64.b64decode(mask) mask = base64.b64decode(mask)
mask = Image.open(BytesIO(mask), mode='r') mask = Image.open(BytesIO(mask), mode="r")
mask = mask.convert('L') mask = mask.convert("L")
layer = Image.new('RGBA', mask.size, colors[i]) layer = Image.new("RGBA", mask.size, colors[i])
image.paste(layer, (0, 0), mask) image.paste(layer, (0, 0), mask)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
image.save(f"public/images/{name}.jpg") image.save(f"public/images/{name}.jpg")
@ -542,15 +603,27 @@ def huggingface_model_inference(model_id, data, task):
predicted = inference(data=img_data) predicted = inference(data=img_data)
image = Image.open(BytesIO(img_data)) image = Image.open(BytesIO(img_data))
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
labels = list(item['label'] for item in predicted) labels = list(item["label"] for item in predicted)
color_map = {} color_map = {}
for label in labels: for label in labels:
if label not in color_map: if label not in color_map:
color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255)) color_map[label] = (
random.randint(0, 255),
random.randint(0, 100),
random.randint(0, 255),
)
for label in predicted: for label in predicted:
box = label["box"] box = label["box"]
draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2) draw.rectangle(
draw.text((box["xmin"] + 5, box["ymin"] - 15), label["label"], fill=color_map[label["label"]]) ((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])),
outline=color_map[label["label"]],
width=2,
)
draw.text(
(box["xmin"] + 5, box["ymin"] - 15),
label["label"],
fill=color_map[label["label"]],
)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
image.save(f"public/images/{name}.jpg") image.save(f"public/images/{name}.jpg")
result = {} result = {}
@ -566,7 +639,9 @@ def huggingface_model_inference(model_id, data, task):
img_url = data["image"] img_url = data["image"]
img_data = image_to_bytes(img_url) img_data = image_to_bytes(img_url)
HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data)) HUGGINGFACE_HEADERS["Content-Length"] = str(len(img_data))
r = requests.post(task_url, headers=HUGGINGFACE_HEADERS, data=img_data, proxies=PROXY) r = requests.post(
task_url, headers=HUGGINGFACE_HEADERS, data=img_data, proxies=PROXY
)
result = {} result = {}
if "generated_text" in r.json()[0]: if "generated_text" in r.json()[0]:
result["generated text"] = r.json()[0].pop("generated_text") result["generated text"] = r.json()[0].pop("generated_text")
@ -580,7 +655,11 @@ def huggingface_model_inference(model_id, data, task):
with open(f"public/audios/{name}.flac", "wb") as f: with open(f"public/audios/{name}.flac", "wb") as f:
f.write(response.content) f.write(response.content)
result = {"generated audio": f"/audios/{name}.flac"} result = {"generated audio": f"/audios/{name}.flac"}
if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]: if task in [
"automatic-speech-recognition",
"audio-to-audio",
"audio-classification",
]:
audio_url = data["audio"] audio_url = data["audio"]
audio_data = requests.get(audio_url, timeout=10).content audio_data = requests.get(audio_url, timeout=10).content
response = inference(data=audio_data, raw_response=True) response = inference(data=audio_data, raw_response=True)
@ -631,7 +710,15 @@ def local_model_inference(model_id, data, task):
if task == "question-answering" or task == "sentence-similarity": if task == "question-answering" or task == "sentence-similarity":
response = requests.post(task_url, json=data) response = requests.post(task_url, json=data)
return response.json() return response.json()
if task in ["text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation"]: if task in [
"text-classification",
"token-classification",
"text2text-generation",
"summarization",
"translation",
"conversational",
"text-generation",
]:
response = requests.post(task_url, json=data) response = requests.post(task_url, json=data)
return response.json() return response.json()
@ -670,22 +757,39 @@ def local_model_inference(model_id, data, task):
return predicted return predicted
image = load_image(img_url) image = load_image(img_url)
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
labels = list(item['label'] for item in predicted) labels = list(item["label"] for item in predicted)
color_map = {} color_map = {}
for label in labels: for label in labels:
if label not in color_map: if label not in color_map:
color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255)) color_map[label] = (
random.randint(0, 255),
random.randint(0, 100),
random.randint(0, 255),
)
for label in predicted: for label in predicted:
box = label["box"] box = label["box"]
draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color_map[label["label"]], width=2) draw.rectangle(
draw.text((box["xmin"] + 5, box["ymin"] - 15), label["label"], fill=color_map[label["label"]]) ((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])),
outline=color_map[label["label"]],
width=2,
)
draw.text(
(box["xmin"] + 5, box["ymin"] - 15),
label["label"],
fill=color_map[label["label"]],
)
name = str(uuid.uuid4())[:4] name = str(uuid.uuid4())[:4]
image.save(f"public/images/{name}.jpg") image.save(f"public/images/{name}.jpg")
results = {} results = {}
results["generated image"] = f"/images/{name}.jpg" results["generated image"] = f"/images/{name}.jpg"
results["predicted"] = predicted results["predicted"] = predicted
return results return results
if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]: if task in [
"image-classification",
"image-to-text",
"document-question-answering",
"visual-question-answering",
]:
img_url = data["image"] img_url = data["image"]
text = None text = None
if "text" in data: if "text" in data:
@ -700,7 +804,11 @@ def local_model_inference(model_id, data, task):
if "path" in results: if "path" in results:
results["generated audio"] = results.pop("path") results["generated audio"] = results.pop("path")
return results return results
if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]: if task in [
"automatic-speech-recognition",
"audio-to-audio",
"audio-classification",
]:
audio_url = data["audio"] audio_url = data["audio"]
response = requests.post(task_url, json={"audio_url": audio_url}) response = requests.post(task_url, json={"audio_url": audio_url})
return response.json() return response.json()
@ -714,8 +822,12 @@ def model_inference(model_id, data, hosted_on, task):
if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]: if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
hosted_on = "local" hosted_on = "local"
else: else:
huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}" huggingfaceStatusUrl = (
r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY) f"https://api-inference.huggingface.co/status/{model_id}"
)
r = requests.get(
huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY
)
logger.debug("Huggingface Status: " + str(r.json())) logger.debug("Huggingface Status: " + str(r.json()))
if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]: if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
hosted_on = "huggingface" hosted_on = "huggingface"
@ -756,14 +868,27 @@ def get_avaliable_models(candidates, topk=5):
model_id = candidate["id"] model_id = candidate["id"]
if inference_mode != "local": if inference_mode != "local":
huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}" huggingfaceStatusUrl = (
thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue)) f"https://api-inference.huggingface.co/status/{model_id}"
)
thread = threading.Thread(
target=get_model_status,
args=(
model_id,
huggingfaceStatusUrl,
HUGGINGFACE_HEADERS,
result_queue,
),
)
threads.append(thread) threads.append(thread)
thread.start() thread.start()
if inference_mode != "huggingface" and config["local_deployment"] != "minimal": if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
localStatusUrl = f"{Model_Server}/status/{model_id}" localStatusUrl = f"{Model_Server}/status/{model_id}"
thread = threading.Thread(target=get_model_status, args=(model_id, localStatusUrl, {}, result_queue)) thread = threading.Thread(
target=get_model_status,
args=(model_id, localStatusUrl, {}, result_queue),
)
threads.append(thread) threads.append(thread)
thread.start() thread.start()
@ -772,7 +897,10 @@ def get_avaliable_models(candidates, topk=5):
model_id, status, endpoint_type = result_queue.get() model_id, status, endpoint_type = result_queue.get()
if status and model_id not in all_available_models: if status and model_id not in all_available_models:
all_available_models[endpoint_type].append(model_id) all_available_models[endpoint_type].append(model_id)
if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk: if (
len(all_available_models["local"] + all_available_models["huggingface"])
>= topk
):
break break
result_count -= 1 result_count -= 1
@ -807,33 +935,45 @@ def run_task(input, command, results, api_key, api_type, api_endpoint):
if "image" in args and "<GENERATED>-" in args["image"]: if "image" in args and "<GENERATED>-" in args["image"]:
resource_id = int(args["image"].split("-")[1]) resource_id = int(args["image"].split("-")[1])
if "generated image" in results[resource_id]["inference result"]: if "generated image" in results[resource_id]["inference result"]:
args["image"] = results[resource_id]["inference result"]["generated image"] args["image"] = results[resource_id]["inference result"][
"generated image"
]
if "audio" in args and "<GENERATED>-" in args["audio"]: if "audio" in args and "<GENERATED>-" in args["audio"]:
resource_id = int(args["audio"].split("-")[1]) resource_id = int(args["audio"].split("-")[1])
if "generated audio" in results[resource_id]["inference result"]: if "generated audio" in results[resource_id]["inference result"]:
args["audio"] = results[resource_id]["inference result"]["generated audio"] args["audio"] = results[resource_id]["inference result"][
"generated audio"
]
if "text" in args and "<GENERATED>-" in args["text"]: if "text" in args and "<GENERATED>-" in args["text"]:
resource_id = int(args["text"].split("-")[1]) resource_id = int(args["text"].split("-")[1])
if "generated text" in results[resource_id]["inference result"]: if "generated text" in results[resource_id]["inference result"]:
args["text"] = results[resource_id]["inference result"]["generated text"] args["text"] = results[resource_id]["inference result"][
"generated text"
]
text = image = audio = None text = image = audio = None
for dep_task in dep_tasks: for dep_task in dep_tasks:
if "generated text" in dep_task["inference result"]: if "generated text" in dep_task["inference result"]:
text = dep_task["inference result"]["generated text"] text = dep_task["inference result"]["generated text"]
logger.debug("Detect the generated text of dependency task (from results):" + text) logger.debug(
"Detect the generated text of dependency task (from results):" + text
)
elif "text" in dep_task["task"]["args"]: elif "text" in dep_task["task"]["args"]:
text = dep_task["task"]["args"]["text"] text = dep_task["task"]["args"]["text"]
logger.debug("Detect the text of dependency task (from args): " + text) logger.debug("Detect the text of dependency task (from args): " + text)
if "generated image" in dep_task["inference result"]: if "generated image" in dep_task["inference result"]:
image = dep_task["inference result"]["generated image"] image = dep_task["inference result"]["generated image"]
logger.debug("Detect the generated image of dependency task (from results): " + image) logger.debug(
"Detect the generated image of dependency task (from results): " + image
)
elif "image" in dep_task["task"]["args"]: elif "image" in dep_task["task"]["args"]:
image = dep_task["task"]["args"]["image"] image = dep_task["task"]["args"]["image"]
logger.debug("Detect the image of dependency task (from args): " + image) logger.debug("Detect the image of dependency task (from args): " + image)
if "generated audio" in dep_task["inference result"]: if "generated audio" in dep_task["inference result"]:
audio = dep_task["inference result"]["generated audio"] audio = dep_task["inference result"]["generated audio"]
logger.debug("Detect the generated audio of dependency task (from results): " + audio) logger.debug(
"Detect the generated audio of dependency task (from results): " + audio
)
elif "audio" in dep_task["task"]["args"]: elif "audio" in dep_task["task"]["args"]:
audio = dep_task["task"]["args"]["audio"] audio = dep_task["task"]["args"]["audio"]
logger.debug("Detect the audio of dependency task (from args): " + audio) logger.debug("Detect the audio of dependency task (from args): " + audio)
@ -849,19 +989,26 @@ def run_task(input, command, results, api_key, api_type, api_endpoint):
args["text"] = text args["text"] = text
for resource in ["image", "audio"]: for resource in ["image", "audio"]:
if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"): if (
resource in args
and not args[resource].startswith("public/")
and len(args[resource]) > 0
and not args[resource].startswith("http")
):
args[resource] = f"public/{args[resource]}" args[resource] = f"public/{args[resource]}"
if "-text-to-image" in command['task'] and "text" not in args: if "-text-to-image" in command["task"] and "text" not in args:
logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.") logger.debug(
"control-text-to-image task, but text is empty, so we use control-generation instead."
)
control = task.split("-")[0] control = task.split("-")[0]
if control == "seg": if control == "seg":
task = "image-segmentation" task = "image-segmentation"
command['task'] = task command["task"] = task
elif control == "depth": elif control == "depth":
task = "depth-estimation" task = "depth-estimation"
command['task'] = task command["task"] = task
else: else:
task = f"{control}-control" task = f"{control}-control"
@ -880,45 +1027,93 @@ def run_task(input, command, results, api_key, api_type, api_endpoint):
choose = {"id": best_model_id, "reason": reason} choose = {"id": best_model_id, "reason": reason}
logger.debug(f"chosen model: {choose}") logger.debug(f"chosen model: {choose}")
else: else:
logger.warning(f"Task {command['task']} is not available. ControlNet need to be deployed locally.") logger.warning(
record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.", "op": "message"}) f"Task {command['task']} is not available. ControlNet need to be deployed locally."
inference_result = {"error": "service related to ControlNet is not available."} )
record_case(
success=False,
**{
"input": input,
"task": command,
"reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.",
"op": "message",
},
)
inference_result = {
"error": "service related to ControlNet is not available."
}
results[id] = collect_result(command, "", inference_result) results[id] = collect_result(command, "", inference_result)
return False return False
elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: # ChatGPT Can do elif task in [
"summarization",
"translation",
"conversational",
"text-generation",
"text2text-generation",
]: # ChatGPT Can do
best_model_id = "ChatGPT" best_model_id = "ChatGPT"
reason = "ChatGPT performs well on some NLP tasks as well." reason = "ChatGPT performs well on some NLP tasks as well."
choose = {"id": best_model_id, "reason": reason} choose = {"id": best_model_id, "reason": reason}
messages = [{ messages = [
{
"role": "user", "role": "user",
"content": f"[ {input} ] contains a task in JSON format {command}. Now you are a {command['task']} system, the arguments are {command['args']}. Just help me do {command['task']} and give me the result. The result must be in text form without any urls." "content": f"[ {input} ] contains a task in JSON format {command}. Now you are a {command['task']} system, the arguments are {command['args']}. Just help me do {command['task']} and give me the result. The result must be in text form without any urls.",
}] }
]
response = chitchat(messages, api_key, api_type, api_endpoint) response = chitchat(messages, api_key, api_type, api_endpoint)
results[id] = collect_result(command, choose, {"response": response}) results[id] = collect_result(command, choose, {"response": response})
return True return True
else: else:
if task not in MODELS_MAP: if task not in MODELS_MAP:
logger.warning(f"no available models on {task} task.") logger.warning(f"no available models on {task} task.")
record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op": "message"}) record_case(
inference_result = {"error": f"{command['task']} not found in available tasks."} success=False,
**{
"input": input,
"task": command,
"reason": f"task not support: {command['task']}",
"op": "message",
},
)
inference_result = {
"error": f"{command['task']} not found in available tasks."
}
results[id] = collect_result(command, "", inference_result) results[id] = collect_result(command, "", inference_result)
return False return False
candidates = MODELS_MAP[task][:10] candidates = MODELS_MAP[task][:10]
all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"]) all_avaliable_models = get_avaliable_models(
all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"] candidates, config["num_candidate_models"]
)
all_avaliable_model_ids = (
all_avaliable_models["local"] + all_avaliable_models["huggingface"]
)
logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}") logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
if len(all_avaliable_model_ids) == 0: if len(all_avaliable_model_ids) == 0:
logger.warning(f"no available models on {command['task']}") logger.warning(f"no available models on {command['task']}")
record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op": "message"}) record_case(
inference_result = {"error": f"no available models on {command['task']} task."} success=False,
**{
"input": input,
"task": command,
"reason": f"no available models: {command['task']}",
"op": "message",
},
)
inference_result = {
"error": f"no available models on {command['task']} task."
}
results[id] = collect_result(command, "", inference_result) results[id] = collect_result(command, "", inference_result)
return False return False
if len(all_avaliable_model_ids) == 1: if len(all_avaliable_model_ids) == 1:
best_model_id = all_avaliable_model_ids[0] best_model_id = all_avaliable_model_ids[0]
hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface" hosted_on = (
"local"
if best_model_id in all_avaliable_models["local"]
else "huggingface"
)
reason = "Only one model available." reason = "Only one model available."
choose = {"id": best_model_id, "reason": reason} choose = {"id": best_model_id, "reason": reason}
logger.debug(f"chosen model: {choose}") logger.debug(f"chosen model: {choose}")
@ -927,34 +1122,60 @@ def run_task(input, command, results, api_key, api_type, api_endpoint):
{ {
"id": model["id"], "id": model["id"],
"inference endpoint": all_avaliable_models.get( "inference endpoint": all_avaliable_models.get(
"local" if model["id"] in all_avaliable_models["local"] else "huggingface" "local"
if model["id"] in all_avaliable_models["local"]
else "huggingface"
), ),
"likes": model.get("likes"), "likes": model.get("likes"),
"description": model.get("description", "")[:config["max_description_length"]], "description": model.get("description", "")[
: config["max_description_length"]
],
# "language": model.get("meta").get("language") if model.get("meta") else None, # "language": model.get("meta").get("language") if model.get("meta") else None,
"tags": model.get("meta").get("tags") if model.get("meta") else None, "tags": model.get("meta").get("tags")
if model.get("meta")
else None,
} }
for model in candidates for model in candidates
if model["id"] in all_avaliable_model_ids if model["id"] in all_avaliable_model_ids
] ]
choose_str = choose_model(input, command, cand_models_info, api_key, api_type, api_endpoint) choose_str = choose_model(
input, command, cand_models_info, api_key, api_type, api_endpoint
)
logger.debug(f"chosen model: {choose_str}") logger.debug(f"chosen model: {choose_str}")
try: try:
choose = json.loads(choose_str) choose = json.loads(choose_str)
reason = choose["reason"] reason = choose["reason"]
best_model_id = choose["id"] best_model_id = choose["id"]
hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface" hosted_on = (
"local"
if best_model_id in all_avaliable_models["local"]
else "huggingface"
)
except Exception: except Exception:
logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.") logger.warning(
f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response."
)
choose_str = find_json(choose_str) choose_str = find_json(choose_str)
best_model_id, reason, choose = get_id_reason(choose_str) best_model_id, reason, choose = get_id_reason(choose_str)
hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface" hosted_on = (
inference_result = model_inference(best_model_id, args, hosted_on, command['task']) "local"
if best_model_id in all_avaliable_models["local"]
else "huggingface"
)
inference_result = model_inference(best_model_id, args, hosted_on, command["task"])
if "error" in inference_result: if "error" in inference_result:
logger.warning(f"Inference error: {inference_result['error']}") logger.warning(f"Inference error: {inference_result['error']}")
record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op": "message"}) record_case(
success=False,
**{
"input": input,
"task": command,
"reason": f"inference error: {inference_result['error']}",
"op": "message",
},
)
results[id] = collect_result(command, choose, inference_result) results[id] = collect_result(command, choose, inference_result)
return False return False
@ -962,7 +1183,14 @@ def run_task(input, command, results, api_key, api_type, api_endpoint):
return True return True
def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=False, return_results=False): def chat_huggingface(
messages,
api_key,
api_type,
api_endpoint,
return_planning=False,
return_results=False,
):
start = time.time() start = time.time()
context = messages[:-1] context = messages[:-1]
input = messages[-1]["content"] input = messages[-1]["content"]
@ -972,7 +1200,15 @@ def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=
task_str = parse_task(context, input, api_key, api_type, api_endpoint) task_str = parse_task(context, input, api_key, api_type, api_endpoint)
if "error" in task_str: if "error" in task_str:
record_case(success=False, **{"input": input, "task": task_str, "reason": f"task parsing error: {task_str['error']['message']}", "op": "report message"}) record_case(
success=False,
**{
"input": input,
"task": task_str,
"reason": f"task parsing error: {task_str['error']['message']}",
"op": "report message",
},
)
return {"message": task_str["error"]["message"]} return {"message": task_str["error"]["message"]}
task_str = task_str.strip() task_str = task_str.strip()
@ -983,16 +1219,46 @@ def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=
except Exception as e: except Exception as e:
logger.debug(e) logger.debug(e)
response = chitchat(messages, api_key, api_type, api_endpoint) response = chitchat(messages, api_key, api_type, api_endpoint)
record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op": "chitchat"}) record_case(
success=False,
**{
"input": input,
"task": task_str,
"reason": "task parsing fail",
"op": "chitchat",
},
)
return {"message": response} return {"message": response}
if task_str == "[]": # using LLM response for empty task if task_str == "[]": # using LLM response for empty task
record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"}) record_case(
success=False,
**{
"input": input,
"task": [],
"reason": "task parsing fail: empty",
"op": "chitchat",
},
)
response = chitchat(messages, api_key, api_type, api_endpoint) response = chitchat(messages, api_key, api_type, api_endpoint)
return {"message": response} return {"message": response}
if len(tasks) == 1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: if len(tasks) == 1 and tasks[0]["task"] in [
record_case(success=True, **{"input": input, "task": tasks, "reason": "chitchat tasks", "op": "chitchat"}) "summarization",
"translation",
"conversational",
"text-generation",
"text2text-generation",
]:
record_case(
success=True,
**{
"input": input,
"task": tasks,
"reason": "chitchat tasks",
"op": "chitchat",
},
)
response = chitchat(messages, api_key, api_type, api_endpoint) response = chitchat(messages, api_key, api_type, api_endpoint)
return {"message": response} return {"message": response}
@ -1019,7 +1285,10 @@ def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=
dep = task["dep"] dep = task["dep"]
if dep[0] == -1 or len(list(set(dep).intersection(d.keys()))) == len(dep): if dep[0] == -1 or len(list(set(dep).intersection(d.keys()))) == len(dep):
tasks.remove(task) tasks.remove(task)
thread = threading.Thread(target=run_task, args=(input, task, d, api_key, api_type, api_endpoint)) thread = threading.Thread(
target=run_task,
args=(input, task, d, api_key, api_type, api_endpoint),
)
thread.start() thread.start()
threads.append(thread) threads.append(thread)
if num_thread == len(threads): if num_thread == len(threads):
@ -1045,7 +1314,17 @@ def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning=
during = end - start during = end - start
answer = {"message": response} answer = {"message": response}
record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op": "response"}) record_case(
success=True,
**{
"input": input,
"task": task_str,
"results": results,
"response": response,
"during": during,
"op": "response",
},
)
logger.info(f"response: {response}") logger.info(f"response: {response}")
return answer return answer
@ -1058,31 +1337,63 @@ def test():
"Please answer all the named entities in the sentence: Iron Man is a superhero appearing in American comic books published by Marvel Comics. The character was co-created by writer and editor Stan Lee, developed by scripter Larry Lieber, and designed by artists Don Heck and Jack Kirby.", "Please answer all the named entities in the sentence: Iron Man is a superhero appearing in American comic books published by Marvel Comics. The character was co-created by writer and editor Stan Lee, developed by scripter Larry Lieber, and designed by artists Don Heck and Jack Kirby.",
"please dub for me: 'Iron Man is a superhero appearing in American comic books published by Marvel Comics. The character was co-created by writer and editor Stan Lee, developed by scripter Larry Lieber, and designed by artists Don Heck and Jack Kirby.'" "please dub for me: 'Iron Man is a superhero appearing in American comic books published by Marvel Comics. The character was co-created by writer and editor Stan Lee, developed by scripter Larry Lieber, and designed by artists Don Heck and Jack Kirby.'"
"Given an image: https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg, please answer the question: What is on top of the building?", "Given an image: https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg, please answer the question: What is on top of the building?",
"Please generate a canny image based on /examples/f.jpg" "Please generate a canny image based on /examples/f.jpg",
] ]
for input in inputs: for input in inputs:
messages = [{"role": "user", "content": input}] messages = [{"role": "user", "content": input}]
chat_huggingface(messages, API_KEY, API_TYPE, API_ENDPOINT, return_planning=False, return_results=False) chat_huggingface(
messages,
API_KEY,
API_TYPE,
API_ENDPOINT,
return_planning=False,
return_results=False,
)
# multi rounds example # multi rounds example
messages = [ messages = [
{"role": "user", "content": "Please generate a canny image based on /examples/f.jpg"}, {
{"role": "assistant", "content": """Sure. I understand your request. Based on the inference results of the models, I have generated a canny image for you. The workflow I used is as follows: First, I used the image-to-text model (nlpconnect/vit-gpt2-image-captioning) to convert the image /examples/f.jpg to text. The generated text is "a herd of giraffes and zebras grazing in a field". Second, I used the canny-control model (canny-control) to generate a canny image from the text. Unfortunately, the model failed to generate the canny image. Finally, I used the canny-text-to-image model (lllyasviel/sd-controlnet-canny) to generate a canny image from the text. The generated image is located at /images/f16d.png. I hope this answers your request. Is there anything else I can help you with?"""}, "role": "user",
{"role": "user", "content": """then based on the above canny image and a prompt "a photo of a zoo", generate a new image."""}, "content": "Please generate a canny image based on /examples/f.jpg",
},
{
"role": "assistant",
"content": """Sure. I understand your request. Based on the inference results of the models, I have generated a canny image for you. The workflow I used is as follows: First, I used the image-to-text model (nlpconnect/vit-gpt2-image-captioning) to convert the image /examples/f.jpg to text. The generated text is "a herd of giraffes and zebras grazing in a field". Second, I used the canny-control model (canny-control) to generate a canny image from the text. Unfortunately, the model failed to generate the canny image. Finally, I used the canny-text-to-image model (lllyasviel/sd-controlnet-canny) to generate a canny image from the text. The generated image is located at /images/f16d.png. I hope this answers your request. Is there anything else I can help you with?""",
},
{
"role": "user",
"content": """then based on the above canny image and a prompt "a photo of a zoo", generate a new image.""",
},
] ]
chat_huggingface(messages, API_KEY, API_TYPE, API_ENDPOINT, return_planning=False, return_results=False) chat_huggingface(
messages,
API_KEY,
API_TYPE,
API_ENDPOINT,
return_planning=False,
return_results=False,
)
def cli(): def cli():
messages = [] messages = []
print("Welcome to Jarvis! A collaborative system that consists of an LLM as the controller and numerous expert models as collaborative executors. Jarvis can plan tasks, schedule Hugging Face models, generate friendly responses based on your requests, and help you with many things. Please enter your request (`exit` to exit).") print(
"Welcome to Jarvis! A collaborative system that consists of an LLM as the controller and numerous expert models as collaborative executors. Jarvis can plan tasks, schedule Hugging Face models, generate friendly responses based on your requests, and help you with many things. Please enter your request (`exit` to exit)."
)
while True: while True:
message = input("[ User ]: ") message = input("[ User ]: ")
if message == "exit": if message == "exit":
break break
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
answer = chat_huggingface(messages, API_KEY, API_TYPE, API_ENDPOINT, return_planning=False, return_results=False) answer = chat_huggingface(
messages,
API_KEY,
API_TYPE,
API_ENDPOINT,
return_planning=False,
return_results=False,
)
print("[ Jarvis ]: ", answer["message"]) print("[ Jarvis ]: ", answer["message"])
messages.append({"role": "assistant", "content": answer["message"]}) messages.append({"role": "assistant", "content": answer["message"]})

@ -17,12 +17,7 @@ from swarms.agents.message import Message
class Step: class Step:
def __init__( def __init__(
self, self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
task: str,
id: int,
dep: List[int],
args: Dict[str, str],
tool: BaseTool
): ):
self.task = task self.task = task
self.id = id self.id = id
@ -32,10 +27,7 @@ class Step:
class Plan: class Plan:
def __init__( def __init__(self, steps: List[Step]):
self,
steps: List[Step]
):
self.steps = steps self.steps = steps
def __str__(self) -> str: def __str__(self) -> str:
@ -104,10 +96,7 @@ class OmniModalAgent:
# self.task_executor = TaskExecutor # self.task_executor = TaskExecutor
self.history = [] self.history = []
def run( def run(self, input: str) -> str:
self,
input: str
) -> str:
"""Run the OmniAgent""" """Run the OmniAgent"""
plan = self.chat_planner.plan( plan = self.chat_planner.plan(
inputs={ inputs={
@ -124,11 +113,7 @@ class OmniModalAgent:
return response return response
def chat( def chat(self, msg: str = None, streaming: bool = False):
self,
msg: str = None,
streaming: bool = False
):
""" """
Run chat Run chat
@ -148,24 +133,14 @@ class OmniModalAgent:
""" """
# add users message to the history # add users message to the history
self.history.append( self.history.append(Message("User", msg))
Message(
"User",
msg
)
)
# process msg # process msg
try: try:
response = self.agent.run(msg) response = self.agent.run(msg)
# add agent's response to the history # add agent's response to the history
self.history.append( self.history.append(Message("Agent", response))
Message(
"Agent",
response
)
)
# if streaming is = True # if streaming is = True
if streaming: if streaming:
@ -177,19 +152,11 @@ class OmniModalAgent:
error_message = f"Error processing message: {str(error)}" error_message = f"Error processing message: {str(error)}"
# add error to history # add error to history
self.history.append( self.history.append(Message("Agent", error_message))
Message(
"Agent",
error_message
)
)
return error_message return error_message
def _stream_response( def _stream_response(self, response: str = None):
self,
response: str = None
):
""" """
Yield the response token by token (word by word) Yield the response token by token (word by word)

@ -85,7 +85,7 @@ class SalesConversationChain(LLMChain):
"1", "1",
"Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional.", "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional.",
), ),
) )
""" """
@ -166,14 +166,12 @@ def get_tools(product_catalog):
func=knowledge_base.run, func=knowledge_base.run,
description="useful for when you need to answer questions about product information", description="useful for when you need to answer questions about product information",
), ),
# Interpreter # Interpreter
Tool( Tool(
name="Code Interepeter", name="Code Interepeter",
func=compile, func=compile,
description="Useful when you need to run code locally, such as Python, Javascript, Shell, and more." description="Useful when you need to run code locally, such as Python, Javascript, Shell, and more.",
) )
# omnimodal agent # omnimodal agent
] ]
@ -354,12 +352,7 @@ class ProfitPilot(Chain, BaseModel):
return {} return {}
@classmethod @classmethod
def from_llm( def from_llm(cls, llm: BaseLLM, verbose: bool = False, **kwargs): # noqa: F821
cls,
llm: BaseLLM,
verbose: bool = False,
**kwargs
): # noqa: F821
"""Initialize the SalesGPT Controller.""" """Initialize the SalesGPT Controller."""
stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose) stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose)

@ -1,5 +1,3 @@
def stream(response): def stream(response):
""" """
Yield the response token by token (word by word) from llm Yield the response token by token (word by word) from llm

@ -10,9 +10,14 @@ from marshmallow.exceptions import RegistryError
@define @define
class BaseArtifact(ABC): class BaseArtifact(ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
name: str = field(default=Factory(lambda self: self.id, takes_self=True), kw_only=True) name: str = field(
default=Factory(lambda self: self.id, takes_self=True), kw_only=True
)
value: any = field() value: any = field()
type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) type: str = field(
default=Factory(lambda self: self.__class__.__name__, takes_self=True),
kw_only=True,
)
@classmethod @classmethod
def value_to_bytes(cls, value: any) -> bytes: def value_to_bytes(cls, value: any) -> bytes:
@ -38,7 +43,7 @@ class BaseArtifact(ABC):
ErrorArtifactSchema, ErrorArtifactSchema,
BlobArtifactSchema, BlobArtifactSchema,
CsvRowArtifactSchema, CsvRowArtifactSchema,
ListArtifactSchema ListArtifactSchema,
) )
class_registry.register("TextArtifact", TextArtifactSchema) class_registry.register("TextArtifact", TextArtifactSchema)

@ -12,14 +12,8 @@ class Artifact(BaseModel):
Artifact that has the task has been produced Artifact that has the task has been produced
""" """
artifact_id: StrictStr = Field( artifact_id: StrictStr = Field(..., description="ID of the artifact")
..., file_name: StrictStr = Field(..., description="Filename of the artifact")
description="ID of the artifact"
)
file_name: StrictStr = Field(
...,
description="Filename of the artifact"
)
relative_path: Optional[StrictStr] = Field( relative_path: Optional[StrictStr] = Field(
None, description="Relative path of the artifact" None, description="Relative path of the artifact"
) )

@ -10,7 +10,9 @@ from langchain.vectorstores import FAISS
from langchain_experimental.autonomous_agents import BabyAGI from langchain_experimental.autonomous_agents import BabyAGI
from pydantic import ValidationError from pydantic import ValidationError
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# ---------- Boss Node ---------- # ---------- Boss Node ----------
@ -48,7 +50,7 @@ class Boss:
boss_system_prompt="You are a boss planner in a swarm...", boss_system_prompt="You are a boss planner in a swarm...",
llm_class=OpenAI, llm_class=OpenAI,
worker_node=None, worker_node=None,
verbose=False verbose=False,
): ):
# Store parameters # Store parameters
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.api_key = api_key or os.getenv("OPENAI_API_KEY")
@ -85,11 +87,7 @@ class Boss:
embedding_size = 8192 embedding_size = 8192
index = faiss.IndexFlatL2(embedding_size) index = faiss.IndexFlatL2(embedding_size)
return FAISS( return FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})
embeddings_model.embed_query,
index,
InMemoryDocstore({}), {}
)
except Exception as e: except Exception as e:
logging.error(f"Failed to initialize vector store: {e}") logging.error(f"Failed to initialize vector store: {e}")
@ -102,9 +100,13 @@ class Boss:
Tool( Tool(
name="Goal Decomposition Tool", name="Goal Decomposition Tool",
func=todo_chain.run, func=todo_chain.run,
description="Use Case: Decompose ambitious goals into as many explicit and well defined tasks for an AI agent to follow. Rules and Regulations, don't use this tool too often only in the beginning when the user grants you a mission." description="Use Case: Decompose ambitious goals into as many explicit and well defined tasks for an AI agent to follow. Rules and Regulations, don't use this tool too often only in the beginning when the user grants you a mission.",
),
Tool(
name="Swarm Worker Agent",
func=worker_node,
description="Use Case: When you want to delegate and assign the decomposed goal sub tasks to a worker agent in your swarm, Rules and Regulations, Provide a task specification sheet to the worker agent. It can use the browser, process csvs and generate content",
), ),
Tool(name="Swarm Worker Agent", func=worker_node, description="Use Case: When you want to delegate and assign the decomposed goal sub tasks to a worker agent in your swarm, Rules and Regulations, Provide a task specification sheet to the worker agent. It can use the browser, process csvs and generate content")
] ]
suffix = """Question: {task}\n{agent_scratchpad}""" suffix = """Question: {task}\n{agent_scratchpad}"""
@ -118,7 +120,9 @@ class Boss:
llm_chain = LLMChain(llm=self.llm, prompt=prompt) llm_chain = LLMChain(llm=self.llm, prompt=prompt)
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tools) agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tools)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=self.verbose) return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=self.verbose
)
def _initialize_baby_agi(self, human_in_the_loop): def _initialize_baby_agi(self, human_in_the_loop):
try: try:
@ -127,7 +131,7 @@ class Boss:
vectorstore=self.vectorstore, vectorstore=self.vectorstore,
task_execution_chain=self.agent_executor, task_execution_chain=self.agent_executor,
max_iterations=self.max_iterations, max_iterations=self.max_iterations,
human_in_the_loop=human_in_the_loop human_in_the_loop=human_in_the_loop,
) )
except ValidationError as e: except ValidationError as e:
logging.error(f"Validation Error while initializing BabyAGI: {e}") logging.error(f"Validation Error while initializing BabyAGI: {e}")

@ -28,7 +28,9 @@ from tenacity import (
from swarms.embeddings.base import Embeddings from swarms.embeddings.base import Embeddings
def get_from_dict_or_env(values: dict, key: str, env_key: str, default: Any = None) -> Any: def get_from_dict_or_env(
values: dict, key: str, env_key: str, default: Any = None
) -> Any:
import os import os
return values.get(key) or os.getenv(env_key) or default return values.get(key) or os.getenv(env_key) or default
@ -345,7 +347,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
disallowed_special=self.disallowed_special, disallowed_special=self.disallowed_special,
) )
for j in range(0, len(token), self.embedding_ctx_length): for j in range(0, len(token), self.embedding_ctx_length):
tokens.append(token[j: j + self.embedding_ctx_length]) tokens.append(token[j : j + self.embedding_ctx_length])
indices.append(i) indices.append(i)
batched_embeddings: List[List[float]] = [] batched_embeddings: List[List[float]] = []
@ -364,7 +366,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in _iter: for i in _iter:
response = embed_with_retry( response = embed_with_retry(
self, self,
input=tokens[i: i + _chunk_size], input=tokens[i : i + _chunk_size],
**self._invocation_params, **self._invocation_params,
) )
batched_embeddings.extend(r["embedding"] for r in response["data"]) batched_embeddings.extend(r["embedding"] for r in response["data"])
@ -426,7 +428,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
disallowed_special=self.disallowed_special, disallowed_special=self.disallowed_special,
) )
for j in range(0, len(token), self.embedding_ctx_length): for j in range(0, len(token), self.embedding_ctx_length):
tokens.append(token[j: j + self.embedding_ctx_length]) tokens.append(token[j : j + self.embedding_ctx_length])
indices.append(i) indices.append(i)
batched_embeddings: List[List[float]] = [] batched_embeddings: List[List[float]] = []
@ -434,7 +436,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in range(0, len(tokens), _chunk_size): for i in range(0, len(tokens), _chunk_size):
response = await async_embed_with_retry( response = await async_embed_with_retry(
self, self,
input=tokens[i: i + _chunk_size], input=tokens[i : i + _chunk_size],
**self._invocation_params, **self._invocation_params,
) )
batched_embeddings.extend(r["embedding"] for r in response["data"]) batched_embeddings.extend(r["embedding"] for r in response["data"])

@ -8,10 +8,7 @@ from pegasus import Pegasus
class PegasusEmbedding: class PegasusEmbedding:
def __init__( def __init__(
self, self, modality: str, multi_process: bool = False, n_processes: int = 4
modality: str,
multi_process: bool = False,
n_processes: int = 4
): ):
self.modality = modality self.modality = modality
self.multi_process = multi_process self.multi_process = multi_process
@ -19,7 +16,9 @@ class PegasusEmbedding:
try: try:
self.pegasus = Pegasus(modality, multi_process, n_processes) self.pegasus = Pegasus(modality, multi_process, n_processes)
except Exception as e: except Exception as e:
logging.error(f"Failed to initialize Pegasus with modality: {modality}: {e}") logging.error(
f"Failed to initialize Pegasus with modality: {modality}: {e}"
)
raise raise
def embed(self, data: Union[str, list[str]]): def embed(self, data: Union[str, list[str]]):

@ -10,16 +10,13 @@ import logging
from swarms.swarms.swarms import HierarchicalSwarm from swarms.swarms.swarms import HierarchicalSwarm
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
)
class HiveMind: class HiveMind:
def __init__( def __init__(self, openai_api_key="", num_swarms=1, max_workers=None):
self,
openai_api_key="",
num_swarms=1,
max_workers=None
):
self.openai_api_key = openai_api_key self.openai_api_key = openai_api_key
self.num_swarms = num_swarms self.num_swarms = num_swarms
self.swarms = [HierarchicalSwarm(openai_api_key) for _ in range(num_swarms)] self.swarms = [HierarchicalSwarm(openai_api_key) for _ in range(num_swarms)]
@ -43,8 +40,13 @@ class HiveMind:
logging.error(f"An error occurred in run: {e}") logging.error(f"An error occurred in run: {e}")
def run(self, objective, timeout=None): def run(self, objective, timeout=None):
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: with concurrent.futures.ThreadPoolExecutor(
futures = {executor.submit(self.run_swarm, swarm, objective) for swarm in self.swarms} max_workers=self.max_workers
) as executor:
futures = {
executor.submit(self.run_swarm, swarm, objective)
for swarm in self.swarms
}
results = [] results = []
for future in concurrent.futures.as_completed(futures, timeout=timeout): for future in concurrent.futures.as_completed(futures, timeout=timeout):
try: try:

@ -4,8 +4,7 @@ from chromadb import EmbeddingFunction
def openai_embed(self, input, api_key, model_name): def openai_embed(self, input, api_key, model_name):
openai = EmbeddingFunction.OpenAIEmbeddingFunction( openai = EmbeddingFunction.OpenAIEmbeddingFunction(
api_key=api_key, api_key=api_key, model_name=model_name
model_name=model_name
) )
embedding = openai(input) embedding = openai(input)
return embedding return embedding

@ -26,19 +26,16 @@ class Artifact(BaseModel):
relative_path: Optional[str] = Field( relative_path: Optional[str] = Field(
None, None,
description="Relative path of the artifact in the agent's workspace", description="Relative path of the artifact in the agent's workspace",
example="python/code/" example="python/code/",
) )
class ArtifactUpload(BaseModel): class ArtifactUpload(BaseModel):
file: bytes = Field( file: bytes = Field(..., description="File to upload")
...,
description="File to upload"
)
relative_path: Optional[str] = Field( relative_path: Optional[str] = Field(
None, None,
description="Relative path of the artifact in the agent's workspace", description="Relative path of the artifact in the agent's workspace",
example="python/code/" example="python/code/",
) )

@ -1,7 +1,9 @@
# prompts # prompts
from swarms.models.anthropic import Anthropic from swarms.models.anthropic import Anthropic
# from swarms.models.palm import GooglePalm # from swarms.models.palm import GooglePalm
from swarms.models.petals import Petals from swarms.models.petals import Petals
# from swarms.models.chat_openai import OpenAIChat # from swarms.models.chat_openai import OpenAIChat
from swarms.models.prompts.debate import * from swarms.models.prompts.debate import *
from swarms.models.mistral import Mistral from swarms.models.mistral import Mistral

@ -13,7 +13,7 @@ class Anthropic:
top_k=None, top_k=None,
top_p=None, top_p=None,
streaming=False, streaming=False,
default_request_timeout=None default_request_timeout=None,
): ):
self.model = model self.model = model
self.max_tokens_to_sample = max_tokens_to_sample self.max_tokens_to_sample = max_tokens_to_sample
@ -22,7 +22,9 @@ class Anthropic:
self.top_p = top_p self.top_p = top_p
self.streaming = streaming self.streaming = streaming
self.default_request_timeout = default_request_timeout or 600 self.default_request_timeout = default_request_timeout or 600
self.anthropic_api_url = os.getenv("ANTHROPIC_API_URL", "https://api.anthropic.com") self.anthropic_api_url = os.getenv(
"ANTHROPIC_API_URL", "https://api.anthropic.com"
)
self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
def _default_params(self): def _default_params(self):
@ -44,12 +46,13 @@ class Anthropic:
stop = stop or [] stop = stop or []
params = self._default_params() params = self._default_params()
headers = {"Authorization": f"Bearer {self.anthropic_api_key}"} headers = {"Authorization": f"Bearer {self.anthropic_api_key}"}
data = { data = {"prompt": prompt, "stop_sequences": stop, **params}
"prompt": prompt, response = requests.post(
"stop_sequences": stop, f"{self.anthropic_api_url}/completions",
**params headers=headers,
} json=data,
response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout) timeout=self.default_request_timeout,
)
return response.json().get("completion") return response.json().get("completion")
def __call__(self, prompt, stop=None): def __call__(self, prompt, stop=None):
@ -57,10 +60,11 @@ class Anthropic:
stop = stop or [] stop = stop or []
params = self._default_params() params = self._default_params()
headers = {"Authorization": f"Bearer {self.anthropic_api_key}"} headers = {"Authorization": f"Bearer {self.anthropic_api_key}"}
data = { data = {"prompt": prompt, "stop_sequences": stop, **params}
"prompt": prompt, response = requests.post(
"stop_sequences": stop, f"{self.anthropic_api_url}/completions",
**params headers=headers,
} json=data,
response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout) timeout=self.default_request_timeout,
)
return response.json().get("completion") return response.json().get("completion")

@ -458,7 +458,7 @@ class BaseOpenAI(BaseLLM):
) )
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [ sub_prompts = [
prompts[i: i + self.batch_size] prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size) for i in range(0, len(prompts), self.batch_size)
] ]
return sub_prompts return sub_prompts
@ -469,7 +469,7 @@ class BaseOpenAI(BaseLLM):
"""Create the LLMResult from the choices and prompts.""" """Create the LLMResult from the choices and prompts."""
generations = [] generations = []
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
sub_choices = choices[i * self.n: (i + 1) * self.n] sub_choices = choices[i * self.n : (i + 1) * self.n]
generations.append( generations.append(
[ [
Generation( Generation(

@ -23,7 +23,7 @@ class Mistral:
use_flash_attention: bool = False, use_flash_attention: bool = False,
temperature: float = 1.0, temperature: float = 1.0,
max_length: int = 100, max_length: int = 100,
do_sample: bool = True do_sample: bool = True,
): ):
self.ai_name = ai_name self.ai_name = ai_name
self.system_prompt = system_prompt self.system_prompt = system_prompt
@ -52,34 +52,24 @@ class Mistral:
except Exception as e: except Exception as e:
raise ValueError(f"Error loading the Mistral model: {str(e)}") raise ValueError(f"Error loading the Mistral model: {str(e)}")
def run( def run(self, task: str):
self,
task: str
):
"""Run the model on a given task.""" """Run the model on a given task."""
try: try:
model_inputs = self.tokenizer( model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
[task],
return_tensors="pt"
).to(self.device)
generated_ids = self.model.generate( generated_ids = self.model.generate(
**model_inputs, **model_inputs,
max_length=self.max_length, max_length=self.max_length,
do_sample=self.do_sample, do_sample=self.do_sample,
temperature=self.temperature, temperature=self.temperature,
max_new_tokens=self.max_length max_new_tokens=self.max_length,
) )
output_text = self.tokenizer.batch_decode(generated_ids)[0] output_text = self.tokenizer.batch_decode(generated_ids)[0]
return output_text return output_text
except Exception as e: except Exception as e:
raise ValueError(f"Error running the model: {str(e)}") raise ValueError(f"Error running the model: {str(e)}")
def chat( def chat(self, msg: str = None, streaming: bool = False):
self,
msg: str = None,
streaming: bool = False
):
""" """
Run chat Run chat
@ -99,24 +89,14 @@ class Mistral:
""" """
# add users message to the history # add users message to the history
self.history.append( self.history.append(Message("User", msg))
Message(
"User",
msg
)
)
# process msg # process msg
try: try:
response = self.agent.run(msg) response = self.agent.run(msg)
# add agent's response to the history # add agent's response to the history
self.history.append( self.history.append(Message("Agent", response))
Message(
"Agent",
response
)
)
# if streaming is = True # if streaming is = True
if streaming: if streaming:
@ -128,19 +108,11 @@ class Mistral:
error_message = f"Error processing message: {str(error)}" error_message = f"Error processing message: {str(error)}"
# add error to history # add error to history
self.history.append( self.history.append(Message("Agent", error_message))
Message(
"Agent",
error_message
)
)
return error_message return error_message
def _stream_response( def _stream_response(self, response: str = None):
self,
response: str = None
):
""" """
Yield the response token by token (word by word) Yield the response token by token (word by word)

@ -12,7 +12,7 @@ class Petals:
top_p=0.9, top_p=0.9,
top_k=None, top_k=None,
do_sample=True, do_sample=True,
max_length=None max_length=None,
): ):
self.model_name = model_name self.model_name = model_name
self.temperature = temperature self.temperature = temperature

@ -6,6 +6,7 @@ from typing import Dict, NamedTuple
class AgentAction(NamedTuple): class AgentAction(NamedTuple):
"""Action returned by AgentOutputParser.""" """Action returned by AgentOutputParser."""
name: str name: str
args: Dict args: Dict

@ -16,14 +16,12 @@ class PromptConstructor:
self.tools = tools self.tools = tools
def construct_full_prompt(self, goals: List[str]) -> str: def construct_full_prompt(self, goals: List[str]) -> str:
prompt_start = ( prompt_start = """Your decisions must always be made independently
"""Your decisions must always be made independently
without seeking user assistance.\n without seeking user assistance.\n
Play to your strengths as an LLM and pursue simple Play to your strengths as an LLM and pursue simple
strategies with no legal complications.\n strategies with no legal complications.\n
If you have completed all your tasks, make sure to If you have completed all your tasks, make sure to
use the "finish" command.""" use the "finish" command."""
)
# Construct full prompt # Construct full prompt
full_prompt = ( full_prompt = (
f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n" f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
@ -56,10 +54,12 @@ class MessageFormatter:
send_token_limit: int = 4196 send_token_limit: int = 4196
def format_messages(self, **kwargs: Any) -> List[Message]: def format_messages(self, **kwargs: Any) -> List[Message]:
prompt_constructor = PromptConstructor(ai_name=kwargs["ai_name"], prompt_constructor = PromptConstructor(
ai_role=kwargs["ai_role"], ai_name=kwargs["ai_name"], ai_role=kwargs["ai_role"], tools=kwargs["tools"]
tools=kwargs["tools"]) )
base_prompt = SystemMessage(content=prompt_constructor.construct_full_prompt(kwargs["goals"])) base_prompt = SystemMessage(
content=prompt_constructor.construct_full_prompt(kwargs["goals"])
)
time_prompt = SystemMessage( time_prompt = SystemMessage(
content=f"The current time and date is {time.strftime('%c')}" content=f"The current time and date is {time.strftime('%c')}"
) )

@ -1,5 +1,5 @@
def generate_agent_role_prompt(agent): def generate_agent_role_prompt(agent):
""" Generates the agent role prompt. """Generates the agent role prompt.
Args: agent (str): The type of the agent. Args: agent (str): The type of the agent.
Returns: str: The agent role prompt. Returns: str: The agent role prompt.
""" """
@ -7,35 +7,38 @@ def generate_agent_role_prompt(agent):
"Finance Agent": "You are a seasoned finance analyst AI assistant. Your primary goal is to compose comprehensive, astute, impartial, and methodically arranged financial reports based on provided data and trends.", "Finance Agent": "You are a seasoned finance analyst AI assistant. Your primary goal is to compose comprehensive, astute, impartial, and methodically arranged financial reports based on provided data and trends.",
"Travel Agent": "You are a world-travelled AI tour guide assistant. Your main purpose is to draft engaging, insightful, unbiased, and well-structured travel reports on given locations, including history, attractions, and cultural insights.", "Travel Agent": "You are a world-travelled AI tour guide assistant. Your main purpose is to draft engaging, insightful, unbiased, and well-structured travel reports on given locations, including history, attractions, and cultural insights.",
"Academic Research Agent": "You are an AI academic research assistant. Your primary responsibility is to create thorough, academically rigorous, unbiased, and systematically organized reports on a given research topic, following the standards of scholarly work.", "Academic Research Agent": "You are an AI academic research assistant. Your primary responsibility is to create thorough, academically rigorous, unbiased, and systematically organized reports on a given research topic, following the standards of scholarly work.",
"Default Agent": "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text." "Default Agent": "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text.",
} }
return prompts.get(agent, "No such agent") return prompts.get(agent, "No such agent")
def generate_report_prompt(question, research_summary): def generate_report_prompt(question, research_summary):
""" Generates the report prompt for the given question and research summary. """Generates the report prompt for the given question and research summary.
Args: question (str): The question to generate the report prompt for Args: question (str): The question to generate the report prompt for
research_summary (str): The research summary to generate the report prompt for research_summary (str): The research summary to generate the report prompt for
Returns: str: The report prompt for the given question and research summary Returns: str: The report prompt for the given question and research summary
""" """
return f'"""{research_summary}""" Using the above information, answer the following'\ return (
f' question or topic: "{question}" in a detailed report --'\ f'"""{research_summary}""" Using the above information, answer the following'
" The report should focus on the answer to the question, should be well structured, informative," \ f' question or topic: "{question}" in a detailed report --'
" in depth, with facts and numbers if available, a minimum of 1,200 words and with markdown syntax and apa format. "\ " The report should focus on the answer to the question, should be well structured, informative,"
" in depth, with facts and numbers if available, a minimum of 1,200 words and with markdown syntax and apa format. "
"Write all source urls at the end of the report in apa format" "Write all source urls at the end of the report in apa format"
)
def generate_search_queries_prompt(question): def generate_search_queries_prompt(question):
""" Generates the search queries prompt for the given question. """Generates the search queries prompt for the given question.
Args: question (str): The question to generate the search queries prompt for Args: question (str): The question to generate the search queries prompt for
Returns: str: The search queries prompt for the given question Returns: str: The search queries prompt for the given question
""" """
return f'Write 4 google search queries to search online that form an objective opinion from the following: "{question}"'\ return (
f'Write 4 google search queries to search online that form an objective opinion from the following: "{question}"'
f'You must respond with a list of strings in the following format: ["query 1", "query 2", "query 3", "query 4"]' f'You must respond with a list of strings in the following format: ["query 1", "query 2", "query 3", "query 4"]'
)
def generate_resource_report_prompt(question, research_summary): def generate_resource_report_prompt(question, research_summary):
@ -48,39 +51,45 @@ def generate_resource_report_prompt(question, research_summary):
Returns: Returns:
str: The resource report prompt for the given question and research summary. str: The resource report prompt for the given question and research summary.
""" """
return f'"""{research_summary}""" Based on the above information, generate a bibliography recommendation report for the following' \ return (
f' question or topic: "{question}". The report should provide a detailed analysis of each recommended resource,' \ f'"""{research_summary}""" Based on the above information, generate a bibliography recommendation report for the following'
' explaining how each source can contribute to finding answers to the research question.' \ f' question or topic: "{question}". The report should provide a detailed analysis of each recommended resource,'
' Focus on the relevance, reliability, and significance of each source.' \ " explaining how each source can contribute to finding answers to the research question."
' Ensure that the report is well-structured, informative, in-depth, and follows Markdown syntax.' \ " Focus on the relevance, reliability, and significance of each source."
' Include relevant facts, figures, and numbers whenever available.' \ " Ensure that the report is well-structured, informative, in-depth, and follows Markdown syntax."
' The report should have a minimum length of 1,200 words.' " Include relevant facts, figures, and numbers whenever available."
" The report should have a minimum length of 1,200 words."
)
def generate_outline_report_prompt(question, research_summary): def generate_outline_report_prompt(question, research_summary):
""" Generates the outline report prompt for the given question and research summary. """Generates the outline report prompt for the given question and research summary.
Args: question (str): The question to generate the outline report prompt for Args: question (str): The question to generate the outline report prompt for
research_summary (str): The research summary to generate the outline report prompt for research_summary (str): The research summary to generate the outline report prompt for
Returns: str: The outline report prompt for the given question and research summary Returns: str: The outline report prompt for the given question and research summary
""" """
return f'"""{research_summary}""" Using the above information, generate an outline for a research report in Markdown syntax'\ return (
f' for the following question or topic: "{question}". The outline should provide a well-structured framework'\ f'"""{research_summary}""" Using the above information, generate an outline for a research report in Markdown syntax'
' for the research report, including the main sections, subsections, and key points to be covered.' \ f' for the following question or topic: "{question}". The outline should provide a well-structured framework'
' The research report should be detailed, informative, in-depth, and a minimum of 1,200 words.' \ " for the research report, including the main sections, subsections, and key points to be covered."
' Use appropriate Markdown syntax to format the outline and ensure readability.' " The research report should be detailed, informative, in-depth, and a minimum of 1,200 words."
" Use appropriate Markdown syntax to format the outline and ensure readability."
)
def generate_concepts_prompt(question, research_summary): def generate_concepts_prompt(question, research_summary):
""" Generates the concepts prompt for the given question. """Generates the concepts prompt for the given question.
Args: question (str): The question to generate the concepts prompt for Args: question (str): The question to generate the concepts prompt for
research_summary (str): The research summary to generate the concepts prompt for research_summary (str): The research summary to generate the concepts prompt for
Returns: str: The concepts prompt for the given question Returns: str: The concepts prompt for the given question
""" """
return f'"""{research_summary}""" Using the above information, generate a list of 5 main concepts to learn for a research report'\ return (
f' on the following question or topic: "{question}". The outline should provide a well-structured framework'\ f'"""{research_summary}""" Using the above information, generate a list of 5 main concepts to learn for a research report'
f' on the following question or topic: "{question}". The outline should provide a well-structured framework'
'You must respond with a list of strings in the following format: ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]' 'You must respond with a list of strings in the following format: ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]'
)
def generate_lesson_prompt(concept): def generate_lesson_prompt(concept):
@ -92,17 +101,19 @@ def generate_lesson_prompt(concept):
str: The lesson prompt for the given concept. str: The lesson prompt for the given concept.
""" """
prompt = f'generate a comprehensive lesson about {concept} in Markdown syntax. This should include the definition'\ prompt = (
f'of {concept}, its historical background and development, its applications or uses in different'\ f"generate a comprehensive lesson about {concept} in Markdown syntax. This should include the definition"
f'fields, and notable events or facts related to {concept}.' f"of {concept}, its historical background and development, its applications or uses in different"
f"fields, and notable events or facts related to {concept}."
)
return prompt return prompt
def get_report_by_type(report_type): def get_report_by_type(report_type):
report_type_mapping = { report_type_mapping = {
'research_report': generate_report_prompt, "research_report": generate_report_prompt,
'resource_report': generate_resource_report_prompt, "resource_report": generate_resource_report_prompt,
'outline_report': generate_outline_report_prompt "outline_report": generate_outline_report_prompt,
} }
return report_type_mapping[report_type] return report_type_mapping[report_type]

@ -38,5 +38,7 @@ def debate_monitor(game_description, word_limit, character_names):
return prompt return prompt
def generate_character_header(game_description, topic, character_name, character_description): def generate_character_header(
game_description, topic, character_name, character_description
):
pass pass

@ -1,4 +1,4 @@
PROJECT_MANAGR_PROMPT_TEMPLATE = ''' PROJECT_MANAGR_PROMPT_TEMPLATE = """
# Context # Context
{context} {context}
@ -23,7 +23,7 @@ Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD W
## Anything UNCLEAR: Provide as Plain text. Make clear here. For example, don't forget a main entry. don't forget to init 3rd party libs. ## Anything UNCLEAR: Provide as Plain text. Make clear here. For example, don't forget a main entry. don't forget to init 3rd party libs.
''' """
FORMAT_EXAMPLE = ''' FORMAT_EXAMPLE = '''
--- ---

@ -1,5 +1,3 @@
SALES_ASSISTANT_PROMPT = """You are a sales assistant helping your sales agent to determine which stage of a sales conversation should the agent move to, or stay at. SALES_ASSISTANT_PROMPT = """You are a sales assistant helping your sales agent to determine which stage of a sales conversation should the agent move to, or stay at.
Following '===' is the conversation history. Following '===' is the conversation history.
Use this conversation history to make your decision. Use this conversation history to make your decision.
@ -47,10 +45,12 @@ Conversation history:
{salesperson_name}: {salesperson_name}:
""" """
conversation_stages = {'1': "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are contacting the prospect.", conversation_stages = {
'2': "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.", "1": "Introduction: Start the conversation by introducing yourself and your company. Be polite and respectful while keeping the tone of the conversation professional. Your greeting should be welcoming. Always clarify in your greeting the reason why you are contacting the prospect.",
'3': "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.", "2": "Qualification: Qualify the prospect by confirming if they are the right person to talk to regarding your product/service. Ensure that they have the authority to make purchasing decisions.",
'4': "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.", "3": "Value proposition: Briefly explain how your product/service can benefit the prospect. Focus on the unique selling points and value proposition of your product/service that sets it apart from competitors.",
'5': "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.", "4": "Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain points. Listen carefully to their responses and take notes.",
'6': "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.", "5": "Solution presentation: Based on the prospect's needs, present your product/service as the solution that can address their pain points.",
'7': "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits."} "6": "Objection handling: Address any objections that the prospect may have regarding your product/service. Be prepared to provide evidence or testimonials to support your claims.",
"7": "Close: Ask for the sale by proposing a next step. This could be a demo, a trial or a meeting with decision-makers. Ensure to summarize what has been discussed and reiterate the benefits.",
}

@ -1,4 +1,3 @@
SUMMARIZE_PROMPT = """ SUMMARIZE_PROMPT = """
Your output should use the following template: Your output should use the following template:
### Summary ### Summary

@ -5,10 +5,7 @@ from graphlib import TopologicalSorter
class Task: class Task:
def __init__( def __init__(
self, self, id: str, parents: List["Task"] = None, children: List["Task"] = None
id: str,
parents: List["Task"] = None,
children: List["Task"] = None
): ):
self.id = id self.id = id
self.parents = parents self.parents = parents
@ -48,11 +45,7 @@ class NonLinearWorkflow:
""" """
def __init__( def __init__(self, agents, iters_per_task):
self,
agents,
iters_per_task
):
"""A workflow is a collection of tasks that can be executed in parallel or sequentially.""" """A workflow is a collection of tasks that can be executed in parallel or sequentially."""
super().__init__() super().__init__()
self.executor = ThreadPoolExecutor() self.executor = ThreadPoolExecutor()
@ -61,10 +54,7 @@ class NonLinearWorkflow:
def add(self, task: Task): def add(self, task: Task):
"""Add a task to the workflow""" """Add a task to the workflow"""
assert isinstance( assert isinstance(task, Task), "Input must be an nstance of Task"
task,
Task
), "Input must be an nstance of Task"
self.tasks.append(task) self.tasks.append(task)
return task return task
@ -100,9 +90,5 @@ class NonLinearWorkflow:
def order_tasks(self) -> List[Task]: def order_tasks(self) -> List[Task]:
"""Order the tasks USING TOPOLOGICAL SORTING""" """Order the tasks USING TOPOLOGICAL SORTING"""
task_order = TopologicalSorter( task_order = TopologicalSorter(self.to_graph()).static_order()
self.to_graph() return [self.find_task(task_id) for task_id in task_order]
).static_order()
return [
self.find_task(task_id) for task_id in task_order
]

@ -24,7 +24,7 @@ class BaseTask(ABC):
self.parent_ids: List[str] = [] self.parent_ids: List[str] = []
self.child_ids: List[str] = [] self.child_ids: List[str] = []
self.output: Optional[Union[Artifact, ErrorArtifact]] = None self.output: Optional[Union[Artifact, ErrorArtifact]] = None
self.structure: Optional['Structure'] = None self.structure: Optional["Structure"] = None
@property @property
@abstractmethod @abstractmethod
@ -45,7 +45,7 @@ class BaseTask(ABC):
def __lshift__(self, child: BaseTask) -> BaseTask: def __lshift__(self, child: BaseTask) -> BaseTask:
return self.add_parent(child) return self.add_parent(child)
def preprocess(self, structure: 'Structure') -> BaseTask: def preprocess(self, structure: "Structure") -> BaseTask:
self.structure = structure self.structure = structure
return self return self
@ -117,7 +117,9 @@ class BaseTask(ABC):
return self.output return self.output
def can_execute(self) -> bool: def can_execute(self) -> bool:
return self.state == self.State.PENDING and all(parent.is_finished() for parent in self.parents) return self.state == self.State.PENDING and all(
parent.is_finished() for parent in self.parents
)
def reset(self) -> BaseTask: def reset(self) -> BaseTask:
self.state = self.State.PENDING self.state = self.State.PENDING
@ -130,21 +132,13 @@ class BaseTask(ABC):
class Task(BaseModel): class Task(BaseModel):
input: Optional[StrictStr] = Field( input: Optional[StrictStr] = Field(None, description="Input prompt for the task")
None,
description="Input prompt for the task"
)
additional_input: Optional[Any] = Field( additional_input: Optional[Any] = Field(
None, None, description="Input parameters for the task. Any value is allowed"
description="Input parameters for the task. Any value is allowed"
)
task_id: StrictStr = Field(
...,
description="ID of the task"
) )
task_id: StrictStr = Field(..., description="ID of the task")
artifacts: conlist(Artifact, min_items=1) = Field( artifacts: conlist(Artifact, min_items=1) = Field(
..., ..., description="A list of artifacts that the task has been produced"
description="A list of artifacts that the task has been produced"
) )
class Config: class Config:
@ -158,21 +152,26 @@ class Task(BaseModel):
return json.dumps(self.dict(by_alias=True, exclude_none=True)) return json.dumps(self.dict(by_alias=True, exclude_none=True))
@classmethod @classmethod
def from_json(cls, json_str: str) -> 'Task': def from_json(cls, json_str: str) -> "Task":
return cls.parse_raw(json_str) return cls.parse_raw(json_str)
def to_dict(self) -> dict: def to_dict(self) -> dict:
_dict = self.dict(by_alias=True, exclude_none=True) _dict = self.dict(by_alias=True, exclude_none=True)
if self.artifacts: if self.artifacts:
_dict["artifacts"] = [artifact.dict(by_alias=True, exclude_none=True) for artifact in self.artifacts] _dict["artifacts"] = [
artifact.dict(by_alias=True, exclude_none=True)
for artifact in self.artifacts
]
return _dict return _dict
@classmethod @classmethod
def from_dict(cls, obj: dict) -> 'Task': def from_dict(cls, obj: dict) -> "Task":
if obj is None: if obj is None:
return None return None
if not isinstance(obj, dict): if not isinstance(obj, dict):
raise ValueError("Input must be a dictionary.") raise ValueError("Input must be a dictionary.")
if 'artifacts' in obj: if "artifacts" in obj:
obj['artifacts'] = [Artifact.parse_obj(artifact) for artifact in obj['artifacts']] obj["artifacts"] = [
Artifact.parse_obj(artifact) for artifact in obj["artifacts"]
]
return cls.parse_obj(obj) return cls.parse_obj(obj)

@ -25,6 +25,7 @@ class Workflow:
""" """
class Task: class Task:
def __init__(self, task: str): def __init__(self, task: str):
self.task = task self.task = task
@ -33,7 +34,7 @@ class Workflow:
self.output = None self.output = None
self.structure = None self.structure = None
def add_child(self, child: 'Workflow.Task'): def add_child(self, child: "Workflow.Task"):
self.children.append(child) self.children.append(child)
child.parents.append(self) child.parents.append(self)
child.structure = self.structure child.structure = self.structure
@ -80,9 +81,11 @@ class Workflow:
def context(self, task: Task) -> Dict[str, Any]: def context(self, task: Task) -> Dict[str, Any]:
return { return {
"parent_output": task.parents[0].output if task.parents and task.parents[0].output else None, "parent_output": task.parents[0].output
if task.parents and task.parents[0].output
else None,
"parent": task.parents[0] if task.parents else None, "parent": task.parents[0] if task.parents else None,
"child": task.children[0] if task.children else None "child": task.children[0] if task.children else None,
} }
def __run_from_task(self, task: Optional[Task]) -> None: def __run_from_task(self, task: Optional[Task]) -> None:

@ -24,6 +24,7 @@ class AutoScaler:
auto_scaler.add_task9f"task {I}}) auto_scaler.add_task9f"task {I}})
``` ```
""" """
@log_decorator @log_decorator
@error_decorator @error_decorator
@timing_decorator @timing_decorator

@ -6,12 +6,7 @@ class DialogueSimulator:
def __init__(self, agents: List[Worker]): def __init__(self, agents: List[Worker]):
self.agents = agents self.agents = agents
def run( def run(self, max_iters: int, name: str = None, message: str = None):
self,
max_iters: int,
name: str = None,
message: str = None
):
step = 0 step = 0
if name and message: if name and message:
prompt = f"Name {name} and message: {message}" prompt = f"Name {name} and message: {message}"
@ -25,7 +20,9 @@ class DialogueSimulator:
speaker_message = speaker.run(prompt) speaker_message = speaker.run(prompt)
for receiver in self.agents: for receiver in self.agents:
message_history = f"Speaker Name: {speaker.name} and message: {speaker_message}" message_history = (
f"Speaker Name: {speaker.name} and message: {speaker_message}"
)
receiver.run(message_history) receiver.run(message_history)
print(f"({speaker.name}): {speaker_message}") print(f"({speaker.name}): {speaker_message}")

@ -30,10 +30,7 @@ class GodMode:
""" """
def __init__( def __init__(self, llms):
self,
llms
):
self.llms = llms self.llms = llms
def run(self, task): def run(self, task):
@ -49,10 +46,6 @@ class GodMode:
table.append([f"LLM {i+1}", response]) table.append([f"LLM {i+1}", response])
print( print(
colored( colored(
tabulate( tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan"
table,
headers=["LLM", "Response"],
tablefmt="pretty"
), "cyan"
) )
) )

@ -52,7 +52,8 @@ class GroupChat:
selector.update_system_message(self.select_speaker_msg()) selector.update_system_message(self.select_speaker_msg())
final, name = selector.run( final, name = selector.run(
self.messages + [ self.messages
+ [
{ {
"role": "system", "role": "system",
"context": f"Read the above conversation. Then select the next role from {self.worker_names} to play. Only return the role.", "context": f"Read the above conversation. Then select the next role from {self.worker_names} to play. Only return the role.",
@ -80,20 +81,17 @@ class GroupChatManager(Worker):
max_consecutive_auto_reply: Optional[int] = sys.maxsize, max_consecutive_auto_reply: Optional[int] = sys.maxsize,
human_input_mode: Optional[str] = "NEVER", human_input_mode: Optional[str] = "NEVER",
system_message: Optional[str] = "Group chat manager", system_message: Optional[str] = "Group chat manager",
**kwargs **kwargs,
): ):
super().__init__( super().__init__(
ai_name=ai_name, ai_name=ai_name,
# max_consecutive_auto_reply=max_consecutive_auto_reply, # max_consecutive_auto_reply=max_consecutive_auto_reply,
# human_input_mode=human_input_mode, # human_input_mode=human_input_mode,
# system_message=system_message, # system_message=system_message,
**kwargs **kwargs,
) )
self.register_reply( self.register_reply(
Worker, Worker, GroupChatManager.run, config=groupchat, reset_config=GroupChat.reset
GroupChatManager.run,
config=groupchat,
reset_config=GroupChat.reset
) )
def run( def run(
@ -147,11 +145,7 @@ class GroupChatManager(Worker):
break break
# speaker sends message without requesting a reply # speaker sends message without requesting a reply
speaker.send( speaker.send(reply, self, request_reply=False)
reply,
self,
request_reply=False
)
message = self.last_message(speaker) message = self.last_message(speaker)
message = self.last_messge(speaker) message = self.last_messge(speaker)
return True, None return True, None

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save