yapf code quality

pull/128/head
Kye 1 year ago
parent c279784458
commit 2e7905db46

@@ -5,7 +5,7 @@
 # Run autopep8 with max aggressiveness (-aaa) and in-place modification (-i)
 # on all Python files (*.py) under the 'swarms' directory.
-autopep8 --in-place --aggressive --aggressive --recursive --experimental swarms/
+autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes swarms/
 
 # Run black with default settings, since black does not have an aggressiveness level.
 # Black will format all Python files it finds in the 'swarms' directory.
@@ -15,4 +15,5 @@ black --experimental-string-processing swarms/
 # Add any additional flags if needed according to your version of ruff.
 ruff swarms/
-# If you want to ensure the script stops if any command fails, add 'set -e' at the top.
+# YAPF
+yapf --recursive --in-place --verbose --style=google --parallel swarms
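Taken together, the two hunks above turn the code-quality script into a four-step pass over swarms/: autopep8, black, ruff, and now yapf. A minimal sketch of the whole script after this commit, assuming it is a standalone bash file and that the 'set -e' behaviour mentioned in the removed comment is still wanted, might look like:

#!/bin/bash
# Stop at the first formatter or linter that fails (optional, per the removed comment).
set -e

# Run autopep8 with max aggressiveness (-aaa) and in-place modification (-i)
# on all Python files (*.py) under the 'swarms' directory.
autopep8 --in-place --aggressive --aggressive --recursive --experimental --list-fixes swarms/

# Run black with default settings, since black does not have an aggressiveness level.
black --experimental-string-processing swarms/

# Run ruff; add any additional flags according to your version of ruff.
ruff swarms/

# YAPF
yapf --recursive --in-place --verbose --style=google --parallel swarms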

@@ -8,8 +8,6 @@ from swarms.agents.registry import Registry
 # from swarms.agents.idea_to_image_agent import Idea2Image
 from swarms.agents.simple_agent import SimpleAgent
 
 """Agent Infrastructure, models, memory, utils, tools"""
 __all__ = [

@@ -8,8 +8,7 @@ from langchain.chains.llm import LLMChain
 from langchain.chat_models.base import BaseChatModel
 from langchain.memory import ChatMessageHistory
 from langchain.prompts.chat import (
-    BaseChatPromptTemplate,
-)
+    BaseChatPromptTemplate,)
 from langchain.schema import (
     BaseChatMessageHistory,
     Document,
@@ -34,7 +33,6 @@ from langchain_experimental.autonomous_agents.autogpt.prompt_generator import (
 )
 from langchain_experimental.pydantic_v1 import BaseModel, ValidationError
 
 # PROMPT
 FINISH_NAME = "finish"
@@ -72,14 +70,12 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
     send_token_limit: int = 4196
 
     def construct_full_prompt(self, goals: List[str]) -> str:
-        prompt_start = (
-            "Your decisions must always be made independently "
-            "without seeking user assistance.\n"
-            "Play to your strengths as an LLM and pursue simple "
-            "strategies with no legal complications.\n"
-            "If you have completed all your tasks, make sure to "
-            'use the "finish" command.'
-        )
+        prompt_start = ("Your decisions must always be made independently "
+                        "without seeking user assistance.\n"
+                        "Play to your strengths as an LLM and pursue simple "
+                        "strategies with no legal complications.\n"
+                        "If you have completed all your tasks, make sure to "
+                        'use the "finish" command.')
 
         # Construct full prompt
         full_prompt = (
             f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
@@ -91,25 +87,23 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
         return full_prompt
 
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
-        base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"]))
+        base_prompt = SystemMessage(
+            content=self.construct_full_prompt(kwargs["goals"]))
         time_prompt = SystemMessage(
-            content=f"The current time and date is {time.strftime('%c')}"
-        )
-        used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
-            time_prompt.content
-        )
+            content=f"The current time and date is {time.strftime('%c')}")
+        used_tokens = self.token_counter(
+            base_prompt.content) + self.token_counter(time_prompt.content)
         memory: VectorStoreRetriever = kwargs["memory"]
         previous_messages = kwargs["messages"]
-        relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
+        relevant_docs = memory.get_relevant_documents(
+            str(previous_messages[-10:]))
         relevant_memory = [d.page_content for d in relevant_docs]
         relevant_memory_tokens = sum(
-            [self.token_counter(doc) for doc in relevant_memory]
-        )
+            [self.token_counter(doc) for doc in relevant_memory])
         while used_tokens + relevant_memory_tokens > 2500:
             relevant_memory = relevant_memory[:-1]
             relevant_memory_tokens = sum(
-                [self.token_counter(doc) for doc in relevant_memory]
-            )
+                [self.token_counter(doc) for doc in relevant_memory])
         content_format = (
             f"This reminds you of these events from your past:\n{relevant_memory}\n\n"
         )
@@ -147,13 +141,23 @@ class PromptGenerator:
         self.performance_evaluation: List[str] = []
         self.response_format = {
             "thoughts": {
-                "text": "thought",
-                "reasoning": "reasoning",
-                "plan": "- short bulleted\n- list that conveys\n- long-term plan",
-                "criticism": "constructive self-criticism",
-                "speak": "thoughts summary to say to user",
+                "text":
+                "thought",
+                "reasoning":
+                "reasoning",
+                "plan":
+                "- short bulleted\n- list that conveys\n- long-term plan",
+                "criticism":
+                "constructive self-criticism",
+                "speak":
+                "thoughts summary to say to user",
+            },
+            "command": {
+                "name": "command name",
+                "args": {
+                    "arg name": "value"
+                }
             },
-            "command": {"name": "command name", "args": {"arg name": "value"}},
         }
 
     def add_constraint(self, constraint: str) -> None:
@@ -191,7 +195,9 @@ class PromptGenerator:
         """
         self.performance_evaluation.append(evaluation)
 
-    def _generate_numbered_list(self, items: list, item_type: str = "list") -> str:
+    def _generate_numbered_list(self,
+                                items: list,
+                                item_type: str = "list") -> str:
         """
         Generate a numbered list from given items based on the item_type.
@@ -209,16 +215,11 @@ class PromptGenerator:
                 for i, item in enumerate(items)
             ]
             finish_description = (
-                "use this to signal that you have finished all your objectives"
-            )
-            finish_args = (
-                '"response": "final response to let '
-                'people know you have finished your objectives"'
-            )
-            finish_string = (
-                f"{len(items) + 1}. {FINISH_NAME}: "
-                f"{finish_description}, args: {finish_args}"
-            )
+                "use this to signal that you have finished all your objectives")
+            finish_args = ('"response": "final response to let '
+                           'people know you have finished your objectives"')
+            finish_string = (f"{len(items) + 1}. {FINISH_NAME}: "
+                             f"{finish_description}, args: {finish_args}")
             return "\n".join(command_strings + [finish_string])
         else:
             return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
@@ -239,8 +240,7 @@ class PromptGenerator:
             f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
             "You should only respond in JSON format as described below "
             f"\nResponse Format: \n{formatted_response_format} "
-            "\nEnsure the response can be parsed by Python json.loads"
-        )
+            "\nEnsure the response can be parsed by Python json.loads")
 
         return prompt_string
@@ -261,13 +261,11 @@ def get_prompt(tools: List[BaseTool]) -> str:
     prompt_generator.add_constraint(
         "~16000 word limit for short term memory. "
         "Your short term memory is short, "
-        "so immediately save important information to files."
-    )
+        "so immediately save important information to files.")
     prompt_generator.add_constraint(
         "If you are unsure how you previously did something "
         "or want to recall past events, "
-        "thinking about similar events will help you remember."
-    )
+        "thinking about similar events will help you remember.")
     prompt_generator.add_constraint("No user assistance")
     prompt_generator.add_constraint(
         'Exclusively use the commands listed in double quotes e.g. "command name"'
@@ -279,29 +277,23 @@ def get_prompt(tools: List[BaseTool]) -> str:
     # Add resources to the PromptGenerator object
     prompt_generator.add_resource(
-        "Internet access for searches and information gathering."
-    )
+        "Internet access for searches and information gathering.")
     prompt_generator.add_resource("Long Term memory management.")
     prompt_generator.add_resource(
-        "GPT-3.5 powered Agents for delegation of simple tasks."
-    )
+        "GPT-3.5 powered Agents for delegation of simple tasks.")
     prompt_generator.add_resource("File output.")
 
     # Add performance evaluations to the PromptGenerator object
     prompt_generator.add_performance_evaluation(
         "Continuously review and analyze your actions "
-        "to ensure you are performing to the best of your abilities."
-    )
+        "to ensure you are performing to the best of your abilities.")
     prompt_generator.add_performance_evaluation(
-        "Constructively self-criticize your big-picture behavior constantly."
-    )
+        "Constructively self-criticize your big-picture behavior constantly.")
     prompt_generator.add_performance_evaluation(
-        "Reflect on past decisions and strategies to refine your approach."
-    )
+        "Reflect on past decisions and strategies to refine your approach.")
     prompt_generator.add_performance_evaluation(
         "Every command has a cost, so be smart and efficient. "
-        "Aim to complete tasks in the least number of steps."
-    )
+        "Aim to complete tasks in the least number of steps.")
 
     # Generate the prompt string
     prompt_string = prompt_generator.generate_prompt_string()
@@ -372,10 +364,8 @@ class AutoGPT:
         )
 
     def run(self, goals: List[str]) -> str:
-        user_input = (
-            "Determine which next command to use, "
-            "and respond using the format specified above:"
-        )
+        user_input = ("Determine which next command to use, "
+                      "and respond using the format specified above:")
         # Interaction Loop
         loop_count = 0
         while True:
@@ -392,8 +382,10 @@ class AutoGPT:
             # Print Assistant thoughts
             print(assistant_reply)
-            self.chat_history_memory.add_message(HumanMessage(content=user_input))
-            self.chat_history_memory.add_message(AIMessage(content=assistant_reply))
+            self.chat_history_memory.add_message(
+                HumanMessage(content=user_input))
+            self.chat_history_memory.add_message(
+                AIMessage(content=assistant_reply))
 
             # Get command name and arguments
             action = self.output_parser.parse(assistant_reply)
@@ -419,8 +411,7 @@ class AutoGPT:
                 result = (
                     f"Unknown command '{action.name}'. "
                     "Please refer to the 'COMMANDS' list for available "
-                    "commands and only respond in the specified JSON format."
-                )
+                    "commands and only respond in the specified JSON format.")
 
             memory_to_add = f"Assistant Reply: {assistant_reply} \nResult: {result} "
             if self.feedback_tool is not None:

@@ -4,13 +4,13 @@ import time
 
 import openai_model
 
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
+logging.basicConfig(level=logging.INFO,
+                    format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
 
 class OpenAI:
+
     def __init__(
         self,
         api_key,
@@ -68,16 +68,13 @@ class OpenAI:
                 temperature=temperature,
             )
             with open("openai.logs", "a") as log_file:
-                log_file.write(
-                    "\n" + "-----------" + "\n" + "Prompt : " + prompt + "\n"
-                )
+                log_file.write("\n" + "-----------" + "\n" + "Prompt : " +
+                               prompt + "\n")
             return response
         except openai_model.error.RateLimitError as e:
             sleep_duratoin = os.environ.get("OPENAI_RATE_TIMEOUT", 30)
-            print(
-                f"{str(e)}, sleep for {sleep_duratoin}s, set it by env"
-                " OPENAI_RATE_TIMEOUT"
-            )
+            print(f"{str(e)}, sleep for {sleep_duratoin}s, set it by env"
+                  " OPENAI_RATE_TIMEOUT")
             time.sleep(sleep_duratoin)
 
     def openai_choice2text_handler(self, choice):
@@ -100,11 +97,16 @@ class OpenAI:
         else:
             response = self.run(prompt, 300, 0.5, k)
             thoughts = [
-                self.openai_choice2text_handler(choice) for choice in response.choices
+                self.openai_choice2text_handler(choice)
+                for choice in response.choices
             ]
             return thoughts
 
-    def generate_thoughts(self, state, k, initial_prompt, rejected_solutions=None):
+    def generate_thoughts(self,
+                          state,
+                          k,
+                          initial_prompt,
+                          rejected_solutions=None):
         if isinstance(state, str):
             pass
         else:
@@ -177,7 +179,8 @@ class OpenAI:
                 """
                 response = self.run(prompt, 10, 1)
                 try:
-                    value_text = self.openai_choice2text_handler(response.choices[0])
+                    value_text = self.openai_choice2text_handler(
+                        response.choices[0])
                     # print(f'state: {value_text}')
                     value = float(value_text)
                     print(f"Evaluated Thought Value: {value}")
@@ -187,10 +190,12 @@ class OpenAI:
             return state_values
 
         else:
-            raise ValueError("Invalid evaluation strategy. Choose 'value' or 'vote'.")
+            raise ValueError(
+                "Invalid evaluation strategy. Choose 'value' or 'vote'.")
 
 
 class AoTAgent:
+
     def __init__(
         self,
         num_thoughts: int = None,
@@ -222,7 +227,8 @@ class AoTAgent:
                 return None
 
             best_state, _ = max(self.output, key=lambda x: x[1])
-            solution = self.model.generate_solution(self.initial_prompt, best_state)
+            solution = self.model.generate_solution(self.initial_prompt,
+                                                    best_state)
             print(f"Solution is {solution}")
             return solution if solution else best_state
         except Exception as error:
@@ -239,11 +245,8 @@ class AoTAgent:
         for next_state in thoughts:
             state_value = self.evaluated_thoughts[next_state]
             if state_value > self.value_threshold:
-                child = (
-                    (state, next_state)
-                    if isinstance(state, str)
-                    else (*state, next_state)
-                )
+                child = ((state, next_state) if isinstance(state, str) else
+                         (*state, next_state))
                 self.dfs(child, step + 1)
 
                 # backtracking
@@ -253,17 +256,14 @@ class AoTAgent:
                     continue
 
    def generate_and_filter_thoughts(self, state):
-        thoughts = self.model.generate_thoughts(
-            state, self.num_thoughts, self.initial_prompt
-        )
+        thoughts = self.model.generate_thoughts(state, self.num_thoughts,
+                                                self.initial_prompt)
 
         self.evaluated_thoughts = self.model.evaluate_states(
-            thoughts, self.initial_prompt
-        )
+            thoughts, self.initial_prompt)
 
         filtered_thoughts = [
-            thought
-            for thought in thoughts
+            thought for thought in thoughts
             if self.evaluated_thoughts[thought] >= self.pruning_threshold
         ]
         print(f"filtered_thoughts: {filtered_thoughts}")

@@ -38,7 +38,8 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None):
     if not os.path.exists("agents"):
         os.makedirs("agents")
 
-    if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local":
+    if os.path.exists(
+            f"agents/{agent_name}.py") and config.environment != "local":
         if not _is_blank_agent(agent_name=agent_name):
             raise Exception(f"Agent with name {agent_name} already exists")
     driver = get_driver(  # noqa: F841
@@ -54,12 +55,10 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None):
     print(
         "\033[34mYou have the Python debugger open, you can run commands in it like you"
-        " would in a normal Python shell.\033[0m"
-    )
+        " would in a normal Python shell.\033[0m")
     print(
         "\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' and"
-        " press enter.\033[0m"
-    )
+        " press enter.\033[0m")
     breakpoint()
@@ -79,12 +78,13 @@ def extract_domain_from_url(url: str):
 class AutotabChromeDriver(uc.Chrome):
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def find_element_with_retry(
-        self, by=By.ID, value: Optional[str] = None
-    ) -> WebElement:
+    def find_element_with_retry(self,
+                                by=By.ID,
+                                value: Optional[str] = None) -> WebElement:
         try:
             return super().find_element(by, value)
         except Exception as e:
@@ -102,11 +102,8 @@ def open_plugin(driver: AutotabChromeDriver):
 
 def open_plugin_and_login(driver: AutotabChromeDriver):
     if config.autotab_api_key is not None:
-        backend_url = (
-            "http://localhost:8000"
-            if config.environment == "local"
-            else "https://api.autotab.com"
-        )
+        backend_url = ("http://localhost:8000" if config.environment == "local"
+                       else "https://api.autotab.com")
         driver.get(f"{backend_url}/auth/signin-api-key-page")
         response = requests.post(
             f"{backend_url}/auth/signin-api-key",
@@ -119,8 +116,7 @@ def open_plugin_and_login(driver: AutotabChromeDriver):
         else:
             raise Exception(
                 f"Error {response.status_code} from backend while logging you in"
-                f" with your API key: {response.text}"
-            )
+                f" with your API key: {response.text}")
         cookie["name"] = cookie["key"]
         del cookie["key"]
         driver.add_cookie(cookie)
@@ -130,26 +126,21 @@ def open_plugin_and_login(driver: AutotabChromeDriver):
     else:
         print("No autotab API key found, heading to autotab.com to sign up")
-        url = (
-            "http://localhost:3000/dashboard"
-            if config.environment == "local"
-            else "https://autotab.com/dashboard"
-        )
+        url = ("http://localhost:3000/dashboard" if config.environment
+               == "local" else "https://autotab.com/dashboard")
         driver.get(url)
 
     time.sleep(0.5)
     open_plugin(driver)
 
 
-def get_driver(
-    autotab_ext_path: Optional[str] = None, record_mode: bool = False
-) -> AutotabChromeDriver:
+def get_driver(autotab_ext_path: Optional[str] = None,
               record_mode: bool = False) -> AutotabChromeDriver:
    options = webdriver.ChromeOptions()
     options.add_argument("--no-sandbox")  # Necessary for running
     options.add_argument(
         "--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
-        " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
-    )
+        " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36")
     options.add_argument("--enable-webgl")
     options.add_argument("--enable-3d-apis")
     options.add_argument("--enable-clipboard-read-write")
@@ -238,7 +229,8 @@ class Config(BaseModel):
         return cls(
             autotab_api_key=autotab_api_key,
             credentials=_credentials,
-            google_credentials=GoogleCredentials(credentials=google_credentials),
+            google_credentials=GoogleCredentials(
+                credentials=google_credentials),
             chrome_binary_location=config.get("chrome_binary_location"),
             environment=config.get("environment", "prod"),
         )
@@ -256,9 +248,9 @@ def is_signed_in_to_google(driver):
     return len([c for c in cookies if c["name"] == "SAPISID"]) != 0
 
 
-def google_login(
-    driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True
-):
+def google_login(driver,
+                 credentials: Optional[SiteCredentials] = None,
+                 navigate: bool = True):
     print("Logging in to Google")
     if navigate:
         driver.get("https://accounts.google.com/")
@@ -290,8 +282,7 @@ def google_login(
         email_input.send_keys(credentials.email)
         email_input.send_keys(Keys.ENTER)
         WebDriverWait(driver, 10).until(
-            EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))
-        )
+            EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']")))
 
         password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
         password_input.send_keys(credentials.password)
@@ -314,21 +305,20 @@ def google_login(
         cookies = driver.get_cookies()
         cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"]
         google_cookies = [
-            cookie
-            for cookie in cookies
-            if cookie["domain"] in [".google.com", "accounts.google.com"]
-            and cookie["name"] in cookie_names
+            cookie for cookie in cookies
+            if cookie["domain"] in [".google.com", "accounts.google.com"] and
+            cookie["name"] in cookie_names
         ]
         with open("google_cookies.json", "w") as f:
             json.dump(google_cookies, f)
 
         # Log back in
         login_button = driver.find_element(
-            By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']"
-        )
+            By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']")
         login_button.click()
         time.sleep(1)
-        password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
+        password_input = driver.find_element(By.CSS_SELECTOR,
+                                              "[type='password']")
         password_input.send_keys(credentials.password)
         password_input.send_keys(Keys.ENTER)
@@ -343,8 +333,7 @@ def login(driver, url: str):
     login_url = credentials.login_url
     if credentials.login_with_google_account:
         google_credentials = config.google_credentials.credentials[
-            credentials.login_with_google_account
-        ]
+            credentials.login_with_google_account]
         _login_with_google(driver, login_url, google_credentials)
     else:
         _login(driver, login_url, credentials=credentials)
@@ -371,16 +360,15 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
     driver.get(url)
 
     WebDriverWait(driver, 10).until(
-        EC.presence_of_element_located((By.TAG_NAME, "body"))
-    )
+        EC.presence_of_element_located((By.TAG_NAME, "body")))
 
     main_window = driver.current_window_handle
     xpath = (
         "//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with"
-        " Google') or contains(@title, 'Sign in with Google')]"
-    )
+        " Google') or contains(@title, 'Sign in with Google')]")
 
-    WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath)))
+    WebDriverWait(driver,
+                  10).until(EC.presence_of_element_located((By.XPATH, xpath)))
     driver.find_element(
         By.XPATH,
         xpath,
@@ -388,8 +376,8 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
     driver.switch_to.window(driver.window_handles[-1])
     driver.find_element(
-        By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]"
-    ).click()
+        By.XPATH,
+        f"//*[contains(text(), '{google_credentials.email}')]").click()
 
     driver.switch_to.window(main_window)
@@ -442,8 +430,11 @@ def should_update():
     # Parse the XML file
     root = ET.fromstring(xml_content)
-    namespaces = {"ns": "http://www.google.com/update2/response"}  # add namespaces
-    xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version")
+    namespaces = {
+        "ns": "http://www.google.com/update2/response"
+    }  # add namespaces
+    xml_version = root.find(".//ns:app/ns:updatecheck",
+                            namespaces).get("version")
 
     # Load the local JSON file
     with open("src/extension/autotab/manifest.json", "r") as f:
@@ -484,8 +475,6 @@ def play(agent_name: Optional[str] = None):
 if __name__ == "__main__":
     play()
 """

@@ -19,7 +19,6 @@ from transformers.utils import is_offline_mode, is_openai_available, logging
 # utils
 logger = logging.get_logger(__name__)
 
 if is_openai_available():
     import openai
@@ -28,7 +27,6 @@ else:
 _tools_are_initialized = False
 
 BASE_PYTHON_TOOLS = {
     "print": print,
     "range": range,
@@ -48,7 +46,6 @@ class PreTool:
 HUGGINGFACE_DEFAULT_TOOLS = {}
 
 HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
     "image-transformation",
     "text-download",
@@ -59,23 +56,24 @@ HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
 
 def get_remote_tools(organization="huggingface-tools"):
     if is_offline_mode():
-        logger.info("You are in offline mode, so remote tools are not available.")
+        logger.info(
+            "You are in offline mode, so remote tools are not available.")
         return {}
 
     spaces = list_spaces(author=organization)
     tools = {}
     for space_info in spaces:
         repo_id = space_info.id
-        resolved_config_file = hf_hub_download(
-            repo_id, TOOL_CONFIG_FILE, repo_type="space"
-        )
+        resolved_config_file = hf_hub_download(repo_id,
+                                               TOOL_CONFIG_FILE,
+                                               repo_type="space")
         with open(resolved_config_file, encoding="utf-8") as reader:
             config = json.load(reader)
         task = repo_id.split("/")[-1]
-        tools[config["name"]] = PreTool(
-            task=task, description=config["description"], repo_id=repo_id
-        )
+        tools[config["name"]] = PreTool(task=task,
+                                        description=config["description"],
+                                        repo_id=repo_id)
 
     return tools
@@ -95,8 +93,7 @@ def _setup_default_tools():
         tool_class = getattr(tools_module, tool_class_name)
         description = tool_class.description
         HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(
-            task=task_name, description=description, repo_id=None
-        )
+            task=task_name, description=description, repo_id=None)
 
     if not is_offline_mode():
         for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
@@ -200,18 +197,19 @@ class Agent:
         one of the default tools, that default tool will be overridden.
     """
 
-    def __init__(
-        self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
-    ):
+    def __init__(self,
+                 chat_prompt_template=None,
+                 run_prompt_template=None,
+                 additional_tools=None):
         _setup_default_tools()
 
         agent_name = self.__class__.__name__
-        self.chat_prompt_template = download_prompt(
-            chat_prompt_template, agent_name, mode="chat"
-        )
-        self.run_prompt_template = download_prompt(
-            run_prompt_template, agent_name, mode="run"
-        )
+        self.chat_prompt_template = download_prompt(chat_prompt_template,
+                                                    agent_name,
+                                                    mode="chat")
+        self.run_prompt_template = download_prompt(run_prompt_template,
+                                                   agent_name,
+                                                   mode="run")
         self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
         self.log = print
         if additional_tools is not None:
@@ -227,17 +225,16 @@ class Agent:
             }
             self._toolbox.update(additional_tools)
             if len(replacements) > 1:
-                names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
+                names = "\n".join(
+                    [f"- {n}: {t}" for n, t in replacements.items()])
                 logger.warning(
                     "The following tools have been replaced by the ones provided in"
-                    f" `additional_tools`:\n{names}."
-                )
+                    f" `additional_tools`:\n{names}.")
             elif len(replacements) == 1:
                 name = list(replacements.keys())[0]
                 logger.warning(
                     f"{name} has been replaced by {replacements[name]} as provided in"
-                    " `additional_tools`."
-                )
+                    " `additional_tools`.")
 
         self.prepare_for_new_chat()
@@ -247,17 +244,20 @@ class Agent:
         return self._toolbox
 
     def format_prompt(self, task, chat_mode=False):
-        description = "\n".join(
-            [f"- {name}: {tool.description}" for name, tool in self.toolbox.items()]
-        )
+        description = "\n".join([
+            f"- {name}: {tool.description}"
+            for name, tool in self.toolbox.items()
+        ])
         if chat_mode:
             if self.chat_history is None:
-                prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
+                prompt = self.chat_prompt_template.replace(
+                    "<<all_tools>>", description)
             else:
                 prompt = self.chat_history
             prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
         else:
-            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
+            prompt = self.run_prompt_template.replace("<<all_tools>>",
+                                                      description)
         prompt = prompt.replace("<<prompt>>", task)
         return prompt
@@ -306,14 +306,19 @@ class Agent:
         if not return_code:
             self.log("\n\n==Result==")
             self.cached_tools = resolve_tools(
-                code, self.toolbox, remote=remote, cached_tools=self.cached_tools
-            )
+                code,
+                self.toolbox,
+                remote=remote,
+                cached_tools=self.cached_tools)
             self.chat_state.update(kwargs)
-            return evaluate(
-                code, self.cached_tools, self.chat_state, chat_mode=True
-            )
+            return evaluate(code,
+                            self.cached_tools,
+                            self.chat_state,
+                            chat_mode=True)
         else:
-            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
+            tool_code = get_tool_creation_code(code,
+                                               self.toolbox,
+                                               remote=remote)
             return f"{tool_code}\n{code}"
 
     def prepare_for_new_chat(self):
@@ -355,12 +360,15 @@ class Agent:
         self.log(f"\n\n==Code generated by the agent==\n{code}")
 
         if not return_code:
             self.log("\n\n==Result==")
-            self.cached_tools = resolve_tools(
-                code, self.toolbox, remote=remote, cached_tools=self.cached_tools
-            )
+            self.cached_tools = resolve_tools(code,
+                                              self.toolbox,
+                                              remote=remote,
+                                              cached_tools=self.cached_tools)
             return evaluate(code, self.cached_tools, state=kwargs.copy())
         else:
-            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
+            tool_code = get_tool_creation_code(code,
+                                               self.toolbox,
+                                               remote=remote)
             return f"{tool_code}\n{code}"
 
     def generate_one(self, prompt, stop):
@@ -420,8 +428,7 @@ class HFAgent(Agent):
     ):
         if not is_openai_available():
             raise ImportError(
-                "Using `OpenAiAgent` requires `openai`: `pip install openai`."
-            )
+                "Using `OpenAiAgent` requires `openai`: `pip install openai`.")
 
         if api_key is None:
             api_key = os.environ.get("OPENAI_API_KEY", None)
@@ -429,8 +436,7 @@ class HFAgent(Agent):
             raise ValueError(
                 "You need an openai key to use `OpenAIAgent`. You can get one here: Get"
                 " one here https://openai.com/api/`. If you have one, set it in your"
-                " env with `os.environ['OPENAI_API_KEY'] = xxx."
-            )
+                " env with `os.environ['OPENAI_API_KEY'] = xxx.")
         else:
             openai.api_key = api_key
         self.model = model
@@ -455,7 +461,10 @@ class HFAgent(Agent):
     def _chat_generate(self, prompt, stop):
         result = openai.ChatCompletion.create(
             model=self.model,
-            messages=[{"role": "user", "content": prompt}],
+            messages=[{
+                "role": "user",
+                "content": prompt
+            }],
             temperature=0,
             stop=stop,
         )
@@ -533,8 +542,7 @@ class AzureOpenAI(Agent):
     ):
         if not is_openai_available():
             raise ImportError(
-                "Using `OpenAiAgent` requires `openai`: `pip install openai`."
-            )
+                "Using `OpenAiAgent` requires `openai`: `pip install openai`.")
 
         self.deployment_id = deployment_id
         openai.api_type = "azure"
@@ -544,8 +552,7 @@ class AzureOpenAI(Agent):
             raise ValueError(
                 "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have"
                 " one, set it in your env with `os.environ['AZURE_OPENAI_API_KEY'] ="
-                " xxx."
-            )
+                " xxx.")
         else:
             openai.api_key = api_key
         if resource_name is None:
@@ -554,8 +561,7 @@ class AzureOpenAI(Agent):
             raise ValueError(
                 "You need a resource_name to use `AzureOpenAIAgent`. If you have one,"
                 " set it in your env with `os.environ['AZURE_OPENAI_RESOURCE_NAME'] ="
-                " xxx."
-            )
+                " xxx.")
         else:
             openai.api_base = f"https://{resource_name}.openai.azure.com"
         openai.api_version = api_version
@@ -585,7 +591,10 @@ class AzureOpenAI(Agent):
     def _chat_generate(self, prompt, stop):
         result = openai.ChatCompletion.create(
             engine=self.deployment_id,
-            messages=[{"role": "user", "content": prompt}],
+            messages=[{
+                "role": "user",
+                "content": prompt
+            }],
             temperature=0,
             stop=stop,
         )

@@ -88,9 +88,8 @@ class MetaPrompterAgent:
         Assistant:
         """
 
-        prompt = PromptTemplate(
-            input_variables=["history", "human_input"], template=template
-        )
+        prompt = PromptTemplate(input_variables=["history", "human_input"],
+                                template=template)
 
         self.chain = LLMChain(
             llm=self.llm(),
@@ -102,13 +101,15 @@ class MetaPrompterAgent:
     def get_chat_history(self, chain_memory):
         """Get Chat History from the memory"""
         memory_key = chain_memory.memory_key
-        chat_history = chain_memory.load_memory_variables(memory_key)[memory_key]
+        chat_history = chain_memory.load_memory_variables(
+            memory_key)[memory_key]
         return chat_history
 
     def get_new_instructions(self, meta_output):
         """Get New Instructions from the meta_output"""
         delimiter = "Instructions: "
-        new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter) :]
+        new_instructions = meta_output[meta_output.find(delimiter) +
+                                       len(delimiter):]
         return new_instructions
 
     def run(self, task: str):
@@ -149,8 +150,7 @@ class MetaPrompterAgent:
             meta_chain = self.initialize_meta_chain()
             meta_output = meta_chain.predict(
-                chat_history=self.get_chat_history(chain.memory)
-            )
+                chat_history=self.get_chat_history(chain.memory))
             print(f"Feedback: {meta_output}")
             self.instructions = self.get_new_instructions(meta_output)

File diff suppressed because it is too large

@@ -2,6 +2,7 @@
 
 class Replicator:
+
     def __init__(
         self,
         model_name,

@@ -3,23 +3,20 @@ from typing import Dict, List
 from langchain.base_language import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_experimental.autonomous_agents.hugginggpt.repsonse_generator import (
-    load_response_generator,
-)
+    load_response_generator,)
 from langchain_experimental.autonomous_agents.hugginggpt.task_executor import (
-    TaskExecutor,
-)
+    TaskExecutor,)
 from langchain_experimental.autonomous_agents.hugginggpt.task_planner import (
-    load_chat_planner,
-)
+    load_chat_planner,)
 from transformers import load_tool
 
 from swarms.agents.message import Message
 
 
 class Step:
-    def __init__(
-        self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
-    ):
+
+    def __init__(self, task: str, id: int, dep: List[int], args: Dict[str, str],
+                 tool: BaseTool):
         self.task = task
         self.id = id
         self.dep = dep
@@ -28,6 +25,7 @@ class Step:
 
 class Plan:
+
     def __init__(self, steps: List[Step]):
         self.steps = steps
@@ -73,8 +71,7 @@ class OmniModalAgent:
         print("Loading tools...")
         self.tools = [
-            load_tool(tool_name)
-            for tool_name in [
+            load_tool(tool_name) for tool_name in [
                 "document-question-answering",
                 "image-captioning",
                 "image-question-answering",
@@ -99,18 +96,15 @@ class OmniModalAgent:
     def run(self, input: str) -> str:
         """Run the OmniAgent"""
 
-        plan = self.chat_planner.plan(
-            inputs={
-                "input": input,
-                "hf_tools": self.tools,
-            }
-        )
+        plan = self.chat_planner.plan(inputs={
+            "input": input,
+            "hf_tools": self.tools,
+        })
         self.task_executor = TaskExecutor(plan)
         self.task_executor.run()
 
         response = self.response_generator.generate(
-            {"task_execution": self.task_executor}
-        )
+            {"task_execution": self.task_executor})
 
         return response

@@ -145,13 +145,12 @@ def setup_knowledge_base(product_catalog: str = None):
     llm = OpenAI(temperature=0)
     embeddings = OpenAIEmbeddings()
-    docsearch = Chroma.from_texts(
-        texts, embeddings, collection_name="product-knowledge-base"
-    )
+    docsearch = Chroma.from_texts(texts,
+                                  embeddings,
+                                  collection_name="product-knowledge-base")
 
     knowledge_base = RetrievalQA.from_chain_type(
-        llm=llm, chain_type="stuff", retriever=docsearch.as_retriever()
-    )
+        llm=llm, chain_type="stuff", retriever=docsearch.as_retriever())
     return knowledge_base
@@ -163,8 +162,8 @@ def get_tools(product_catalog):
         Tool(
             name="ProductSearch",
             func=knowledge_base.run,
-            description=(
-                "useful for when you need to answer questions about product information"
+            description=
+            ("useful for when you need to answer questions about product information"
             ),
         ),
         # omnimodal agent
@@ -194,8 +193,7 @@ class CustomPromptTemplateForTools(StringPromptTemplate):
         tools = self.tools_getter(kwargs["input"])
         # Create a tools variable from the list of tools provided
         kwargs["tools"] = "\n".join(
-            [f"{tool.name}: {tool.description}" for tool in tools]
-        )
+            [f"{tool.name}: {tool.description}" for tool in tools])
         # Create a list of tool names for the tools provided
         kwargs["tool_names"] = ", ".join([tool.name for tool in tools])
         return self.template.format(**kwargs)
@@ -218,8 +216,7 @@ class SalesConvoOutputParser(AgentOutputParser):
         print("-------")
         if f"{self.ai_prefix}:" in text:
             return AgentFinish(
-                {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text
-            )
+                {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text)
         regex = r"Action: (.*?)[\n]*Action Input: (.*)"
         match = re.search(regex, text)
         if not match:
@@ -228,15 +225,15 @@ class SalesConvoOutputParser(AgentOutputParser):
                 {
                     "output": (
                         "I apologize, I was unable to find the answer to your question."
-                        " Is there anything else I can help with?"
-                    )
+                        " Is there anything else I can help with?")
                 },
                 text,
             )
             # raise OutputParserException(f"Could not parse LLM output: `{text}`")
         action = match.group(1)
         action_input = match.group(2)
-        return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
+        return AgentAction(action.strip(),
+                           action_input.strip(" ").strip('"'), text)
 
     @property
     def _type(self) -> str:
@@ -264,13 +261,11 @@ class ProfitPilot(Chain, BaseModel):
         "2": (
             "Qualification: Qualify the prospect by confirming if they are the right"
             " person to talk to regarding your product/service. Ensure that they have"
-            " the authority to make purchasing decisions."
-        ),
+            " the authority to make purchasing decisions."),
         "3": (
             "Value proposition: Briefly explain how your product/service can benefit"
             " the prospect. Focus on the unique selling points and value proposition of"
-            " your product/service that sets it apart from competitors."
-        ),
+            " your product/service that sets it apart from competitors."),
         "4": (
             "Needs analysis: Ask open-ended questions to uncover the prospect's needs"
             " and pain points. Listen carefully to their responses and take notes."
@@ -282,13 +277,11 @@ class ProfitPilot(Chain, BaseModel):
         "6": (
             "Objection handling: Address any objections that the prospect may have"
             " regarding your product/service. Be prepared to provide evidence or"
-            " testimonials to support your claims."
-        ),
+            " testimonials to support your claims."),
         "7": (
             "Close: Ask for the sale by proposing a next step. This could be a demo, a"
             " trial or a meeting with decision-makers. Ensure to summarize what has"
-            " been discussed and reiterate the benefits."
-        ),
+            " been discussed and reiterate the benefits."),
     }
 
     salesperson_name: str = "Ted Lasso"
@@ -298,19 +291,16 @@ class ProfitPilot(Chain, BaseModel):
         "Sleep Haven is a premium mattress company that provides customers with the"
         " most comfortable and supportive sleeping experience possible. We offer a"
         " range of high-quality mattresses, pillows, and bedding accessories that are"
-        " designed to meet the unique needs of our customers."
-    )
+        " designed to meet the unique needs of our customers.")
     company_values: str = (
         "Our mission at Sleep Haven is to help people achieve a better night's sleep by"
         " providing them with the best possible sleep solutions. We believe that"
         " quality sleep is essential to overall health and well-being, and we are"
         " committed to helping our customers achieve optimal sleep by offering"
-        " exceptional products and customer service."
-    )
+        " exceptional products and customer service.")
     conversation_purpose: str = (
         "find out whether they are looking to achieve better sleep via buying a premier"
-        " mattress."
-    )
+        " mattress.")
     conversation_type: str = "call"
 
     def retrieve_conversation_stage(self, key):
@@ -336,8 +326,7 @@ class ProfitPilot(Chain, BaseModel):
         )
 
         self.current_conversation_stage = self.retrieve_conversation_stage(
-            conversation_stage_id
-        )
+            conversation_stage_id)
 
         print(f"Conversation Stage: {self.current_conversation_stage}")
@@ -391,13 +380,15 @@ class ProfitPilot(Chain, BaseModel):
         return {}
 
     @classmethod
-    def from_llm(cls, llm: BaseLLM, verbose: bool = False, **kwargs):  # noqa: F821
+    def from_llm(cls,
+                 llm: BaseLLM,
+                 verbose: bool = False,
+                 **kwargs):  # noqa: F821
         """Initialize the SalesGPT Controller."""
         stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose)
 
         sales_conversation_utterance_chain = SalesConversationChain.from_llm(
-            llm, verbose=verbose
-        )
+            llm, verbose=verbose)
 
         if "use_tools" in kwargs.keys() and kwargs["use_tools"] is False:
             sales_agent_executor = None
@@ -430,7 +421,8 @@ class ProfitPilot(Chain, BaseModel):
             # WARNING: this output parser is NOT reliable yet
             # It makes assumptions about output from LLM which can break and throw an error
-            output_parser = SalesConvoOutputParser(ai_prefix=kwargs["salesperson_name"])
+            output_parser = SalesConvoOutputParser(
+                ai_prefix=kwargs["salesperson_name"])
 
             sales_agent_with_tools = LLMSingleActionAgent(
                 llm_chain=llm_chain,
@@ -441,12 +433,12 @@ class ProfitPilot(Chain, BaseModel):
             )
 
             sales_agent_executor = AgentExecutor.from_agent_and_tools(
-                agent=sales_agent_with_tools, tools=tools, verbose=verbose
-            )
+                agent=sales_agent_with_tools, tools=tools, verbose=verbose)
 
         return cls(
             stage_analyzer_chain=stage_analyzer_chain,
-            sales_conversation_utterance_chain=sales_conversation_utterance_chain,
+            sales_conversation_utterance_chain=
+            sales_conversation_utterance_chain,
             sales_agent_executor=sales_agent_executor,
             verbose=verbose,
             **kwargs,
@@ -458,32 +450,27 @@ config = dict(
     salesperson_name="Ted Lasso",
     salesperson_role="Business Development Representative",
     company_name="Sleep Haven",
-    company_business=(
-        "Sleep Haven is a premium mattress company that provides customers with the"
-        " most comfortable and supportive sleeping experience possible. We offer a"
-        " range of high-quality mattresses, pillows, and bedding accessories that are"
-        " designed to meet the unique needs of our customers."
-    ),
-    company_values=(
-        "Our mission at Sleep Haven is to help people achieve a better night's sleep by"
-        " providing them with the best possible sleep solutions. We believe that"
-        " quality sleep is essential to overall health and well-being, and we are"
-        " committed to helping our customers achieve optimal sleep by offering"
-        " exceptional products and customer service."
-    ),
-    conversation_purpose=(
-        "find out whether they are looking to achieve better sleep via buying a premier"
-        " mattress."
-    ),
+    company_business=
+    ("Sleep Haven is a premium mattress company that provides customers with the"
+     " most comfortable and supportive sleeping experience possible. We offer a"
+     " range of high-quality mattresses, pillows, and bedding accessories that are"
+     " designed to meet the unique needs of our customers."),
+    company_values=
+    ("Our mission at Sleep Haven is to help people achieve a better night's sleep by"
+     " providing them with the best possible sleep solutions. We believe that"
+     " quality sleep is essential to overall health and well-being, and we are"
+     " committed to helping our customers achieve optimal sleep by offering"
+     " exceptional products and customer service."),
+    conversation_purpose=
+    ("find out whether they are looking to achieve better sleep via buying a premier"
+     " mattress."),
     conversation_history=[],
     conversation_type="call",
     conversation_stage=conversation_stages.get(
         "1",
-        (
-            "Introduction: Start the conversation by introducing yourself and your"
-            " company. Be polite and respectful while keeping the tone of the"
-            " conversation professional."
-        ),
+        ("Introduction: Start the conversation by introducing yourself and your"
+         " company. Be polite and respectful while keeping the tone of the"
+         " conversation professional."),
     ),
     use_tools=True,
     product_catalog="sample_product_catalog.txt",

@@ -1,9 +1,11 @@
 class PromptRefiner:
+
     def __init__(self, system_prompt: str, llm):
         super().__init__()
         self.system_prompt = system_prompt
         self.llm = llm
 
     def run(self, task: str):
-        refine = self.llm(f"System Prompt: {self.system_prompt} Current task: {task}")
+        refine = self.llm(
+            f"System Prompt: {self.system_prompt} Current task: {task}")
         return refine

@@ -10,6 +10,7 @@ class Registry(BaseModel):
     entries: Dict = {}
 
     def register(self, key: str):
+
         def decorator(class_builder):
             self.entries[key] = class_builder
             return class_builder
@@ -20,8 +21,7 @@ class Registry(BaseModel):
         if type not in self.entries:
             raise ValueError(
                 f"{type} is not registered. Please register with the"
-                f' .register("{type}") method provided in {self.name} registry'
-            )
+                f' .register("{type}") method provided in {self.name} registry')
         return self.entries[type](**kwargs)
 
     def get_all_entries(self):

@@ -29,7 +29,8 @@ class SimpleAgent:
     def run(self, task: str) -> str:
         """Run method"""
-        metrics = print(colored(f"Agent {self.name} is running task: {task}", "red"))
+        metrics = print(
+            colored(f"Agent {self.name} is running task: {task}", "red"))
         print(metrics)
 
         response = self.flow.run(task)

@ -10,9 +10,8 @@ from marshmallow.exceptions import RegistryError
@define @define
class BaseArtifact(ABC): class BaseArtifact(ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
name: str = field( name: str = field(default=Factory(lambda self: self.id, takes_self=True),
default=Factory(lambda self: self.id, takes_self=True), kw_only=True kw_only=True)
)
value: any = field() value: any = field()
type: str = field( type: str = field(
default=Factory(lambda self: self.__class__.__name__, takes_self=True), default=Factory(lambda self: self.__class__.__name__, takes_self=True),
@ -54,7 +53,8 @@ class BaseArtifact(ABC):
class_registry.register("ListArtifact", ListArtifactSchema) class_registry.register("ListArtifact", ListArtifactSchema)
try: try:
return class_registry.get_class(artifact_dict["type"])().load(artifact_dict) return class_registry.get_class(
artifact_dict["type"])().load(artifact_dict)
except RegistryError: except RegistryError:
raise ValueError("Unsupported artifact type") raise ValueError("Unsupported artifact type")

@ -15,8 +15,7 @@ class Artifact(BaseModel):
artifact_id: StrictStr = Field(..., description="ID of the artifact") artifact_id: StrictStr = Field(..., description="ID of the artifact")
file_name: StrictStr = Field(..., description="Filename of the artifact") file_name: StrictStr = Field(..., description="Filename of the artifact")
relative_path: Optional[StrictStr] = Field( relative_path: Optional[StrictStr] = Field(
None, description="Relative path of the artifact" None, description="Relative path of the artifact")
)
__properties = ["artifact_id", "file_name", "relative_path"] __properties = ["artifact_id", "file_name", "relative_path"]
class Config: class Config:
@ -49,12 +48,10 @@ class Artifact(BaseModel):
if not isinstance(obj, dict): if not isinstance(obj, dict):
return Artifact.parse_obj(obj) return Artifact.parse_obj(obj)
_obj = Artifact.parse_obj( _obj = Artifact.parse_obj({
{ "artifact_id": obj.get("artifact_id"),
"artifact_id": obj.get("artifact_id"), "file_name": obj.get("file_name"),
"file_name": obj.get("file_name"), "relative_path": obj.get("relative_path"),
"relative_path": obj.get("relative_path"), })
}
)
return _obj return _obj

@ -3,7 +3,6 @@
# from swarms.chunkers.text import TextChunker # from swarms.chunkers.text import TextChunker
# from swarms.chunkers.pdf import PdfChunker # from swarms.chunkers.pdf import PdfChunker
# __all__ = [ # __all__ = [
# "BaseChunker", # "BaseChunker",
# "ChunkSeparator", # "ChunkSeparator",

@ -48,15 +48,13 @@ class BaseChunker(ABC):
kw_only=True, kw_only=True,
) )
tokenizer: OpenAITokenizer = field( tokenizer: OpenAITokenizer = field(
default=Factory( default=Factory(lambda: OpenAITokenizer(
lambda: OpenAITokenizer( model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)),
model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
)
),
kw_only=True, kw_only=True,
) )
max_tokens: int = field( max_tokens: int = field(
default=Factory(lambda self: self.tokenizer.max_tokens, takes_self=True), default=Factory(lambda self: self.tokenizer.max_tokens,
takes_self=True),
kw_only=True, kw_only=True,
) )
@ -66,8 +64,9 @@ class BaseChunker(ABC):
return [TextArtifact(c) for c in self._chunk_recursively(text)] return [TextArtifact(c) for c in self._chunk_recursively(text)]
def _chunk_recursively( def _chunk_recursively(
self, chunk: str, current_separator: Optional[ChunkSeparator] = None self,
) -> list[str]: chunk: str,
current_separator: Optional[ChunkSeparator] = None) -> list[str]:
token_count = self.tokenizer.count_tokens(chunk) token_count = self.tokenizer.count_tokens(chunk)
if token_count <= self.max_tokens: if token_count <= self.max_tokens:
@ -79,7 +78,8 @@ class BaseChunker(ABC):
half_token_count = token_count // 2 half_token_count = token_count // 2
if current_separator: if current_separator:
separators = self.separators[self.separators.index(current_separator) :] separators = self.separators[self.separators.
index(current_separator):]
else: else:
separators = self.separators separators = self.separators
@ -102,26 +102,19 @@ class BaseChunker(ABC):
if separator.is_prefix: if separator.is_prefix:
first_subchunk = separator.value + separator.value.join( first_subchunk = separator.value + separator.value.join(
subchanks[: balance_index + 1] subchanks[:balance_index + 1])
)
second_subchunk = separator.value + separator.value.join( second_subchunk = separator.value + separator.value.join(
subchanks[balance_index + 1 :] subchanks[balance_index + 1:])
)
else: else:
first_subchunk = ( first_subchunk = (separator.value.join(
separator.value.join(subchanks[: balance_index + 1]) subchanks[:balance_index + 1]) + separator.value)
+ separator.value
)
second_subchunk = separator.value.join( second_subchunk = separator.value.join(
subchanks[balance_index + 1 :] subchanks[balance_index + 1:])
)
first_subchunk_rec = self._chunk_recursively( first_subchunk_rec = self._chunk_recursively(
first_subchunk.strip(), separator first_subchunk.strip(), separator)
)
second_subchunk_rec = self._chunk_recursively( second_subchunk_rec = self._chunk_recursively(
second_subchunk.strip(), separator second_subchunk.strip(), separator)
)
if first_subchunk_rec and second_subchunk_rec: if first_subchunk_rec and second_subchunk_rec:
return first_subchunk_rec + second_subchunk_rec return first_subchunk_rec + second_subchunk_rec

@ -76,8 +76,7 @@ class OmniChunker:
colored( colored(
f"Could not decode file with extension {file_extension}: {e}", f"Could not decode file with extension {file_extension}: {e}",
"yellow", "yellow",
) ))
)
return "" return ""
def chunk_content(self, content: str) -> List[str]: def chunk_content(self, content: str) -> List[str]:
@ -91,7 +90,7 @@ class OmniChunker:
List[str]: The list of chunks. List[str]: The list of chunks.
""" """
return [ return [
content[i : i + self.chunk_size] content[i:i + self.chunk_size]
for i in range(0, len(content), self.chunk_size) for i in range(0, len(content), self.chunk_size)
] ]
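
A standalone sketch of the fixed-size slicing pattern that OmniChunker.chunk_content uses above; the chunk_size value is illustrative.

    def chunk_content(content: str, chunk_size: int = 1000) -> list[str]:
        # Slice the string into consecutive windows of chunk_size characters.
        return [content[i:i + chunk_size] for i in range(0, len(content), chunk_size)]

    print(chunk_content("abcdefghij", chunk_size=4))  # ['abcd', 'efgh', 'ij']
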
@ -113,5 +112,4 @@ class OmniChunker:
{self.metrics()} {self.metrics()}
""", """,
"cyan", "cyan",
) ))
)

@ -18,9 +18,9 @@ class AsanaReader(BaseReader):
self.client = asana.Client.access_token(asana_token) self.client = asana.Client.access_token(asana_token)
def load_data( def load_data(self,
self, workspace_id: Optional[str] = None, project_id: Optional[str] = None workspace_id: Optional[str] = None,
) -> List[Document]: project_id: Optional[str] = None) -> List[Document]:
"""Load data from the workspace. """Load data from the workspace.
Args: Args:
@ -31,18 +31,20 @@ class AsanaReader(BaseReader):
""" """
if workspace_id is None and project_id is None: if workspace_id is None and project_id is None:
raise ValueError("Either workspace_id or project_id must be provided") raise ValueError(
"Either workspace_id or project_id must be provided")
if workspace_id is not None and project_id is not None: if workspace_id is not None and project_id is not None:
raise ValueError( raise ValueError(
"Only one of workspace_id or project_id should be provided" "Only one of workspace_id or project_id should be provided")
)
results = [] results = []
if workspace_id is not None: if workspace_id is not None:
workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"] workspace_name = self.client.workspaces.find_by_id(
projects = self.client.projects.find_all({"workspace": workspace_id}) workspace_id)["name"]
projects = self.client.projects.find_all(
{"workspace": workspace_id})
# Case: Only project_id is provided # Case: Only project_id is provided
else: # since we've handled the other cases, this means project_id is not None else: # since we've handled the other cases, this means project_id is not None
@ -50,54 +52,58 @@ class AsanaReader(BaseReader):
workspace_name = projects[0]["workspace"]["name"] workspace_name = projects[0]["workspace"]["name"]
for project in projects: for project in projects:
tasks = self.client.tasks.find_all( tasks = self.client.tasks.find_all({
{ "project":
"project": project["gid"], project["gid"],
"opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", "opt_fields":
} "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields",
) })
for task in tasks: for task in tasks:
stories = self.client.tasks.stories(task["gid"], opt_fields="type,text") stories = self.client.tasks.stories(task["gid"],
comments = "\n".join( opt_fields="type,text")
[ comments = "\n".join([
story["text"] story["text"]
for story in stories for story in stories
if story.get("type") == "comment" and "text" in story if story.get("type") == "comment" and "text" in story
] ])
)
task_metadata = { task_metadata = {
"task_id": task.get("gid", ""), "task_id":
"name": task.get("name", ""), task.get("gid", ""),
"name":
task.get("name", ""),
"assignee": (task.get("assignee") or {}).get("name", ""), "assignee": (task.get("assignee") or {}).get("name", ""),
"completed_on": task.get("completed_at", ""), "completed_on":
"completed_by": (task.get("completed_by") or {}).get("name", ""), task.get("completed_at", ""),
"project_name": project.get("name", ""), "completed_by": (task.get("completed_by") or
{}).get("name", ""),
"project_name":
project.get("name", ""),
"custom_fields": [ "custom_fields": [
i["display_value"] i["display_value"]
for i in task.get("custom_fields") for i in task.get("custom_fields")
if task.get("custom_fields") is not None if task.get("custom_fields") is not None
], ],
"workspace_name": workspace_name, "workspace_name":
"url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}", workspace_name,
"url":
f"https://app.asana.com/0/{project['gid']}/{task['gid']}",
} }
if task.get("followers") is not None: if task.get("followers") is not None:
task_metadata["followers"] = [ task_metadata["followers"] = [
i.get("name") for i in task.get("followers") if "name" in i i.get("name")
for i in task.get("followers")
if "name" in i
] ]
else: else:
task_metadata["followers"] = [] task_metadata["followers"] = []
results.append( results.append(
Document( Document(
text=task.get("name", "") text=task.get("name", "") + " " +
+ " " task.get("notes", "") + " " + comments,
+ task.get("notes", "")
+ " "
+ comments,
extra_info=task_metadata, extra_info=task_metadata,
) ))
)
return results return results
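
A hedged usage sketch for the AsanaReader.load_data signature above; the token and project id are placeholders, the constructor argument is assumed to be the access token, and the call needs network access to Asana.

    reader = AsanaReader("YOUR_ASANA_TOKEN")                  # placeholder token
    docs = reader.load_data(project_id="1200000000000000")    # placeholder id
    for doc in docs[:3]:
        # extra_info carries the task_metadata dict built in the loop above.
        print(doc.extra_info["name"], "-", doc.extra_info["url"])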

@ -15,7 +15,6 @@ if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument from haystack.schema import Document as HaystackDocument
from semantic_kernel.memory.memory_record import MemoryRecord from semantic_kernel.memory.memory_record import MemoryRecord
#### ####
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}" DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}" DEFAULT_METADATA_TMPL = "{key}: {value}"
@ -48,7 +47,8 @@ class BaseComponent(BaseModel):
# TODO: return type here not supported by current mypy version # TODO: return type here not supported by current mypy version
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore def from_dict(cls, data: Dict[str, Any],
**kwargs: Any) -> Self: # type: ignore
if isinstance(kwargs, dict): if isinstance(kwargs, dict):
data.update(kwargs) data.update(kwargs)
@ -119,13 +119,10 @@ class BaseNode(BaseComponent):
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True
id_: str = Field( id_: str = Field(default_factory=lambda: str(uuid.uuid4()),
default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node." description="Unique ID of the node.")
)
embedding: Optional[List[float]] = Field( embedding: Optional[List[float]] = Field(
default=None, description="Embedding of the node." default=None, description="Embedding of the node.")
)
"""" """"
metadata fields metadata fields
- injected as part of the text shown to LLMs as context - injected as part of the text shown to LLMs as context
@ -140,7 +137,8 @@ class BaseNode(BaseComponent):
) )
excluded_embed_metadata_keys: List[str] = Field( excluded_embed_metadata_keys: List[str] = Field(
default_factory=list, default_factory=list,
description="Metadata keys that are excluded from text for the embed model.", description=
"Metadata keys that are excluded from text for the embed model.",
) )
excluded_llm_metadata_keys: List[str] = Field( excluded_llm_metadata_keys: List[str] = Field(
default_factory=list, default_factory=list,
@ -158,7 +156,8 @@ class BaseNode(BaseComponent):
"""Get Object type.""" """Get Object type."""
@abstractmethod @abstractmethod
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str: def get_content(self,
metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
"""Get object content.""" """Get object content."""
@abstractmethod @abstractmethod
@ -189,7 +188,8 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.SOURCE] relation = self.relationships[NodeRelationship.SOURCE]
if isinstance(relation, list): if isinstance(relation, list):
raise ValueError("Source object must be a single RelatedNodeInfo object") raise ValueError(
"Source object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -200,7 +200,8 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.PREVIOUS] relation = self.relationships[NodeRelationship.PREVIOUS]
if not isinstance(relation, RelatedNodeInfo): if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Previous object must be a single RelatedNodeInfo object") raise ValueError(
"Previous object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -211,7 +212,8 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.NEXT] relation = self.relationships[NodeRelationship.NEXT]
if not isinstance(relation, RelatedNodeInfo): if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Next object must be a single RelatedNodeInfo object") raise ValueError(
"Next object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -222,7 +224,8 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.PARENT] relation = self.relationships[NodeRelationship.PARENT]
if not isinstance(relation, RelatedNodeInfo): if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Parent object must be a single RelatedNodeInfo object") raise ValueError(
"Parent object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -233,7 +236,8 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.CHILD] relation = self.relationships[NodeRelationship.CHILD]
if not isinstance(relation, list): if not isinstance(relation, list):
raise ValueError("Child objects must be a list of RelatedNodeInfo objects.") raise ValueError(
"Child objects must be a list of RelatedNodeInfo objects.")
return relation return relation
@property @property
@ -250,12 +254,10 @@ class BaseNode(BaseComponent):
return self.metadata return self.metadata
def __str__(self) -> str: def __str__(self) -> str:
source_text_truncated = truncate_text( source_text_truncated = truncate_text(self.get_content().strip(),
self.get_content().strip(), TRUNCATE_LENGTH TRUNCATE_LENGTH)
) source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n",
source_text_wrapped = textwrap.fill( width=WRAP_WIDTH)
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Node ID: {self.node_id}\n{source_text_wrapped}" return f"Node ID: {self.node_id}\n{source_text_wrapped}"
def get_embedding(self) -> List[float]: def get_embedding(self) -> List[float]:
@ -281,28 +283,23 @@ class BaseNode(BaseComponent):
class TextNode(BaseNode): class TextNode(BaseNode):
text: str = Field(default="", description="Text content of the node.") text: str = Field(default="", description="Text content of the node.")
start_char_idx: Optional[int] = Field( start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node." default=None, description="Start char index of the node.")
)
end_char_idx: Optional[int] = Field( end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node." default=None, description="End char index of the node.")
)
text_template: str = Field( text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL, default=DEFAULT_TEXT_NODE_TMPL,
description=( description=("Template for how text is formatted, with {content} and "
"Template for how text is formatted, with {content} and " "{metadata_str} placeholders."),
"{metadata_str} placeholders."
),
) )
metadata_template: str = Field( metadata_template: str = Field(
default=DEFAULT_METADATA_TMPL, default=DEFAULT_METADATA_TMPL,
description=( description=("Template for how metadata is formatted, with {key} and "
"Template for how metadata is formatted, with {key} and " "{value} placeholders."),
"{value} placeholders."
),
) )
metadata_seperator: str = Field( metadata_seperator: str = Field(
default="\n", default="\n",
description="Separator between metadata fields when converting to string.", description=
"Separator between metadata fields when converting to string.",
) )
@classmethod @classmethod
@ -316,8 +313,7 @@ class TextNode(BaseNode):
metadata = values.get("metadata", {}) metadata = values.get("metadata", {})
doc_identity = str(text) + str(metadata) doc_identity = str(text) + str(metadata)
values["hash"] = str( values["hash"] = str(
sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest() sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest())
)
return values return values
@classmethod @classmethod
@ -325,15 +321,15 @@ class TextNode(BaseNode):
"""Get Object type.""" """Get Object type."""
return ObjectType.TEXT return ObjectType.TEXT
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: def get_content(self,
metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get object content.""" """Get object content."""
metadata_str = self.get_metadata_str(mode=metadata_mode).strip() metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
if not metadata_str: if not metadata_str:
return self.text return self.text
return self.text_template.format( return self.text_template.format(content=self.text,
content=self.text, metadata_str=metadata_str metadata_str=metadata_str).strip()
).strip()
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""Metadata info string.""" """Metadata info string."""
@ -350,13 +346,11 @@ class TextNode(BaseNode):
if key in usable_metadata_keys: if key in usable_metadata_keys:
usable_metadata_keys.remove(key) usable_metadata_keys.remove(key)
return self.metadata_seperator.join( return self.metadata_seperator.join([
[ self.metadata_template.format(key=key, value=str(value))
self.metadata_template.format(key=key, value=str(value)) for key, value in self.metadata.items()
for key, value in self.metadata.items() if key in usable_metadata_keys
if key in usable_metadata_keys ])
]
)
def set_content(self, value: str) -> None: def set_content(self, value: str) -> None:
"""Set the content of the node.""" """Set the content of the node."""
@ -480,7 +474,8 @@ class NodeWithScore(BaseComponent):
else: else:
raise ValueError("Node must be a TextNode to get text.") raise ValueError("Node must be a TextNode to get text.")
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: def get_content(self,
metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
return self.node.get_content(metadata_mode=metadata_mode) return self.node.get_content(metadata_mode=metadata_mode)
def get_embedding(self) -> List[float]: def get_embedding(self) -> List[float]:
@ -517,12 +512,10 @@ class Document(TextNode):
return self.id_ return self.id_
def __str__(self) -> str: def __str__(self) -> str:
source_text_truncated = truncate_text( source_text_truncated = truncate_text(self.get_content().strip(),
self.get_content().strip(), TRUNCATE_LENGTH TRUNCATE_LENGTH)
) source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n",
source_text_wrapped = textwrap.fill( width=WRAP_WIDTH)
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
def get_doc_id(self) -> str: def get_doc_id(self) -> str:
@ -538,22 +531,27 @@ class Document(TextNode):
"""Convert struct to Haystack document format.""" """Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument from haystack.schema import Document as HaystackDocument
return HaystackDocument( return HaystackDocument(content=self.text,
content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_ meta=self.metadata,
) embedding=self.embedding,
id=self.id_)
@classmethod @classmethod
def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": def from_haystack_format(cls, doc: "HaystackDocument") -> "Document":
"""Convert struct from Haystack document format.""" """Convert struct from Haystack document format."""
return cls( return cls(text=doc.content,
text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id metadata=doc.meta,
) embedding=doc.embedding,
id_=doc.id)
def to_embedchain_format(self) -> Dict[str, Any]: def to_embedchain_format(self) -> Dict[str, Any]:
"""Convert struct to EmbedChain document format.""" """Convert struct to EmbedChain document format."""
return { return {
"doc_id": self.id_, "doc_id": self.id_,
"data": {"content": self.text, "meta_data": self.metadata}, "data": {
"content": self.text,
"meta_data": self.metadata
},
} }
@classmethod @classmethod
@ -583,7 +581,8 @@ class Document(TextNode):
return cls( return cls(
text=doc._text, text=doc._text,
metadata={"additional_metadata": doc._additional_metadata}, metadata={"additional_metadata": doc._additional_metadata},
embedding=doc._embedding.tolist() if doc._embedding is not None else None, embedding=doc._embedding.tolist()
if doc._embedding is not None else None,
id_=doc._id, id_=doc._id,
) )
@ -591,7 +590,10 @@ class Document(TextNode):
def example(cls) -> "Document": def example(cls) -> "Document":
return Document( return Document(
text=SAMPLE_TEXT, text=SAMPLE_TEXT,
metadata={"filename": "README.md", "category": "codebase"}, metadata={
"filename": "README.md",
"category": "codebase"
},
) )
@classmethod @classmethod
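
A self-contained sketch of the text/metadata templating that the TextNode methods above perform, using the default templates and "\n" separator defined earlier in this file; the sample metadata mirrors Document.example().

    DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
    DEFAULT_METADATA_TMPL = "{key}: {value}"

    metadata = {"filename": "README.md", "category": "codebase"}
    # get_metadata_str: one "key: value" line per usable metadata key.
    metadata_str = "\n".join(
        DEFAULT_METADATA_TMPL.format(key=k, value=str(v)) for k, v in metadata.items())
    # get_content: prepend the metadata block to the node text.
    print(DEFAULT_TEXT_NODE_TMPL.format(content="hello world",
                                        metadata_str=metadata_str).strip())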

@ -30,32 +30,25 @@ class BaseVectorStore(ABC):
embedding_driver: Any embedding_driver: Any
futures_executor: futures.Executor = field( futures_executor: futures.Executor = field(
default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)
)
def upsert_text_artifacts(self,
def upsert_text_artifacts( artifacts: dict[str, list[TextArtifact]],
self, meta: Optional[dict] = None,
artifacts: dict[str, list[TextArtifact]], **kwargs) -> None:
meta: Optional[dict] = None, execute_futures_dict({
**kwargs namespace:
) -> None: self.futures_executor.submit(self.upsert_text_artifact, a,
execute_futures_dict( namespace, meta, **kwargs)
{ for namespace, artifact_list in artifacts.items()
namespace: self.futures_executor.submit( for a in artifact_list
self.upsert_text_artifact, a, namespace, meta, **kwargs })
)
for namespace, artifact_list in artifacts.items() def upsert_text_artifact(self,
for a in artifact_list artifact: TextArtifact,
} namespace: Optional[str] = None,
) meta: Optional[dict] = None,
**kwargs) -> str:
def upsert_text_artifact(
self,
artifact: TextArtifact,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs
) -> str:
if not meta: if not meta:
meta = {} meta = {}
@ -66,39 +59,37 @@ class BaseVectorStore(ABC):
else: else:
vector = artifact.generate_embedding(self.embedding_driver) vector = artifact.generate_embedding(self.embedding_driver)
return self.upsert_vector( return self.upsert_vector(vector,
vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs vector_id=artifact.id,
) namespace=namespace,
meta=meta,
def upsert_text( **kwargs)
self,
string: str, def upsert_text(self,
vector_id: Optional[str] = None, string: str,
namespace: Optional[str] = None, vector_id: Optional[str] = None,
meta: Optional[dict] = None, namespace: Optional[str] = None,
**kwargs meta: Optional[dict] = None,
) -> str: **kwargs) -> str:
return self.upsert_vector( return self.upsert_vector(self.embedding_driver.embed_string(string),
self.embedding_driver.embed_string(string), vector_id=vector_id,
vector_id=vector_id, namespace=namespace,
namespace=namespace, meta=meta if meta else {},
meta=meta if meta else {}, **kwargs)
**kwargs
)
@abstractmethod @abstractmethod
def upsert_vector( def upsert_vector(self,
self, vector: list[float],
vector: list[float], vector_id: Optional[str] = None,
vector_id: Optional[str] = None, namespace: Optional[str] = None,
namespace: Optional[str] = None, meta: Optional[dict] = None,
meta: Optional[dict] = None, **kwargs) -> str:
**kwargs
) -> str:
... ...
@abstractmethod @abstractmethod
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry: def load_entry(self,
vector_id: str,
namespace: Optional[str] = None) -> Entry:
... ...
@abstractmethod @abstractmethod
@ -106,12 +97,10 @@ class BaseVectorStore(ABC):
... ...
@abstractmethod @abstractmethod
def query( def query(self,
self, query: str,
query: str, count: Optional[int] = None,
count: Optional[int] = None, namespace: Optional[str] = None,
namespace: Optional[str] = None, include_vectors: bool = False,
include_vectors: bool = False, **kwargs) -> list[QueryResult]:
**kwargs
) -> list[QueryResult]:
... ...
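
A simplified, standalone sketch of the thread-pool fan-out that upsert_text_artifacts performs above: one future per artifact, grouped by namespace. The upsert function and artifact values here are placeholders, not the real TextArtifact API.

    from concurrent import futures

    def upsert(artifact: str, namespace: str) -> str:
        return f"{namespace}:{artifact}"   # stand-in for upsert_text_artifact

    artifacts = {"docs": ["a", "b"], "web": ["c"]}
    with futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(upsert, a, ns)
                   for ns, items in artifacts.items() for a in items]
        print(sorted(f.result() for f in futures.as_completed(pending)))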

@ -80,10 +80,8 @@ class Chroma(VectorStore):
import chromadb import chromadb
import chromadb.config import chromadb.config
except ImportError: except ImportError:
raise ImportError( raise ImportError("Could not import chromadb python package. "
"Could not import chromadb python package. " "Please install it with `pip install chromadb`.")
"Please install it with `pip install chromadb`."
)
if client is not None: if client is not None:
self._client_settings = client_settings self._client_settings = client_settings
@ -94,8 +92,7 @@ class Chroma(VectorStore):
# If client_settings is provided with persist_directory specified, # If client_settings is provided with persist_directory specified,
# then it is "in-memory and persisting to disk" mode. # then it is "in-memory and persisting to disk" mode.
client_settings.persist_directory = ( client_settings.persist_directory = (
persist_directory or client_settings.persist_directory persist_directory or client_settings.persist_directory)
)
if client_settings.persist_directory is not None: if client_settings.persist_directory is not None:
# Maintain backwards compatibility with chromadb < 0.4.0 # Maintain backwards compatibility with chromadb < 0.4.0
major, minor, _ = chromadb.__version__.split(".") major, minor, _ = chromadb.__version__.split(".")
@ -108,25 +105,23 @@ class Chroma(VectorStore):
major, minor, _ = chromadb.__version__.split(".") major, minor, _ = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) < 4: if int(major) == 0 and int(minor) < 4:
_client_settings = chromadb.config.Settings( _client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet", chroma_db_impl="duckdb+parquet",)
)
else: else:
_client_settings = chromadb.config.Settings(is_persistent=True) _client_settings = chromadb.config.Settings(
is_persistent=True)
_client_settings.persist_directory = persist_directory _client_settings.persist_directory = persist_directory
else: else:
_client_settings = chromadb.config.Settings() _client_settings = chromadb.config.Settings()
self._client_settings = _client_settings self._client_settings = _client_settings
self._client = chromadb.Client(_client_settings) self._client = chromadb.Client(_client_settings)
self._persist_directory = ( self._persist_directory = (_client_settings.persist_directory or
_client_settings.persist_directory or persist_directory persist_directory)
)
self._embedding_function = embedding_function self._embedding_function = embedding_function
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
name=collection_name, name=collection_name,
embedding_function=self._embedding_function.embed_documents embedding_function=self._embedding_function.embed_documents
if self._embedding_function is not None if self._embedding_function is not None else None,
else None,
metadata=collection_metadata, metadata=collection_metadata,
) )
self.override_relevance_score_fn = relevance_score_fn self.override_relevance_score_fn = relevance_score_fn
@ -149,10 +144,8 @@ class Chroma(VectorStore):
try: try:
import chromadb # noqa: F401 import chromadb # noqa: F401
except ImportError: except ImportError:
raise ValueError( raise ValueError("Could not import chromadb python package. "
"Could not import chromadb python package. " "Please install it with `pip install chromadb`.")
"Please install it with `pip install chromadb`."
)
return self._collection.query( return self._collection.query(
query_texts=query_texts, query_texts=query_texts,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
@ -202,9 +195,9 @@ class Chroma(VectorStore):
if non_empty_ids: if non_empty_ids:
metadatas = [metadatas[idx] for idx in non_empty_ids] metadatas = [metadatas[idx] for idx in non_empty_ids]
texts_with_metadatas = [texts[idx] for idx in non_empty_ids] texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
embeddings_with_metadatas = ( embeddings_with_metadatas = ([
[embeddings[idx] for idx in non_empty_ids] if embeddings else None embeddings[idx] for idx in non_empty_ids
) ] if embeddings else None)
ids_with_metadata = [ids[idx] for idx in non_empty_ids] ids_with_metadata = [ids[idx] for idx in non_empty_ids]
try: try:
self._collection.upsert( self._collection.upsert(
@ -225,8 +218,7 @@ class Chroma(VectorStore):
if empty_ids: if empty_ids:
texts_without_metadatas = [texts[j] for j in empty_ids] texts_without_metadatas = [texts[j] for j in empty_ids]
embeddings_without_metadatas = ( embeddings_without_metadatas = (
[embeddings[j] for j in empty_ids] if embeddings else None [embeddings[j] for j in empty_ids] if embeddings else None)
)
ids_without_metadatas = [ids[j] for j in empty_ids] ids_without_metadatas = [ids[j] for j in empty_ids]
self._collection.upsert( self._collection.upsert(
embeddings=embeddings_without_metadatas, embeddings=embeddings_without_metadatas,
@ -258,7 +250,9 @@ class Chroma(VectorStore):
Returns: Returns:
List[Document]: List of documents most similar to the query text. List[Document]: List of documents most similar to the query text.
""" """
docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) docs_and_scores = self.similarity_search_with_score(query,
k,
filter=filter)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def similarity_search_by_vector( def similarity_search_by_vector(
@ -381,8 +375,7 @@ class Chroma(VectorStore):
raise ValueError( raise ValueError(
"No supported normalization function" "No supported normalization function"
f" for distance metric of type: {distance}." f" for distance metric of type: {distance}."
"Consider providing relevance_score_fn to Chroma constructor." "Consider providing relevance_score_fn to Chroma constructor.")
)
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, self,
@ -428,7 +421,9 @@ class Chroma(VectorStore):
candidates = _results_to_docs(results) candidates = _results_to_docs(results)
selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] selected_results = [
r for i, r in enumerate(candidates) if i in mmr_selected
]
return selected_results return selected_results
def max_marginal_relevance_search( def max_marginal_relevance_search(
@ -523,10 +518,8 @@ class Chroma(VectorStore):
It will also be called automatically when the object is destroyed. It will also be called automatically when the object is destroyed.
""" """
if self._persist_directory is None: if self._persist_directory is None:
raise ValueError( raise ValueError("You must specify a persist_directory on"
"You must specify a persist_directory on" "creation to persist the collection.")
"creation to persist the collection."
)
import chromadb import chromadb
# Maintain backwards compatibility with chromadb < 0.4.0 # Maintain backwards compatibility with chromadb < 0.4.0
@ -543,7 +536,8 @@ class Chroma(VectorStore):
""" """
return self.update_documents([document_id], [document]) return self.update_documents([document_id], [document])
def update_documents(self, ids: List[str], documents: List[Document]) -> None: def update_documents(self, ids: List[str],
documents: List[Document]) -> None:
"""Update a document in the collection. """Update a document in the collection.
Args: Args:
@ -558,17 +552,16 @@ class Chroma(VectorStore):
) )
embeddings = self._embedding_function.embed_documents(text) embeddings = self._embedding_function.embed_documents(text)
if hasattr( if hasattr(self._collection._client,
self._collection._client, "max_batch_size" "max_batch_size"): # for Chroma 0.4.10 and above
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
for batch in create_batches( for batch in create_batches(
api=self._collection._client, api=self._collection._client,
ids=ids, ids=ids,
metadatas=metadata, metadatas=metadata,
documents=text, documents=text,
embeddings=embeddings, embeddings=embeddings,
): ):
self._collection.update( self._collection.update(
ids=batch[0], ids=batch[0],
@ -628,16 +621,15 @@ class Chroma(VectorStore):
) )
if ids is None: if ids is None:
ids = [str(uuid.uuid1()) for _ in texts] ids = [str(uuid.uuid1()) for _ in texts]
if hasattr( if hasattr(chroma_collection._client,
chroma_collection._client, "max_batch_size" "max_batch_size"): # for Chroma 0.4.10 and above
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
for batch in create_batches( for batch in create_batches(
api=chroma_collection._client, api=chroma_collection._client,
ids=ids, ids=ids,
metadatas=metadatas, metadatas=metadatas,
documents=texts, documents=texts,
): ):
chroma_collection.add_texts( chroma_collection.add_texts(
texts=batch[3] if batch[3] else [], texts=batch[3] if batch[3] else [],
@ -645,7 +637,9 @@ class Chroma(VectorStore):
ids=batch[0], ids=batch[0],
) )
else: else:
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) chroma_collection.add_texts(texts=texts,
metadatas=metadatas,
ids=ids)
return chroma_collection return chroma_collection
@classmethod @classmethod
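
A hedged usage sketch for the Chroma wrapper above; it assumes chromadb is installed and that `my_embeddings` is some embedding object exposing embed_documents, which is not defined here, and the constructor keyword names follow the attributes used above.

    db = Chroma(collection_name="swarms_docs",
                embedding_function=my_embeddings,     # assumed embedding object
                persist_directory="./chroma_store")
    db.add_texts(["Sleep Haven sells premium mattresses."])
    print(db.similarity_search("what does Sleep Haven sell?", k=1))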

@ -19,8 +19,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
if X.shape[1] != Y.shape[1]: if X.shape[1] != Y.shape[1]:
raise ValueError( raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}." f"and Y has shape {Y.shape}.")
)
try: try:
import simsimd as simd import simsimd as simd
@ -33,8 +32,7 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
except ImportError: except ImportError:
logger.info( logger.info(
"Unable to import simsimd, defaulting to NumPy implementation. If you want " "Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`." "to use simsimd please install with `pip install simsimd`.")
)
X_norm = np.linalg.norm(X, axis=1) X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1) Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide by zero errors run time warnings as those are handled below. # Ignore divide by zero errors run time warnings as those are handled below.
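
A minimal NumPy-only version of the fallback path sketched above (shape check, row norms, zero-division handling); a standalone sketch rather than the exact implementation.

    import numpy as np

    def cosine_similarity(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        if X.shape[1] != Y.shape[1]:
            raise ValueError("Number of columns in X and Y must be the same.")
        X_norm = np.linalg.norm(X, axis=1)
        Y_norm = np.linalg.norm(Y, axis=1)
        with np.errstate(divide="ignore", invalid="ignore"):
            sim = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
        sim[np.isnan(sim) | np.isinf(sim)] = 0.0   # handle zero-norm rows
        return sim

    print(cosine_similarity(np.eye(2), np.array([[1.0, 1.0]])))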

@ -27,6 +27,7 @@ class NotFoundException(Exception):
class TaskDB(ABC): class TaskDB(ABC):
async def create_task( async def create_task(
self, self,
input: Optional[str], input: Optional[str],
@ -67,9 +68,9 @@ class TaskDB(ABC):
async def list_tasks(self) -> List[Task]: async def list_tasks(self) -> List[Task]:
raise NotImplementedError raise NotImplementedError
async def list_steps( async def list_steps(self,
self, task_id: str, status: Optional[Status] = None task_id: str,
) -> List[Step]: status: Optional[Status] = None) -> List[Step]:
raise NotImplementedError raise NotImplementedError
@ -136,8 +137,8 @@ class InMemoryTaskDB(TaskDB):
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
task = await self.get_task(task_id) task = await self.get_task(task_id)
artifact = next( artifact = next(
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None filter(lambda a: a.artifact_id == artifact_id, task.artifacts),
) None)
if not artifact: if not artifact:
raise NotFoundException("Artifact", artifact_id) raise NotFoundException("Artifact", artifact_id)
return artifact return artifact
@ -150,9 +151,9 @@ class InMemoryTaskDB(TaskDB):
step_id: Optional[str] = None, step_id: Optional[str] = None,
) -> Artifact: ) -> Artifact:
artifact_id = str(uuid.uuid4()) artifact_id = str(uuid.uuid4())
artifact = Artifact( artifact = Artifact(artifact_id=artifact_id,
artifact_id=artifact_id, file_name=file_name, relative_path=relative_path file_name=file_name,
) relative_path=relative_path)
task = await self.get_task(task_id) task = await self.get_task(task_id)
task.artifacts.append(artifact) task.artifacts.append(artifact)
@ -165,9 +166,9 @@ class InMemoryTaskDB(TaskDB):
async def list_tasks(self) -> List[Task]: async def list_tasks(self) -> List[Task]:
return [task for task in self._tasks.values()] return [task for task in self._tasks.values()]
async def list_steps( async def list_steps(self,
self, task_id: str, status: Optional[Status] = None task_id: str,
) -> List[Step]: status: Optional[Status] = None) -> List[Step]:
task = await self.get_task(task_id) task = await self.get_task(task_id)
steps = task.steps steps = task.steps
if status: if status:

@ -63,8 +63,7 @@ class OceanDB:
try: try:
embedding_function = MultiModalEmbeddingFunction(modality=modality) embedding_function = MultiModalEmbeddingFunction(modality=modality)
collection = self.client.create_collection( collection = self.client.create_collection(
collection_name, embedding_function=embedding_function collection_name, embedding_function=embedding_function)
)
return collection return collection
except Exception as e: except Exception as e:
logging.error(f"Failed to create collection. Error {e}") logging.error(f"Failed to create collection. Error {e}")
@ -91,7 +90,8 @@ class OceanDB:
try: try:
return collection.add(documents=[document], ids=[id]) return collection.add(documents=[document], ids=[id])
except Exception as e: except Exception as e:
logging.error(f"Failed to append document to the collection. Error {e}") logging.error(
f"Failed to append document to the collection. Error {e}")
raise raise
def add_documents(self, collection, documents: List[str], ids: List[str]): def add_documents(self, collection, documents: List[str], ids: List[str]):
@ -137,7 +137,8 @@ class OceanDB:
the results of the query the results of the query
""" """
try: try:
results = collection.query(query_texts=query_texts, n_results=n_results) results = collection.query(query_texts=query_texts,
n_results=n_results)
return results return results
except Exception as e: except Exception as e:
logging.error(f"Failed to query the collection. Error {e}") logging.error(f"Failed to query the collection. Error {e}")

@ -88,12 +88,12 @@ class PgVectorVectorStore(BaseVectorStore):
create_engine_params: dict = field(factory=dict, kw_only=True) create_engine_params: dict = field(factory=dict, kw_only=True)
engine: Optional[Engine] = field(default=None, kw_only=True) engine: Optional[Engine] = field(default=None, kw_only=True)
table_name: str = field(kw_only=True) table_name: str = field(kw_only=True)
_model: any = field( _model: any = field(default=Factory(
default=Factory(lambda self: self.default_vector_model(), takes_self=True) lambda self: self.default_vector_model(), takes_self=True))
)
@connection_string.validator @connection_string.validator
def validate_connection_string(self, _, connection_string: Optional[str]) -> None: def validate_connection_string(self, _,
connection_string: Optional[str]) -> None:
# If an engine is provided, the connection string is not used. # If an engine is provided, the connection string is not used.
if self.engine is not None: if self.engine is not None:
return return
@ -122,9 +122,8 @@ class PgVectorVectorStore(BaseVectorStore):
If not, a connection string is used to create a new database connection here. If not, a connection string is used to create a new database connection here.
""" """
if self.engine is None: if self.engine is None:
self.engine = create_engine( self.engine = create_engine(self.connection_string,
self.connection_string, **self.create_engine_params **self.create_engine_params)
)
def setup( def setup(
self, self,
@ -142,14 +141,12 @@ class PgVectorVectorStore(BaseVectorStore):
if create_schema: if create_schema:
self._model.metadata.create_all(self.engine) self._model.metadata.create_all(self.engine)
def upsert_vector( def upsert_vector(self,
self, vector: list[float],
vector: list[float], vector_id: Optional[str] = None,
vector_id: Optional[str] = None, namespace: Optional[str] = None,
namespace: Optional[str] = None, meta: Optional[dict] = None,
meta: Optional[dict] = None, **kwargs) -> str:
**kwargs
) -> str:
"""Inserts or updates a vector in the collection.""" """Inserts or updates a vector in the collection."""
with Session(self.engine) as session: with Session(self.engine) as session:
obj = self._model( obj = self._model(
@ -164,9 +161,9 @@ class PgVectorVectorStore(BaseVectorStore):
return str(obj.id) return str(obj.id)
def load_entry( def load_entry(self,
self, vector_id: str, namespace: Optional[str] = None vector_id: str,
) -> BaseVectorStore.Entry: namespace: Optional[str] = None) -> BaseVectorStore.Entry:
"""Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
with Session(self.engine) as session: with Session(self.engine) as session:
result = session.get(self._model, vector_id) result = session.get(self._model, vector_id)
@ -179,8 +176,8 @@ class PgVectorVectorStore(BaseVectorStore):
) )
def load_entries( def load_entries(
self, namespace: Optional[str] = None self,
) -> list[BaseVectorStore.Entry]: namespace: Optional[str] = None) -> list[BaseVectorStore.Entry]:
"""Retrieves all vector entries from the collection, optionally filtering to only """Retrieves all vector entries from the collection, optionally filtering to only
those that match the provided namespace. those that match the provided namespace.
""" """
@ -197,19 +194,16 @@ class PgVectorVectorStore(BaseVectorStore):
vector=result.vector, vector=result.vector,
namespace=result.namespace, namespace=result.namespace,
meta=result.meta, meta=result.meta,
) ) for result in results
for result in results
] ]
def query( def query(self,
self, query: str,
query: str, count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT,
count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, namespace: Optional[str] = None,
namespace: Optional[str] = None, include_vectors: bool = False,
include_vectors: bool = False, distance_metric: str = "cosine_distance",
distance_metric: str = "cosine_distance", **kwargs) -> list[BaseVectorStore.QueryResult]:
**kwargs
) -> list[BaseVectorStore.QueryResult]:
"""Performs a search on the collection to find vectors similar to the provided input vector, """Performs a search on the collection to find vectors similar to the provided input vector,
optionally filtering to only those that match the provided namespace. optionally filtering to only those that match the provided namespace.
""" """
@ -245,8 +239,7 @@ class PgVectorVectorStore(BaseVectorStore):
score=result[1], score=result[1],
meta=result[0].meta, meta=result[0].meta,
namespace=result[0].namespace, namespace=result[0].namespace,
) ) for result in results
for result in results
] ]
def default_vector_model(self) -> any: def default_vector_model(self) -> any:

@ -102,14 +102,12 @@ class PineconeVectorStoreStore(BaseVector):
self.index = pinecone.Index(self.index_name) self.index = pinecone.Index(self.index_name)
def upsert_vector( def upsert_vector(self,
self, vector: list[float],
vector: list[float], vector_id: Optional[str] = None,
vector_id: Optional[str] = None, namespace: Optional[str] = None,
namespace: Optional[str] = None, meta: Optional[dict] = None,
meta: Optional[dict] = None, **kwargs) -> str:
**kwargs
) -> str:
"""Upsert vector""" """Upsert vector"""
vector_id = vector_id if vector_id else str_to_hash(str(vector)) vector_id = vector_id if vector_id else str_to_hash(str(vector))
@ -120,10 +118,12 @@ class PineconeVectorStoreStore(BaseVector):
return vector_id return vector_id
def load_entry( def load_entry(
self, vector_id: str, namespace: Optional[str] = None self,
) -> Optional[BaseVector.Entry]: vector_id: str,
namespace: Optional[str] = None) -> Optional[BaseVector.Entry]:
"""Load entry""" """Load entry"""
result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() result = self.index.fetch(ids=[vector_id],
namespace=namespace).to_dict()
vectors = list(result["vectors"].values()) vectors = list(result["vectors"].values())
if len(vectors) > 0: if len(vectors) > 0:
@ -138,7 +138,8 @@ class PineconeVectorStoreStore(BaseVector):
else: else:
return None return None
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVector.Entry]: def load_entries(self,
namespace: Optional[str] = None) -> list[BaseVector.Entry]:
"""Load entries""" """Load entries"""
# This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
# all values from a namespace: # all values from a namespace:
@ -157,20 +158,18 @@ class PineconeVectorStoreStore(BaseVector):
vector=r["values"], vector=r["values"],
meta=r["metadata"], meta=r["metadata"],
namespace=results["namespace"], namespace=results["namespace"],
) ) for r in results["matches"]
for r in results["matches"]
] ]
def query( def query(
self, self,
query: str, query: str,
count: Optional[int] = None, count: Optional[int] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
include_vectors: bool = False, include_vectors: bool = False,
# PineconeVectorStoreStorageDriver-specific params: # PineconeVectorStoreStorageDriver-specific params:
include_metadata=True, include_metadata=True,
**kwargs **kwargs) -> list[BaseVector.QueryResult]:
) -> list[BaseVector.QueryResult]:
"""Query vectors""" """Query vectors"""
vector = self.embedding_driver.embed_string(query) vector = self.embedding_driver.embed_string(query)
@ -190,12 +189,14 @@ class PineconeVectorStoreStore(BaseVector):
score=r["score"], score=r["score"],
meta=r["metadata"], meta=r["metadata"],
namespace=results["namespace"], namespace=results["namespace"],
) ) for r in results["matches"]
for r in results["matches"]
] ]
def create_index(self, name: str, **kwargs) -> None: def create_index(self, name: str, **kwargs) -> None:
"""Create index""" """Create index"""
params = {"name": name, "dimension": self.embedding_driver.dimensions} | kwargs params = {
"name": name,
"dimension": self.embedding_driver.dimensions
} | kwargs
pinecone.create_index(**params) pinecone.create_index(**params)

@ -20,9 +20,9 @@ class Artifact(BaseModel):
description="Id of the artifact", description="Id of the artifact",
example="b225e278-8b4c-4f99-a696-8facf19f0e56", example="b225e278-8b4c-4f99-a696-8facf19f0e56",
) )
file_name: str = Field( file_name: str = Field(...,
..., description="Filename of the artifact", example="main.py" description="Filename of the artifact",
) example="main.py")
relative_path: Optional[str] = Field( relative_path: Optional[str] = Field(
None, None,
description="Relative path of the artifact in the agent's workspace", description="Relative path of the artifact in the agent's workspace",
@ -50,7 +50,8 @@ class StepInput(BaseModel):
class StepOutput(BaseModel): class StepOutput(BaseModel):
__root__: Any = Field( __root__: Any = Field(
..., ...,
description="Output that the task step has produced. Any value is allowed.", description=
"Output that the task step has produced. Any value is allowed.",
example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
) )
@ -81,9 +82,9 @@ class Task(TaskRequestBody):
class StepRequestBody(BaseModel): class StepRequestBody(BaseModel):
input: Optional[str] = Field( input: Optional[str] = Field(None,
None, description="Input prompt for the step.", example="Washington" description="Input prompt for the step.",
) example="Washington")
additional_input: Optional[StepInput] = None additional_input: Optional[StepInput] = None
@ -104,22 +105,19 @@ class Step(StepRequestBody):
description="The ID of the task step.", description="The ID of the task step.",
example="6bb1801a-fd80-45e8-899a-4dd723cc602e", example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
) )
name: Optional[str] = Field( name: Optional[str] = Field(None,
None, description="The name of the task step.", example="Write to file" description="The name of the task step.",
) example="Write to file")
status: Status = Field(..., description="The status of the task step.") status: Status = Field(..., description="The status of the task step.")
output: Optional[str] = Field( output: Optional[str] = Field(
None, None,
description="Output of the task step.", description="Output of the task step.",
example=( example=
"I am going to use the write_to_file command and write Washington to a file" ("I am going to use the write_to_file command and write Washington to a file"
" called output.txt <write_to_file('output.txt', 'Washington')" " called output.txt <write_to_file('output.txt', 'Washington')"),
),
) )
additional_output: Optional[StepOutput] = None additional_output: Optional[StepOutput] = None
artifacts: List[Artifact] = Field( artifacts: List[Artifact] = Field(
[], description="A list of artifacts that the step has produced." [], description="A list of artifacts that the step has produced.")
)
is_last: Optional[bool] = Field( is_last: Optional[bool] = Field(
False, description="Whether this is the last step in the task." False, description="Whether this is the last step in the task.")
)

@ -43,9 +43,8 @@ def maximal_marginal_relevance(
if i in idxs: if i in idxs:
continue continue
redundant_score = max(similarity_to_selected[i]) redundant_score = max(similarity_to_selected[i])
equation_score = ( equation_score = (lambda_mult * query_score -
lambda_mult * query_score - (1 - lambda_mult) * redundant_score (1 - lambda_mult) * redundant_score)
)
if equation_score > best_score: if equation_score > best_score:
best_score = equation_score best_score = equation_score
idx_to_add = i idx_to_add = i
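
A short worked example of the MMR trade-off computed above; the scores are illustrative.

    lambda_mult = 0.5
    query_score = 0.9       # similarity of the candidate to the query
    redundant_score = 0.8   # max similarity to documents already selected
    equation_score = lambda_mult * query_score - (1 - lambda_mult) * redundant_score
    print(equation_score)   # ~0.05: still relevant, but penalized for redundancy
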
@ -57,8 +56,8 @@ def maximal_marginal_relevance(
def filter_complex_metadata( def filter_complex_metadata(
documents: List[Document], documents: List[Document],
*, *,
allowed_types: Tuple[Type, ...] = (str, bool, int, float) allowed_types: Tuple[Type,
) -> List[Document]: ...] = (str, bool, int, float)) -> List[Document]:
"""Filter out metadata types that are not supported for a vector store.""" """Filter out metadata types that are not supported for a vector store."""
updated_documents = [] updated_documents = []
for document in documents: for document in documents:

@ -9,7 +9,6 @@ from swarms.models.huggingface import HuggingfaceLLM
from swarms.models.wizard_storytelling import WizardLLMStoryTeller from swarms.models.wizard_storytelling import WizardLLMStoryTeller
from swarms.models.mpt import MPT7B from swarms.models.mpt import MPT7B
# MultiModal Models # MultiModal Models
from swarms.models.idefics import Idefics from swarms.models.idefics import Idefics
from swarms.models.kosmos_two import Kosmos from swarms.models.kosmos_two import Kosmos
@ -27,7 +26,6 @@ import sys
log_file = open("errors.txt", "w") log_file = open("errors.txt", "w")
sys.stderr = log_file sys.stderr = log_file
__all__ = [ __all__ = [
"Anthropic", "Anthropic",
"Petals", "Petals",

@ -41,21 +41,24 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive.""" """Validate specified keyword args are mutually exclusive."""
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Validate exactly one arg in each group is not None.""" """Validate exactly one arg in each group is not None."""
counts = [ counts = [
sum(1 for arg in arg_group if kwargs.get(arg) is not None) sum(1
for arg in arg_group
if kwargs.get(arg) is not None)
for arg_group in arg_groups for arg_group in arg_groups
] ]
invalid_groups = [i for i, count in enumerate(counts) if count != 1] invalid_groups = [i for i, count in enumerate(counts) if count != 1]
if invalid_groups: if invalid_groups:
invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups] invalid_group_names = [
raise ValueError( ", ".join(arg_groups[i]) for i in invalid_groups
"Exactly one argument in each of the following" ]
" groups must be defined:" raise ValueError("Exactly one argument in each of the following"
f" {', '.join(invalid_group_names)}" " groups must be defined:"
) f" {', '.join(invalid_group_names)}")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
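
A hedged usage sketch for the xor_args validator above: exactly one argument per group must be supplied, and it must be passed as a keyword for the check to see it.

    @xor_args(("workspace_id", "project_id"))
    def load(*, workspace_id=None, project_id=None):
        return workspace_id or project_id

    print(load(project_id="123"))   # ok: exactly one of the pair is set
    # load() or load(workspace_id="a", project_id="b") raises ValueError
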
@ -105,9 +108,10 @@ def mock_now(dt_value): # type: ignore
datetime.datetime = real_datetime datetime.datetime = real_datetime
def guard_import( def guard_import(module_name: str,
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None *,
) -> Any: pip_name: Optional[str] = None,
package: Optional[str] = None) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not """Dynamically imports a module and raises a helpful exception if the module is not
installed.""" installed."""
try: try:
@ -115,8 +119,7 @@ def guard_import(
except ImportError: except ImportError:
raise ImportError( raise ImportError(
f"Could not import {module_name} python package. " f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`." f"Please install it with `pip install {pip_name or module_name}`.")
)
return module return module
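
Example calls for the guard_import helper above; both assume the target packages are already installed, with pip_name used where the install name differs from the import name.

    np = guard_import("numpy")                      # returns the imported module
    yaml = guard_import("yaml", pip_name="pyyaml")  # hint the pip name when it differs
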
@ -132,23 +135,19 @@ def check_package_version(
if lt_version is not None and imported_version >= parse(lt_version): if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError( raise ValueError(
f"Expected {package} version to be < {lt_version}. Received " f"Expected {package} version to be < {lt_version}. Received "
f"{imported_version}." f"{imported_version}.")
)
if lte_version is not None and imported_version > parse(lte_version): if lte_version is not None and imported_version > parse(lte_version):
raise ValueError( raise ValueError(
f"Expected {package} version to be <= {lte_version}. Received " f"Expected {package} version to be <= {lte_version}. Received "
f"{imported_version}." f"{imported_version}.")
)
if gt_version is not None and imported_version <= parse(gt_version): if gt_version is not None and imported_version <= parse(gt_version):
raise ValueError( raise ValueError(
f"Expected {package} version to be > {gt_version}. Received " f"Expected {package} version to be > {gt_version}. Received "
f"{imported_version}." f"{imported_version}.")
)
if gte_version is not None and imported_version < parse(gte_version): if gte_version is not None and imported_version < parse(gte_version):
raise ValueError( raise ValueError(
f"Expected {package} version to be >= {gte_version}. Received " f"Expected {package} version to be >= {gte_version}. Received "
f"{imported_version}." f"{imported_version}.")
)
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]: def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
@ -180,19 +179,17 @@ def build_extra_kwargs(
if field_name in extra_kwargs: if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.") raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names: if field_name not in all_required_field_names:
warnings.warn( warnings.warn(f"""WARNING! {field_name} is not a default parameter.
f"""WARNING! {field_name} is not a default parameter.
{field_name} was transferred to model_kwargs. {field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended.""" Please confirm that {field_name} is what you intended.""")
)
extra_kwargs[field_name] = values.pop(field_name) extra_kwargs[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys()) invalid_model_kwargs = all_required_field_names.intersection(
extra_kwargs.keys())
if invalid_model_kwargs: if invalid_model_kwargs:
raise ValueError( raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. " f"Parameters {invalid_model_kwargs} should be specified explicitly. "
"Instead they were passed in as part of `model_kwargs` parameter." "Instead they were passed in as part of `model_kwargs` parameter.")
)
return extra_kwargs return extra_kwargs
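A toy, self-contained version of the pattern above: undeclared keyword arguments are moved into model_kwargs with a warning, while declared fields must stay explicit. The field names in the demo call are made up.

import warnings
from typing import Any, Dict, Set

def build_extra_kwargs(extra_kwargs: Dict[str, Any], values: Dict[str, Any],
                       all_required_field_names: Set[str]) -> Dict[str, Any]:
    for field_name in list(values):
        if field_name in extra_kwargs:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(f"{field_name} was transferred to model_kwargs. "
                          f"Please confirm that {field_name} is what you intended.")
            extra_kwargs[field_name] = values.pop(field_name)
    invalid = all_required_field_names.intersection(extra_kwargs.keys())
    if invalid:
        raise ValueError(
            f"Parameters {invalid} should be specified explicitly. "
            "Instead they were passed in as part of `model_kwargs` parameter.")
    return extra_kwargs

values = {"model": "claude-2", "top_p": 0.9}
print(build_extra_kwargs({}, values, {"model"}))  # {'top_p': 0.9}
print(values)                                     # {'model': 'claude-2'}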
@ -241,17 +238,16 @@ class _AnthropicCommon(BaseLanguageModel):
def build_extra(cls, values: Dict) -> Dict: def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs( values["model_kwargs"] = build_extra_kwargs(extra, values,
extra, values, all_required_field_names all_required_field_names)
)
return values return values
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str( values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") get_from_dict_or_env(values, "anthropic_api_key",
) "ANTHROPIC_API_KEY"))
# Get custom api url from environment. # Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env( values["anthropic_api_url"] = get_from_dict_or_env(
values, values,
@ -281,8 +277,7 @@ class _AnthropicCommon(BaseLanguageModel):
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import anthropic python package. " "Could not import anthropic python package. "
"Please it install it with `pip install anthropic`." "Please it install it with `pip install anthropic`.")
)
return values return values
@property @property
@ -305,7 +300,8 @@ class _AnthropicCommon(BaseLanguageModel):
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {**{}, **self._default_params} return {**{}, **self._default_params}
def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: def _get_anthropic_stop(self,
stop: Optional[List[str]] = None) -> List[str]:
if not self.HUMAN_PROMPT or not self.AI_PROMPT: if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded") raise NameError("Please ensure the anthropic package is loaded")
@ -372,7 +368,8 @@ class Anthropic(LLM, _AnthropicCommon):
return prompt # Already wrapped. return prompt # Already wrapped.
# Guard against common errors in specifying wrong number of newlines. # Guard against common errors in specifying wrong number of newlines.
corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT,
prompt)
if n_subs == 1: if n_subs == 1:
return corrected_prompt return corrected_prompt
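The newline guard above can be exercised in isolation; the constant below stands in for the HUMAN_PROMPT that the anthropic package normally provides.

import re

HUMAN_PROMPT = "\n\nHuman:"  # assumed stand-in for anthropic.HUMAN_PROMPT

corrected_prompt, n_subs = re.subn(r"^\n*Human:", HUMAN_PROMPT, "Human: hello")
print(n_subs, repr(corrected_prompt))  # 1 '\n\nHuman: hello'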
@ -405,9 +402,10 @@ class Anthropic(LLM, _AnthropicCommon):
""" """
if self.streaming: if self.streaming:
completion = "" completion = ""
for chunk in self._stream( for chunk in self._stream(prompt=prompt,
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs stop=stop,
): run_manager=run_manager,
**kwargs):
completion += chunk.text completion += chunk.text
return completion return completion
@ -433,9 +431,10 @@ class Anthropic(LLM, _AnthropicCommon):
"""Call out to Anthropic's completion endpoint asynchronously.""" """Call out to Anthropic's completion endpoint asynchronously."""
if self.streaming: if self.streaming:
completion = "" completion = ""
async for chunk in self._astream( async for chunk in self._astream(prompt=prompt,
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs stop=stop,
): run_manager=run_manager,
**kwargs):
completion += chunk.text completion += chunk.text
return completion return completion
@ -476,8 +475,10 @@ class Anthropic(LLM, _AnthropicCommon):
params = {**self._default_params, **kwargs} params = {**self._default_params, **kwargs}
for token in self.client.completions.create( for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params prompt=self._wrap_prompt(prompt),
): stop_sequences=stop,
stream=True,
**params):
chunk = GenerationChunk(text=token.completion) chunk = GenerationChunk(text=token.completion)
yield chunk yield chunk
if run_manager: if run_manager:
@ -509,10 +510,10 @@ class Anthropic(LLM, _AnthropicCommon):
params = {**self._default_params, **kwargs} params = {**self._default_params, **kwargs}
async for token in await self.async_client.completions.create( async for token in await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt), prompt=self._wrap_prompt(prompt),
stop_sequences=stop, stop_sequences=stop,
stream=True, stream=True,
**params, **params,
): ):
chunk = GenerationChunk(text=token.completion) chunk = GenerationChunk(text=token.completion)
yield chunk yield chunk
@ -97,9 +97,8 @@ class BioClip:
self.preprocess_val, self.preprocess_val,
) = open_clip.create_model_and_transforms(model_path) ) = open_clip.create_model_and_transforms(model_path)
self.tokenizer = open_clip.get_tokenizer(model_path) self.tokenizer = open_clip.get_tokenizer(model_path)
self.device = ( self.device = (torch.device("cuda")
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if torch.cuda.is_available() else torch.device("cpu"))
)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
@ -110,18 +109,17 @@ class BioClip:
template: str = "this is a photo of ", template: str = "this is a photo of ",
context_length: int = 256, context_length: int = 256,
): ):
image = torch.stack([self.preprocess_val(Image.open(img_path))]).to(self.device) image = torch.stack([self.preprocess_val(Image.open(img_path))
texts = self.tokenizer( ]).to(self.device)
[template + l for l in labels], context_length=context_length texts = self.tokenizer([template + l for l in labels],
).to(self.device) context_length=context_length).to(self.device)
with torch.no_grad(): with torch.no_grad():
image_features, text_features, logit_scale = self.model(image, texts) image_features, text_features, logit_scale = self.model(
logits = ( image, texts)
(logit_scale * image_features @ text_features.t()) logits = ((logit_scale *
.detach() image_features @ text_features.t()).detach().softmax(
.softmax(dim=-1) dim=-1))
)
sorted_indices = torch.argsort(logits, dim=-1, descending=True) sorted_indices = torch.argsort(logits, dim=-1, descending=True)
logits = logits.cpu().numpy() logits = logits.cpu().numpy()
sorted_indices = sorted_indices.cpu().numpy() sorted_indices = sorted_indices.cpu().numpy()
@ -139,11 +137,8 @@ class BioClip:
fig, ax = plt.subplots(figsize=(5, 5)) fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(img) ax.imshow(img)
ax.axis("off") ax.axis("off")
title = ( title = (metadata["filename"] + "\n" + "\n".join(
metadata["filename"] [f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()]))
+ "\n"
+ "\n".join([f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()])
)
ax.set_title(title, fontsize=14) ax.set_title(title, fontsize=14)
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
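For orientation, a toy illustration of the zero-shot scoring step the class above performs: scaled image-text similarity followed by a softmax. The tensors are random stand-ins, not real CLIP embeddings.

import torch

image_features = torch.randn(1, 512)  # one image embedding (synthetic)
text_features = torch.randn(4, 512)   # four candidate label embeddings (synthetic)
logit_scale = torch.tensor(100.0)

logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
sorted_indices = torch.argsort(logits, dim=-1, descending=True)
print(logits.shape, sorted_indices[0].tolist())  # torch.Size([1, 4]) plus a ranking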
@ -102,9 +102,9 @@ class BioGPT:
list[dict]: A list of generated texts. list[dict]: A list of generated texts.
""" """
set_seed(42) set_seed(42)
generator = pipeline( generator = pipeline("text-generation",
"text-generation", model=self.model, tokenizer=self.tokenizer model=self.model,
) tokenizer=self.tokenizer)
out = generator( out = generator(
text, text,
max_length=self.max_length, max_length=self.max_length,
@ -149,13 +149,11 @@ class BioGPT:
inputs = self.tokenizer(sentence, return_tensors="pt") inputs = self.tokenizer(sentence, return_tensors="pt")
set_seed(42) set_seed(42)
with torch.no_grad(): with torch.no_grad():
beam_output = self.model.generate( beam_output = self.model.generate(**inputs,
**inputs, min_length=self.min_length,
min_length=self.min_length, max_length=self.max_length,
max_length=self.max_length, num_beams=num_beams,
num_beams=num_beams, early_stopping=early_stopping)
early_stopping=early_stopping
)
return self.tokenizer.decode(beam_output[0], skip_special_tokens=True) return self.tokenizer.decode(beam_output[0], skip_special_tokens=True)
# Feature 1: Set a new tokenizer and model # Feature 1: Set a new tokenizer and model
@ -124,13 +124,10 @@ class Dalle3:
# Handling exceptions and printing the errors details # Handling exceptions and printing the errors details
print( print(
colored( colored(
( (f"Error running Dalle3: {error} try optimizing your api key and"
f"Error running Dalle3: {error} try optimizing your api key and" " or try again"),
" or try again"
),
"red", "red",
) ))
)
raise error raise error
def create_variations(self, img: str): def create_variations(self, img: str):
@ -157,22 +154,19 @@ class Dalle3:
""" """
try: try:
response = self.client.images.create_variation( response = self.client.images.create_variation(img=open(img, "rb"),
img=open(img, "rb"), n=self.n, size=self.size n=self.n,
) size=self.size)
img = response.data[0].url img = response.data[0].url
return img return img
except (Exception, openai.OpenAIError) as error: except (Exception, openai.OpenAIError) as error:
print( print(
colored( colored(
( (f"Error running Dalle3: {error} try optimizing your api key and"
f"Error running Dalle3: {error} try optimizing your api key and" " or try again"),
" or try again"
),
"red", "red",
) ))
)
print(colored(f"Error running Dalle3: {error.http_status}", "red")) print(colored(f"Error running Dalle3: {error.http_status}", "red"))
print(colored(f"Error running Dalle3: {error.error}", "red")) print(colored(f"Error running Dalle3: {error.error}", "red"))
raise error raise error
@ -18,6 +18,7 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1):
""" """
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
retries = max_retries retries = max_retries
@ -28,7 +29,9 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1):
retries -= 1 retries -= 1
if retries <= 0: if retries <= 0:
raise raise
print(f"Retry after exception: {e}, Attempts remaining: {retries}") print(
f"Retry after exception: {e}, Attempts remaining: {retries}"
)
await asyncio.sleep(delay) await asyncio.sleep(delay)
return wrapper return wrapper
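A self-contained sketch of the retry decorator above; the retry count, delay, and the failing coroutine are illustrative.

import asyncio
from functools import wraps

def async_retry(max_retries=3, exceptions=(Exception,), delay=1):

    def decorator(func):

        @wraps(func)
        async def wrapper(*args, **kwargs):
            retries = max_retries
            while True:
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    retries -= 1
                    if retries <= 0:
                        raise
                    print(f"Retry after exception: {e}, "
                          f"Attempts remaining: {retries}")
                    await asyncio.sleep(delay)

        return wrapper

    return decorator

@async_retry(max_retries=3, delay=0)
async def flaky() -> str:
    raise RuntimeError("transient failure")

# asyncio.run(flaky())  # prints two retry messages, then re-raises RuntimeError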
@ -62,7 +65,8 @@ class DistilWhisperModel:
def __init__(self, model_id="distil-whisper/distil-large-v2"): def __init__(self, model_id="distil-whisper/distil-large-v2"):
self.device = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 self.torch_dtype = torch.float16 if torch.cuda.is_available(
) else torch.float32
self.model_id = model_id self.model_id = model_id
self.model = AutoModelForSpeechSeq2Seq.from_pretrained( self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, model_id,
@ -119,14 +123,14 @@ class DistilWhisperModel:
try: try:
with torch.no_grad(): with torch.no_grad():
# Load the whole audio file, but process and transcribe it in chunks # Load the whole audio file, but process and transcribe it in chunks
audio_input = self.processor.audio_file_to_array(audio_file_path) audio_input = self.processor.audio_file_to_array(
audio_file_path)
sample_rate = audio_input.sampling_rate sample_rate = audio_input.sampling_rate
total_duration = len(audio_input.array) / sample_rate total_duration = len(audio_input.array) / sample_rate
chunks = [ chunks = [
audio_input.array[i : i + sample_rate * chunk_duration] audio_input.array[i:i + sample_rate * chunk_duration]
for i in range( for i in range(0, len(audio_input.array), sample_rate *
0, len(audio_input.array), sample_rate * chunk_duration chunk_duration)
)
] ]
print(colored("Starting real-time transcription...", "green")) print(colored("Starting real-time transcription...", "green"))
@ -139,22 +143,22 @@ class DistilWhisperModel:
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
) )
processed_inputs = processed_inputs.input_values.to(self.device) processed_inputs = processed_inputs.input_values.to(
self.device)
# Generate transcription for the chunk # Generate transcription for the chunk
logits = self.model.generate(processed_inputs) logits = self.model.generate(processed_inputs)
transcription = self.processor.batch_decode( transcription = self.processor.batch_decode(
logits, skip_special_tokens=True logits, skip_special_tokens=True)[0]
)[0]
# Print the chunk's transcription # Print the chunk's transcription
print( print(
colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow") +
+ transcription transcription)
)
# Wait for the chunk's duration to simulate real-time processing # Wait for the chunk's duration to simulate real-time processing
time.sleep(chunk_duration) time.sleep(chunk_duration)
except Exception as e: except Exception as e:
print(colored(f"An error occurred during transcription: {e}", "red")) print(colored(f"An error occurred during transcription: {e}",
"red"))
@ -11,7 +11,8 @@ from pydantic import BaseModel, StrictFloat, StrictInt, validator
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the classes for image classification # Load the classes for image classification
with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f: with open(os.path.join(os.path.dirname(__file__),
"fast_vit_classes.json")) as f:
FASTVIT_IMAGENET_1K_CLASSES = json.load(f) FASTVIT_IMAGENET_1K_CLASSES = json.load(f)
@ -21,7 +22,8 @@ class ClassificationResult(BaseModel):
@validator("class_id", "confidence", pre=True, each_item=True) @validator("class_id", "confidence", pre=True, each_item=True)
def check_list_contents(cls, v): def check_list_contents(cls, v):
assert isinstance(v, int) or isinstance(v, float), "must be integer or float" assert isinstance(v, int) or isinstance(
v, float), "must be integer or float"
return v return v
@ -47,16 +49,16 @@ class FastViT:
""" """
def __init__(self): def __init__(self):
self.model = timm.create_model( self.model = timm.create_model("hf_hub:timm/fastvit_s12.apple_in1k",
"hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True pretrained=True).to(DEVICE)
).to(DEVICE)
data_config = timm.data.resolve_model_data_config(self.model) data_config = timm.data.resolve_model_data_config(self.model)
self.transforms = timm.data.create_transform(**data_config, is_training=False) self.transforms = timm.data.create_transform(**data_config,
is_training=False)
self.model.eval() self.model.eval()
def __call__( def __call__(self,
self, img: str, confidence_threshold: float = 0.5 img: str,
) -> ClassificationResult: confidence_threshold: float = 0.5) -> ClassificationResult:
"""classifies the input image and returns the top k classes and their probabilities""" """classifies the input image and returns the top k classes and their probabilities"""
img = Image.open(img).convert("RGB") img = Image.open(img).convert("RGB")
img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE) img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
@ -65,9 +67,8 @@ class FastViT:
probabilities = torch.nn.functional.softmax(output, dim=1) probabilities = torch.nn.functional.softmax(output, dim=1)
# Get top k classes and their probabilities # Get top k classes and their probabilities
top_probs, top_classes = torch.topk( top_probs, top_classes = torch.topk(probabilities,
probabilities, k=len(FASTVIT_IMAGENET_1K_CLASSES) k=len(FASTVIT_IMAGENET_1K_CLASSES))
)
# Filter by confidence threshold # Filter by confidence threshold
mask = top_probs > confidence_threshold mask = top_probs > confidence_threshold
@ -46,9 +46,9 @@ class Fuyu:
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path) self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
self.image_processor = FuyuImageProcessor() self.image_processor = FuyuImageProcessor()
self.processor = FuyuProcessor( self.processor = FuyuProcessor(image_processor=self.image_processor,
image_processor=self.image_processor, tokenizer=self.tokenizer, **kwargs tokenizer=self.tokenizer,
) **kwargs)
self.model = FuyuForCausalLM.from_pretrained( self.model = FuyuForCausalLM.from_pretrained(
pretrained_path, pretrained_path,
device_map=device_map, device_map=device_map,
@ -63,15 +63,17 @@ class Fuyu:
def __call__(self, text: str, img: str): def __call__(self, text: str, img: str):
"""Call the model with text and img paths""" """Call the model with text and img paths"""
image_pil = Image.open(img) image_pil = Image.open(img)
model_inputs = self.processor( model_inputs = self.processor(text=text,
text=text, images=[image_pil], device=self.device_map images=[image_pil],
) device=self.device_map)
for k, v in model_inputs.items(): for k, v in model_inputs.items():
model_inputs[k] = v.to(self.device_map) model_inputs[k] = v.to(self.device_map)
output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens) output = self.model.generate(**model_inputs,
text = self.processor.batch_decode(output[:, -7:], skip_special_tokens=True) max_new_tokens=self.max_new_tokens)
text = self.processor.batch_decode(output[:, -7:],
skip_special_tokens=True)
return print(str(text)) return print(str(text))
def get_img_from_web(self, img_url: str): def get_img_from_web(self, img_url: str):
@ -130,19 +130,23 @@ class GPT4Vision:
} }
# Image content # Image content
image_content = [ image_content = [{
{"type": "imavge_url", "image_url": img} "type": "imavge_url",
if img.startswith("http") "image_url": img
else {"type": "image", "data": img} } if img.startswith("http") else {
for img in img "type": "image",
] "data": img
} for img in img]
messages = [
{ messages = [{
"role": "user", "role":
"content": image_content + [{"type": "text", "text": q} for q in tasks], "user",
} "content":
] image_content + [{
"type": "text",
"text": q
} for q in tasks],
}]
payload = { payload = {
"model": "gpt-4-vision-preview", "model": "gpt-4-vision-preview",
@ -160,7 +164,8 @@ class GPT4Vision:
timeout=self.timeout_seconds, timeout=self.timeout_seconds,
) )
response.raise_for_status() response.raise_for_status()
answer = response.json()["choices"][0]["message"]["content"]["text"] answer = response.json(
)["choices"][0]["message"]["content"]["text"]
return GPT4VisionResponse(answer=answer) return GPT4VisionResponse(answer=answer)
except requests.exceptions.HTTPError as error: except requests.exceptions.HTTPError as error:
self.logger.error( self.logger.error(
@ -179,8 +184,7 @@ class GPT4Vision:
except Exception as error: except Exception as error:
self.logger.error( self.logger.error(
f"Unexpected Error: {error} try optimizing your api key and try" f"Unexpected Error: {error} try optimizing your api key and try"
" again" " again")
)
raise error from None raise error from None
raise TimeoutError("API Request timed out after multiple retries") raise TimeoutError("API Request timed out after multiple retries")
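The request body assembled above follows the OpenAI vision chat layout; a minimal stand-alone example of that structure, with a made-up task and image URL:

task = "What is shown in this image?"
img = "https://example.com/sample.png"  # placeholder URL

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": task},
        {"type": "image_url", "image_url": img},
    ],
}]
payload = {
    "model": "gpt-4-vision-preview",
    "messages": messages,
    "max_tokens": 300,
}
print(payload["messages"][0]["content"][1]["type"])  # image_url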
@ -212,18 +216,20 @@ class GPT4Vision:
try: try:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.model, model=self.model,
messages=[ messages=[{
{ "role":
"role": "user", "user",
"content": [ "content": [
{"type": "text", "text": f"{task}"}, {
{ "type": "text",
"type": "image_url", "text": f"{task}"
"image_url": f"{img}", },
}, {
], "type": "image_url",
} "image_url": f"{img}",
], },
],
}],
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
@ -232,13 +238,10 @@ class GPT4Vision:
except Exception as error: except Exception as error:
print( print(
colored( colored(
( (f"Error when calling GPT4Vision, Error: {error} Try optimizing"
f"Error when calling GPT4Vision, Error: {error} Try optimizing" " your key, and try again"),
" your key, and try again"
),
"red", "red",
) ))
)
async def arun(self, task: str, img: str) -> str: async def arun(self, task: str, img: str) -> str:
""" """
@ -267,18 +270,20 @@ class GPT4Vision:
try: try:
response = await self.client.chat.completions.create( response = await self.client.chat.completions.create(
model=self.model, model=self.model,
messages=[ messages=[{
{ "role":
"role": "user", "user",
"content": [ "content": [
{"type": "text", "text": f"{task}"}, {
{ "type": "text",
"type": "image_url", "text": f"{task}"
"image_url": f"{img}", },
}, {
], "type": "image_url",
} "image_url": f"{img}",
], },
],
}],
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
out = response.choices[0].text out = response.choices[0].text
@ -286,10 +291,7 @@ class GPT4Vision:
except Exception as error: except Exception as error:
print( print(
colored( colored(
( (f"Error when calling GPT4Vision, Error: {error} Try optimizing"
f"Error when calling GPT4Vision, Error: {error} Try optimizing" " your key, and try again"),
" your key, and try again"
),
"red", "red",
) ))
)
@ -47,9 +47,8 @@ class HuggingfaceLLM:
**kwargs, **kwargs,
): ):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.device = ( self.device = (device if device else
device if device else ("cuda" if torch.cuda.is_available() else "cpu") ("cuda" if torch.cuda.is_available() else "cpu"))
)
self.model_id = model_id self.model_id = model_id
self.max_length = max_length self.max_length = max_length
self.verbose = verbose self.verbose = verbose
@ -58,9 +57,8 @@ class HuggingfaceLLM:
self.model, self.tokenizer = None, None self.model, self.tokenizer = None, None
if self.distributed: if self.distributed:
assert ( assert (torch.cuda.device_count() >
torch.cuda.device_count() > 1 1), "You need more than 1 gpu for distributed processing"
), "You need more than 1 gpu for distributed processing"
bnb_config = None bnb_config = None
if quantize: if quantize:
@ -75,17 +73,17 @@ class HuggingfaceLLM:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
self.model_id, *args, **kwargs self.model_id, *args, **kwargs)
)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config, *args, **kwargs self.model_id, quantization_config=bnb_config, *args, **kwargs)
)
self.model # .to(self.device) self.model # .to(self.device)
except Exception as e: except Exception as e:
# self.logger.error(f"Failed to load the model or the tokenizer: {e}") # self.logger.error(f"Failed to load the model or the tokenizer: {e}")
# raise # raise
print(colored(f"Failed to load the model and or the tokenizer: {e}", "red")) print(
colored(f"Failed to load the model and or the tokenizer: {e}",
"red"))
def print_error(self, error: str): def print_error(self, error: str):
"""Print error""" """Print error"""
@ -97,20 +95,18 @@ class HuggingfaceLLM:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
bnb_config = ( bnb_config = (BitsAndBytesConfig(**self.quantization_config)
BitsAndBytesConfig(**self.quantization_config) if self.quantization_config else None)
if self.quantization_config
else None
)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config self.model_id,
).to(self.device) quantization_config=bnb_config).to(self.device)
if self.distributed: if self.distributed:
self.model = DDP(self.model) self.model = DDP(self.model)
except Exception as error: except Exception as error:
self.logger.error(f"Failed to load the model or the tokenizer: {error}") self.logger.error(
f"Failed to load the model or the tokenizer: {error}")
raise raise
def run(self, task: str): def run(self, task: str):
@ -131,7 +127,8 @@ class HuggingfaceLLM:
self.print_dashboard(task) self.print_dashboard(task)
try: try:
inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) inputs = self.tokenizer.encode(task,
return_tensors="pt").to(self.device)
# self.log.start() # self.log.start()
@ -140,39 +137,36 @@ class HuggingfaceLLM:
for _ in range(max_length): for _ in range(max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate( outputs = self.model.generate(inputs,
inputs, max_length=len(inputs) + 1, do_sample=True max_length=len(inputs) +
) 1,
do_sample=True)
output_tokens = outputs[0][-1] output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item()) output_sequence.append(output_tokens.item())
# print token in real-time # print token in real-time
print( print(
self.tokenizer.decode( self.tokenizer.decode([output_tokens],
[output_tokens], skip_special_tokens=True skip_special_tokens=True),
),
end="", end="",
flush=True, flush=True,
) )
inputs = outputs inputs = outputs
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(inputs,
inputs, max_length=max_length, do_sample=True max_length=max_length,
) do_sample=True)
del inputs del inputs
return self.tokenizer.decode(outputs[0], skip_special_tokens=True) return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e: except Exception as e:
print( print(
colored( colored(
( (f"HuggingfaceLLM could not generate text because of error: {e},"
f"HuggingfaceLLM could not generate text because of error: {e}," " try optimizing your arguments"),
" try optimizing your arguments"
),
"red", "red",
) ))
)
raise raise
async def run_async(self, task: str, *args, **kwargs) -> str: async def run_async(self, task: str, *args, **kwargs) -> str:
@ -216,7 +210,8 @@ class HuggingfaceLLM:
self.print_dashboard(task) self.print_dashboard(task)
try: try:
inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device) inputs = self.tokenizer.encode(task,
return_tensors="pt").to(self.device)
# self.log.start() # self.log.start()
@ -225,26 +220,26 @@ class HuggingfaceLLM:
for _ in range(max_length): for _ in range(max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate( outputs = self.model.generate(inputs,
inputs, max_length=len(inputs) + 1, do_sample=True max_length=len(inputs) +
) 1,
do_sample=True)
output_tokens = outputs[0][-1] output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item()) output_sequence.append(output_tokens.item())
# print token in real-time # print token in real-time
print( print(
self.tokenizer.decode( self.tokenizer.decode([output_tokens],
[output_tokens], skip_special_tokens=True skip_special_tokens=True),
),
end="", end="",
flush=True, flush=True,
) )
inputs = outputs inputs = outputs
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(inputs,
inputs, max_length=max_length, do_sample=True max_length=max_length,
) do_sample=True)
del inputs del inputs
@ -305,8 +300,7 @@ class HuggingfaceLLM:
""", """,
"red", "red",
) ))
)
print(dashboard) print(dashboard)
@ -65,9 +65,8 @@ class Idefics:
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
max_length=100, max_length=100,
): ):
self.device = ( self.device = (device if device else
device if device else ("cuda" if torch.cuda.is_available() else "cpu") ("cuda" if torch.cuda.is_available() else "cpu"))
)
self.model = IdeficsForVisionText2Text.from_pretrained( self.model = IdeficsForVisionText2Text.from_pretrained(
checkpoint, checkpoint,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@ -96,21 +95,17 @@ class Idefics:
list list
A list of generated text strings. A list of generated text strings.
""" """
inputs = ( inputs = (self.processor(
self.processor( prompts, add_end_of_utterance_token=False, return_tensors="pt").to(
prompts, add_end_of_utterance_token=False, return_tensors="pt" self.device) if batched_mode else self.processor(
).to(self.device) prompts[0], return_tensors="pt").to(self.device))
if batched_mode
else self.processor(prompts[0], return_tensors="pt").to(self.device)
)
exit_condition = self.processor.tokenizer( exit_condition = self.processor.tokenizer(
"<end_of_utterance>", add_special_tokens=False "<end_of_utterance>", add_special_tokens=False).input_ids
).input_ids
bad_words_ids = self.processor.tokenizer( bad_words_ids = self.processor.tokenizer(
["<image>", "<fake_token_around_image"], add_special_tokens=False ["<image>", "<fake_token_around_image"],
).input_ids add_special_tokens=False).input_ids
generated_ids = self.model.generate( generated_ids = self.model.generate(
**inputs, **inputs,
@ -118,9 +113,8 @@ class Idefics:
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
max_length=self.max_length, max_length=self.max_length,
) )
generated_text = self.processor.batch_decode( generated_text = self.processor.batch_decode(generated_ids,
generated_ids, skip_special_tokens=True skip_special_tokens=True)
)
return generated_text return generated_text
def __call__(self, prompts, batched_mode=True): def __call__(self, prompts, batched_mode=True):
@ -141,21 +135,17 @@ class Idefics:
list list
A list of generated text strings. A list of generated text strings.
""" """
inputs = ( inputs = (self.processor(
self.processor( prompts, add_end_of_utterance_token=False, return_tensors="pt").to(
prompts, add_end_of_utterance_token=False, return_tensors="pt" self.device) if batched_mode else self.processor(
).to(self.device) prompts[0], return_tensors="pt").to(self.device))
if batched_mode
else self.processor(prompts[0], return_tensors="pt").to(self.device)
)
exit_condition = self.processor.tokenizer( exit_condition = self.processor.tokenizer(
"<end_of_utterance>", add_special_tokens=False "<end_of_utterance>", add_special_tokens=False).input_ids
).input_ids
bad_words_ids = self.processor.tokenizer( bad_words_ids = self.processor.tokenizer(
["<image>", "<fake_token_around_image"], add_special_tokens=False ["<image>", "<fake_token_around_image"],
).input_ids add_special_tokens=False).input_ids
generated_ids = self.model.generate( generated_ids = self.model.generate(
**inputs, **inputs,
@ -163,9 +153,8 @@ class Idefics:
bad_words_ids=bad_words_ids, bad_words_ids=bad_words_ids,
max_length=self.max_length, max_length=self.max_length,
) )
generated_text = self.processor.batch_decode( generated_text = self.processor.batch_decode(generated_ids,
generated_ids, skip_special_tokens=True skip_special_tokens=True)
)
return generated_text return generated_text
def chat(self, user_input): def chat(self, user_input):
@ -202,8 +191,7 @@ class Idefics:
The name of the new pre-trained model checkpoint. The name of the new pre-trained model checkpoint.
""" """
self.model = IdeficsForVisionText2Text.from_pretrained( self.model = IdeficsForVisionText2Text.from_pretrained(
checkpoint, torch_dtype=torch.bfloat16 checkpoint, torch_dtype=torch.bfloat16).to(self.device)
).to(self.device)
self.processor = AutoProcessor.from_pretrained(checkpoint) self.processor = AutoProcessor.from_pretrained(checkpoint)
def set_device(self, device): def set_device(self, device):
@ -53,9 +53,8 @@ class JinaEmbeddings:
**kwargs, **kwargs,
): ):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.device = ( self.device = (device if device else
device if device else ("cuda" if torch.cuda.is_available() else "cpu") ("cuda" if torch.cuda.is_available() else "cpu"))
)
self.model_id = model_id self.model_id = model_id
self.max_length = max_length self.max_length = max_length
self.verbose = verbose self.verbose = verbose
@ -66,9 +65,8 @@ class JinaEmbeddings:
self.cos_sim = cos_sim self.cos_sim = cos_sim
if self.distributed: if self.distributed:
assert ( assert (torch.cuda.device_count() >
torch.cuda.device_count() > 1 1), "You need more than 1 gpu for distributed processing"
), "You need more than 1 gpu for distributed processing"
bnb_config = None bnb_config = None
if quantize: if quantize:
@ -83,8 +81,9 @@ class JinaEmbeddings:
try: try:
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config, trust_remote_code=True self.model_id,
) quantization_config=bnb_config,
trust_remote_code=True)
self.model # .to(self.device) self.model # .to(self.device)
except Exception as e: except Exception as e:
@ -97,11 +96,8 @@ class JinaEmbeddings:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
bnb_config = ( bnb_config = (BitsAndBytesConfig(**self.quantization_config)
BitsAndBytesConfig(**self.quantization_config) if self.quantization_config else None)
if self.quantization_config
else None
)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, self.model_id,
@ -112,7 +108,8 @@ class JinaEmbeddings:
if self.distributed: if self.distributed:
self.model = DDP(self.model) self.model = DDP(self.model)
except Exception as error: except Exception as error:
self.logger.error(f"Failed to load the model or the tokenizer: {error}") self.logger.error(
f"Failed to load the model or the tokenizer: {error}")
raise raise
def run(self, task: str): def run(self, task: str):
@ -14,11 +14,8 @@ class Detections(BaseModel):
@root_validator @root_validator
def check_length(cls, values): def check_length(cls, values):
assert ( assert (len(values.get("xyxy")) == len(values.get("class_id")) == len(
len(values.get("xyxy")) values.get("confidence"))), "All fields must have the same length."
== len(values.get("class_id"))
== len(values.get("confidence"))
), "All fields must have the same length."
return values return values
@validator("xyxy", "class_id", "confidence", pre=True, each_item=True) @validator("xyxy", "class_id", "confidence", pre=True, each_item=True)
@ -39,11 +36,9 @@ class Kosmos2(BaseModel):
@classmethod @classmethod
def initialize(cls): def initialize(cls):
model = AutoModelForVision2Seq.from_pretrained( model = AutoModelForVision2Seq.from_pretrained(
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True "ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
)
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
"ydshieh/kosmos-2-patch14-224", trust_remote_code=True "ydshieh/kosmos-2-patch14-224", trust_remote_code=True)
)
return cls(model=model, processor=processor) return cls(model=model, processor=processor)
def __call__(self, img: str) -> Detections: def __call__(self, img: str) -> Detections:
@ -51,11 +46,12 @@ class Kosmos2(BaseModel):
prompt = "<grounding>An image of" prompt = "<grounding>An image of"
inputs = self.processor(text=prompt, images=image, return_tensors="pt") inputs = self.processor(text=prompt, images=image, return_tensors="pt")
outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64) outputs = self.model.generate(**inputs,
use_cache=True,
max_new_tokens=64)
generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[ generated_text = self.processor.batch_decode(
0 outputs, skip_special_tokens=True)[0]
]
# The actual processing of generated_text to entities would go here # The actual processing of generated_text to entities would go here
# For the purpose of this example, assume a mock function 'extract_entities' exists: # For the purpose of this example, assume a mock function 'extract_entities' exists:
@ -66,8 +62,8 @@ class Kosmos2(BaseModel):
return detections return detections
def extract_entities( def extract_entities(
self, text: str self,
) -> List[Tuple[str, Tuple[float, float, float, float]]]: text: str) -> List[Tuple[str, Tuple[float, float, float, float]]]:
# Placeholder function for entity extraction # Placeholder function for entity extraction
# This should be replaced with the actual method of extracting entities # This should be replaced with the actual method of extracting entities
return [] return []
@ -80,19 +76,19 @@ class Kosmos2(BaseModel):
if not entities: if not entities:
return Detections.empty() return Detections.empty()
class_ids = [0] * len(entities) # Replace with actual class ID extraction logic class_ids = [0] * len(
xyxys = [ entities) # Replace with actual class ID extraction logic
( xyxys = [(
e[1][0] * image.width, e[1][0] * image.width,
e[1][1] * image.height, e[1][1] * image.height,
e[1][2] * image.width, e[1][2] * image.width,
e[1][3] * image.height, e[1][3] * image.height,
) ) for e in entities]
for e in entities
]
confidences = [1.0] * len(entities) # Placeholder confidence confidences = [1.0] * len(entities) # Placeholder confidence
return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences) return Detections(xyxy=xyxys,
class_id=class_ids,
confidence=confidences)
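A toy illustration of the normalised-to-pixel box conversion performed above; the entity tuple and the image size are made up.

entities = [("cat", (0.1, 0.2, 0.5, 0.8))]  # (name, (x1, y1, x2, y2) in [0, 1])
image_w, image_h = 640, 480

xyxys = [(x1 * image_w, y1 * image_h, x2 * image_w, y2 * image_h)
         for _, (x1, y1, x2, y2) in entities]
class_ids = [0] * len(entities)             # placeholder class ids
confidences = [1.0] * len(entities)         # placeholder confidences
print(xyxys)                                # [(64.0, 96.0, 320.0, 384.0)]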
# Usage: # Usage:
@ -46,11 +46,9 @@ class Kosmos:
model_name="ydshieh/kosmos-2-patch14-224", model_name="ydshieh/kosmos-2-patch14-224",
): ):
self.model = AutoModelForVision2Seq.from_pretrained( self.model = AutoModelForVision2Seq.from_pretrained(
model_name, trust_remote_code=True model_name, trust_remote_code=True)
) self.processor = AutoProcessor.from_pretrained(model_name,
self.processor = AutoProcessor.from_pretrained( trust_remote_code=True)
model_name, trust_remote_code=True
)
def get_image(self, url): def get_image(self, url):
"""Image""" """Image"""
@ -73,8 +71,7 @@ class Kosmos:
skip_special_tokens=True, skip_special_tokens=True,
)[0] )[0]
processed_text, entities = self.processor.post_process_generation( processed_text, entities = self.processor.post_process_generation(
generated_texts generated_texts)
)
def __call__(self, prompt, image): def __call__(self, prompt, image):
"""Run call""" """Run call"""
@ -93,8 +90,7 @@ class Kosmos:
skip_special_tokens=True, skip_special_tokens=True,
)[0] )[0]
processed_text, entities = self.processor.post_process_generation( processed_text, entities = self.processor.post_process_generation(
generated_texts generated_texts)
)
# tasks # tasks
def multimodal_grounding(self, phrase, image_url): def multimodal_grounding(self, phrase, image_url):
@ -145,12 +141,10 @@ class Kosmos:
elif isinstance(image, torch.Tensor): elif isinstance(image, torch.Tensor):
# pdb.set_trace() # pdb.set_trace()
image_tensor = image.cpu() image_tensor = image.cpu()
reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[ reverse_norm_mean = torch.tensor(
:, None, None [0.48145466, 0.4578275, 0.40821073])[:, None, None]
] reverse_norm_std = torch.tensor(
reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[ [0.26862954, 0.26130258, 0.27577711])[:, None, None]
:, None, None
]
image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
pil_img = T.ToPILImage()(image_tensor) pil_img = T.ToPILImage()(image_tensor)
image_h = pil_img.height image_h = pil_img.height
@ -169,9 +163,9 @@ class Kosmos:
# thickness of text # thickness of text
text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
box_line = 3 box_line = 3
(c_width, text_height), _ = cv2.getTextSize( (c_width, text_height), _ = cv2.getTextSize("F",
"F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line cv2.FONT_HERSHEY_COMPLEX,
) text_size, text_line)
base_height = int(text_height * 0.675) base_height = int(text_height * 0.675)
text_offset_original = text_height - base_height text_offset_original = text_height - base_height
text_spaces = 3 text_spaces = 3
@ -187,9 +181,8 @@ class Kosmos:
# draw bbox # draw bbox
# random color # random color
color = tuple(np.random.randint(0, 255, size=3).tolist()) color = tuple(np.random.randint(0, 255, size=3).tolist())
new_image = cv2.rectangle( new_image = cv2.rectangle(new_image, (orig_x1, orig_y1),
new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line (orig_x2, orig_y2), color, box_line)
)
l_o, r_o = ( l_o, r_o = (
box_line // 2 + box_line % 2, box_line // 2 + box_line % 2,
@ -200,19 +193,15 @@ class Kosmos:
y1 = orig_y1 - l_o y1 = orig_y1 - l_o
if y1 < text_height + text_offset_original + 2 * text_spaces: if y1 < text_height + text_offset_original + 2 * text_spaces:
y1 = ( y1 = (orig_y1 + r_o + text_height + text_offset_original +
orig_y1 2 * text_spaces)
+ r_o
+ text_height
+ text_offset_original
+ 2 * text_spaces
)
x1 = orig_x1 + r_o x1 = orig_x1 + r_o
# add text background # add text background
(text_width, text_height), _ = cv2.getTextSize( (text_width,
f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line text_height), _ = cv2.getTextSize(f" {entity_name}",
) cv2.FONT_HERSHEY_COMPLEX,
text_size, text_line)
text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = ( text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = (
x1, x1,
y1 - (text_height + text_offset_original + 2 * text_spaces), y1 - (text_height + text_offset_original + 2 * text_spaces),
@ -222,23 +211,19 @@ class Kosmos:
for prev_bbox in previous_bboxes: for prev_bbox in previous_bboxes:
while is_overlapping( while is_overlapping(
(text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2),
): prev_bbox):
text_bg_y1 += ( text_bg_y1 += (text_height + text_offset_original +
text_height + text_offset_original + 2 * text_spaces 2 * text_spaces)
) text_bg_y2 += (text_height + text_offset_original +
text_bg_y2 += ( 2 * text_spaces)
text_height + text_offset_original + 2 * text_spaces
)
y1 += text_height + text_offset_original + 2 * text_spaces y1 += text_height + text_offset_original + 2 * text_spaces
if text_bg_y2 >= image_h: if text_bg_y2 >= image_h:
text_bg_y1 = max( text_bg_y1 = max(
0, 0,
image_h image_h - (text_height + text_offset_original +
- ( 2 * text_spaces),
text_height + text_offset_original + 2 * text_spaces
),
) )
text_bg_y2 = image_h text_bg_y2 = image_h
y1 = image_h y1 = image_h
@ -255,9 +240,9 @@ class Kosmos:
# white # white
bg_color = [255, 255, 255] bg_color = [255, 255, 255]
new_image[i, j] = ( new_image[i, j] = (
alpha * new_image[i, j] alpha * new_image[i, j] +
+ (1 - alpha) * np.array(bg_color) (1 - alpha) * np.array(bg_color)).astype(
).astype(np.uint8) np.uint8)
cv2.putText( cv2.putText(
new_image, new_image,
@ -270,7 +255,8 @@ class Kosmos:
cv2.LINE_AA, cv2.LINE_AA,
) )
# previous_locations.append((x1, y1)) # previous_locations.append((x1, y1))
previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) previous_bboxes.append(
(text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]]) pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
if save_path: if save_path:
@ -48,9 +48,8 @@ class MultiModalLlava:
revision=revision, revision=revision,
).to(self.device) ).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
model_name_or_path, use_fast=True use_fast=True)
)
self.pipe = pipeline( self.pipe = pipeline(
"text-generation", "text-generation",
model=self.model, model=self.model,
@ -49,7 +49,8 @@ class Mistral:
# Check if the specified device is available # Check if the specified device is available
if not torch.cuda.is_available() and device == "cuda": if not torch.cuda.is_available() and device == "cuda":
raise ValueError("CUDA is not available. Please choose a different device.") raise ValueError(
"CUDA is not available. Please choose a different device.")
# Load the model and tokenizer # Load the model and tokenizer
self.model = None self.model = None
@ -70,7 +71,8 @@ class Mistral:
"""Run the model on a given task.""" """Run the model on a given task."""
try: try:
model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) model_inputs = self.tokenizer([task],
return_tensors="pt").to(self.device)
generated_ids = self.model.generate( generated_ids = self.model.generate(
**model_inputs, **model_inputs,
max_length=self.max_length, max_length=self.max_length,
@ -87,7 +89,8 @@ class Mistral:
"""Run the model on a given task.""" """Run the model on a given task."""
try: try:
model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device) model_inputs = self.tokenizer([task],
return_tensors="pt").to(self.device)
generated_ids = self.model.generate( generated_ids = self.model.generate(
**model_inputs, **model_inputs,
max_length=self.max_length, max_length=self.max_length,
@ -26,7 +26,10 @@ class MPT7B:
""" """
def __init__(self, model_name: str, tokenizer_name: str, max_tokens: int = 100): def __init__(self,
model_name: str,
tokenizer_name: str,
max_tokens: int = 100):
# Loading model and tokenizer details # Loading model and tokenizer details
self.model_name = model_name self.model_name = model_name
self.tokenizer_name = tokenizer_name self.tokenizer_name = tokenizer_name
@ -37,11 +40,9 @@ class MPT7B:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
config = AutoModelForCausalLM.from_pretrained( config = AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True model_name, trust_remote_code=True).config
).config
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, config=config, trust_remote_code=True model_name, config=config, trust_remote_code=True)
)
# Initializing a text-generation pipeline # Initializing a text-generation pipeline
self.pipe = pipeline( self.pipe = pipeline(
@ -114,9 +115,10 @@ class MPT7B:
""" """
with torch.autocast("cuda", dtype=torch.bfloat16): with torch.autocast("cuda", dtype=torch.bfloat16):
return self.pipe( return self.pipe(prompt,
prompt, max_new_tokens=self.max_tokens, do_sample=True, use_cache=True max_new_tokens=self.max_tokens,
)[0]["generated_text"] do_sample=True,
use_cache=True)[0]["generated_text"]
async def generate_async(self, prompt: str) -> str: async def generate_async(self, prompt: str) -> str:
"""Generate Async""" """Generate Async"""
@ -41,8 +41,10 @@ class Nougat:
self.min_length = min_length self.min_length = min_length
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
self.processor = NougatProcessor.from_pretrained(self.model_name_or_path) self.processor = NougatProcessor.from_pretrained(
self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path) self.model_name_or_path)
self.model = VisionEncoderDecoderModel.from_pretrained(
self.model_name_or_path)
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device) self.model.to(self.device)
@ -63,8 +65,10 @@ class Nougat:
max_new_tokens=self.max_new_tokens, max_new_tokens=self.max_new_tokens,
) )
sequence = self.processor.batch_decode(outputs, skip_special_tokens=True)[0] sequence = self.processor.batch_decode(outputs,
sequence = self.processor.post_process_generation(sequence, fix_markdown=False) skip_special_tokens=True)[0]
sequence = self.processor.post_process_generation(sequence,
fix_markdown=False)
out = print(repr(sequence)) out = print(repr(sequence))
return out return out
@ -55,9 +55,9 @@ class OpenAIAssistant:
return thread return thread
def add_message_to_thread(self, thread_id: str, message: str): def add_message_to_thread(self, thread_id: str, message: str):
message = self.client.beta.threads.add_message( message = self.client.beta.threads.add_message(thread_id=thread_id,
thread_id=thread_id, role=self.user, content=message role=self.user,
) content=message)
return message return message
def run(self, task: str): def run(self, task: str):
@ -67,8 +67,7 @@ class OpenAIAssistant:
instructions=self.instructions, instructions=self.instructions,
) )
out = self.client.beta.threads.runs.retrieve( out = self.client.beta.threads.runs.retrieve(thread_id=run.thread_id,
thread_id=run.thread_id, run_id=run.id run_id=run.id)
)
return out return out
@ -28,9 +28,10 @@ from tenacity import (
from swarms.models.embeddings_base import Embeddings from swarms.models.embeddings_base import Embeddings
def get_from_dict_or_env( def get_from_dict_or_env(values: dict,
values: dict, key: str, env_key: str, default: Any = None key: str,
) -> Any: env_key: str,
default: Any = None) -> Any:
import os import os
return values.get(key) or os.getenv(env_key) or default return values.get(key) or os.getenv(env_key) or default
@ -43,7 +44,8 @@ def get_pydantic_field_names(cls: Any) -> Set[str]:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: def _create_retry_decorator(
embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]:
import llm import llm
min_seconds = 4 min_seconds = 4
@ -54,13 +56,11 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
reraise=True, reraise=True,
stop=stop_after_attempt(embeddings.max_retries), stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=( retry=(retry_if_exception_type(llm.error.Timeout) |
retry_if_exception_type(llm.error.Timeout) retry_if_exception_type(llm.error.APIError) |
| retry_if_exception_type(llm.error.APIError) retry_if_exception_type(llm.error.APIConnectionError) |
| retry_if_exception_type(llm.error.APIConnectionError) retry_if_exception_type(llm.error.RateLimitError) |
| retry_if_exception_type(llm.error.RateLimitError) retry_if_exception_type(llm.error.ServiceUnavailableError)),
| retry_if_exception_type(llm.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING), before_sleep=before_sleep_log(logger, logging.WARNING),
) )
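A hedged, self-contained sketch of the same tenacity retry policy, using builtin exception types in place of the provider-specific llm.error classes assumed above.

import logging

from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_attempt, wait_exponential)

logger = logging.getLogger(__name__)

@retry(
    reraise=True,
    stop=stop_after_attempt(6),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=(retry_if_exception_type(TimeoutError) |
           retry_if_exception_type(ConnectionError)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def call_embedding_api() -> str:
    return "ok"  # a real network call would go here

print(call_embedding_api())  # "ok" on the first attempt; failures would be retried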
@ -76,17 +76,16 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
reraise=True, reraise=True,
stop=stop_after_attempt(embeddings.max_retries), stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=( retry=(retry_if_exception_type(llm.error.Timeout) |
retry_if_exception_type(llm.error.Timeout) retry_if_exception_type(llm.error.APIError) |
| retry_if_exception_type(llm.error.APIError) retry_if_exception_type(llm.error.APIConnectionError) |
| retry_if_exception_type(llm.error.APIConnectionError) retry_if_exception_type(llm.error.RateLimitError) |
| retry_if_exception_type(llm.error.RateLimitError) retry_if_exception_type(llm.error.ServiceUnavailableError)),
| retry_if_exception_type(llm.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING), before_sleep=before_sleep_log(logger, logging.WARNING),
) )
def wrap(func: Callable) -> Callable: def wrap(func: Callable) -> Callable:
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
async for _ in async_retrying: async for _ in async_retrying:
return await func(*args, **kwargs) return await func(*args, **kwargs)
@ -118,7 +117,8 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
return _embed_with_retry(**kwargs) return _embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: async def async_embed_with_retry(embeddings: OpenAIEmbeddings,
**kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call.""" """Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings) @_async_retry_decorator(embeddings)
@ -225,11 +225,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
warnings.warn( warnings.warn(
f"""WARNING! {field_name} is not default parameter. f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs. {field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended.""" Please confirm that {field_name} is what you intended.""")
)
extra[field_name] = values.pop(field_name) extra[field_name] = values.pop(field_name)
        invalid_model_kwargs = all_required_field_names.intersection(
            extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
@ -242,9 +242,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["openai_api_key"] = get_from_dict_or_env(values,
                                                        "openai_api_key",
                                                        "OPENAI_API_KEY")
        values["openai_api_base"] = get_from_dict_or_env(
            values,
            "openai_api_base",
@ -284,10 +284,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
            values["client"] = llm.Embedding
        except ImportError:
            raise ImportError("Could not import openai python package. "
                              "Please install it with `pip install openai`.")
        return values
    @property
@ -315,8 +313,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        return openai_args
    def _get_len_safe_embeddings(
            self,
            texts: List[str],
            *,
            engine: str,
            chunk_size: Optional[int] = None) -> List[List[float]]:
        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
        try:
            import tiktoken
@ -324,8 +325,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to for OpenAIEmbeddings. "
                "Please install it with `pip install tiktoken`.")
        tokens = []
        indices = []
@ -333,7 +333,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            logger.warning(
                "Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken.get_encoding(model)
        for i, text in enumerate(texts):
@ -347,7 +348,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                disallowed_special=self.disallowed_special,
            )
            for j in range(0, len(token), self.embedding_ctx_length):
                tokens.append(token[j:j + self.embedding_ctx_length])
                indices.append(i)
        batched_embeddings: List[List[float]] = []
@ -366,7 +367,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        for i in _iter:
            response = embed_with_retry(
                self,
                input=tokens[i:i + _chunk_size],
                **self._invocation_params,
            )
            batched_embeddings.extend(r["embedding"] for r in response["data"])
@ -384,11 +385,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                    self,
                    input="",
                    **self._invocation_params,
                )["data"][0]["embedding"]
            else:
                average = np.average(_result,
                                     axis=0,
                                     weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
        return embeddings
@ -396,8 +397,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
    # please refer to
    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
    async def _aget_len_safe_embeddings(
            self,
            texts: List[str],
            *,
            engine: str,
            chunk_size: Optional[int] = None) -> List[List[float]]:
        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
        try:
            import tiktoken
@ -405,8 +409,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to for OpenAIEmbeddings. "
                "Please install it with `pip install tiktoken`.")
        tokens = []
        indices = []
@ -414,7 +417,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            logger.warning(
                "Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken.get_encoding(model)
        for i, text in enumerate(texts):
@ -428,7 +432,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                disallowed_special=self.disallowed_special,
            )
            for j in range(0, len(token), self.embedding_ctx_length):
                tokens.append(token[j:j + self.embedding_ctx_length])
                indices.append(i)
        batched_embeddings: List[List[float]] = []
@ -436,7 +440,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        for i in range(0, len(tokens), _chunk_size):
            response = await async_embed_with_retry(
                self,
                input=tokens[i:i + _chunk_size],
                **self._invocation_params,
            )
            batched_embeddings.extend(r["embedding"] for r in response["data"])
@ -450,22 +454,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        for i in range(len(texts)):
            _result = results[i]
            if len(_result) == 0:
                average = (await async_embed_with_retry(
                    self,
                    input="",
                    **self._invocation_params,
                ))["data"][0]["embedding"]
            else:
                average = np.average(_result,
                                     axis=0,
                                     weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
        return embeddings
    def embed_documents(self,
                        texts: List[str],
                        chunk_size: Optional[int] = 0) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.
        Args:
@ -481,8 +485,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        return self._get_len_safe_embeddings(texts, engine=self.deployment)
    async def aembed_documents(
            self,
            texts: List[str],
            chunk_size: Optional[int] = 0) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint async for embedding search docs.
        Args:
@ -495,7 +500,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        """
        # NOTE: to keep things simple, we assume the list may contain texts longer
        # than the maximum context and use length-safe embedding function.
        return await self._aget_len_safe_embeddings(texts,
                                                    engine=self.deployment)
    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.
@ -33,9 +33,8 @@ from langchain.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__)
def update_token_usage(keys: Set[str], response: Dict[str, Any],
                       token_usage: Dict[str, Any]) -> None:
    """Update token usage."""
    _keys_to_use = keys.intersection(response["usage"])
    for _key in _keys_to_use:
@ -46,44 +45,42 @@ def update_token_usage(
def _stream_response_to_generation_chunk(
    stream_response: Dict[str, Any],) -> GenerationChunk:
    """Convert a stream response to a generation chunk."""
    return GenerationChunk(
        text=stream_response["choices"][0]["text"],
        generation_info=dict(
            finish_reason=stream_response["choices"][0].get(
                "finish_reason", None),
            logprobs=stream_response["choices"][0].get("logprobs", None),
        ),
    )
def _update_response(response: Dict[str, Any],
                     stream_response: Dict[str, Any]) -> None:
    """Update response from the stream response."""
    response["choices"][0]["text"] += stream_response["choices"][0]["text"]
    response["choices"][0]["finish_reason"] = stream_response["choices"][0].get(
        "finish_reason", None)
    response["choices"][0]["logprobs"] = stream_response["choices"][0][
        "logprobs"]
def _streaming_response_template() -> Dict[str, Any]:
    return {
        "choices": [{
            "text": "",
            "finish_reason": None,
            "logprobs": None,
        }]
    }
def _create_retry_decorator(
    llm: Union[BaseOpenAI, OpenAIChat],
    run_manager: Optional[Union[AsyncCallbackManagerForLLMRun,
                                CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    import openai
@ -94,9 +91,9 @@ def _create_retry_decorator(
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(error_types=errors,
                                       max_retries=llm.max_retries,
                                       run_manager=run_manager)
def completion_with_retry(
@ -206,7 +203,8 @@ class BaseOpenAI(BaseLLM):
    API but with different models. In those cases, in order to avoid erroring
    when tiktoken is called, you can specify a model name to use here."""
    def __new__(cls,
                **data: Any) -> Union[OpenAIChat, BaseOpenAI]:  # type: ignore
        """Initialize the OpenAI object."""
        data.get("model_name", "")
        return super().__new__(cls)
@ -221,17 +219,16 @@ class BaseOpenAI(BaseLLM):
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        values["model_kwargs"] = build_extra_kwargs(extra, values,
                                                    all_required_field_names)
        return values
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["openai_api_key"] = get_from_dict_or_env(values,
                                                        "openai_api_key",
                                                        "OPENAI_API_KEY")
        values["openai_api_base"] = get_from_dict_or_env(
            values,
            "openai_api_base",
@ -255,10 +252,8 @@ class BaseOpenAI(BaseLLM):
            values["client"] = openai.Completion
        except ImportError:
            raise ImportError("Could not import openai python package. "
                              "Please install it with `pip install openai`.")
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
@ -295,9 +290,10 @@ class BaseOpenAI(BaseLLM):
    ) -> Iterator[GenerationChunk]:
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
        for stream_resp in completion_with_retry(self,
                                                 prompt=prompt,
                                                 run_manager=run_manager,
                                                 **params):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            yield chunk
            if run_manager:
@ -306,8 +302,7 @@ class BaseOpenAI(BaseLLM):
                    chunk=chunk,
                    verbose=self.verbose,
                    logprobs=chunk.generation_info["logprobs"]
                    if chunk.generation_info else None,
                )
    async def _astream(
@ -320,8 +315,7 @@ class BaseOpenAI(BaseLLM):
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutate params
        async for stream_resp in await acompletion_with_retry(
                self, prompt=prompt, run_manager=run_manager, **params):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            yield chunk
            if run_manager:
@ -330,8 +324,7 @@ class BaseOpenAI(BaseLLM):
                    chunk=chunk,
                    verbose=self.verbose,
                    logprobs=chunk.generation_info["logprobs"]
                    if chunk.generation_info else None,
                )
    def _generate(
@ -367,30 +360,32 @@ class BaseOpenAI(BaseLLM):
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError(
                        "Cannot stream results with multiple prompts.")
                generation: Optional[GenerationChunk] = None
                for chunk in self._stream(_prompts[0], stop, run_manager,
                                          **kwargs):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
                choices.append({
                    "text":
                        generation.text,
                    "finish_reason":
                        generation.generation_info.get("finish_reason")
                        if generation.generation_info else None,
                    "logprobs":
                        generation.generation_info.get("logprobs")
                        if generation.generation_info else None,
                })
            else:
                response = completion_with_retry(self,
                                                 prompt=_prompts,
                                                 run_manager=run_manager,
                                                 **params)
                choices.extend(response["choices"])
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)
@ -414,32 +409,32 @@ class BaseOpenAI(BaseLLM):
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError(
                        "Cannot stream results with multiple prompts.")
                generation: Optional[GenerationChunk] = None
                async for chunk in self._astream(_prompts[0], stop, run_manager,
                                                 **kwargs):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
                choices.append({
                    "text":
                        generation.text,
                    "finish_reason":
                        generation.generation_info.get("finish_reason")
                        if generation.generation_info else None,
                    "logprobs":
                        generation.generation_info.get("logprobs")
                        if generation.generation_info else None,
                })
            else:
                response = await acompletion_with_retry(self,
                                                        prompt=_prompts,
                                                        run_manager=run_manager,
                                                        **params)
                choices.extend(response["choices"])
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)
@ -453,39 +448,35 @@ class BaseOpenAI(BaseLLM):
"""Get the sub prompts for llm call.""" """Get the sub prompts for llm call."""
if stop is not None: if stop is not None:
if "stop" in params: if "stop" in params:
raise ValueError("`stop` found in both the input and default params.") raise ValueError(
"`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
if params["max_tokens"] == -1: if params["max_tokens"] == -1:
if len(prompts) != 1: if len(prompts) != 1:
raise ValueError( raise ValueError(
"max_tokens set to -1 not supported for multiple inputs." "max_tokens set to -1 not supported for multiple inputs.")
)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [ sub_prompts = [
prompts[i : i + self.batch_size] prompts[i:i + self.batch_size]
for i in range(0, len(prompts), self.batch_size) for i in range(0, len(prompts), self.batch_size)
] ]
return sub_prompts return sub_prompts
def create_llm_result( def create_llm_result(self, choices: Any, prompts: List[str],
self, choices: Any, prompts: List[str], token_usage: Dict[str, int] token_usage: Dict[str, int]) -> LLMResult:
) -> LLMResult:
"""Create the LLMResult from the choices and prompts.""" """Create the LLMResult from the choices and prompts."""
generations = [] generations = []
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
sub_choices = choices[i * self.n : (i + 1) * self.n] sub_choices = choices[i * self.n:(i + 1) * self.n]
generations.append( generations.append([
[ Generation(
Generation( text=choice["text"],
text=choice["text"], generation_info=dict(
generation_info=dict( finish_reason=choice.get("finish_reason"),
finish_reason=choice.get("finish_reason"), logprobs=choice.get("logprobs"),
logprobs=choice.get("logprobs"), ),
), ) for choice in sub_choices
) ])
for choice in sub_choices
]
)
llm_output = {"token_usage": token_usage, "model_name": self.model_name} llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return LLMResult(generations=generations, llm_output=llm_output) return LLMResult(generations=generations, llm_output=llm_output)
@ -500,7 +491,10 @@ class BaseOpenAI(BaseLLM):
if self.openai_proxy: if self.openai_proxy:
import openai import openai
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy
} # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params} return {**openai_creds, **self._default_params}
@property @property
@ -524,14 +518,14 @@ class BaseOpenAI(BaseLLM):
raise ImportError( raise ImportError(
"Could not import tiktoken python package. " "Could not import tiktoken python package. "
"This is needed in order to calculate get_num_tokens. " "This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install tiktoken`." "Please install it with `pip install tiktoken`.")
)
model_name = self.tiktoken_model_name or self.model_name model_name = self.tiktoken_model_name or self.model_name
try: try:
enc = tiktoken.encoding_for_model(model_name) enc = tiktoken.encoding_for_model(model_name)
except KeyError: except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.") logger.warning(
"Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base" model = "cl100k_base"
enc = tiktoken.get_encoding(model) enc = tiktoken.get_encoding(model)
@ -593,9 +587,7 @@ class BaseOpenAI(BaseLLM):
if context_size is None: if context_size is None:
raise ValueError( raise ValueError(
f"Unknown model: {modelname}. Please provide a valid OpenAI model name." f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
"Known models are: " "Known models are: " + ", ".join(model_token_mapping.keys()))
+ ", ".join(model_token_mapping.keys())
)
return context_size return context_size
@ -673,14 +665,15 @@ class AzureOpenAI(BaseOpenAI):
"OPENAI_API_VERSION", "OPENAI_API_VERSION",
) )
values["openai_api_type"] = get_from_dict_or_env( values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", "azure" values, "openai_api_type", "OPENAI_API_TYPE", "azure")
)
return values return values
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
return { return {
**{"deployment_name": self.deployment_name}, **{
"deployment_name": self.deployment_name
},
**super()._identifying_params, **super()._identifying_params,
} }
@ -745,7 +738,9 @@ class OpenAIChat(BaseLLM):
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()} all_required_field_names = {
field.alias for field in cls.__fields__.values()
}
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
for field_name in list(values): for field_name in list(values):
@ -759,9 +754,8 @@ class OpenAIChat(BaseLLM):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(values, "openai_api_key",
values, "openai_api_key", "OPENAI_API_KEY" "OPENAI_API_KEY")
)
openai_api_base = get_from_dict_or_env( openai_api_base = get_from_dict_or_env(
values, values,
"openai_api_base", "openai_api_base",
@ -774,9 +768,10 @@ class OpenAIChat(BaseLLM):
"OPENAI_PROXY", "OPENAI_PROXY",
default="", default="",
) )
openai_organization = get_from_dict_or_env( openai_organization = get_from_dict_or_env(values,
values, "openai_organization", "OPENAI_ORGANIZATION", default="" "openai_organization",
) "OPENAI_ORGANIZATION",
default="")
try: try:
import openai import openai
@ -786,20 +781,20 @@ class OpenAIChat(BaseLLM):
if openai_organization: if openai_organization:
openai.organization = openai_organization openai.organization = openai_organization
if openai_proxy: if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 openai.proxy = {
"http": openai_proxy,
"https": openai_proxy
} # type: ignore[assignment] # noqa: E501
except ImportError: except ImportError:
raise ImportError( raise ImportError("Could not import openai python package. "
"Could not import openai python package. " "Please install it with `pip install openai`.")
"Please install it with `pip install openai`."
)
try: try:
values["client"] = openai.ChatCompletion values["client"] = openai.ChatCompletion
except AttributeError: except AttributeError:
raise ValueError( raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely " "`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it " "due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`." "with `pip install --upgrade openai`.")
)
return values return values
@property @property
@ -807,18 +802,27 @@ class OpenAIChat(BaseLLM):
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
return self.model_kwargs return self.model_kwargs
def _get_chat_params( def _get_chat_params(self,
self, prompts: List[str], stop: Optional[List[str]] = None prompts: List[str],
) -> Tuple: stop: Optional[List[str]] = None) -> Tuple:
if len(prompts) > 1: if len(prompts) > 1:
raise ValueError( raise ValueError(
f"OpenAIChat currently only supports single prompt, got {prompts}" f"OpenAIChat currently only supports single prompt, got {prompts}"
) )
messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] messages = self.prefix_messages + [{
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} "role": "user",
"content": prompts[0]
}]
params: Dict[str, Any] = {
**{
"model": self.model_name
},
**self._default_params
}
if stop is not None: if stop is not None:
if "stop" in params: if "stop" in params:
raise ValueError("`stop` found in both the input and default params.") raise ValueError(
"`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
if params.get("max_tokens") == -1: if params.get("max_tokens") == -1:
# for ChatGPT api, omitting max_tokens is equivalent to having no limit # for ChatGPT api, omitting max_tokens is equivalent to having no limit
@ -834,9 +838,10 @@ class OpenAIChat(BaseLLM):
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop) messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry( for stream_resp in completion_with_retry(self,
self, messages=messages, run_manager=run_manager, **params messages=messages,
): run_manager=run_manager,
**params):
token = stream_resp["choices"][0]["delta"].get("content", "") token = stream_resp["choices"][0]["delta"].get("content", "")
chunk = GenerationChunk(text=token) chunk = GenerationChunk(text=token)
yield chunk yield chunk
@ -853,8 +858,7 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params([prompt], stop) messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry( async for stream_resp in await acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params self, messages=messages, run_manager=run_manager, **params):
):
token = stream_resp["choices"][0]["delta"].get("content", "") token = stream_resp["choices"][0]["delta"].get("content", "")
chunk = GenerationChunk(text=token) chunk = GenerationChunk(text=token)
yield chunk yield chunk
@ -880,17 +884,19 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
full_response = completion_with_retry( full_response = completion_with_retry(self,
self, messages=messages, run_manager=run_manager, **params messages=messages,
) run_manager=run_manager,
**params)
llm_output = { llm_output = {
"token_usage": full_response["usage"], "token_usage": full_response["usage"],
"model_name": self.model_name, "model_name": self.model_name,
} }
return LLMResult( return LLMResult(
generations=[ generations=[[
[Generation(text=full_response["choices"][0]["message"]["content"])] Generation(
], text=full_response["choices"][0]["message"]["content"])
]],
llm_output=llm_output, llm_output=llm_output,
) )
@ -903,7 +909,8 @@ class OpenAIChat(BaseLLM):
) -> LLMResult: ) -> LLMResult:
if self.streaming: if self.streaming:
generation: Optional[GenerationChunk] = None generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): async for chunk in self._astream(prompts[0], stop, run_manager,
**kwargs):
if generation is None: if generation is None:
generation = chunk generation = chunk
else: else:
@ -913,17 +920,19 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
full_response = await acompletion_with_retry( full_response = await acompletion_with_retry(self,
self, messages=messages, run_manager=run_manager, **params messages=messages,
) run_manager=run_manager,
**params)
llm_output = { llm_output = {
"token_usage": full_response["usage"], "token_usage": full_response["usage"],
"model_name": self.model_name, "model_name": self.model_name,
} }
return LLMResult( return LLMResult(
generations=[ generations=[[
[Generation(text=full_response["choices"][0]["message"]["content"])] Generation(
], text=full_response["choices"][0]["message"]["content"])
]],
llm_output=llm_output, llm_output=llm_output,
) )
@ -948,8 +957,7 @@ class OpenAIChat(BaseLLM):
raise ImportError( raise ImportError(
"Could not import tiktoken python package. " "Could not import tiktoken python package. "
"This is needed in order to calculate get_num_tokens. " "This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install tiktoken`." "Please install it with `pip install tiktoken`.")
)
enc = tiktoken.encoding_for_model(self.model_name) enc = tiktoken.encoding_for_model(self.model_name)
return enc.encode( return enc.encode(
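For context on the wrappers reformatted above, a hedged sketch of typical use (not part of this commit; assumes OPENAI_API_KEY is set and that OpenAIChat is exported by this module):

    # hypothetical usage; generate() comes from the BaseLLM interface these classes implement
    llm = OpenAIChat(model_name="gpt-3.5-turbo")
    result = llm.generate(["Summarize the yapf google style in one sentence."])
    print(result.generations[0][0].text)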
@ -71,16 +71,15 @@ class OpenAITokenizer(BaseTokenizer):
    @property
    def max_tokens(self) -> int:
        tokens = next(v for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
                      if self.model.startswith(k))
        offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET
        return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
    def count_tokens(self,
                     text: str | list,
                     model: Optional[str] = None) -> int:
        """
        Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@ -96,12 +95,12 @@ class OpenAITokenizer(BaseTokenizer):
                encoding = tiktoken.get_encoding("cl100k_base")
            if model in {
                    "gpt-3.5-turbo-0613",
                    "gpt-3.5-turbo-16k-0613",
                    "gpt-4-0314",
                    "gpt-4-32k-0314",
                    "gpt-4-0613",
                    "gpt-4-32k-0613",
            }:
                tokens_per_message = 3
                tokens_per_name = 1
@ -113,21 +112,18 @@ class OpenAITokenizer(BaseTokenizer):
            elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
                logging.info(
                    "gpt-3.5-turbo may update over time. Returning num tokens assuming"
                    " gpt-3.5-turbo-0613.")
                return self.count_tokens(text, model="gpt-3.5-turbo-0613")
            elif "gpt-4" in model:
                logging.info(
                    "gpt-4 may update over time. Returning num tokens assuming"
                    " gpt-4-0613.")
                return self.count_tokens(text, model="gpt-4-0613")
            else:
                raise NotImplementedError(
                    f"""token_count() is not implemented for model {model}.
                    See https://github.com/openai/openai-python/blob/main/chatml.md for
                    information on how messages are converted to tokens.""")
            num_tokens = 0
@ -144,5 +140,5 @@ class OpenAITokenizer(BaseTokenizer):
            return num_tokens
        else:
            return len(
                self.encoding.encode(text,
                                     allowed_special=set(self.stop_sequences)))
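A small sketch of the ChatML counting path described in the docstring above (not part of this commit; the constructor argument name is an assumption based on self.model):

    # hypothetical usage; count_tokens accepts either a plain string or a list of ChatML messages
    tokenizer = OpenAITokenizer(model="gpt-4-0613")
    messages = [{"role": "user", "content": "Hello there"}]
    print(tokenizer.count_tokens(messages))
    print(tokenizer.count_tokens("plain text is counted with allowed_special"))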
@ -26,8 +26,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
    except ImportError:
        raise ImportError(
            "Could not import google-api-core python package. "
            "Please install it with `pip install google-api-core`.")
    multiplier = 2
    min_seconds = 1
@ -37,12 +36,15 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=multiplier,
                              min=min_seconds,
                              max=max_seconds),
        retry=(retry_if_exception_type(
            google.api_core.exceptions.ResourceExhausted) |
               retry_if_exception_type(
                   google.api_core.exceptions.ServiceUnavailable) |
               retry_if_exception_type(
                   google.api_core.exceptions.GoogleAPIError)),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
@ -64,7 +66,8 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
    The PaLM API will sometimes erroneously return a single leading space in all
    lines > 1. This function strips that space.
    """
    has_leading_space = all(
        not line or line[0] == " " for line in text.split("\n")[1:])
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
@ -97,9 +100,8 @@ class GooglePalm(BaseLLM, BaseModel):
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(values, "google_api_key",
                                              "GOOGLE_API_KEY")
        try:
            import google.generativeai as genai
@ -107,12 +109,12 @@ class GooglePalm(BaseLLM, BaseModel):
        except ImportError:
            raise ImportError(
                "Could not import google-generativeai python package. "
                "Please install it with `pip install google-generativeai`.")
        values["client"] = genai
        if values["temperature"] is not None and not 0 <= values[
                "temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")
        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
@ -121,7 +123,8 @@ class GooglePalm(BaseLLM, BaseModel):
        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")
        if values["max_output_tokens"] is not None and values[
                "max_output_tokens"] <= 0:
            raise ValueError("max_output_tokens must be greater than zero")
        return values
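A hedged usage sketch for the validators above (not part of this commit; assumes GOOGLE_API_KEY is set and that GooglePalm is exported here):

    # hypothetical usage; the keyword bounds follow the validate_environment checks above
    palm = GooglePalm(temperature=0.7, top_p=0.95, top_k=40, max_output_tokens=256)
    result = palm.generate(["Name three uses of a retry decorator."])
    print(result.generations[0][0].text)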
@ -33,9 +33,10 @@ class PegasusEmbedding:
""" """
def __init__( def __init__(self,
self, modality: str, multi_process: bool = False, n_processes: int = 4 modality: str,
): multi_process: bool = False,
n_processes: int = 4):
self.modality = modality self.modality = modality
self.multi_process = multi_process self.multi_process = multi_process
self.n_processes = n_processes self.n_processes = n_processes
@ -43,8 +44,7 @@ class PegasusEmbedding:
self.pegasus = Pegasus(modality, multi_process, n_processes) self.pegasus = Pegasus(modality, multi_process, n_processes)
except Exception as e: except Exception as e:
logging.error( logging.error(
f"Failed to initialize Pegasus with modality: {modality}: {e}" f"Failed to initialize Pegasus with modality: {modality}: {e}")
)
raise raise
def embed(self, data: Union[str, list[str]]): def embed(self, data: Union[str, list[str]]):
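A short sketch of the constructor and embed() call shown above (not part of this commit; the "text" modality value is an assumption):

    # hypothetical usage
    embedder = PegasusEmbedding(modality="text", multi_process=False, n_processes=4)
    vectors = embedder.embed(["pegasus accepts a string or a list of strings"])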
@ -21,6 +21,4 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
    return openai.Embedding.create(
        input=[text],
        model=model,
    )["data"][0]["embedding"]
@ -90,17 +90,17 @@ class SpeechT5:
        self.processor = SpeechT5Processor.from_pretrained(self.model_name)
        self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name)
        self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name)
        self.embeddings_dataset = load_dataset(self.dataset_name,
                                               split="validation")
    def __call__(self, text: str, speaker_id: float = 7306):
        """Call the model on some text and return the speech."""
        speaker_embedding = torch.tensor(
            self.embeddings_dataset[speaker_id]["xvector"]).unsqueeze(0)
        inputs = self.processor(text=text, return_tensors="pt")
        speech = self.model.generate_speech(inputs["input_ids"],
                                            speaker_embedding,
                                            vocoder=self.vocoder)
        return speech
    def save_speech(self, speech, filename="speech.wav"):
@ -121,7 +121,8 @@ class SpeechT5:
    def set_embeddings_dataset(self, dataset_name):
        """Set the embeddings dataset to a new dataset."""
        self.dataset_name = dataset_name
        self.embeddings_dataset = load_dataset(self.dataset_name,
                                               split="validation")
    # Feature 1: Get sampling rate
    def get_sampling_rate(self):
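A hedged sketch of the text-to-speech flow above (not part of this commit; assumes SpeechT5 can be constructed with its default model, vocoder, and dataset names):

    # hypothetical usage; __call__ and save_speech follow the signatures in the diff above
    tts = SpeechT5()
    speech = tts("The quick brown fox jumps over the lazy dog.", speaker_id=7306)
    tts.save_speech(speech, filename="speech.wav")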
@ -50,9 +50,8 @@ class TimmModel:
            in_chans=model_info.in_chans,
        )
    def __call__(self, model_info: TimmModelInfo,
                 input_tensor: torch.Tensor) -> torch.Size:
        """
        Create and run a model specified by `model_info` on `input_tensor`.
@ -10,9 +10,8 @@ import requests
class TrOCR:
    def __init__(self,):
        pass
    def __call__(self):
@ -23,11 +23,9 @@ class Vilt:
    def __init__(self):
        self.processor = ViltProcessor.from_pretrained(
            "dandelin/vilt-b32-finetuned-vqa")
        self.model = ViltForQuestionAnswering.from_pretrained(
            "dandelin/vilt-b32-finetuned-vqa")
    def __call__(self, text: str, image_url: str):
        """
@ -33,7 +33,8 @@ class WizardLLMStoryTeller:
    def __init__(
        self,
        model_id:
        str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF",
        device: str = None,
        max_length: int = 500,
        quantize: bool = False,
@ -44,9 +45,8 @@ class WizardLLMStoryTeller:
        decoding=False,
    ):
        self.logger = logging.getLogger(__name__)
        self.device = (device if device else
                       ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model_id = model_id
        self.max_length = max_length
        self.verbose = verbose
@ -56,9 +56,8 @@ class WizardLLMStoryTeller:
        # self.log = Logging()
        if self.distributed:
            assert (torch.cuda.device_count() >
                    1), "You need more than 1 gpu for distributed processing"
        bnb_config = None
        if quantize:
@ -74,8 +73,7 @@ class WizardLLMStoryTeller:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id, quantization_config=bnb_config)
            self.model  # .to(self.device)
        except Exception as e:
@ -88,20 +86,18 @@ class WizardLLMStoryTeller:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            bnb_config = (BitsAndBytesConfig(**self.quantization_config)
                          if self.quantization_config else None)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=bnb_config).to(self.device)
            if self.distributed:
                self.model = DDP(self.model)
        except Exception as error:
            self.logger.error(
                f"Failed to load the model or the tokenizer: {error}")
            raise
    def run(self, prompt_text: str):
@ -120,9 +116,8 @@ class WizardLLMStoryTeller:
        max_length = self.max_length
        try:
            inputs = self.tokenizer.encode(prompt_text,
                                           return_tensors="pt").to(self.device)
            # self.log.start()
@ -131,26 +126,26 @@ class WizardLLMStoryTeller:
                for _ in range(max_length):
                    output_sequence = []
                    outputs = self.model.generate(inputs,
                                                  max_length=len(inputs) +
                                                  1,
                                                  do_sample=True)
                    output_tokens = outputs[0][-1]
                    output_sequence.append(output_tokens.item())
                    # print token in real-time
                    print(
                        self.tokenizer.decode([output_tokens],
                                              skip_special_tokens=True),
                        end="",
                        flush=True,
                    )
                    inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(inputs,
                                                  max_length=max_length,
                                                  do_sample=True)
            del inputs
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
@ -174,9 +169,8 @@ class WizardLLMStoryTeller:
        max_length = self.max_
        try:
            inputs = self.tokenizer.encode(prompt_text,
                                           return_tensors="pt").to(self.device)
            # self.log.start()
@ -185,26 +179,26 @@ class WizardLLMStoryTeller:
                for _ in range(max_length):
                    output_sequence = []
                    outputs = self.model.generate(inputs,
                                                  max_length=len(inputs) +
                                                  1,
                                                  do_sample=True)
                    output_tokens = outputs[0][-1]
                    output_sequence.append(output_tokens.item())
                    # print token in real-time
                    print(
                        self.tokenizer.decode([output_tokens],
                                              skip_special_tokens=True),
                        end="",
                        flush=True,
                    )
                    inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(inputs,
                                                  max_length=max_length,
                                                  do_sample=True)
            del inputs
@ -44,9 +44,8 @@ class YarnMistral128:
        decoding=False,
    ):
        self.logger = logging.getLogger(__name__)
        self.device = (device if device else
                       ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model_id = model_id
        self.max_length = max_length
        self.verbose = verbose
@ -56,9 +55,8 @@ class YarnMistral128:
        # self.log = Logging()
        if self.distributed:
            assert (torch.cuda.device_count() >
                    1), "You need more than 1 gpu for distributed processing"
        bnb_config = None
        if quantize:
@ -93,20 +91,18 @@ class YarnMistral128:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            bnb_config = (BitsAndBytesConfig(**self.quantization_config)
                          if self.quantization_config else None)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=bnb_config).to(self.device)
            if self.distributed:
                self.model = DDP(self.model)
        except Exception as error:
            self.logger.error(
                f"Failed to load the model or the tokenizer: {error}")
            raise
    def run(self, prompt_text: str):
@ -125,9 +121,8 @@ class YarnMistral128:
        max_length = self.max_length
        try:
            inputs = self.tokenizer.encode(prompt_text,
                                           return_tensors="pt").to(self.device)
            # self.log.start()
@ -136,26 +131,26 @@ class YarnMistral128:
                for _ in range(max_length):
                    output_sequence = []
                    outputs = self.model.generate(inputs,
                                                  max_length=len(inputs) +
                                                  1,
                                                  do_sample=True)
                    output_tokens = outputs[0][-1]
                    output_sequence.append(output_tokens.item())
                    # print token in real-time
                    print(
                        self.tokenizer.decode([output_tokens],
                                              skip_special_tokens=True),
                        end="",
                        flush=True,
                    )
                    inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(inputs,
                                                  max_length=max_length,
                                                  do_sample=True)
            del inputs
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
@ -202,9 +197,8 @@ class YarnMistral128:
        max_length = self.max_
        try:
            inputs = self.tokenizer.encode(prompt_text,
                                           return_tensors="pt").to(self.device)
            # self.log.start()
@ -213,26 +207,26 @@ class YarnMistral128:
                for _ in range(max_length):
                    output_sequence = []
                    outputs = self.model.generate(inputs,
                                                  max_length=len(inputs) +
                                                  1,
                                                  do_sample=True)
                    output_tokens = outputs[0][-1]
                    output_sequence.append(output_tokens.item())
                    # print token in real-time
                    print(
                        self.tokenizer.decode([output_tokens],
                                              skip_special_tokens=True),
                        end="",
                        flush=True,
                    )
                    inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(inputs,
                                                  max_length=max_length,
                                                  do_sample=True)
            del inputs
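A usage sketch for the loader and run() method above (not part of this commit; assumes a CUDA device with enough memory, and that the constructor keywords mirror those shown for WizardLLMStoryTeller):

    # hypothetical usage
    mistral = YarnMistral128(max_length=200, quantize=True)
    print(mistral.run("Write a haiku about long context windows."))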
@ -28,7 +28,8 @@ class Zephyr:
model_name: str = "HuggingFaceH4/zephyr-7b-alpha", model_name: str = "HuggingFaceH4/zephyr-7b-alpha",
tokenize: bool = False, tokenize: bool = False,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
system_prompt: str = "You are a friendly chatbot who always responds in the style of a pirate", system_prompt:
str = "You are a friendly chatbot who always responds in the style of a pirate",
max_new_tokens: int = 300, max_new_tokens: int = 300,
temperature: float = 0.5, temperature: float = 0.5,
top_k: float = 50, top_k: float = 50,
@ -70,7 +71,7 @@ class Zephyr:
) )
outputs = self.pipe(prompt) # max_new_token=self.max_new_tokens) outputs = self.pipe(prompt) # max_new_token=self.max_new_tokens)
print(outputs[0]["generated_text"]) print(outputs[0]["generated_text"])
def chat(self, message: str): def chat(self, message: str):
""" """
Adds a user message to the conversation and generates a chatbot response. Adds a user message to the conversation and generates a chatbot response.
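A short sketch of the chat flow (not part of this commit; constructor defaults follow the signature above):

    # hypothetical usage; chat() adds a user message and generates a reply, as the docstring says
    zephyr = Zephyr(max_new_tokens=300, temperature=0.5, top_k=50)
    zephyr.chat("Ahoy! What be the weather like today?")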
@ -24,9 +24,8 @@ class AgentOutputParser(BaseAgentOutputParser):
    @staticmethod
    def _preprocess_json_input(input_str: str) -> str:
        corrected_str = re.sub(r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})',
                               r"\\\\", input_str)
        return corrected_str
    def _parse_json(self, text: str) -> dict:
@ -13,13 +13,23 @@ class PromptGenerator:
        self.performance_evaluation: List[str] = []
        self.response_format = {
            "thoughts": {
                "text":
                    "thought",
                "reasoning":
                    "reasoning",
                "plan":
                    "- short bulleted\n- list that conveys\n- long-term plan",
                "criticism":
                    "constructive self-criticism",
                "speak":
                    "thoughts summary to say to user",
            },
            "command": {
                "name": "command name",
                "args": {
                    "arg name": "value"
                }
            },
        }
    def add_constraint(self, constraint: str) -> None:
@ -72,7 +82,6 @@ class PromptGenerator:
            f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n"
            "You should only respond in JSON format as described below "
            f"\nResponse Format: \n{formatted_response_format} "
            "\nEnsure the response can be parsed by Python json.loads")
        return prompt_string
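A sketch of how the response_format above ends up in the final prompt (not part of this commit; the method name generate_prompt_string is an assumption about the surrounding code):

    # hypothetical usage; add_constraint is defined in the diff above
    generator = PromptGenerator()
    generator.add_constraint("Respond only with valid JSON.")
    print(generator.generate_prompt_string())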
@ -7,25 +7,21 @@ def generate_agent_role_prompt(agent):
"Finance Agent": ( "Finance Agent": (
"You are a seasoned finance analyst AI assistant. Your primary goal is to" "You are a seasoned finance analyst AI assistant. Your primary goal is to"
" compose comprehensive, astute, impartial, and methodically arranged" " compose comprehensive, astute, impartial, and methodically arranged"
" financial reports based on provided data and trends." " financial reports based on provided data and trends."),
),
"Travel Agent": ( "Travel Agent": (
"You are a world-travelled AI tour guide assistant. Your main purpose is to" "You are a world-travelled AI tour guide assistant. Your main purpose is to"
" draft engaging, insightful, unbiased, and well-structured travel reports" " draft engaging, insightful, unbiased, and well-structured travel reports"
" on given locations, including history, attractions, and cultural" " on given locations, including history, attractions, and cultural"
" insights." " insights."),
),
"Academic Research Agent": ( "Academic Research Agent": (
"You are an AI academic research assistant. Your primary responsibility is" "You are an AI academic research assistant. Your primary responsibility is"
" to create thorough, academically rigorous, unbiased, and systematically" " to create thorough, academically rigorous, unbiased, and systematically"
" organized reports on a given research topic, following the standards of" " organized reports on a given research topic, following the standards of"
" scholarly work." " scholarly work."),
),
"Default Agent": ( "Default Agent": (
"You are an AI critical thinker research assistant. Your sole purpose is to" "You are an AI critical thinker research assistant. Your sole purpose is to"
" write well written, critically acclaimed, objective and structured" " write well written, critically acclaimed, objective and structured"
" reports on given text." " reports on given text."),
),
} }
return prompts.get(agent, "No such agent") return prompts.get(agent, "No such agent")
@ -44,8 +40,7 @@ def generate_report_prompt(question, research_summary):
" focus on the answer to the question, should be well structured, informative," " focus on the answer to the question, should be well structured, informative,"
" in depth, with facts and numbers if available, a minimum of 1,200 words and" " in depth, with facts and numbers if available, a minimum of 1,200 words and"
" with markdown syntax and apa format. Write all source urls at the end of the" " with markdown syntax and apa format. Write all source urls at the end of the"
" report in apa format" " report in apa format")
)
def generate_search_queries_prompt(question): def generate_search_queries_prompt(question):
@ -57,8 +52,7 @@ def generate_search_queries_prompt(question):
return ( return (
"Write 4 google search queries to search online that form an objective opinion" "Write 4 google search queries to search online that form an objective opinion"
f' from the following: "{question}" You must respond with a list of strings in' f' from the following: "{question}" You must respond with a list of strings in'
' the following format: ["query 1", "query 2", "query 3", "query 4"]' ' the following format: ["query 1", "query 2", "query 3", "query 4"]')
)
def generate_resource_report_prompt(question, research_summary): def generate_resource_report_prompt(question, research_summary):
@ -80,8 +74,7 @@ def generate_resource_report_prompt(question, research_summary):
" significance of each source. Ensure that the report is well-structured," " significance of each source. Ensure that the report is well-structured,"
" informative, in-depth, and follows Markdown syntax. Include relevant facts," " informative, in-depth, and follows Markdown syntax. Include relevant facts,"
" figures, and numbers whenever available. The report should have a minimum" " figures, and numbers whenever available. The report should have a minimum"
" length of 1,200 words." " length of 1,200 words.")
)
def generate_outline_report_prompt(question, research_summary): def generate_outline_report_prompt(question, research_summary):
@ -98,8 +91,7 @@ def generate_outline_report_prompt(question, research_summary):
" research report, including the main sections, subsections, and key points to" " research report, including the main sections, subsections, and key points to"
" be covered. The research report should be detailed, informative, in-depth," " be covered. The research report should be detailed, informative, in-depth,"
" and a minimum of 1,200 words. Use appropriate Markdown syntax to format the" " and a minimum of 1,200 words. Use appropriate Markdown syntax to format the"
" outline and ensure readability." " outline and ensure readability.")
)
def generate_concepts_prompt(question, research_summary): def generate_concepts_prompt(question, research_summary):
@ -114,8 +106,7 @@ def generate_concepts_prompt(question, research_summary):
" main concepts to learn for a research report on the following question or" " main concepts to learn for a research report on the following question or"
f' topic: "{question}". The outline should provide a well-structured' f' topic: "{question}". The outline should provide a well-structured'
" frameworkYou must respond with a list of strings in the following format:" " frameworkYou must respond with a list of strings in the following format:"
' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]' ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]')
)
def generate_lesson_prompt(concept): def generate_lesson_prompt(concept):
@ -131,8 +122,7 @@ def generate_lesson_prompt(concept):
f"generate a comprehensive lesson about {concept} in Markdown syntax. This" f"generate a comprehensive lesson about {concept} in Markdown syntax. This"
f" should include the definitionof {concept}, its historical background and" f" should include the definitionof {concept}, its historical background and"
" development, its applications or uses in differentfields, and notable events" " development, its applications or uses in differentfields, and notable events"
f" or facts related to {concept}." f" or facts related to {concept}.")
)
return prompt return prompt

@ -11,9 +11,9 @@ if TYPE_CHECKING:
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
def get_buffer_string( def get_buffer_string(messages: Sequence[BaseMessage],
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" human_prefix: str = "Human",
) -> str: ai_prefix: str = "AI") -> str:
"""Convert sequence of Messages to strings and concatenate them into one string. """Convert sequence of Messages to strings and concatenate them into one string.
Args: Args:
@ -88,9 +88,9 @@ class BaseMessage(Serializable):
class BaseMessageChunk(BaseMessage): class BaseMessageChunk(BaseMessage):
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any] def _merge_kwargs_dict(self, left: Dict[str, Any],
) -> Dict[str, Any]: right: Dict[str, Any]) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one.""" """Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy() merged = left.copy()
for k, v in right.items(): for k, v in right.items():
@ -99,8 +99,7 @@ class BaseMessageChunk(BaseMessage):
elif not isinstance(merged[k], type(v)): elif not isinstance(merged[k], type(v)):
raise ValueError( raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,' f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type." " but with a different type.")
)
elif isinstance(merged[k], str): elif isinstance(merged[k], str):
merged[k] += v merged[k] += v
elif isinstance(merged[k], dict): elif isinstance(merged[k], dict):
@ -119,15 +118,12 @@ class BaseMessageChunk(BaseMessage):
return self.__class__( return self.__class__(
content=self.content + other.content, content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict( additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs self.additional_kwargs, other.additional_kwargs),
),
) )
else: else:
raise TypeError( raise TypeError('unsupported operand type(s) for +: "'
'unsupported operand type(s) for +: "' f"{self.__class__.__name__}"
f"{self.__class__.__name__}" f'" and "{other.__class__.__name__}"')
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage): class HumanMessage(BaseMessage):

@ -66,9 +66,10 @@ class SystemMessage(Message):
of input messages. of input messages.
""" """
def __init__( def __init__(self,
self, content: str, role: str = "System", additional_kwargs: Dict = None content: str,
): role: str = "System",
additional_kwargs: Dict = None):
super().__init__(content, role, additional_kwargs) super().__init__(content, role, additional_kwargs)
def get_type(self) -> str: def get_type(self) -> str:
@ -106,9 +107,9 @@ class ChatMessage(Message):
return "chat" return "chat"
def get_buffer_string( def get_buffer_string(messages: Sequence[Message],
messages: Sequence[Message], human_prefix: str = "Human", ai_prefix: str = "AI" human_prefix: str = "Human",
) -> str: ai_prefix: str = "AI") -> str:
string_messages = [] string_messages = []
for m in messages: for m in messages:
message = f"{m.role}: {m.content}" message = f"{m.role}: {m.content}"

@ -38,7 +38,6 @@ def debate_monitor(game_description, word_limit, character_names):
return prompt return prompt
def generate_character_header( def generate_character_header(game_description, topic, character_name,
game_description, topic, character_name, character_description character_description):
):
pass pass

@ -1,7 +1,6 @@
ERROR_PROMPT = ( ERROR_PROMPT = (
"An error has occurred for the following text: \n{promptedQuery} Please explain" "An error has occurred for the following text: \n{promptedQuery} Please explain"
" this error.\n {e}" " this error.\n {e}")
)
IMAGE_PROMPT = """ IMAGE_PROMPT = """
provide a figure named {filename}. The description is: {description}. provide a figure named {filename}. The description is: {description}.
@ -12,7 +11,6 @@ USER INPUT
============ ============
""" """
AUDIO_PROMPT = """ AUDIO_PROMPT = """
provide an audio file named {filename}. The description is: {description}. provide an audio file named {filename}. The description is: {description}.
@ -41,7 +39,6 @@ USER INPUT
============ ============
""" """
EVAL_PREFIX = """{bot_name} can execute any user's request. EVAL_PREFIX = """{bot_name} can execute any user's request.
{bot_name} has permission to handle one instance and can handle the environment in it at will. {bot_name} has permission to handle one instance and can handle the environment in it at will.

@ -3,30 +3,25 @@ PY_REFLEXION_COMPLETION_INSTRUCTION = (
"You are a Python writing assistant. You will be given your past function" "You are a Python writing assistant. You will be given your past function"
" implementation, a series of unit tests, and a hint to change the implementation" " implementation, a series of unit tests, and a hint to change the implementation"
" appropriately. Write your full implementation (restate the function" " appropriately. Write your full implementation (restate the function"
" signature).\n\n-----" " signature).\n\n-----")
)
PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = (
"You are a Python writing assistant. You will be given a function implementation" "You are a Python writing assistant. You will be given a function implementation"
" and a series of unit tests. Your goal is to write a few sentences to explain why" " and a series of unit tests. Your goal is to write a few sentences to explain why"
" your implementation is wrong as indicated by the tests. You will need this as a" " your implementation is wrong as indicated by the tests. You will need this as a"
" hint when you try again later. Only provide the few sentence description in your" " hint when you try again later. Only provide the few sentence description in your"
" answer, not the implementation.\n\n-----" " answer, not the implementation.\n\n-----")
)
USE_PYTHON_CODEBLOCK_INSTRUCTION = ( USE_PYTHON_CODEBLOCK_INSTRUCTION = (
"Use a Python code block to write your response. For" "Use a Python code block to write your response. For"
" example:\n```python\nprint('Hello world!')\n```" " example:\n```python\nprint('Hello world!')\n```")
)
PY_SIMPLE_CHAT_INSTRUCTION = ( PY_SIMPLE_CHAT_INSTRUCTION = (
"You are an AI that only responds with python code, NOT ENGLISH. You will be given" "You are an AI that only responds with python code, NOT ENGLISH. You will be given"
" a function signature and its docstring by the user. Write your full" " a function signature and its docstring by the user. Write your full"
" implementation (restate the function signature)." " implementation (restate the function signature).")
)
PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( PY_SIMPLE_CHAT_INSTRUCTION_V2 = (
"You are an AI that only responds with only python code. You will be given a" "You are an AI that only responds with only python code. You will be given a"
" function signature and its docstring by the user. Write your full implementation" " function signature and its docstring by the user. Write your full implementation"
" (restate the function signature)." " (restate the function signature).")
)
PY_REFLEXION_CHAT_INSTRUCTION = ( PY_REFLEXION_CHAT_INSTRUCTION = (
"You are an AI Python assistant. You will be given your past function" "You are an AI Python assistant. You will be given your past function"
" implementation, a series of unit tests, and a hint to change the implementation" " implementation, a series of unit tests, and a hint to change the implementation"
@ -36,8 +31,7 @@ PY_REFLEXION_CHAT_INSTRUCTION_V2 = (
"You are an AI Python assistant. You will be given your previous implementation of" "You are an AI Python assistant. You will be given your previous implementation of"
" a function, a series of unit tests results, and your self-reflection on your" " a function, a series of unit tests results, and your self-reflection on your"
" previous implementation. Write your full implementation (restate the function" " previous implementation. Write your full implementation (restate the function"
" signature)." " signature).")
)
PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: PY_REFLEXION_FEW_SHOT_ADD = '''Example 1:
[previous impl]: [previous impl]:
```python ```python
@ -175,16 +169,14 @@ PY_SELF_REFLECTION_CHAT_INSTRUCTION = (
" implementation and a series of unit tests. Your goal is to write a few sentences" " implementation and a series of unit tests. Your goal is to write a few sentences"
" to explain why your implementation is wrong as indicated by the tests. You will" " to explain why your implementation is wrong as indicated by the tests. You will"
" need this as a hint when you try again later. Only provide the few sentence" " need this as a hint when you try again later. Only provide the few sentence"
" description in your answer, not the implementation." " description in your answer, not the implementation.")
)
PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = (
"You are a Python programming assistant. You will be given a function" "You are a Python programming assistant. You will be given a function"
" implementation and a series of unit test results. Your goal is to write a few" " implementation and a series of unit test results. Your goal is to write a few"
" sentences to explain why your implementation is wrong as indicated by the tests." " sentences to explain why your implementation is wrong as indicated by the tests."
" You will need this as guidance when you try again later. Only provide the few" " You will need this as guidance when you try again later. Only provide the few"
" sentence description in your answer, not the implementation. You will be given a" " sentence description in your answer, not the implementation. You will be given a"
" few examples by the user." " few examples by the user.")
)
PY_SELF_REFLECTION_FEW_SHOT = """Example 1: PY_SELF_REFLECTION_FEW_SHOT = """Example 1:
[function impl]: [function impl]:
```python ```python

@ -3,39 +3,31 @@ conversation_stages = {
"Introduction: Start the conversation by introducing yourself and your company." "Introduction: Start the conversation by introducing yourself and your company."
" Be polite and respectful while keeping the tone of the conversation" " Be polite and respectful while keeping the tone of the conversation"
" professional. Your greeting should be welcoming. Always clarify in your" " professional. Your greeting should be welcoming. Always clarify in your"
" greeting the reason why you are contacting the prospect." " greeting the reason why you are contacting the prospect."),
),
"2": ( "2": (
"Qualification: Qualify the prospect by confirming if they are the right person" "Qualification: Qualify the prospect by confirming if they are the right person"
" to talk to regarding your product/service. Ensure that they have the" " to talk to regarding your product/service. Ensure that they have the"
" authority to make purchasing decisions." " authority to make purchasing decisions."),
),
"3": ( "3": (
"Value proposition: Briefly explain how your product/service can benefit the" "Value proposition: Briefly explain how your product/service can benefit the"
" prospect. Focus on the unique selling points and value proposition of your" " prospect. Focus on the unique selling points and value proposition of your"
" product/service that sets it apart from competitors." " product/service that sets it apart from competitors."),
),
"4": ( "4": (
"Needs analysis: Ask open-ended questions to uncover the prospect's needs and" "Needs analysis: Ask open-ended questions to uncover the prospect's needs and"
" pain points. Listen carefully to their responses and take notes." " pain points. Listen carefully to their responses and take notes."),
), "5": ("Solution presentation: Based on the prospect's needs, present your"
"5": ( " product/service as the solution that can address their pain points."
"Solution presentation: Based on the prospect's needs, present your" ),
" product/service as the solution that can address their pain points." "6":
), ("Objection handling: Address any objections that the prospect may have"
"6": ( " regarding your product/service. Be prepared to provide evidence or"
"Objection handling: Address any objections that the prospect may have" " testimonials to support your claims."),
" regarding your product/service. Be prepared to provide evidence or"
" testimonials to support your claims."
),
"7": ( "7": (
"Close: Ask for the sale by proposing a next step. This could be a demo, a" "Close: Ask for the sale by proposing a next step. This could be a demo, a"
" trial or a meeting with decision-makers. Ensure to summarize what has been" " trial or a meeting with decision-makers. Ensure to summarize what has been"
" discussed and reiterate the benefits." " discussed and reiterate the benefits."),
),
} }
SALES_AGENT_TOOLS_PROMPT = """ SALES_AGENT_TOOLS_PROMPT = """
Never forget your name is {salesperson_name}. You work as a {salesperson_role}. Never forget your name is {salesperson_name}. You work as a {salesperson_role}.
You work at company named {company_name}. {company_name}'s business is the following: {company_business}. You work at company named {company_name}. {company_name}'s business is the following: {company_business}.

@ -20,7 +20,6 @@ The answer needs to be one number only, no words.
If there is no conversation history, output 1. If there is no conversation history, output 1.
Do not answer anything else nor add anything to your answer.""" Do not answer anything else nor add anything to your answer."""
SALES = """Never forget your name is {salesperson_name}. You work as a {salesperson_role}. SALES = """Never forget your name is {salesperson_name}. You work as a {salesperson_role}.
You work at company named {company_name}. {company_name}'s business is the following: {company_business} You work at company named {company_name}. {company_name}'s business is the following: {company_business}
Company values are the following. {company_values} Company values are the following. {company_values}
@ -50,34 +49,27 @@ conversation_stages = {
"Introduction: Start the conversation by introducing yourself and your company." "Introduction: Start the conversation by introducing yourself and your company."
" Be polite and respectful while keeping the tone of the conversation" " Be polite and respectful while keeping the tone of the conversation"
" professional. Your greeting should be welcoming. Always clarify in your" " professional. Your greeting should be welcoming. Always clarify in your"
" greeting the reason why you are contacting the prospect." " greeting the reason why you are contacting the prospect."),
),
"2": ( "2": (
"Qualification: Qualify the prospect by confirming if they are the right person" "Qualification: Qualify the prospect by confirming if they are the right person"
" to talk to regarding your product/service. Ensure that they have the" " to talk to regarding your product/service. Ensure that they have the"
" authority to make purchasing decisions." " authority to make purchasing decisions."),
),
"3": ( "3": (
"Value proposition: Briefly explain how your product/service can benefit the" "Value proposition: Briefly explain how your product/service can benefit the"
" prospect. Focus on the unique selling points and value proposition of your" " prospect. Focus on the unique selling points and value proposition of your"
" product/service that sets it apart from competitors." " product/service that sets it apart from competitors."),
),
"4": ( "4": (
"Needs analysis: Ask open-ended questions to uncover the prospect's needs and" "Needs analysis: Ask open-ended questions to uncover the prospect's needs and"
" pain points. Listen carefully to their responses and take notes." " pain points. Listen carefully to their responses and take notes."),
), "5": ("Solution presentation: Based on the prospect's needs, present your"
"5": ( " product/service as the solution that can address their pain points."
"Solution presentation: Based on the prospect's needs, present your" ),
" product/service as the solution that can address their pain points." "6":
), ("Objection handling: Address any objections that the prospect may have"
"6": ( " regarding your product/service. Be prepared to provide evidence or"
"Objection handling: Address any objections that the prospect may have" " testimonials to support your claims."),
" regarding your product/service. Be prepared to provide evidence or"
" testimonials to support your claims."
),
"7": ( "7": (
"Close: Ask for the sale by proposing a next step. This could be a demo, a" "Close: Ask for the sale by proposing a next step. This could be a demo, a"
" trial or a meeting with decision-makers. Ensure to summarize what has been" " trial or a meeting with decision-makers. Ensure to summarize what has been"
" discussed and reiterate the benefits." " discussed and reiterate the benefits."),
),
} }

@ -10,7 +10,6 @@ summary. Pick a suitable emoji for every bullet point. Your response should be i
a YouTube video, use the following text: {{CONTENT}}. a YouTube video, use the following text: {{CONTENT}}.
""" """
SUMMARIZE_PROMPT_2 = """ SUMMARIZE_PROMPT_2 = """
Provide a very short summary, no more than three sentences, for the following article: Provide a very short summary, no more than three sentences, for the following article:
@ -25,7 +24,6 @@ Summary:
""" """
SUMMARIZE_PROMPT_3 = """ SUMMARIZE_PROMPT_3 = """
Provide a TL;DR for the following article: Provide a TL;DR for the following article:
@ -39,7 +37,6 @@ Instead of computing on the individual qubits themselves, we will then compute o
TL;DR: TL;DR:
""" """
SUMMARIZE_PROMPT_4 = """ SUMMARIZE_PROMPT_4 = """
Provide a very short summary in four bullet points for the following article: Provide a very short summary in four bullet points for the following article:
@ -54,7 +51,6 @@ Bulletpoints:
""" """
SUMMARIZE_PROMPT_5 = """ SUMMARIZE_PROMPT_5 = """
Please generate a summary of the following conversation and at the end summarize the to-do's for the support Agent: Please generate a summary of the following conversation and at the end summarize the to-do's for the support Agent:

@ -7,7 +7,6 @@ import platform
from enum import Enum from enum import Enum
from typing import Union from typing import Union
python_version = list(platform.python_version_tuple()) python_version = list(platform.python_version_tuple())
SUPPORT_ADD_NOTES = int(python_version[0]) >= 3 and int(python_version[1]) >= 11 SUPPORT_ADD_NOTES = int(python_version[0]) >= 3 and int(python_version[1]) >= 11
@ -19,13 +18,11 @@ class ChatbotError(Exception):
def __init__(self, *args: object) -> None: def __init__(self, *args: object) -> None:
if SUPPORT_ADD_NOTES: if SUPPORT_ADD_NOTES:
super().add_note((
"Please check that the input is correct, or you can resolve this"
" issue by filing an issue"),)
super().add_note( super().add_note(
( "Project URL: https://github.com/acheong08/ChatGPT")
"Please check that the input is correct, or you can resolve this"
" issue by filing an issue"
),
)
super().add_note("Project URL: https://github.com/acheong08/ChatGPT")
super().__init__(*args) super().__init__(*args)

@ -63,9 +63,8 @@ class BaseDocumentTransformer(ABC):
""" # noqa: E501 """ # noqa: E501
@abstractmethod @abstractmethod
def transform_documents( def transform_documents(self, documents: Sequence[Document],
self, documents: Sequence[Document], **kwargs: Any **kwargs: Any) -> Sequence[Document]:
) -> Sequence[Document]:
"""Transform a list of documents. """Transform a list of documents.
Args: Args:
@ -75,9 +74,8 @@ class BaseDocumentTransformer(ABC):
A list of transformed Documents. A list of transformed Documents.
""" """
async def atransform_documents( async def atransform_documents(self, documents: Sequence[Document],
self, documents: Sequence[Document], **kwargs: Any **kwargs: Any) -> Sequence[Document]:
) -> Sequence[Document]:
"""Asynchronously transform a list of documents. """Asynchronously transform a list of documents.
Args: Args:
@ -87,5 +85,4 @@ class BaseDocumentTransformer(ABC):
A list of transformed Documents. A list of transformed Documents.
""" """
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents None, partial(self.transform_documents, **kwargs), documents)
)

@ -19,14 +19,12 @@ from termcolor import colored
import inspect import inspect
import random import random
# Prompts # Prompts
DYNAMIC_STOP_PROMPT = """ DYNAMIC_STOP_PROMPT = """
When you have finished the task from the Human, output a special token: <DONE> When you have finished the task from the Human, output a special token: <DONE>
This will enable you to leave the autonomous loop. This will enable you to leave the autonomous loop.
""" """
# Constants # Constants
FLOW_SYSTEM_PROMPT = f""" FLOW_SYSTEM_PROMPT = f"""
You are an autonomous agent granted autonomy by a Flow structure. You are an autonomous agent granted autonomy by a Flow structure.
@ -40,7 +38,6 @@ to aid in these complex tasks. Your responses should be coherent, contextually r
""" """
# Utility functions # Utility functions
@ -184,8 +181,7 @@ class Flow:
value = self.llm.__dict__.get(name, "Unknown") value = self.llm.__dict__.get(name, "Unknown")
params_str_list.append( params_str_list.append(
f" {name.capitalize().replace('_', ' ')}: {value}" f" {name.capitalize().replace('_', ' ')}: {value}")
)
return "\n".join(params_str_list) return "\n".join(params_str_list)
@ -193,7 +189,7 @@ class Flow:
""" """
Take the history and truncate it to fit into the model context length Take the history and truncate it to fit into the model context length
""" """
truncated_history = self.memory[-1][-self.context_length :] truncated_history = self.memory[-1][-self.context_length:]
self.memory[-1] = truncated_history self.memory[-1] = truncated_history
def add_task_to_memory(self, task: str): def add_task_to_memory(self, task: str):
@ -243,8 +239,7 @@ class Flow:
---------------------------------------- ----------------------------------------
""", """,
"green", "green",
) ))
)
# print(dashboard) # print(dashboard)
@ -254,18 +249,17 @@ class Flow:
print(colored("Initializing Autonomous Agent...", "yellow")) print(colored("Initializing Autonomous Agent...", "yellow"))
# print(colored("Loading modules...", "yellow")) # print(colored("Loading modules...", "yellow"))
# print(colored("Modules loaded successfully.", "green")) # print(colored("Modules loaded successfully.", "green"))
print(colored("Autonomous Agent Activated.", "cyan", attrs=["bold"])) print(colored("Autonomous Agent Activated.", "cyan",
print(colored("All systems operational. Executing task...", "green")) attrs=["bold"]))
print(colored("All systems operational. Executing task...",
"green"))
except Exception as error: except Exception as error:
print( print(
colored( colored(
( ("Error activating autonomous agent. Try optimizing your"
"Error activating autonomous agent. Try optimizing your" " parameters..."),
" parameters..."
),
"red", "red",
) ))
)
print(error) print(error)
def run(self, task: str, **kwargs): def run(self, task: str, **kwargs):
@ -307,7 +301,8 @@ class Flow:
for i in range(self.max_loops): for i in range(self.max_loops):
print(colored(f"\nLoop {i+1} of {self.max_loops}", "blue")) print(colored(f"\nLoop {i+1} of {self.max_loops}", "blue"))
print("\n") print("\n")
if self._check_stopping_condition(response) or parse_done_token(response): if self._check_stopping_condition(response) or parse_done_token(
response):
break break
# Adjust temperature, comment if no work # Adjust temperature, comment if no work
@ -351,7 +346,6 @@ class Flow:
async def arun(self, task: str, **kwargs): async def arun(self, task: str, **kwargs):
"""Async run""" """Async run"""
pass pass
""" """
Run the autonomous agent loop Run the autonomous agent loop
@ -387,7 +381,8 @@ class Flow:
for i in range(self.max_loops): for i in range(self.max_loops):
print(colored(f"\nLoop {i+1} of {self.max_loops}", "blue")) print(colored(f"\nLoop {i+1} of {self.max_loops}", "blue"))
print("\n") print("\n")
if self._check_stopping_condition(response) or parse_done_token(response): if self._check_stopping_condition(response) or parse_done_token(
response):
break break
# Adjust temperature, comment if no work # Adjust temperature, comment if no work
@ -565,7 +560,9 @@ class Flow:
import boto3 import boto3
s3 = boto3.client("s3") s3 = boto3.client("s3")
s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory)) s3.put_object(Bucket=bucket_name,
Key=object_name,
Body=json.dumps(self.memory))
print(f"Backed up memory to S3: {bucket_name}/{object_name}") print(f"Backed up memory to S3: {bucket_name}/{object_name}")
def analyze_feedback(self): def analyze_feedback(self):
@ -684,8 +681,8 @@ class Flow:
if hasattr(self.llm, name): if hasattr(self.llm, name):
value = getattr(self.llm, name) value = getattr(self.llm, name)
if isinstance( if isinstance(
value, (str, int, float, bool, list, dict, tuple, type(None)) value,
): (str, int, float, bool, list, dict, tuple, type(None))):
llm_params[name] = value llm_params[name] = value
else: else:
llm_params[name] = str( llm_params[name] = str(
@ -745,7 +742,10 @@ class Flow:
print(f"Flow state loaded from {file_path}") print(f"Flow state loaded from {file_path}")
def retry_on_failure(self, function, retries: int = 3, retry_delay: int = 1): def retry_on_failure(self,
function,
retries: int = 3,
retry_delay: int = 1):
"""Retry wrapper for LLM calls.""" """Retry wrapper for LLM calls."""
attempt = 0 attempt = 0
while attempt < retries: while attempt < retries:

@ -8,9 +8,10 @@ class Task:
Task is a unit of work that can be executed by an agent Task is a unit of work that can be executed by an agent
""" """
def __init__( def __init__(self,
self, id: str, parents: List["Task"] = None, children: List["Task"] = None id: str,
): parents: List["Task"] = None,
children: List["Task"] = None):
self.id = id self.id = id
self.parents = parents self.parents = parents
self.children = children self.children = children
@ -79,7 +80,8 @@ class NonLinearWorkflow:
for task in ordered_tasks: for task in ordered_tasks:
if task.can_execute: if task.can_execute:
future = self.executor.submit(self.agents.run, task.task_string) future = self.executor.submit(self.agents.run,
task.task_string)
futures_list[future] = task futures_list[future] = task
for future in as_completed(futures_list): for future in as_completed(futures_list):
@ -95,7 +97,8 @@ class NonLinearWorkflow:
def to_graph(self) -> Dict[str, set[str]]: def to_graph(self) -> Dict[str, set[str]]:
"""Convert the workflow to a graph""" """Convert the workflow to a graph"""
graph = { graph = {
task.id: set(child.id for child in task.children) for task in self.tasks task.id: set(child.id for child in task.children)
for task in self.tasks
} }
return graph return graph

@ -61,13 +61,12 @@ class Task:
if isinstance(self.flow, Flow): if isinstance(self.flow, Flow):
# Add a prompt to notify the Flow of the sequential workflow # Add a prompt to notify the Flow of the sequential workflow
if "prompt" in self.kwargs: if "prompt" in self.kwargs:
self.kwargs["prompt"] += ( self.kwargs["prompt"] += (f"\n\nPrevious output: {self.result}"
f"\n\nPrevious output: {self.result}" if self.result else "" if self.result else "")
)
else: else:
self.kwargs["prompt"] = f"Main task: {self.description}" + ( self.kwargs["prompt"] = f"Main task: {self.description}" + (
f"\n\nPrevious output: {self.result}" if self.result else "" f"\n\nPrevious output: {self.result}"
) if self.result else "")
self.result = self.flow.run(*self.args, **self.kwargs) self.result = self.flow.run(*self.args, **self.kwargs)
else: else:
self.result = self.flow(*self.args, **self.kwargs) self.result = self.flow(*self.args, **self.kwargs)
@ -111,7 +110,8 @@ class SequentialWorkflow:
restore_state_filepath: Optional[str] = None restore_state_filepath: Optional[str] = None
dashboard: bool = False dashboard: bool = False
def add(self, task: str, flow: Union[Callable, Flow], *args, **kwargs) -> None: def add(self, task: str, flow: Union[Callable, Flow], *args,
**kwargs) -> None:
""" """
Add a task to the workflow. Add a task to the workflow.
@ -127,8 +127,7 @@ class SequentialWorkflow:
# Append the task to the tasks list # Append the task to the tasks list
self.tasks.append( self.tasks.append(
Task(description=task, flow=flow, args=list(args), kwargs=kwargs) Task(description=task, flow=flow, args=list(args), kwargs=kwargs))
)
def reset_workflow(self) -> None: def reset_workflow(self) -> None:
"""Resets the workflow by clearing the results of each task.""" """Resets the workflow by clearing the results of each task."""
@ -180,8 +179,9 @@ class SequentialWorkflow:
raise ValueError(f"Task {task_description} not found in workflow.") raise ValueError(f"Task {task_description} not found in workflow.")
def save_workflow_state( def save_workflow_state(
self, filepath: Optional[str] = "sequential_workflow_state.json", **kwargs self,
) -> None: filepath: Optional[str] = "sequential_workflow_state.json",
**kwargs) -> None:
""" """
Saves the workflow state to a json file. Saves the workflow state to a json file.
@ -202,16 +202,13 @@ class SequentialWorkflow:
with open(filepath, "w") as f: with open(filepath, "w") as f:
# Saving the state as JSON for simplicity # Saving the state as JSON for simplicity
state = { state = {
"tasks": [ "tasks": [{
{ "description": task.description,
"description": task.description, "args": task.args,
"args": task.args, "kwargs": task.kwargs,
"kwargs": task.kwargs, "result": task.result,
"result": task.result, "history": task.history,
"history": task.history, } for task in self.tasks],
}
for task in self.tasks
],
"max_loops": self.max_loops, "max_loops": self.max_loops,
} }
json.dump(state, f, indent=4) json.dump(state, f, indent=4)
@ -223,8 +220,7 @@ class SequentialWorkflow:
Sequential Workflow Initializing...""", Sequential Workflow Initializing...""",
"green", "green",
attrs=["bold", "underline"], attrs=["bold", "underline"],
) ))
)
def workflow_dashboard(self, **kwargs) -> None: def workflow_dashboard(self, **kwargs) -> None:
""" """
@ -263,8 +259,7 @@ class SequentialWorkflow:
""", """,
"cyan", "cyan",
attrs=["bold", "underline"], attrs=["bold", "underline"],
) ))
)
def workflow_shutdown(self, **kwargs) -> None: def workflow_shutdown(self, **kwargs) -> None:
print( print(
@ -273,8 +268,7 @@ class SequentialWorkflow:
Sequential Workflow Shutdown...""", Sequential Workflow Shutdown...""",
"red", "red",
attrs=["bold", "underline"], attrs=["bold", "underline"],
) ))
)
def add_objective_to_workflow(self, task: str, **kwargs) -> None: def add_objective_to_workflow(self, task: str, **kwargs) -> None:
print( print(
@ -283,8 +277,7 @@ class SequentialWorkflow:
Adding Objective to Workflow...""", Adding Objective to Workflow...""",
"green", "green",
attrs=["bold", "underline"], attrs=["bold", "underline"],
) ))
)
task = Task( task = Task(
description=task, description=task,
@ -349,13 +342,12 @@ class SequentialWorkflow:
if "task" not in task.kwargs: if "task" not in task.kwargs:
raise ValueError( raise ValueError(
"The 'task' argument is required for the Flow flow" "The 'task' argument is required for the Flow flow"
f" execution in '{task.description}'" f" execution in '{task.description}'")
)
# Separate the 'task' argument from other kwargs # Separate the 'task' argument from other kwargs
flow_task_arg = task.kwargs.pop("task") flow_task_arg = task.kwargs.pop("task")
task.result = task.flow.run( task.result = task.flow.run(flow_task_arg,
flow_task_arg, *task.args, **task.kwargs *task.args,
) **task.kwargs)
else: else:
# If it's not a Flow instance, call the flow directly # If it's not a Flow instance, call the flow directly
task.result = task.flow(*task.args, **task.kwargs) task.result = task.flow(*task.args, **task.kwargs)
@ -373,19 +365,17 @@ class SequentialWorkflow:
# Autosave the workflow state # Autosave the workflow state
if self.autosave: if self.autosave:
self.save_workflow_state("sequential_workflow_state.json") self.save_workflow_state(
"sequential_workflow_state.json")
except Exception as e: except Exception as e:
print( print(
colored( colored(
( (f"Error initializing the Sequential workflow: {e} try"
f"Error initializing the Sequential workflow: {e} try" " optimizing your inputs like the flow class and task"
" optimizing your inputs like the flow class and task" " description"),
" description"
),
"red", "red",
attrs=["bold", "underline"], attrs=["bold", "underline"],
) ))
)
async def arun(self) -> None: async def arun(self) -> None:
""" """
@ -405,13 +395,11 @@ class SequentialWorkflow:
if "task" not in task.kwargs: if "task" not in task.kwargs:
raise ValueError( raise ValueError(
"The 'task' argument is required for the Flow flow" "The 'task' argument is required for the Flow flow"
f" execution in '{task.description}'" f" execution in '{task.description}'")
)
# Separate the 'task' argument from other kwargs # Separate the 'task' argument from other kwargs
flow_task_arg = task.kwargs.pop("task") flow_task_arg = task.kwargs.pop("task")
task.result = await task.flow.arun( task.result = await task.flow.arun(
flow_task_arg, *task.args, **task.kwargs flow_task_arg, *task.args, **task.kwargs)
)
else: else:
# If it's not a Flow instance, call the flow directly # If it's not a Flow instance, call the flow directly
task.result = await task.flow(*task.args, **task.kwargs) task.result = await task.flow(*task.args, **task.kwargs)
@ -429,4 +417,5 @@ class SequentialWorkflow:
# Autosave the workflow state # Autosave the workflow state
if self.autosave: if self.autosave:
self.save_workflow_state("sequential_workflow_state.json") self.save_workflow_state(
"sequential_workflow_state.json")

@ -13,6 +13,7 @@ from swarms.artifacts.error_artifact import ErrorArtifact
class BaseTask(ABC): class BaseTask(ABC):
class State(Enum): class State(Enum):
PENDING = 1 PENDING = 1
EXECUTING = 2 EXECUTING = 2
@ -33,11 +34,15 @@ class BaseTask(ABC):
@property @property
def parents(self) -> List[BaseTask]: def parents(self) -> List[BaseTask]:
return [self.structure.find_task(parent_id) for parent_id in self.parent_ids] return [
self.structure.find_task(parent_id) for parent_id in self.parent_ids
]
@property @property
def children(self) -> List[BaseTask]: def children(self) -> List[BaseTask]:
return [self.structure.find_task(child_id) for child_id in self.child_ids] return [
self.structure.find_task(child_id) for child_id in self.child_ids
]
def __rshift__(self, child: BaseTask) -> BaseTask: def __rshift__(self, child: BaseTask) -> BaseTask:
return self.add_child(child) return self.add_child(child)
@ -118,8 +123,7 @@ class BaseTask(ABC):
def can_execute(self) -> bool: def can_execute(self) -> bool:
return self.state == self.State.PENDING and all( return self.state == self.State.PENDING and all(
parent.is_finished() for parent in self.parents parent.is_finished() for parent in self.parents)
)
def reset(self) -> BaseTask: def reset(self) -> BaseTask:
self.state = self.State.PENDING self.state = self.State.PENDING
@ -132,10 +136,10 @@ class BaseTask(ABC):
class Task(BaseModel): class Task(BaseModel):
input: Optional[StrictStr] = Field(None, description="Input prompt for the task") input: Optional[StrictStr] = Field(None,
description="Input prompt for the task")
additional_input: Optional[Any] = Field( additional_input: Optional[Any] = Field(
None, description="Input parameters for the task. Any value is allowed" None, description="Input parameters for the task. Any value is allowed")
)
task_id: StrictStr = Field(..., description="ID of the task") task_id: StrictStr = Field(..., description="ID of the task")
class Config: class Config:

@ -65,11 +65,13 @@ class Workflow:
def context(self, task: Task) -> Dict[str, Any]: def context(self, task: Task) -> Dict[str, Any]:
"""Context in tasks""" """Context in tasks"""
return { return {
"parent_output": task.parents[0].output "parent_output":
if task.parents and task.parents[0].output task.parents[0].output
else None, if task.parents and task.parents[0].output else None,
"parent": task.parents[0] if task.parents else None, "parent":
"child": task.children[0] if task.children else None, task.parents[0] if task.parents else None,
"child":
task.children[0] if task.children else None,
} }
def __run_from_task(self, task: Optional[Task]) -> None: def __run_from_task(self, task: Optional[Task]) -> None:

@ -87,7 +87,8 @@ class AutoScaler:
while True: while True:
sleep(60) # check minute sleep(60) # check minute
pending_tasks = self.task_queue.qsize() pending_tasks = self.task_queue.qsize()
active_agents = sum([1 for agent in self.agents_pool if agent.is_busy()]) active_agents = sum(
[1 for agent in self.agents_pool if agent.is_busy()])
if pending_tasks / len(self.agents_pool) > self.busy_threshold: if pending_tasks / len(self.agents_pool) > self.busy_threshold:
self.scale_up() self.scale_up()

@ -117,7 +117,9 @@ class AbstractSwarm(ABC):
pass pass
@abstractmethod @abstractmethod
def broadcast(self, message: str, sender: Optional["AbstractWorker"] = None): def broadcast(self,
message: str,
sender: Optional["AbstractWorker"] = None):
"""Broadcast a message to all workers""" """Broadcast a message to all workers"""
pass pass

@ -77,19 +77,15 @@ class BattleRoyalSwarm:
# Check for clashes and handle them # Check for clashes and handle them
for i, worker1 in enumerate(self.workers): for i, worker1 in enumerate(self.workers):
for j, worker2 in enumerate(self.workers): for j, worker2 in enumerate(self.workers):
if ( if (i != j and worker1.is_within_proximity(worker2) and
i != j set(worker1.teams) != set(worker2.teams)):
and worker1.is_within_proximity(worker2)
and set(worker1.teams) != set(worker2.teams)
):
winner, loser = self.clash(worker1, worker2, question) winner, loser = self.clash(worker1, worker2, question)
print(f"Worker {winner.id} won over Worker {loser.id}") print(f"Worker {winner.id} won over Worker {loser.id}")
def communicate(self, sender: Worker, reciever: Worker, message: str): def communicate(self, sender: Worker, reciever: Worker, message: str):
"""Communicate a message from one worker to another.""" """Communicate a message from one worker to another."""
if sender.is_within_proximity(reciever) or any( if sender.is_within_proximity(reciever) or any(
team in sender.teams for team in reciever.teams team in sender.teams for team in reciever.teams):
):
pass pass
def clash(self, worker1: Worker, worker2: Worker, question: str): def clash(self, worker1: Worker, worker2: Worker, question: str):

@ -49,9 +49,8 @@ class GodMode:
table.append([f"LLM {i+1}", response]) table.append([f"LLM {i+1}", response])
print( print(
colored( colored(
tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"),
) "cyan"))
)
def run_all(self, task): def run_all(self, task):
"""Run the task on all LLMs""" """Run the task on all LLMs"""
@ -74,18 +73,15 @@ class GodMode:
table.append([f"LLM {i+1}", response]) table.append([f"LLM {i+1}", response])
print( print(
colored( colored(
tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"),
) "cyan"))
)
# New Features # New Features
def save_responses_to_file(self, filename): def save_responses_to_file(self, filename):
"""Save responses to file""" """Save responses to file"""
with open(filename, "w") as file: with open(filename, "w") as file:
table = [ table = [[f"LLM {i+1}", response]
[f"LLM {i+1}", response] for i, response in enumerate(self.last_responses)]
for i, response in enumerate(self.last_responses)
]
file.write(tabulate(table, headers=["LLM", "Response"])) file.write(tabulate(table, headers=["LLM", "Response"]))
@classmethod @classmethod
@ -105,11 +101,9 @@ class GodMode:
for i, task in enumerate(self.task_history): for i, task in enumerate(self.task_history):
print(f"{i + 1}. {task}") print(f"{i + 1}. {task}")
print("\nLast Responses:") print("\nLast Responses:")
table = [ table = [[f"LLM {i+1}", response]
[f"LLM {i+1}", response] for i, response in enumerate(self.last_responses) for i, response in enumerate(self.last_responses)]
]
print( print(
colored( colored(
tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan" tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"),
) "cyan"))
)

@ -3,7 +3,6 @@ from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List
from swarms.structs.flow import Flow from swarms.structs.flow import Flow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,7 +33,8 @@ class GroupChat:
def next_agent(self, agent: Flow) -> Flow: def next_agent(self, agent: Flow) -> Flow:
"""Return the next agent in the list.""" """Return the next agent in the list."""
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] return self.agents[(self.agent_names.index(agent.name) + 1) %
len(self.agents)]
def select_speaker_msg(self): def select_speaker_msg(self):
"""Return the message for selecting the next speaker.""" """Return the message for selecting the next speaker."""
@ -55,24 +55,17 @@ class GroupChat:
if n_agents < 3: if n_agents < 3:
logger.warning( logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct" f"GroupChat is underpopulated with {n_agents} agents. Direct"
" communication would be more efficient." " communication would be more efficient.")
)
name = selector.generate_reply( name = selector.generate_reply(
self.format_history( self.format_history(self.messages + [{
self.messages "role":
+ [ "system",
{ "content":
"role": "system", ("Read the above conversation. Then select the next most"
"content": ( f" suitable role from {self.agent_names} to play. Only"
"Read the above conversation. Then select the next most" " return the role."),
f" suitable role from {self.agent_names} to play. Only" }]))
" return the role."
),
}
]
)
)
try: try:
return self.agent_by_name(name["content"]) return self.agent_by_name(name["content"])
except ValueError: except ValueError:
@ -80,8 +73,7 @@ class GroupChat:
def _participant_roles(self): def _participant_roles(self):
return "\n".join( return "\n".join(
[f"{agent.name}: {agent.system_message}" for agent in self.agents] [f"{agent.name}: {agent.system_message}" for agent in self.agents])
)
def format_history(self, messages: List[Dict]) -> str: def format_history(self, messages: List[Dict]) -> str:
formatted_messages = [] formatted_messages = []
@ -92,19 +84,21 @@ class GroupChat:
class GroupChatManager: class GroupChatManager:
def __init__(self, groupchat: GroupChat, selector: Flow): def __init__(self, groupchat: GroupChat, selector: Flow):
self.groupchat = groupchat self.groupchat = groupchat
self.selector = selector self.selector = selector
def __call__(self, task: str): def __call__(self, task: str):
self.groupchat.messages.append({"role": self.selector.name, "content": task}) self.groupchat.messages.append({
"role": self.selector.name,
"content": task
})
for i in range(self.groupchat.max_round): for i in range(self.groupchat.max_round):
speaker = self.groupchat.select_speaker( speaker = self.groupchat.select_speaker(last_speaker=self.selector,
last_speaker=self.selector, selector=self.selector selector=self.selector)
)
reply = speaker.generate_reply( reply = speaker.generate_reply(
self.groupchat.format_history(self.groupchat.messages) self.groupchat.format_history(self.groupchat.messages))
)
self.groupchat.messages.append(reply) self.groupchat.messages.append(reply)
print(reply) print(reply)
if i == self.groupchat.max_round - 1: if i == self.groupchat.max_round - 1:

@ -5,16 +5,16 @@ from langchain.output_parsers import RegexParser
# utils # utils
class BidOutputParser(RegexParser): class BidOutputParser(RegexParser):
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return ( return (
"Your response should be an integrater delimited by angled brackets like" "Your response should be an integrater delimited by angled brackets like"
" this: <int>" " this: <int>")
)
bid_parser = BidOutputParser( bid_parser = BidOutputParser(regex=r"<(\d+)>",
regex=r"<(\d+)>", output_keys=["bid"], default_output_key="bid" output_keys=["bid"],
) default_output_key="bid")
def select_next_speaker(step: int, agents, director) -> int: def select_next_speaker(step: int, agents, director) -> int:
@ -29,6 +29,7 @@ def select_next_speaker(step: int, agents, director) -> int:
# main # main
class MultiAgentCollaboration: class MultiAgentCollaboration:
def __init__( def __init__(
self, self,
agents, agents,

@ -46,7 +46,6 @@ class MultiAgentDebate:
def format_results(self, results): def format_results(self, results):
formatted_results = "\n".join( formatted_results = "\n".join(
[f"Agent responded: {result['response']}" for result in results] [f"Agent responded: {result['response']}" for result in results])
)
return formatted_results return formatted_results

@ -111,7 +111,8 @@ class Orchestrator:
self.chroma_client = chromadb.Client() self.chroma_client = chromadb.Client()
self.collection = self.chroma_client.create_collection(name=collection_name) self.collection = self.chroma_client.create_collection(
name=collection_name)
self.current_tasks = {} self.current_tasks = {}
@ -137,9 +138,8 @@ class Orchestrator:
result = self.worker.run(task["content"]) result = self.worker.run(task["content"])
# using the embed method to get the vector representation of the result # using the embed method to get the vector representation of the result
vector_representation = self.embed( vector_representation = self.embed(result, self.api_key,
result, self.api_key, self.model_name self.model_name)
)
self.collection.add( self.collection.add(
embeddings=[vector_representation], embeddings=[vector_representation],
@ -154,8 +154,7 @@ class Orchestrator:
except Exception as error: except Exception as error:
logging.error( logging.error(
f"Failed to process task {id(task)} by agent {id(agent)}. Error:" f"Failed to process task {id(task)} by agent {id(agent)}. Error:"
f" {error}" f" {error}")
)
finally: finally:
with self.condition: with self.condition:
self.agents.put(agent) self.agents.put(agent)
@ -163,8 +162,7 @@ class Orchestrator:
def embed(self, input, api_key, model_name): def embed(self, input, api_key, model_name):
openai = embedding_functions.OpenAIEmbeddingFunction( openai = embedding_functions.OpenAIEmbeddingFunction(
api_key=api_key, model_name=model_name api_key=api_key, model_name=model_name)
)
embedding = openai(input) embedding = openai(input)
return embedding return embedding
@ -175,13 +173,13 @@ class Orchestrator:
try: try:
# Query the vector database for documents created by the agents # Query the vector database for documents created by the agents
results = self.collection.query(query_texts=[str(agent_id)], n_results=10) results = self.collection.query(query_texts=[str(agent_id)],
n_results=10)
return results return results
except Exception as e: except Exception as e:
logging.error( logging.error(
f"Failed to retrieve results from agent {agent_id}. Error {e}" f"Failed to retrieve results from agent {agent_id}. Error {e}")
)
raise raise
# @abstractmethod # @abstractmethod
@ -212,7 +210,8 @@ class Orchestrator:
self.collection.add(documents=[result], ids=[str(id(result))]) self.collection.add(documents=[result], ids=[str(id(result))])
except Exception as e: except Exception as e:
logging.error(f"Failed to append the agent output to database. Error: {e}") logging.error(
f"Failed to append the agent output to database. Error: {e}")
raise raise
def run(self, objective: str): def run(self, objective: str):
@ -225,8 +224,8 @@ class Orchestrator:
self.task_queue.append(objective) self.task_queue.append(objective)
results = [ results = [
self.assign_task(agent_id, task) self.assign_task(agent_id, task) for agent_id, task in zip(
for agent_id, task in zip(range(len(self.agents)), self.task_queue) range(len(self.agents)), self.task_queue)
] ]
for result in results: for result in results:

@ -2,6 +2,7 @@ from queue import Queue, PriorityQueue
class SimpleSwarm: class SimpleSwarm:
def __init__( def __init__(
self, self,
llm, llm,

@ -8,8 +8,7 @@ import torch
from langchain.agents import tool from langchain.agents import tool
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.chains.qa_with_sources.loading import ( from langchain.chains.qa_with_sources.loading import (
BaseCombineDocumentsChain, BaseCombineDocumentsChain,)
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool from langchain.tools import BaseTool
@ -37,9 +36,10 @@ def pushd(new_dir):
@tool @tool
def process_csv( def process_csv(llm,
llm, csv_file_path: str, instructions: str, output_path: Optional[str] = None csv_file_path: str,
) -> str: instructions: str,
output_path: Optional[str] = None) -> str:
"""Process a CSV by with pandas in a limited REPL.\ """Process a CSV by with pandas in a limited REPL.\
Only use this after writing data to disk as a csv file.\ Only use this after writing data to disk as a csv file.\
Any figures must be saved to disk to be viewed by the human.\ Any figures must be saved to disk to be viewed by the human.\
@ -49,7 +49,10 @@ def process_csv(
df = pd.read_csv(csv_file_path) df = pd.read_csv(csv_file_path)
except Exception as e: except Exception as e:
return f"Error: {e}" return f"Error: {e}"
agent = create_pandas_dataframe_agent(llm, df, max_iterations=30, verbose=False) agent = create_pandas_dataframe_agent(llm,
df,
max_iterations=30,
verbose=False)
if output_path is not None: if output_path is not None:
instructions += f" Save output to disk at {output_path}" instructions += f" Save output to disk at {output_path}"
try: try:
@ -79,7 +82,8 @@ async def async_load_playwright(url: str) -> str:
text = soup.get_text() text = soup.get_text()
lines = (line.strip() for line in text.splitlines()) lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) chunks = (
phrase.strip() for line in lines for phrase in line.split(" "))
results = "\n".join(chunk for chunk in chunks if chunk) results = "\n".join(chunk for chunk in chunks if chunk)
except Exception as e: except Exception as e:
results = f"Error: {e}" results = f"Error: {e}"
@ -113,8 +117,7 @@ class WebpageQATool(BaseTool):
"Browse a webpage and retrieve the information relevant to the question." "Browse a webpage and retrieve the information relevant to the question."
) )
text_splitter: RecursiveCharacterTextSplitter = Field( text_splitter: RecursiveCharacterTextSplitter = Field(
default_factory=_get_text_splitter default_factory=_get_text_splitter)
)
qa_chain: BaseCombineDocumentsChain qa_chain: BaseCombineDocumentsChain
def _run(self, url: str, question: str) -> str: def _run(self, url: str, question: str) -> str:
@ -125,9 +128,12 @@ class WebpageQATool(BaseTool):
results = [] results = []
# TODO: Handle this with a MapReduceChain # TODO: Handle this with a MapReduceChain
for i in range(0, len(web_docs), 4): for i in range(0, len(web_docs), 4):
input_docs = web_docs[i : i + 4] input_docs = web_docs[i:i + 4]
window_result = self.qa_chain( window_result = self.qa_chain(
{"input_documents": input_docs, "question": question}, {
"input_documents": input_docs,
"question": question
},
return_only_outputs=True, return_only_outputs=True,
) )
results.append(f"Response from window {i} - {window_result}") results.append(f"Response from window {i} - {window_result}")
@ -135,7 +141,10 @@ class WebpageQATool(BaseTool):
Document(page_content="\n".join(results), metadata={"source": url}) Document(page_content="\n".join(results), metadata={"source": url})
] ]
return self.qa_chain( return self.qa_chain(
{"input_documents": results_docs, "question": question}, {
"input_documents": results_docs,
"question": question
},
return_only_outputs=True, return_only_outputs=True,
) )
@ -171,18 +180,17 @@ def VQAinference(self, inputs):
torch_dtype = torch.float16 if "cuda" in device else torch.float32 torch_dtype = torch.float16 if "cuda" in device else torch.float32
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained( model = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base", torch_dtype=torch_dtype "Salesforce/blip-vqa-base", torch_dtype=torch_dtype).to(device)
).to(device)
image_path, question = inputs.split(",") image_path, question = inputs.split(",")
raw_image = Image.open(image_path).convert("RGB") raw_image = Image.open(image_path).convert("RGB")
inputs = processor(raw_image, question, return_tensors="pt").to(device, torch_dtype) inputs = processor(raw_image, question,
return_tensors="pt").to(device, torch_dtype)
out = model.generate(**inputs) out = model.generate(**inputs)
answer = processor.decode(out[0], skip_special_tokens=True) answer = processor.decode(out[0], skip_special_tokens=True)
logger.debug( logger.debug(
f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
f" Question: {question}, Output Answer: {answer}" f" Question: {question}, Output Answer: {answer}")
)
return answer return answer
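For context on the window loop reformatted above: WebpageQATool splits the scraped page into documents and runs the QA chain over them four at a time. A minimal sketch of that batching pattern, with a hypothetical answer_window callable standing in for self.qa_chain (not part of this diff):

# Sketch of the 4-document windowing used in WebpageQATool._run.
# answer_window is a hypothetical stand-in for the QA chain call.
from typing import Callable, List


def answer_over_windows(web_docs: List[str], question: str,
                        answer_window: Callable[[List[str], str], str]) -> List[str]:
    results = []
    for i in range(0, len(web_docs), 4):
        input_docs = web_docs[i:i + 4]  # same slice as in the diff
        results.append(f"Response from window {i} - "
                       f"{answer_window(input_docs, question)}")
    return results


if __name__ == "__main__":
    docs = [f"chunk {n}" for n in range(10)]
    print(answer_over_windows(docs, "What is this page about?",
                              lambda d, q: f"{len(d)} docs considered"))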

@ -25,13 +25,14 @@ from swarms.utils.main import BaseHandler, get_new_image_name
class MaskFormer: class MaskFormer:
def __init__(self, device): def __init__(self, device):
print("Initializing MaskFormer to %s" % device) print("Initializing MaskFormer to %s" % device)
self.device = device self.device = device
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") self.processor = CLIPSegProcessor.from_pretrained(
"CIDAS/clipseg-rd64-refined")
self.model = CLIPSegForImageSegmentation.from_pretrained( self.model = CLIPSegForImageSegmentation.from_pretrained(
"CIDAS/clipseg-rd64-refined" "CIDAS/clipseg-rd64-refined").to(device)
).to(device)
def inference(self, image_path, text): def inference(self, image_path, text):
threshold = 0.5 threshold = 0.5
@ -39,9 +40,10 @@ class MaskFormer:
padding = 20 padding = 20
original_image = Image.open(image_path) original_image = Image.open(image_path)
image = original_image.resize((512, 512)) image = original_image.resize((512, 512))
inputs = self.processor( inputs = self.processor(text=text,
text=text, images=image, padding="max_length", return_tensors="pt" images=image,
).to(self.device) padding="max_length",
return_tensors="pt").to(self.device)
with torch.no_grad(): with torch.no_grad():
outputs = self.model(**inputs) outputs = self.model(**inputs)
mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
@ -52,8 +54,7 @@ class MaskFormer:
mask_array = np.zeros_like(mask, dtype=bool) mask_array = np.zeros_like(mask, dtype=bool)
for idx in true_indices: for idx in true_indices:
padded_slice = tuple( padded_slice = tuple(
slice(max(0, i - padding), i + padding + 1) for i in idx slice(max(0, i - padding), i + padding + 1) for i in idx)
)
mask_array[padded_slice] = True mask_array[padded_slice] = True
visual_mask = (mask_array * 255).astype(np.uint8) visual_mask = (mask_array * 255).astype(np.uint8)
image_mask = Image.fromarray(visual_mask) image_mask = Image.fromarray(visual_mask)
@ -61,6 +62,7 @@ class MaskFormer:
class ImageEditing: class ImageEditing:
def __init__(self, device): def __init__(self, device):
print("Initializing ImageEditing to %s" % device) print("Initializing ImageEditing to %s" % device)
self.device = device self.device = device
@ -75,25 +77,24 @@ class ImageEditing:
@tool( @tool(
name="Remove Something From The Photo", name="Remove Something From The Photo",
description=( description=
"useful when you want to remove and object or something from the photo " ("useful when you want to remove and object or something from the photo "
"from its description or location. " "from its description or location. "
"The input to this tool should be a comma separated string of two, " "The input to this tool should be a comma separated string of two, "
"representing the image_path and the object need to be removed. " "representing the image_path and the object need to be removed. "),
),
) )
def inference_remove(self, inputs): def inference_remove(self, inputs):
image_path, to_be_removed_txt = inputs.split(",") image_path, to_be_removed_txt = inputs.split(",")
return self.inference_replace(f"{image_path},{to_be_removed_txt},background") return self.inference_replace(
f"{image_path},{to_be_removed_txt},background")
@tool( @tool(
name="Replace Something From The Photo", name="Replace Something From The Photo",
description=( description=
"useful when you want to replace an object from the object description or" ("useful when you want to replace an object from the object description or"
" location with another object from its description. The input to this tool" " location with another object from its description. The input to this tool"
" should be a comma separated string of three, representing the image_path," " should be a comma separated string of three, representing the image_path,"
" the object to be replaced, the object to be replaced with " " the object to be replaced, the object to be replaced with "),
),
) )
def inference_replace(self, inputs): def inference_replace(self, inputs):
image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",") image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
@ -105,22 +106,21 @@ class ImageEditing:
image=original_image.resize((512, 512)), image=original_image.resize((512, 512)),
mask_image=mask_image.resize((512, 512)), mask_image=mask_image.resize((512, 512)),
).images[0] ).images[0]
updated_image_path = get_new_image_name( updated_image_path = get_new_image_name(image_path,
image_path, func_name="replace-something" func_name="replace-something")
)
updated_image = updated_image.resize(original_size) updated_image = updated_image.resize(original_size)
updated_image.save(updated_image_path) updated_image.save(updated_image_path)
logger.debug( logger.debug(
f"\nProcessed ImageEditing, Input Image: {image_path}, Replace" f"\nProcessed ImageEditing, Input Image: {image_path}, Replace"
f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:" f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:"
f" {updated_image_path}" f" {updated_image_path}")
)
return updated_image_path return updated_image_path
class InstructPix2Pix: class InstructPix2Pix:
def __init__(self, device): def __init__(self, device):
print("Initializing InstructPix2Pix to %s" % device) print("Initializing InstructPix2Pix to %s" % device)
self.device = device self.device = device
@ -131,60 +131,56 @@ class InstructPix2Pix:
torch_dtype=self.torch_dtype, torch_dtype=self.torch_dtype,
).to(device) ).to(device)
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
self.pipe.scheduler.config self.pipe.scheduler.config)
)
@tool( @tool(
name="Instruct Image Using Text", name="Instruct Image Using Text",
description=( description=
"useful when you want to the style of the image to be like the text. " ("useful when you want to the style of the image to be like the text. "
"like: make it look like a painting. or make it like a robot. " "like: make it look like a painting. or make it like a robot. "
"The input to this tool should be a comma separated string of two, " "The input to this tool should be a comma separated string of two, "
"representing the image_path and the text. " "representing the image_path and the text. "),
),
) )
def inference(self, inputs): def inference(self, inputs):
"""Change style of image.""" """Change style of image."""
logger.debug("===> Starting InstructPix2Pix Inference") logger.debug("===> Starting InstructPix2Pix Inference")
image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:]) image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:])
original_image = Image.open(image_path) original_image = Image.open(image_path)
image = self.pipe( image = self.pipe(text,
text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2 image=original_image,
).images[0] num_inference_steps=40,
image_guidance_scale=1.2).images[0]
updated_image_path = get_new_image_name(image_path, func_name="pix2pix") updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
image.save(updated_image_path) image.save(updated_image_path)
logger.debug( logger.debug(
f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:" f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:"
f" {text}, Output Image: {updated_image_path}" f" {text}, Output Image: {updated_image_path}")
)
return updated_image_path return updated_image_path
class Text2Image: class Text2Image:
def __init__(self, device): def __init__(self, device):
print("Initializing Text2Image to %s" % device) print("Initializing Text2Image to %s" % device)
self.device = device self.device = device
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
self.pipe = StableDiffusionPipeline.from_pretrained( self.pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype)
)
self.pipe.to(device) self.pipe.to(device)
self.a_prompt = "best quality, extremely detailed" self.a_prompt = "best quality, extremely detailed"
self.n_prompt = ( self.n_prompt = (
"longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, " "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, "
"fewer digits, cropped, worst quality, low quality" "fewer digits, cropped, worst quality, low quality")
)
@tool( @tool(
name="Generate Image From User Input Text", name="Generate Image From User Input Text",
description=( description=
"useful when you want to generate an image from a user input text and save" ("useful when you want to generate an image from a user input text and save"
" it to a file. like: generate an image of an object or something, or" " it to a file. like: generate an image of an object or something, or"
" generate an image that includes some objects. The input to this tool" " generate an image that includes some objects. The input to this tool"
" should be a string, representing the text used to generate image. " " should be a string, representing the text used to generate image. "),
),
) )
def inference(self, text): def inference(self, text):
image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png") image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
@ -194,59 +190,59 @@ class Text2Image:
logger.debug( logger.debug(
f"\nProcessed Text2Image, Input Text: {text}, Output Image:" f"\nProcessed Text2Image, Input Text: {text}, Output Image:"
f" {image_filename}" f" {image_filename}")
)
return image_filename return image_filename
class VisualQuestionAnswering: class VisualQuestionAnswering:
def __init__(self, device): def __init__(self, device):
print("Initializing VisualQuestionAnswering to %s" % device) print("Initializing VisualQuestionAnswering to %s" % device)
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
self.device = device self.device = device
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") self.processor = BlipProcessor.from_pretrained(
"Salesforce/blip-vqa-base")
self.model = BlipForQuestionAnswering.from_pretrained( self.model = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype "Salesforce/blip-vqa-base",
).to(self.device) torch_dtype=self.torch_dtype).to(self.device)
@tool( @tool(
name="Answer Question About The Image", name="Answer Question About The Image",
description=( description=
"useful when you need an answer for a question based on an image. like:" ("useful when you need an answer for a question based on an image. like:"
" what is the background color of the last image, how many cats in this" " what is the background color of the last image, how many cats in this"
" figure, what is in this figure. The input to this tool should be a comma" " figure, what is in this figure. The input to this tool should be a comma"
" separated string of two, representing the image_path and the question" " separated string of two, representing the image_path and the question"
), ),
) )
def inference(self, inputs): def inference(self, inputs):
image_path, question = inputs.split(",") image_path, question = inputs.split(",")
raw_image = Image.open(image_path).convert("RGB") raw_image = Image.open(image_path).convert("RGB")
inputs = self.processor(raw_image, question, return_tensors="pt").to( inputs = self.processor(raw_image, question,
self.device, self.torch_dtype return_tensors="pt").to(self.device,
) self.torch_dtype)
out = self.model.generate(**inputs) out = self.model.generate(**inputs)
answer = self.processor.decode(out[0], skip_special_tokens=True) answer = self.processor.decode(out[0], skip_special_tokens=True)
logger.debug( logger.debug(
f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
f" Question: {question}, Output Answer: {answer}" f" Question: {question}, Output Answer: {answer}")
)
return answer return answer
class ImageCaptioning(BaseHandler): class ImageCaptioning(BaseHandler):
def __init__(self, device): def __init__(self, device):
print("Initializing ImageCaptioning to %s" % device) print("Initializing ImageCaptioning to %s" % device)
self.device = device self.device = device
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32 self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
self.processor = BlipProcessor.from_pretrained( self.processor = BlipProcessor.from_pretrained(
"Salesforce/blip-image-captioning-base" "Salesforce/blip-image-captioning-base")
)
self.model = BlipForConditionalGeneration.from_pretrained( self.model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype "Salesforce/blip-image-captioning-base",
).to(self.device) torch_dtype=self.torch_dtype).to(self.device)
def handle(self, filename: str): def handle(self, filename: str):
img = Image.open(filename) img = Image.open(filename)
@ -258,14 +254,13 @@ class ImageCaptioning(BaseHandler):
img.save(filename, "PNG") img.save(filename, "PNG")
print(f"Resize image form {width}x{height} to {width_new}x{height_new}") print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
inputs = self.processor(Image.open(filename), return_tensors="pt").to( inputs = self.processor(Image.open(filename),
self.device, self.torch_dtype return_tensors="pt").to(self.device,
) self.torch_dtype)
out = self.model.generate(**inputs) out = self.model.generate(**inputs)
description = self.processor.decode(out[0], skip_special_tokens=True) description = self.processor.decode(out[0], skip_special_tokens=True)
print( print(
f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:" f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:"
f" {description}" f" {description}")
)
return IMAGE_PROMPT.format(filename=filename, description=description) return IMAGE_PROMPT.format(filename=filename, description=description)
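A standalone sketch of the BLIP captioning call that ImageCaptioning.handle wraps above. It assumes torch, transformers and Pillow are installed, that the model weights can be downloaded from the Hugging Face hub, and that "photo.png" is a placeholder image path.

# Minimal BLIP captioning sketch mirroring the calls in the diff.
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if "cuda" in device else torch.float32

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=torch_dtype).to(device)

# "photo.png" is a placeholder path, not part of this diff.
inputs = processor(Image.open("photo.png"), return_tensors="pt").to(device, torch_dtype)
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))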

@ -9,6 +9,7 @@ from pytube import YouTube
class SpeechToText: class SpeechToText:
def __init__( def __init__(
self, self,
video_url, video_url,
@ -61,14 +62,15 @@ class SpeechToText:
compute_type = "float16" compute_type = "float16"
# 1. Transcribe with original Whisper (batched) 🗣️ # 1. Transcribe with original Whisper (batched) 🗣️
model = whisperx.load_model("large-v2", device, compute_type=compute_type) model = whisperx.load_model("large-v2",
device,
compute_type=compute_type)
audio = whisperx.load_audio(audio_file) audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size) result = model.transcribe(audio, batch_size=batch_size)
# 2. Align Whisper output 🔍 # 2. Align Whisper output 🔍
model_a, metadata = whisperx.load_align_model( model_a, metadata = whisperx.load_align_model(
language_code=result["language"], device=device language_code=result["language"], device=device)
)
result = whisperx.align( result = whisperx.align(
result["segments"], result["segments"],
model_a, model_a,
@ -80,8 +82,7 @@ class SpeechToText:
# 3. Assign speaker labels 🏷️ # 3. Assign speaker labels 🏷️
diarize_model = whisperx.DiarizationPipeline( diarize_model = whisperx.DiarizationPipeline(
use_auth_token=self.hf_api_key, device=device use_auth_token=self.hf_api_key, device=device)
)
diarize_model(audio_file) diarize_model(audio_file)
try: try:
@ -98,8 +99,7 @@ class SpeechToText:
# 2. Align Whisper output 🔍 # 2. Align Whisper output 🔍
model_a, metadata = whisperx.load_align_model( model_a, metadata = whisperx.load_align_model(
language_code=result["language"], device=self.device language_code=result["language"], device=self.device)
)
result = whisperx.align( result = whisperx.align(
result["segments"], result["segments"],
@ -112,8 +112,7 @@ class SpeechToText:
# 3. Assign speaker labels 🏷️ # 3. Assign speaker labels 🏷️
diarize_model = whisperx.DiarizationPipeline( diarize_model = whisperx.DiarizationPipeline(
use_auth_token=self.hf_api_key, device=self.device use_auth_token=self.hf_api_key, device=self.device)
)
diarize_model(audio_file) diarize_model(audio_file)
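A condensed sketch of the whisperx pipeline the SpeechToText hunks above reformat: transcribe, align, then diarize. It assumes whisperx is installed, "audio.wav" is a placeholder file, "HF_API_KEY" is a placeholder token, and the alignment arguments after metadata follow the upstream whisperx API (the diff truncates that call).

# Transcribe with Whisper, align the output, then run diarization.
import whisperx

device = "cuda"
batch_size = 16
compute_type = "float16"

model = whisperx.load_model("large-v2", device, compute_type=compute_type)
audio = whisperx.load_audio("audio.wav")  # placeholder path
result = model.transcribe(audio, batch_size=batch_size)

model_a, metadata = whisperx.load_align_model(
    language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device)

diarize_model = whisperx.DiarizationPipeline(use_auth_token="HF_API_KEY",
                                             device=device)
diarize_model(audio)
print(result["segments"])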

@ -34,9 +34,8 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation.""" """Raised when 'args_schema' is missing or has an incorrect type annotation."""
def _create_subset_model( def _create_subset_model(name: str, model: BaseModel,
name: str, model: BaseModel, field_names: list field_names: list) -> Type[BaseModel]:
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields.""" """Create a pydantic model with only a subset of model's fields."""
fields = {} fields = {}
for field_name in field_names: for field_name in field_names:
@ -52,7 +51,11 @@ def _get_filtered_args(
"""Get the arguments from a function's signature.""" """Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"] schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} return {
k: schema[k]
for k in valid_keys
if k not in ("run_manager", "callbacks")
}
class _SchemaConfig: class _SchemaConfig:
@ -82,9 +85,8 @@ def create_schema_from_function(
del inferred_model.__fields__["callbacks"] del inferred_model.__fields__["callbacks"]
# Pydantic adds placeholder virtual fields we need to strip # Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func) valid_properties = _get_filtered_args(inferred_model, func)
return _create_subset_model( return _create_subset_model(f"{model_name}Schema", inferred_model,
f"{model_name}Schema", inferred_model, list(valid_properties) list(valid_properties))
)
class ToolException(Exception): class ToolException(Exception):
@ -125,8 +127,7 @@ class ChildTool(BaseTool):
"Expected annotation of 'Type[BaseModel]'" "Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n" f" but got '{args_schema_type}'.\n"
"Expected class looks like:\n" "Expected class looks like:\n"
f"{typehint_mandate}" f"{typehint_mandate}")
)
name: str name: str
"""The unique name of the tool that clearly communicates its purpose.""" """The unique name of the tool that clearly communicates its purpose."""
@ -147,7 +148,8 @@ class ChildTool(BaseTool):
callbacks: Callbacks = Field(default=None, exclude=True) callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to be called during tool execution.""" """Callbacks to be called during tool execution."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) callback_manager: Optional[BaseCallbackManager] = Field(default=None,
exclude=True)
"""Deprecated. Please use callbacks instead.""" """Deprecated. Please use callbacks instead."""
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
"""Optional list of tags associated with the tool. Defaults to None """Optional list of tags associated with the tool. Defaults to None
@ -162,9 +164,8 @@ class ChildTool(BaseTool):
You can use these to eg identify a specific instance of a tool with its use case. You can use these to eg identify a specific instance of a tool with its use case.
""" """
handle_tool_error: Optional[ handle_tool_error: Optional[Union[bool, str, Callable[[ToolException],
Union[bool, str, Callable[[ToolException], str]] str]]] = False
] = False
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
class Config(Serializable.Config): class Config(Serializable.Config):
@ -244,7 +245,9 @@ class ChildTool(BaseTool):
else: else:
if input_args is not None: if input_args is not None:
result = input_args.parse_obj(tool_input) result = input_args.parse_obj(tool_input)
return {k: v for k, v in result.dict().items() if k in tool_input} return {
k: v for k, v in result.dict().items() if k in tool_input
}
return tool_input return tool_input
@root_validator() @root_validator()
@ -286,7 +289,8 @@ class ChildTool(BaseTool):
*args, *args,
) )
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: def _to_args_and_kwargs(self,
tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string, # For backwards compatibility, if run_input is a string,
# pass as a positional argument. # pass as a positional argument.
if isinstance(tool_input, str): if isinstance(tool_input, str):
@ -325,7 +329,10 @@ class ChildTool(BaseTool):
# TODO: maybe also pass through run_manager is _run supports kwargs # TODO: maybe also pass through run_manager is _run supports kwargs
new_arg_supported = signature(self._run).parameters.get("run_manager") new_arg_supported = signature(self._run).parameters.get("run_manager")
run_manager = callback_manager.on_tool_start( run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {
"name": self.name,
"description": self.description
},
tool_input if isinstance(tool_input, str) else str(tool_input), tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color, color=start_color,
name=run_name, name=run_name,
@ -335,9 +342,7 @@ class ChildTool(BaseTool):
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = ( observation = (
self._run(*tool_args, run_manager=run_manager, **tool_kwargs) self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported if new_arg_supported else self._run(*tool_args, **tool_kwargs))
else self._run(*tool_args, **tool_kwargs)
)
except ToolException as e: except ToolException as e:
if not self.handle_tool_error: if not self.handle_tool_error:
run_manager.on_tool_error(e) run_manager.on_tool_error(e)
@ -354,19 +359,20 @@ class ChildTool(BaseTool):
else: else:
raise ValueError( raise ValueError(
"Got unexpected type of `handle_tool_error`. Expected bool, str " "Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}" f"or callable. Received: {self.handle_tool_error}")
) run_manager.on_tool_end(str(observation),
run_manager.on_tool_end( color="red",
str(observation), color="red", name=self.name, **kwargs name=self.name,
) **kwargs)
return observation return observation
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
run_manager.on_tool_error(e) run_manager.on_tool_error(e)
raise e raise e
else: else:
run_manager.on_tool_end( run_manager.on_tool_end(str(observation),
str(observation), color=color, name=self.name, **kwargs color=color,
) name=self.name,
**kwargs)
return observation return observation
async def arun( async def arun(
@ -399,7 +405,10 @@ class ChildTool(BaseTool):
) )
new_arg_supported = signature(self._arun).parameters.get("run_manager") new_arg_supported = signature(self._arun).parameters.get("run_manager")
run_manager = await callback_manager.on_tool_start( run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description}, {
"name": self.name,
"description": self.description
},
tool_input if isinstance(tool_input, str) else str(tool_input), tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color, color=start_color,
name=run_name, name=run_name,
@ -408,11 +417,10 @@ class ChildTool(BaseTool):
try: try:
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = ( observation = (await self._arun(*tool_args,
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) run_manager=run_manager,
if new_arg_supported **tool_kwargs) if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs) else await self._arun(*tool_args, **tool_kwargs))
)
except ToolException as e: except ToolException as e:
if not self.handle_tool_error: if not self.handle_tool_error:
await run_manager.on_tool_error(e) await run_manager.on_tool_error(e)
@ -429,19 +437,20 @@ class ChildTool(BaseTool):
else: else:
raise ValueError( raise ValueError(
"Got unexpected type of `handle_tool_error`. Expected bool, str " "Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}" f"or callable. Received: {self.handle_tool_error}")
) await run_manager.on_tool_end(str(observation),
await run_manager.on_tool_end( color="red",
str(observation), color="red", name=self.name, **kwargs name=self.name,
) **kwargs)
return observation return observation
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
await run_manager.on_tool_error(e) await run_manager.on_tool_error(e)
raise e raise e
else: else:
await run_manager.on_tool_end( await run_manager.on_tool_end(str(observation),
str(observation), color=color, name=self.name, **kwargs color=color,
) name=self.name,
**kwargs)
return observation return observation
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
@ -459,7 +468,6 @@ class Tool(BaseTool):
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
# --- Runnable --- # --- Runnable ---
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict],
@ -469,8 +477,7 @@ class Tool(BaseTool):
if not self.coroutine: if not self.coroutine:
# If the tool does not implement async, fall back to default implementation # If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs) None, partial(self.invoke, input, config, **kwargs))
)
return await super().ainvoke(input, config, **kwargs) return await super().ainvoke(input, config, **kwargs)
@ -485,7 +492,8 @@ class Tool(BaseTool):
# assume it takes a single string input. # assume it takes a single string input.
return {"tool_input": {"type": "string"}} return {"tool_input": {"type": "string"}}
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: def _to_args_and_kwargs(self,
tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model.""" """Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input) args, kwargs = super()._to_args_and_kwargs(tool_input)
# For backwards compatibility. The tool must be run with a single input # For backwards compatibility. The tool must be run with a single input
@ -504,16 +512,13 @@ class Tool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool.""" """Use the tool."""
if self.func: if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks") new_argument_supported = signature(
return ( self.func).parameters.get("callbacks")
self.func( return (self.func(
*args, *args,
callbacks=run_manager.get_child() if run_manager else None, callbacks=run_manager.get_child() if run_manager else None,
**kwargs, **kwargs,
) ) if new_argument_supported else self.func(*args, **kwargs))
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync") raise NotImplementedError("Tool does not support sync")
async def _arun( async def _arun(
@ -524,31 +529,27 @@ class Tool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get( new_argument_supported = signature(
"callbacks" self.coroutine).parameters.get("callbacks")
) return (await self.coroutine(
return ( *args,
await self.coroutine( callbacks=run_manager.get_child() if run_manager else None,
*args, **kwargs,
callbacks=run_manager.get_child() if run_manager else None, ) if new_argument_supported else await self.coroutine(
**kwargs, *args, **kwargs))
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
else: else:
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, partial(self._run, run_manager=run_manager, **kwargs), *args None, partial(self._run, run_manager=run_manager, **kwargs),
) *args)
# TODO: this is for backwards compatibility, remove in future # TODO: this is for backwards compatibility, remove in future
def __init__( def __init__(self, name: str, func: Optional[Callable], description: str,
self, name: str, func: Optional[Callable], description: str, **kwargs: Any **kwargs: Any) -> None:
) -> None:
"""Initialize tool.""" """Initialize tool."""
super(Tool, self).__init__( super(Tool, self).__init__(name=name,
name=name, func=func, description=description, **kwargs func=func,
) description=description,
**kwargs)
@classmethod @classmethod
def from_function( def from_function(
@ -558,9 +559,8 @@ class Tool(BaseTool):
description: str, description: str,
return_direct: bool = False, return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None, args_schema: Optional[Type[BaseModel]] = None,
coroutine: Optional[ coroutine: Optional[Callable[..., Awaitable[
Callable[..., Awaitable[Any]] Any]]] = None, # This is last for compatibility, but should be after func
] = None, # This is last for compatibility, but should be after func
**kwargs: Any, **kwargs: Any,
) -> Tool: ) -> Tool:
"""Initialize tool from a function.""" """Initialize tool from a function."""
@ -589,7 +589,6 @@ class StructuredTool(BaseTool):
"""The asynchronous version of the function.""" """The asynchronous version of the function."""
# --- Runnable --- # --- Runnable ---
async def ainvoke( async def ainvoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict],
@ -599,8 +598,7 @@ class StructuredTool(BaseTool):
if not self.coroutine: if not self.coroutine:
# If the tool does not implement async, fall back to default implementation # If the tool does not implement async, fall back to default implementation
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs) None, partial(self.invoke, input, config, **kwargs))
)
return await super().ainvoke(input, config, **kwargs) return await super().ainvoke(input, config, **kwargs)
@ -619,16 +617,13 @@ class StructuredTool(BaseTool):
) -> Any: ) -> Any:
"""Use the tool.""" """Use the tool."""
if self.func: if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks") new_argument_supported = signature(
return ( self.func).parameters.get("callbacks")
self.func( return (self.func(
*args, *args,
callbacks=run_manager.get_child() if run_manager else None, callbacks=run_manager.get_child() if run_manager else None,
**kwargs, **kwargs,
) ) if new_argument_supported else self.func(*args, **kwargs))
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync") raise NotImplementedError("Tool does not support sync")
async def _arun( async def _arun(
@ -639,18 +634,14 @@ class StructuredTool(BaseTool):
) -> str: ) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
if self.coroutine: if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get( new_argument_supported = signature(
"callbacks" self.coroutine).parameters.get("callbacks")
) return (await self.coroutine(
return ( *args,
await self.coroutine( callbacks=run_manager.get_child() if run_manager else None,
*args, **kwargs,
callbacks=run_manager.get_child() if run_manager else None, ) if new_argument_supported else await self.coroutine(
**kwargs, *args, **kwargs))
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, None,
partial(self._run, run_manager=run_manager, **kwargs), partial(self._run, run_manager=run_manager, **kwargs),
@ -707,8 +698,7 @@ class StructuredTool(BaseTool):
description = description or source_function.__doc__ description = description or source_function.__doc__
if description is None: if description is None:
raise ValueError( raise ValueError(
"Function must have a docstring if description not provided." "Function must have a docstring if description not provided.")
)
# Description example: # Description example:
# search_api(query: str) - Searches the API for the query. # search_api(query: str) - Searches the API for the query.
@ -716,7 +706,8 @@ class StructuredTool(BaseTool):
description = f"{name}{sig} - {description.strip()}" description = f"{name}{sig} - {description.strip()}"
_args_schema = args_schema _args_schema = args_schema
if _args_schema is None and infer_schema: if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", source_function) _args_schema = create_schema_from_function(f"{name}Schema",
source_function)
return cls( return cls(
name=name, name=name,
func=func, func=func,
@ -764,6 +755,7 @@ def tool(
""" """
def _make_with_name(tool_name: str) -> Callable: def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool: def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
if isinstance(dec_func, Runnable): if isinstance(dec_func, Runnable):
runnable = dec_func runnable = dec_func
@ -771,14 +763,13 @@ def tool(
if runnable.input_schema.schema().get("type") != "object": if runnable.input_schema.schema().get("type") != "object":
raise ValueError("Runnable must have an object schema.") raise ValueError("Runnable must have an object schema.")
async def ainvoke_wrapper( async def ainvoke_wrapper(callbacks: Optional[Callbacks] = None,
callbacks: Optional[Callbacks] = None, **kwargs: Any **kwargs: Any) -> Any:
) -> Any: return await runnable.ainvoke(kwargs,
return await runnable.ainvoke(kwargs, {"callbacks": callbacks}) {"callbacks": callbacks})
def invoke_wrapper( def invoke_wrapper(callbacks: Optional[Callbacks] = None,
callbacks: Optional[Callbacks] = None, **kwargs: Any **kwargs: Any) -> Any:
) -> Any:
return runnable.invoke(kwargs, {"callbacks": callbacks}) return runnable.invoke(kwargs, {"callbacks": callbacks})
coroutine = ainvoke_wrapper coroutine = ainvoke_wrapper
@ -811,8 +802,7 @@ def tool(
if func.__doc__ is None: if func.__doc__ is None:
raise ValueError( raise ValueError(
"Function must have a docstring if " "Function must have a docstring if "
"description not provided and infer_schema is False." "description not provided and infer_schema is False.")
)
return Tool( return Tool(
name=tool_name, name=tool_name,
func=func, func=func,
@ -823,7 +813,8 @@ def tool(
return _make_tool return _make_tool
if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable): if len(args) == 2 and isinstance(args[0], str) and isinstance(
args[1], Runnable):
return _make_with_name(args[0])(args[1]) return _make_with_name(args[0])(args[1])
elif len(args) == 1 and isinstance(args[0], str): elif len(args) == 1 and isinstance(args[0], str):
# if the argument is a string, then we use the string as the tool name # if the argument is a string, then we use the string as the tool name
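Usage sketch for the `tool` decorator whose wrapper plumbing is reformatted above. It is shown with the upstream langchain.tools import for illustration; the vendored copy in this diff is assumed to expose the same decorator and Tool interface.

# Define a tool from a plain function; the docstring becomes its description.
from langchain.tools import tool


@tool
def word_count(text: str) -> str:
    """Count the words in a piece of text."""
    return str(len(text.split()))


print(word_count.name)                   # -> "word_count"
print(word_count.run("one two three"))   # -> "3"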

@ -6,6 +6,7 @@ FuncToolBuilder = Callable[[], ToolBuilder]
class ToolsRegistry: class ToolsRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self.tools: Dict[str, FuncToolBuilder] = {} self.tools: Dict[str, FuncToolBuilder] = {}
@ -18,8 +19,7 @@ class ToolsRegistry:
if isinstance(ret, tool): if isinstance(ret, tool):
return ret return ret
raise ValueError( raise ValueError(
"Tool builder {} did not return a Tool instance".format(tool_name) "Tool builder {} did not return a Tool instance".format(tool_name))
)
def list_tools(self) -> List[str]: def list_tools(self) -> List[str]:
return list(self.tools.keys()) return list(self.tools.keys())
@ -29,6 +29,7 @@ tools_registry = ToolsRegistry()
def register(tool_name): def register(tool_name):
def decorator(tool: FuncToolBuilder): def decorator(tool: FuncToolBuilder):
tools_registry.register(tool_name, tool) tools_registry.register(tool_name, tool)
return tool return tool
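A self-contained mini-version of the ToolsRegistry / register decorator pattern touched above, using plain callables instead of FuncToolBuilder so the sketch runs on its own; names other than register/list_tools are illustrative.

# Register tool builders by name, then list and build them.
from typing import Callable, Dict, List


class MiniToolsRegistry:

    def __init__(self) -> None:
        self.tools: Dict[str, Callable] = {}

    def register(self, name: str, builder: Callable) -> None:
        self.tools[name] = builder

    def build(self, name: str):
        return self.tools[name]()

    def list_tools(self) -> List[str]:
        return list(self.tools.keys())


registry = MiniToolsRegistry()


def register(tool_name: str):

    def decorator(builder: Callable):
        registry.register(tool_name, builder)
        return builder

    return decorator


@register("echo")
def echo_builder():
    return lambda text: text


print(registry.list_tools())          # ['echo']
print(registry.build("echo")("hi"))   # 'hi'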

@ -118,14 +118,19 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
# Most of the time it doesn't matter, but we should figure out why it happens frequently with: # Most of the time it doesn't matter, but we should figure out why it happens frequently with:
# applescript # applescript
yield {"output": traceback.format_exc()} yield {"output": traceback.format_exc()}
yield {"output": f"Retrying... ({retry_count}/{max_retries})"} yield {
"output": f"Retrying... ({retry_count}/{max_retries})"
}
yield {"output": "Restarting process."} yield {"output": "Restarting process."}
self.start_process() self.start_process()
retry_count += 1 retry_count += 1
if retry_count > max_retries: if retry_count > max_retries:
yield {"output": "Maximum retries reached. Could not execute code."} yield {
"output":
"Maximum retries reached. Could not execute code."
}
return return
while True: while True:
@ -134,7 +139,8 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
else: else:
time.sleep(0.1) time.sleep(0.1)
try: try:
output = self.output_queue.get(timeout=0.3) # Waits for 0.3 seconds output = self.output_queue.get(
timeout=0.3) # Waits for 0.3 seconds
yield output yield output
except queue.Empty: except queue.Empty:
if self.done.is_set(): if self.done.is_set():
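A minimal reproduction of the output-polling loop reformatted above: read from a queue with a 0.3 second timeout and stop once a "done" event is set. The producer thread here is illustrative, standing in for the subprocess reader threads in the real class.

# Poll a queue with a timeout and exit when the producer signals completion.
import queue
import threading
import time

output_queue = queue.Queue()
done = threading.Event()


def producer() -> None:
    for n in range(3):
        output_queue.put({"output": f"line {n}"})
        time.sleep(0.1)
    done.set()


threading.Thread(target=producer).start()

while True:
    try:
        print(output_queue.get(timeout=0.3))  # waits for 0.3 seconds
    except queue.Empty:
        if done.is_set():
            break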

@ -6,6 +6,7 @@ import warnings
def log_decorator(func): def log_decorator(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
logging.info(f"Entering {func.__name__}") logging.info(f"Entering {func.__name__}")
result = func(*args, **kwargs) result = func(*args, **kwargs)
@ -16,6 +17,7 @@ def log_decorator(func):
def error_decorator(func): def error_decorator(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
@ -27,18 +29,22 @@ def error_decorator(func):
def timing_decorator(func): def timing_decorator(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start_time = time.time() start_time = time.time()
result = func(*args, **kwargs) result = func(*args, **kwargs)
end_time = time.time() end_time = time.time()
logging.info(f"{func.__name__} executed in {end_time - start_time} seconds") logging.info(
f"{func.__name__} executed in {end_time - start_time} seconds")
return result return result
return wrapper return wrapper
def retry_decorator(max_retries=5): def retry_decorator(max_retries=5):
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for _ in range(max_retries): for _ in range(max_retries):
@ -77,16 +83,20 @@ def synchronized_decorator(func):
def deprecated_decorator(func): def deprecated_decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
warnings.warn(f"{func.__name__} is deprecated", category=DeprecationWarning) warnings.warn(f"{func.__name__} is deprecated",
category=DeprecationWarning)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
def validate_inputs_decorator(validator): def validate_inputs_decorator(validator):
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not validator(*args, **kwargs): if not validator(*args, **kwargs):
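A self-contained version of the timing_decorator pattern reformatted above, using print instead of the logging configuration the real module relies on.

# Time a function call and report how long it took.
import functools
import time


def timing_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time} seconds")
        return result

    return wrapper


@timing_decorator
def slow_add(a: int, b: int) -> int:
    time.sleep(0.2)
    return a + b


print(slow_add(2, 3))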

@ -5,6 +5,8 @@ T = TypeVar("T")
def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]: def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]:
futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED) futures.wait(fs_dict.values(),
timeout=None,
return_when=futures.ALL_COMPLETED)
return {key: future.result() for key, future in fs_dict.items()} return {key: future.result() for key, future in fs_dict.items()}
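A runnable sketch of execute_futures_dict from the hunk above: wait for every future in a dict, then collect the results keyed by the same names. The ThreadPoolExecutor usage is illustrative.

# Wait for all futures, then gather results by key.
from concurrent import futures


def execute_futures_dict(fs_dict):
    futures.wait(fs_dict.values(), timeout=None,
                 return_when=futures.ALL_COMPLETED)
    return {key: future.result() for key, future in fs_dict.items()}


with futures.ThreadPoolExecutor() as pool:
    fs = {
        "square": pool.submit(lambda: 3 * 3),
        "cube": pool.submit(lambda: 3 * 3 * 3),
    }
    print(execute_futures_dict(fs))  # {'square': 9, 'cube': 27}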

@ -4,8 +4,7 @@ import hashlib
def dataframe_to_hash(dataframe: pd.DataFrame) -> str: def dataframe_to_hash(dataframe: pd.DataFrame) -> str:
return hashlib.sha256( return hashlib.sha256(
pd.util.hash_pandas_object(dataframe, index=True).values pd.util.hash_pandas_object(dataframe, index=True).values).hexdigest()
).hexdigest()
def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str: def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str:
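A sketch of the hashing helpers touched above. dataframe_to_hash mirrors the diff; the body of str_to_hash is not shown in the diff, so the hashlib.new call below is an assumption. Assumes pandas is installed.

# Hash a DataFrame's contents and a plain string with hashlib.
import hashlib

import pandas as pd


def dataframe_to_hash(dataframe: pd.DataFrame) -> str:
    return hashlib.sha256(
        pd.util.hash_pandas_object(dataframe, index=True).values).hexdigest()


def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str:
    # Assumed implementation; the diff only shows the signature.
    return hashlib.new(hash_algorithm, text.encode()).hexdigest()


df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
print(dataframe_to_hash(df))
print(str_to_hash("hello"))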

@ -51,16 +51,16 @@ def get_new_image_name(org_img_name, func_name="update"):
if len(name_split) == 1: if len(name_split) == 1:
most_org_file_name = name_split[0] most_org_file_name = name_split[0]
recent_prev_file_name = name_split[0] recent_prev_file_name = name_split[0]
new_file_name = "{}_{}_{}_{}.png".format( new_file_name = "{}_{}_{}_{}.png".format(this_new_uuid, func_name,
this_new_uuid, func_name, recent_prev_file_name, most_org_file_name recent_prev_file_name,
) most_org_file_name)
else: else:
assert len(name_split) == 4 assert len(name_split) == 4
most_org_file_name = name_split[3] most_org_file_name = name_split[3]
recent_prev_file_name = name_split[0] recent_prev_file_name = name_split[0]
new_file_name = "{}_{}_{}_{}.png".format( new_file_name = "{}_{}_{}_{}.png".format(this_new_uuid, func_name,
this_new_uuid, func_name, recent_prev_file_name, most_org_file_name recent_prev_file_name,
) most_org_file_name)
return os.path.join(head, new_file_name) return os.path.join(head, new_file_name)
@ -73,26 +73,26 @@ def get_new_dataframe_name(org_img_name, func_name="update"):
if len(name_split) == 1: if len(name_split) == 1:
most_org_file_name = name_split[0] most_org_file_name = name_split[0]
recent_prev_file_name = name_split[0] recent_prev_file_name = name_split[0]
new_file_name = "{}_{}_{}_{}.csv".format( new_file_name = "{}_{}_{}_{}.csv".format(this_new_uuid, func_name,
this_new_uuid, func_name, recent_prev_file_name, most_org_file_name recent_prev_file_name,
) most_org_file_name)
else: else:
assert len(name_split) == 4 assert len(name_split) == 4
most_org_file_name = name_split[3] most_org_file_name = name_split[3]
recent_prev_file_name = name_split[0] recent_prev_file_name = name_split[0]
new_file_name = "{}_{}_{}_{}.csv".format( new_file_name = "{}_{}_{}_{}.csv".format(this_new_uuid, func_name,
this_new_uuid, func_name, recent_prev_file_name, most_org_file_name recent_prev_file_name,
) most_org_file_name)
return os.path.join(head, new_file_name) return os.path.join(head, new_file_name)
# =======================> utils end # =======================> utils end
# =======================> ANSI BEGINNING # =======================> ANSI BEGINNING
class Code: class Code:
def __init__(self, value: int): def __init__(self, value: int):
self.value = value self.value = value
@ -101,6 +101,7 @@ class Code:
class Color(Code): class Color(Code):
def bg(self) -> "Color": def bg(self) -> "Color":
self.value += 10 self.value += 10
return self return self
@ -147,6 +148,7 @@ class Color(Code):
class Style(Code): class Style(Code):
@staticmethod @staticmethod
def reset() -> "Style": def reset() -> "Style":
return Style(0) return Style(0)
@ -203,19 +205,19 @@ def dim_multiline(message: str) -> str:
lines = message.split("\n") lines = message.split("\n")
if len(lines) <= 1: if len(lines) <= 1:
return lines[0] return lines[0]
return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(Color.black().bright()) return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(
Color.black().bright())
# +=============================> ANSI Ending # +=============================> ANSI Ending
# ================================> upload base # ================================> upload base
STATIC_DIR = "static" STATIC_DIR = "static"
class AbstractUploader(ABC): class AbstractUploader(ABC):
@abstractmethod @abstractmethod
def upload(self, filepath: str) -> str: def upload(self, filepath: str) -> str:
pass pass
@ -227,12 +229,13 @@ class AbstractUploader(ABC):
# ================================> upload end # ================================> upload end
# ========================= upload s3 # ========================= upload s3
class S3Uploader(AbstractUploader): class S3Uploader(AbstractUploader):
def __init__(self, accessKey: str, secretKey: str, region: str, bucket: str):
def __init__(self, accessKey: str, secretKey: str, region: str,
bucket: str):
self.accessKey = accessKey self.accessKey = accessKey
self.secretKey = secretKey self.secretKey = secretKey
self.region = region self.region = region
@ -263,11 +266,11 @@ class S3Uploader(AbstractUploader):
# ========================= upload s3 # ========================= upload s3
# ========================> upload/static # ========================> upload/static
class StaticUploader(AbstractUploader): class StaticUploader(AbstractUploader):
def __init__(self, server: str, path: Path, endpoint: str): def __init__(self, server: str, path: Path, endpoint: str):
self.server = server self.server = server
self.path = path self.path = path
@ -292,7 +295,6 @@ class StaticUploader(AbstractUploader):
# ========================> handlers/base # ========================> handlers/base
# from env import settings # from env import settings
@ -336,16 +338,19 @@ class FileType(Enum):
class BaseHandler: class BaseHandler:
def handle(self, filename: str) -> str: def handle(self, filename: str) -> str:
raise NotImplementedError raise NotImplementedError
class FileHandler: class FileHandler:
def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path): def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path):
self.handlers = handlers self.handlers = handlers
self.path = path self.path = path
def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler": def register(self, filetype: FileType,
handler: BaseHandler) -> "FileHandler":
self.handlers[filetype] = handler self.handlers[filetype] = handler
return self return self
@ -353,8 +358,8 @@ class FileHandler:
filetype = FileType.from_url(url) filetype = FileType.from_url(url)
data = requests.get(url).content data = requests.get(url).content
local_filename = os.path.join( local_filename = os.path.join(
"file", str(uuid.uuid4())[0:8] + filetype.to_extension() "file",
) str(uuid.uuid4())[0:8] + filetype.to_extension())
os.makedirs(os.path.dirname(local_filename), exist_ok=True) os.makedirs(os.path.dirname(local_filename), exist_ok=True)
with open(local_filename, "wb") as f: with open(local_filename, "wb") as f:
size = f.write(data) size = f.write(data)
@ -363,17 +368,15 @@ class FileHandler:
def handle(self, url: str) -> str: def handle(self, url: str) -> str:
try: try:
if url.startswith(os.environ.get("SERVER", "http://localhost:8000")): if url.startswith(os.environ.get("SERVER",
"http://localhost:8000")):
local_filepath = url[ local_filepath = url[
len(os.environ.get("SERVER", "http://localhost:8000")) + 1 : len(os.environ.get("SERVER", "http://localhost:8000")) + 1:]
]
local_filename = Path("file") / local_filepath.split("/")[-1] local_filename = Path("file") / local_filepath.split("/")[-1]
src = self.path / local_filepath src = self.path / local_filepath
dst = ( dst = (self.path /
self.path os.environ.get("PLAYGROUND_DIR", "./playground") /
/ os.environ.get("PLAYGROUND_DIR", "./playground") local_filename)
/ local_filename
)
os.makedirs(os.path.dirname(dst), exist_ok=True) os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.copy(src, dst) shutil.copy(src, dst)
else: else:
@ -383,8 +386,7 @@ class FileHandler:
if FileType.from_url(url) == FileType.IMAGE: if FileType.from_url(url) == FileType.IMAGE:
raise Exception( raise Exception(
f"No handler for {FileType.from_url(url)}. " f"No handler for {FileType.from_url(url)}. "
"Please set USE_GPU to True in env/settings.py" "Please set USE_GPU to True in env/settings.py")
)
else: else:
raise Exception(f"No handler for {FileType.from_url(url)}") raise Exception(f"No handler for {FileType.from_url(url)}")
return handler.handle(local_filename) return handler.handle(local_filename)
@ -394,22 +396,21 @@ class FileHandler:
# => base end # => base end
# ===========================> # ===========================>
class CsvToDataframe(BaseHandler): class CsvToDataframe(BaseHandler):
def handle(self, filename: str): def handle(self, filename: str):
df = pd.read_csv(filename) df = pd.read_csv(filename)
description = ( description = (
f"Dataframe with {len(df)} rows and {len(df.columns)} columns. " f"Dataframe with {len(df)} rows and {len(df.columns)} columns. "
"Columns are: " "Columns are: "
f"{', '.join(df.columns)}" f"{', '.join(df.columns)}")
)
print( print(
f"\nProcessed CsvToDataframe, Input CSV: {filename}, Output Description:" f"\nProcessed CsvToDataframe, Input CSV: {filename}, Output Description:"
f" {description}" f" {description}")
)
return DATAFRAME_PROMPT.format(filename=filename, description=description) return DATAFRAME_PROMPT.format(filename=filename,
description=description)
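A small sketch of the summary string CsvToDataframe.handle assembles above; the CSV written here is a placeholder so the example runs on its own. Assumes pandas is installed.

# Build the "rows / columns / column names" description for a CSV.
import pandas as pd

pd.DataFrame({"name": ["a", "b"], "score": [1, 2]}).to_csv("example.csv", index=False)

df = pd.read_csv("example.csv")
description = (
    f"Dataframe with {len(df)} rows and {len(df.columns)} columns. "
    "Columns are: "
    f"{', '.join(df.columns)}")
print(description)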

@ -7,5 +7,6 @@ def extract_code_in_backticks_in_string(message: str) -> str:
""" """
pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks pattern = r"`` ``(.*?)`` " # Non-greedy match between six backticks
match = re.search(pattern, message, re.DOTALL) # re.DOTALL to match newline chars match = re.search(pattern, message,
re.DOTALL) # re.DOTALL to match newline chars
return match.group(1).strip() if match else None return match.group(1).strip() if match else None
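A runnable sketch of the code-block extraction reformatted above: grab the text between a pair of triple backticks (six backticks in total), matching across newlines.

# Extract the first fenced code block from a message.
import re


def extract_code_in_backticks_in_string(message: str) -> str:
    pattern = r"```(.*?)```"  # Non-greedy match between six backticks
    match = re.search(pattern, message,
                      re.DOTALL)  # re.DOTALL to match newline chars
    return match.group(1).strip() if match else None


print(extract_code_in_backticks_in_string("Here you go:\n```\nprint('hi')\n```"))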

Some files were not shown because too many files have changed in this diff.