dockerfile running

pull/128/head
Kye 1 year ago
parent 371da7944e
commit 991979dfc6

@@ -0,0 +1,31 @@
# Use an official Python runtime as a parent image
FROM python:3.9-slim
# Set environment variables to make Python output unbuffered and disable the PIP cache
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
ENV PIP_NO_CACHE_DIR off
ENV PIP_DISABLE_PIP_VERSION_CHECK on
ENV PIP_DEFAULT_TIMEOUT 100
# Set the working directory in the container
WORKDIR /usr/src/app
# Copy the current directory contents into the container at /usr/src/app
COPY . .
# Install Poetry
RUN pip install poetry
# Disable virtualenv creation by poetry and install dependencies
RUN poetry config virtualenvs.create false
RUN poetry install --no-interaction --no-ansi
# Install the 'swarms' package if it's not included in the poetry.lock
RUN pip install swarms
# Assuming tests require pytest to run
RUN pip install pytest
# Run pytest on all tests in the tests directory
CMD find ./tests -name '*.py' -exec pytest {} +
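
Note: the CMD above starts one pytest process per test file. A single-session equivalent, if that were ever preferred, could be a small runner script baked into the image; the run_tests.py name and the CMD swap below are assumptions for illustration, not part of this commit.

# Hypothetical run_tests.py: collects ./tests in one pytest session.
# Would be wired in with: CMD ["python", "run_tests.py"]
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["./tests"]))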

@@ -2,4 +2,4 @@
Autonomous swarm that optimizes UI autonomously
GPT4Vision ->> GPT4 ->> UI
"""

@@ -9,6 +9,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from swarms.agents import *
from swarms.swarms import *
from swarms.structs import *
from swarms.models import *
from swarms.chunkers import *
from swarms.workers import *

@@ -8,6 +8,7 @@ from swarms.agents.registry import Registry
# from swarms.agents.idea_to_image_agent import Idea2Image
from swarms.agents.simple_agent import SimpleAgent
"""Agent Infrastructure, models, memory, utils, tools"""
__all__ = [

@@ -8,7 +8,8 @@ from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ChatMessageHistory
from langchain.prompts.chat import (
    BaseChatPromptTemplate,
)
from langchain.schema import (
    BaseChatMessageHistory,
    Document,
@@ -70,12 +71,14 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
    send_token_limit: int = 4196
    def construct_full_prompt(self, goals: List[str]) -> str:
        prompt_start = (
            "Your decisions must always be made independently "
            "without seeking user assistance.\n"
            "Play to your strengths as an LLM and pursue simple "
            "strategies with no legal complications.\n"
            "If you have completed all your tasks, make sure to "
            'use the "finish" command.'
        )
        # Construct full prompt
        full_prompt = (
            f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
@@ -87,23 +90,25 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
        return full_prompt
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"]))
        time_prompt = SystemMessage(
            content=f"The current time and date is {time.strftime('%c')}"
        )
        used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
            time_prompt.content
        )
        memory: VectorStoreRetriever = kwargs["memory"]
        previous_messages = kwargs["messages"]
        relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
        relevant_memory = [d.page_content for d in relevant_docs]
        relevant_memory_tokens = sum(
            [self.token_counter(doc) for doc in relevant_memory]
        )
        while used_tokens + relevant_memory_tokens > 2500:
            relevant_memory = relevant_memory[:-1]
            relevant_memory_tokens = sum(
                [self.token_counter(doc) for doc in relevant_memory]
            )
        content_format = (
            f"This reminds you of these events from your past:\n{relevant_memory}\n\n"
        )
@@ -141,23 +146,13 @@ class PromptGenerator:
        self.performance_evaluation: List[str] = []
        self.response_format = {
            "thoughts": {
                "text": "thought",
                "reasoning": "reasoning",
                "plan": "- short bulleted\n- list that conveys\n- long-term plan",
                "criticism": "constructive self-criticism",
                "speak": "thoughts summary to say to user",
            },
            "command": {"name": "command name", "args": {"arg name": "value"}},
        }
    def add_constraint(self, constraint: str) -> None:
@@ -195,9 +190,7 @@ class PromptGenerator:
        """
        self.performance_evaluation.append(evaluation)
    def _generate_numbered_list(self, items: list, item_type: str = "list") -> str:
        """
        Generate a numbered list from given items based on the item_type.
@@ -215,11 +208,16 @@ class PromptGenerator:
                for i, item in enumerate(items)
            ]
            finish_description = (
                "use this to signal that you have finished all your objectives"
            )
            finish_args = (
                '"response": "final response to let '
                'people know you have finished your objectives"'
            )
            finish_string = (
                f"{len(items) + 1}. {FINISH_NAME}: "
                f"{finish_description}, args: {finish_args}"
            )
            return "\n".join(command_strings + [finish_string])
        else:
            return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
@@ -240,7 +238,8 @@ class PromptGenerator:
            f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
            "You should only respond in JSON format as described below "
            f"\nResponse Format: \n{formatted_response_format} "
            "\nEnsure the response can be parsed by Python json.loads"
        )
        return prompt_string
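
For orientation, the JSON reply that generate_prompt_string asks the model for (and that json.loads must accept) would look roughly like this; the command name and argument values below are made up for illustration:

import json

# Illustrative assistant reply following the response_format defined above.
reply = """
{
  "thoughts": {
    "text": "thought",
    "reasoning": "reasoning",
    "plan": "- check the catalog\\n- draft an answer",
    "criticism": "constructive self-criticism",
    "speak": "thoughts summary to say to user"
  },
  "command": {"name": "search", "args": {"query": "premium mattresses"}}
}
"""

parsed = json.loads(reply)
print(parsed["command"]["name"])  # -> "search"
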
@@ -261,11 +260,13 @@ def get_prompt(tools: List[BaseTool]) -> str:
    prompt_generator.add_constraint(
        "~16000 word limit for short term memory. "
        "Your short term memory is short, "
        "so immediately save important information to files."
    )
    prompt_generator.add_constraint(
        "If you are unsure how you previously did something "
        "or want to recall past events, "
        "thinking about similar events will help you remember."
    )
    prompt_generator.add_constraint("No user assistance")
    prompt_generator.add_constraint(
        'Exclusively use the commands listed in double quotes e.g. "command name"'
@@ -277,23 +278,29 @@ def get_prompt(tools: List[BaseTool]) -> str:
    # Add resources to the PromptGenerator object
    prompt_generator.add_resource(
        "Internet access for searches and information gathering."
    )
    prompt_generator.add_resource("Long Term memory management.")
    prompt_generator.add_resource(
        "GPT-3.5 powered Agents for delegation of simple tasks."
    )
    prompt_generator.add_resource("File output.")
    # Add performance evaluations to the PromptGenerator object
    prompt_generator.add_performance_evaluation(
        "Continuously review and analyze your actions "
        "to ensure you are performing to the best of your abilities."
    )
    prompt_generator.add_performance_evaluation(
        "Constructively self-criticize your big-picture behavior constantly."
    )
    prompt_generator.add_performance_evaluation(
        "Reflect on past decisions and strategies to refine your approach."
    )
    prompt_generator.add_performance_evaluation(
        "Every command has a cost, so be smart and efficient. "
        "Aim to complete tasks in the least number of steps."
    )
    # Generate the prompt string
    prompt_string = prompt_generator.generate_prompt_string()
@@ -364,8 +371,10 @@ class AutoGPT:
        )
    def run(self, goals: List[str]) -> str:
        user_input = (
            "Determine which next command to use, "
            "and respond using the format specified above:"
        )
        # Interaction Loop
        loop_count = 0
        while True:
@@ -382,10 +391,8 @@ class AutoGPT:
            # Print Assistant thoughts
            print(assistant_reply)
            self.chat_history_memory.add_message(HumanMessage(content=user_input))
            self.chat_history_memory.add_message(AIMessage(content=assistant_reply))
            # Get command name and arguments
            action = self.output_parser.parse(assistant_reply)
@@ -411,7 +418,8 @@ class AutoGPT:
                result = (
                    f"Unknown command '{action.name}'. "
                    "Please refer to the 'COMMANDS' list for available "
                    "commands and only respond in the specified JSON format."
                )
            memory_to_add = f"Assistant Reply: {assistant_reply} \nResult: {result} "
            if self.feedback_tool is not None:

@@ -4,13 +4,13 @@ import time
import openai_model
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class OpenAI:
    def __init__(
        self,
        api_key,
@@ -68,13 +68,16 @@ class OpenAI:
                temperature=temperature,
            )
            with open("openai.logs", "a") as log_file:
                log_file.write(
                    "\n" + "-----------" + "\n" + "Prompt : " + prompt + "\n"
                )
            return response
        except openai_model.error.RateLimitError as e:
            sleep_duratoin = os.environ.get("OPENAI_RATE_TIMEOUT", 30)
            print(
                f"{str(e)}, sleep for {sleep_duratoin}s, set it by env"
                " OPENAI_RATE_TIMEOUT"
            )
            time.sleep(sleep_duratoin)
    def openai_choice2text_handler(self, choice):
@@ -97,16 +100,11 @@ class OpenAI:
        else:
            response = self.run(prompt, 300, 0.5, k)
            thoughts = [
                self.openai_choice2text_handler(choice) for choice in response.choices
            ]
            return thoughts
    def generate_thoughts(self, state, k, initial_prompt, rejected_solutions=None):
        if isinstance(state, str):
            pass
        else:
@@ -179,8 +177,7 @@ class OpenAI:
        """
        response = self.run(prompt, 10, 1)
        try:
            value_text = self.openai_choice2text_handler(response.choices[0])
            # print(f'state: {value_text}')
            value = float(value_text)
            print(f"Evaluated Thought Value: {value}")
@@ -190,12 +187,10 @@ class OpenAI:
            return state_values
        else:
            raise ValueError("Invalid evaluation strategy. Choose 'value' or 'vote'.")
class AoTAgent:
    def __init__(
        self,
        num_thoughts: int = None,
@@ -227,8 +222,7 @@ class AoTAgent:
                return None
            best_state, _ = max(self.output, key=lambda x: x[1])
            solution = self.model.generate_solution(self.initial_prompt, best_state)
            print(f"Solution is {solution}")
            return solution if solution else best_state
        except Exception as error:
@@ -245,8 +239,11 @@ class AoTAgent:
        for next_state in thoughts:
            state_value = self.evaluated_thoughts[next_state]
            if state_value > self.value_threshold:
                child = (
                    (state, next_state)
                    if isinstance(state, str)
                    else (*state, next_state)
                )
                self.dfs(child, step + 1)
        # backtracking
@@ -256,14 +253,17 @@ class AoTAgent:
                continue
    def generate_and_filter_thoughts(self, state):
        thoughts = self.model.generate_thoughts(
            state, self.num_thoughts, self.initial_prompt
        )
        self.evaluated_thoughts = self.model.evaluate_states(
            thoughts, self.initial_prompt
        )
        filtered_thoughts = [
            thought
            for thought in thoughts
            if self.evaluated_thoughts[thought] >= self.pruning_threshold
        ]
        print(f"filtered_thoughts: {filtered_thoughts}")

@@ -38,8 +38,7 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None):
    if not os.path.exists("agents"):
        os.makedirs("agents")
    if os.path.exists(f"agents/{agent_name}.py") and config.environment != "local":
        if not _is_blank_agent(agent_name=agent_name):
            raise Exception(f"Agent with name {agent_name} already exists")
    driver = get_driver(  # noqa: F841
@@ -55,10 +54,12 @@ def record(agent_name: str, autotab_ext_path: Optional[str] = None):
    print(
        "\033[34mYou have the Python debugger open, you can run commands in it like you"
        " would in a normal Python shell.\033[0m"
    )
    print(
        "\033[34mTo exit, type 'q' and press enter. For a list of commands type '?' and"
        " press enter.\033[0m"
    )
    breakpoint()
@@ -78,13 +79,12 @@ def extract_domain_from_url(url: str):
class AutotabChromeDriver(uc.Chrome):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def find_element_with_retry(
        self, by=By.ID, value: Optional[str] = None
    ) -> WebElement:
        try:
            return super().find_element(by, value)
        except Exception as e:
@@ -102,8 +102,11 @@ def open_plugin(driver: AutotabChromeDriver):
def open_plugin_and_login(driver: AutotabChromeDriver):
    if config.autotab_api_key is not None:
        backend_url = (
            "http://localhost:8000"
            if config.environment == "local"
            else "https://api.autotab.com"
        )
        driver.get(f"{backend_url}/auth/signin-api-key-page")
        response = requests.post(
            f"{backend_url}/auth/signin-api-key",
@@ -116,7 +119,8 @@ def open_plugin_and_login(driver: AutotabChromeDriver):
        else:
            raise Exception(
                f"Error {response.status_code} from backend while logging you in"
                f" with your API key: {response.text}"
            )
        cookie["name"] = cookie["key"]
        del cookie["key"]
        driver.add_cookie(cookie)
@@ -126,21 +130,26 @@ def open_plugin_and_login(driver: AutotabChromeDriver):
    else:
        print("No autotab API key found, heading to autotab.com to sign up")
        url = (
            "http://localhost:3000/dashboard"
            if config.environment == "local"
            else "https://autotab.com/dashboard"
        )
        driver.get(url)
    time.sleep(0.5)
    open_plugin(driver)
def get_driver(
    autotab_ext_path: Optional[str] = None, record_mode: bool = False
) -> AutotabChromeDriver:
    options = webdriver.ChromeOptions()
    options.add_argument("--no-sandbox")  # Necessary for running
    options.add_argument(
        "--user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        " (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
    )
    options.add_argument("--enable-webgl")
    options.add_argument("--enable-3d-apis")
    options.add_argument("--enable-clipboard-read-write")
@@ -229,8 +238,7 @@ class Config(BaseModel):
        return cls(
            autotab_api_key=autotab_api_key,
            credentials=_credentials,
            google_credentials=GoogleCredentials(credentials=google_credentials),
            chrome_binary_location=config.get("chrome_binary_location"),
            environment=config.get("environment", "prod"),
        )
@@ -248,9 +256,9 @@ def is_signed_in_to_google(driver):
    return len([c for c in cookies if c["name"] == "SAPISID"]) != 0
def google_login(
    driver, credentials: Optional[SiteCredentials] = None, navigate: bool = True
):
    print("Logging in to Google")
    if navigate:
        driver.get("https://accounts.google.com/")
@@ -282,7 +290,8 @@ def google_login(driver,
    email_input.send_keys(credentials.email)
    email_input.send_keys(Keys.ENTER)
    WebDriverWait(driver, 10).until(
        EC.element_to_be_clickable((By.CSS_SELECTOR, "[type='password']"))
    )
    password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
    password_input.send_keys(credentials.password)
@@ -305,20 +314,21 @@ def google_login(driver,
        cookies = driver.get_cookies()
        cookie_names = ["__Host-GAPS", "SMSV", "NID", "ACCOUNT_CHOOSER"]
        google_cookies = [
            cookie
            for cookie in cookies
            if cookie["domain"] in [".google.com", "accounts.google.com"]
            and cookie["name"] in cookie_names
        ]
        with open("google_cookies.json", "w") as f:
            json.dump(google_cookies, f)
        # Log back in
        login_button = driver.find_element(
            By.CSS_SELECTOR, f"[data-identifier='{credentials.email}']"
        )
        login_button.click()
        time.sleep(1)
        password_input = driver.find_element(By.CSS_SELECTOR, "[type='password']")
        password_input.send_keys(credentials.password)
        password_input.send_keys(Keys.ENTER)
@@ -333,7 +343,8 @@ def login(driver, url: str):
    login_url = credentials.login_url
    if credentials.login_with_google_account:
        google_credentials = config.google_credentials.credentials[
            credentials.login_with_google_account
        ]
        _login_with_google(driver, login_url, google_credentials)
    else:
        _login(driver, login_url, credentials=credentials)
@@ -360,15 +371,16 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
    driver.get(url)
    WebDriverWait(driver, 10).until(
        EC.presence_of_element_located((By.TAG_NAME, "body"))
    )
    main_window = driver.current_window_handle
    xpath = (
        "//*[contains(text(), 'Continue with Google') or contains(text(), 'Sign in with"
        " Google') or contains(@title, 'Sign in with Google')]"
    )
    WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.XPATH, xpath)))
    driver.find_element(
        By.XPATH,
        xpath,
@@ -376,8 +388,8 @@ def _login_with_google(driver, url: str, google_credentials: SiteCredentials):
    driver.switch_to.window(driver.window_handles[-1])
    driver.find_element(
        By.XPATH, f"//*[contains(text(), '{google_credentials.email}')]"
    ).click()
    driver.switch_to.window(main_window)
@@ -430,11 +442,8 @@ def should_update():
        # Parse the XML file
        root = ET.fromstring(xml_content)
        namespaces = {"ns": "http://www.google.com/update2/response"}  # add namespaces
        xml_version = root.find(".//ns:app/ns:updatecheck", namespaces).get("version")
        # Load the local JSON file
        with open("src/extension/autotab/manifest.json", "r") as f:

@@ -56,24 +56,23 @@ HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
def get_remote_tools(organization="huggingface-tools"):
    if is_offline_mode():
        logger.info("You are in offline mode, so remote tools are not available.")
        return {}
    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(
            repo_id, TOOL_CONFIG_FILE, repo_type="space"
        )
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)
        task = repo_id.split("/")[-1]
        tools[config["name"]] = PreTool(
            task=task, description=config["description"], repo_id=repo_id
        )
    return tools
@@ -93,7 +92,8 @@ def _setup_default_tools():
        tool_class = getattr(tools_module, tool_class_name)
        description = tool_class.description
        HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(
            task=task_name, description=description, repo_id=None
        )
    if not is_offline_mode():
        for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
@@ -197,19 +197,18 @@ class Agent:
    one of the default tools, that default tool will be overridden.
    """
    def __init__(
        self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        _setup_default_tools()
        agent_name = self.__class__.__name__
        self.chat_prompt_template = download_prompt(
            chat_prompt_template, agent_name, mode="chat"
        )
        self.run_prompt_template = download_prompt(
            run_prompt_template, agent_name, mode="run"
        )
        self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
        self.log = print
        if additional_tools is not None:
@@ -225,16 +224,17 @@ class Agent:
            }
            self._toolbox.update(additional_tools)
            if len(replacements) > 1:
                names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
                logger.warning(
                    "The following tools have been replaced by the ones provided in"
                    f" `additional_tools`:\n{names}."
                )
            elif len(replacements) == 1:
                name = list(replacements.keys())[0]
                logger.warning(
                    f"{name} has been replaced by {replacements[name]} as provided in"
                    " `additional_tools`."
                )
        self.prepare_for_new_chat()
@@ -244,20 +244,17 @@ class Agent:
        return self._toolbox
    def format_prompt(self, task, chat_mode=False):
        description = "\n".join(
            [f"- {name}: {tool.description}" for name, tool in self.toolbox.items()]
        )
        if chat_mode:
            if self.chat_history is None:
                prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
            else:
                prompt = self.chat_history
            prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
        else:
            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
        prompt = prompt.replace("<<prompt>>", task)
        return prompt
@@ -306,19 +303,14 @@ class Agent:
        if not return_code:
            self.log("\n\n==Result==")
            self.cached_tools = resolve_tools(
                code, self.toolbox, remote=remote, cached_tools=self.cached_tools
            )
            self.chat_state.update(kwargs)
            return evaluate(
                code, self.cached_tools, self.chat_state, chat_mode=True
            )
        else:
            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
            return f"{tool_code}\n{code}"
    def prepare_for_new_chat(self):
@@ -360,15 +352,12 @@ class Agent:
        self.log(f"\n\n==Code generated by the agent==\n{code}")
        if not return_code:
            self.log("\n\n==Result==")
            self.cached_tools = resolve_tools(
                code, self.toolbox, remote=remote, cached_tools=self.cached_tools
            )
            return evaluate(code, self.cached_tools, state=kwargs.copy())
        else:
            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
            return f"{tool_code}\n{code}"
    def generate_one(self, prompt, stop):
def generate_one(self, prompt, stop): def generate_one(self, prompt, stop):
@ -428,7 +417,8 @@ class HFAgent(Agent):
): ):
if not is_openai_available(): if not is_openai_available():
raise ImportError( raise ImportError(
"Using `OpenAiAgent` requires `openai`: `pip install openai`.") "Using `OpenAiAgent` requires `openai`: `pip install openai`."
)
if api_key is None: if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY", None) api_key = os.environ.get("OPENAI_API_KEY", None)
@ -436,7 +426,8 @@ class HFAgent(Agent):
raise ValueError( raise ValueError(
"You need an openai key to use `OpenAIAgent`. You can get one here: Get" "You need an openai key to use `OpenAIAgent`. You can get one here: Get"
" one here https://openai.com/api/`. If you have one, set it in your" " one here https://openai.com/api/`. If you have one, set it in your"
" env with `os.environ['OPENAI_API_KEY'] = xxx.") " env with `os.environ['OPENAI_API_KEY'] = xxx."
)
else: else:
openai.api_key = api_key openai.api_key = api_key
self.model = model self.model = model
@ -461,10 +452,7 @@ class HFAgent(Agent):
def _chat_generate(self, prompt, stop): def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create( result = openai.ChatCompletion.create(
model=self.model, model=self.model,
messages=[{ messages=[{"role": "user", "content": prompt}],
"role": "user",
"content": prompt
}],
temperature=0, temperature=0,
stop=stop, stop=stop,
) )
@ -542,7 +530,8 @@ class AzureOpenAI(Agent):
): ):
if not is_openai_available(): if not is_openai_available():
raise ImportError( raise ImportError(
"Using `OpenAiAgent` requires `openai`: `pip install openai`.") "Using `OpenAiAgent` requires `openai`: `pip install openai`."
)
self.deployment_id = deployment_id self.deployment_id = deployment_id
openai.api_type = "azure" openai.api_type = "azure"
@ -552,7 +541,8 @@ class AzureOpenAI(Agent):
raise ValueError( raise ValueError(
"You need an Azure openAI key to use `AzureOpenAIAgent`. If you have" "You need an Azure openAI key to use `AzureOpenAIAgent`. If you have"
" one, set it in your env with `os.environ['AZURE_OPENAI_API_KEY'] =" " one, set it in your env with `os.environ['AZURE_OPENAI_API_KEY'] ="
" xxx.") " xxx."
)
else: else:
openai.api_key = api_key openai.api_key = api_key
if resource_name is None: if resource_name is None:
@ -561,7 +551,8 @@ class AzureOpenAI(Agent):
raise ValueError( raise ValueError(
"You need a resource_name to use `AzureOpenAIAgent`. If you have one," "You need a resource_name to use `AzureOpenAIAgent`. If you have one,"
" set it in your env with `os.environ['AZURE_OPENAI_RESOURCE_NAME'] =" " set it in your env with `os.environ['AZURE_OPENAI_RESOURCE_NAME'] ="
" xxx.") " xxx."
)
else: else:
openai.api_base = f"https://{resource_name}.openai.azure.com" openai.api_base = f"https://{resource_name}.openai.azure.com"
openai.api_version = api_version openai.api_version = api_version
@ -591,10 +582,7 @@ class AzureOpenAI(Agent):
def _chat_generate(self, prompt, stop): def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create( result = openai.ChatCompletion.create(
engine=self.deployment_id, engine=self.deployment_id,
messages=[{ messages=[{"role": "user", "content": prompt}],
"role": "user",
"content": prompt
}],
temperature=0, temperature=0,
stop=stop, stop=stop,
) )

@@ -88,8 +88,9 @@ class MetaPrompterAgent:
        Assistant:
        """
        prompt = PromptTemplate(
            input_variables=["history", "human_input"], template=template
        )
        self.chain = LLMChain(
            llm=self.llm(),
@@ -101,15 +102,13 @@ class MetaPrompterAgent:
    def get_chat_history(self, chain_memory):
        """Get Chat History from the memory"""
        memory_key = chain_memory.memory_key
        chat_history = chain_memory.load_memory_variables(memory_key)[memory_key]
        return chat_history
    def get_new_instructions(self, meta_output):
        """Get New Instructions from the meta_output"""
        delimiter = "Instructions: "
        new_instructions = meta_output[meta_output.find(delimiter) + len(delimiter) :]
        return new_instructions
    def run(self, task: str):
@@ -150,7 +149,8 @@ class MetaPrompterAgent:
            meta_chain = self.initialize_meta_chain()
            meta_output = meta_chain.predict(
                chat_history=self.get_chat_history(chain.memory)
            )
            print(f"Feedback: {meta_output}")
            self.instructions = self.get_new_instructions(meta_output)

File diff suppressed because it is too large.

@@ -2,7 +2,6 @@
class Replicator:
    def __init__(
        self,
        model_name,

@@ -3,20 +3,23 @@ from typing import Dict, List
from langchain.base_language import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain_experimental.autonomous_agents.hugginggpt.repsonse_generator import (
    load_response_generator,
)
from langchain_experimental.autonomous_agents.hugginggpt.task_executor import (
    TaskExecutor,
)
from langchain_experimental.autonomous_agents.hugginggpt.task_planner import (
    load_chat_planner,
)
from transformers import load_tool
from swarms.agents.message import Message
class Step:
    def __init__(
        self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
    ):
        self.task = task
        self.id = id
        self.dep = dep
@@ -25,7 +28,6 @@ class Step:
class Plan:
    def __init__(self, steps: List[Step]):
        self.steps = steps
@@ -71,7 +73,8 @@ class OmniModalAgent:
        print("Loading tools...")
        self.tools = [
            load_tool(tool_name)
            for tool_name in [
                "document-question-answering",
                "image-captioning",
                "image-question-answering",
@@ -96,15 +99,18 @@ class OmniModalAgent:
    def run(self, input: str) -> str:
        """Run the OmniAgent"""
        plan = self.chat_planner.plan(
            inputs={
                "input": input,
                "hf_tools": self.tools,
            }
        )
        self.task_executor = TaskExecutor(plan)
        self.task_executor.run()
        response = self.response_generator.generate(
            {"task_execution": self.task_executor}
        )
        return response
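
A minimal usage sketch for OmniModalAgent; only run() is visible in these hunks, so treating the constructor as taking a single LLM argument is an assumption:

# Hypothetical usage; the constructor argument is an assumption, see note above.
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
agent = OmniModalAgent(llm)
print(agent.run("Caption the image at images/chart.png and summarize it."))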

@@ -145,12 +145,13 @@ def setup_knowledge_base(product_catalog: str = None):
    llm = OpenAI(temperature=0)
    embeddings = OpenAIEmbeddings()
    docsearch = Chroma.from_texts(
        texts, embeddings, collection_name="product-knowledge-base"
    )
    knowledge_base = RetrievalQA.from_chain_type(
        llm=llm, chain_type="stuff", retriever=docsearch.as_retriever()
    )
    return knowledge_base
@@ -162,8 +163,8 @@ def get_tools(product_catalog):
        Tool(
            name="ProductSearch",
            func=knowledge_base.run,
            description=(
                "useful for when you need to answer questions about product information"
            ),
        ),
        # omnimodal agent
@@ -193,7 +194,8 @@ class CustomPromptTemplateForTools(StringPromptTemplate):
        tools = self.tools_getter(kwargs["input"])
        # Create a tools variable from the list of tools provided
        kwargs["tools"] = "\n".join(
            [f"{tool.name}: {tool.description}" for tool in tools]
        )
        # Create a list of tool names for the tools provided
        kwargs["tool_names"] = ", ".join([tool.name for tool in tools])
        return self.template.format(**kwargs)
@@ -216,7 +218,8 @@ class SalesConvoOutputParser(AgentOutputParser):
        print("-------")
        if f"{self.ai_prefix}:" in text:
            return AgentFinish(
                {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text
            )
        regex = r"Action: (.*?)[\n]*Action Input: (.*)"
        match = re.search(regex, text)
        if not match:
@@ -225,15 +228,15 @@ class SalesConvoOutputParser(AgentOutputParser):
                {
                    "output": (
                        "I apologize, I was unable to find the answer to your question."
                        " Is there anything else I can help with?"
                    )
                },
                text,
            )
            # raise OutputParserException(f"Could not parse LLM output: `{text}`")
        action = match.group(1)
        action_input = match.group(2)
        return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)
    @property
    def _type(self) -> str:
@@ -261,11 +264,13 @@ class ProfitPilot(Chain, BaseModel):
        "2": (
            "Qualification: Qualify the prospect by confirming if they are the right"
            " person to talk to regarding your product/service. Ensure that they have"
            " the authority to make purchasing decisions."
        ),
        "3": (
            "Value proposition: Briefly explain how your product/service can benefit"
            " the prospect. Focus on the unique selling points and value proposition of"
            " your product/service that sets it apart from competitors."
        ),
        "4": (
            "Needs analysis: Ask open-ended questions to uncover the prospect's needs"
            " and pain points. Listen carefully to their responses and take notes."
@@ -277,11 +282,13 @@ class ProfitPilot(Chain, BaseModel):
        "6": (
            "Objection handling: Address any objections that the prospect may have"
            " regarding your product/service. Be prepared to provide evidence or"
            " testimonials to support your claims."
        ),
        "7": (
            "Close: Ask for the sale by proposing a next step. This could be a demo, a"
            " trial or a meeting with decision-makers. Ensure to summarize what has"
            " been discussed and reiterate the benefits."
        ),
    }
    salesperson_name: str = "Ted Lasso"
@@ -291,16 +298,19 @@ class ProfitPilot(Chain, BaseModel):
        "Sleep Haven is a premium mattress company that provides customers with the"
        " most comfortable and supportive sleeping experience possible. We offer a"
        " range of high-quality mattresses, pillows, and bedding accessories that are"
        " designed to meet the unique needs of our customers."
    )
    company_values: str = (
        "Our mission at Sleep Haven is to help people achieve a better night's sleep by"
        " providing them with the best possible sleep solutions. We believe that"
        " quality sleep is essential to overall health and well-being, and we are"
        " committed to helping our customers achieve optimal sleep by offering"
        " exceptional products and customer service."
    )
    conversation_purpose: str = (
        "find out whether they are looking to achieve better sleep via buying a premier"
        " mattress."
    )
    conversation_type: str = "call"
    def retrieve_conversation_stage(self, key):
@@ -326,7 +336,8 @@ class ProfitPilot(Chain, BaseModel):
        )
        self.current_conversation_stage = self.retrieve_conversation_stage(
            conversation_stage_id
        )
        print(f"Conversation Stage: {self.current_conversation_stage}")
@@ -380,15 +391,13 @@ class ProfitPilot(Chain, BaseModel):
        return {}
    @classmethod
    def from_llm(cls, llm: BaseLLM, verbose: bool = False, **kwargs):  # noqa: F821
        """Initialize the SalesGPT Controller."""
        stage_analyzer_chain = StageAnalyzerChain.from_llm(llm, verbose=verbose)
        sales_conversation_utterance_chain = SalesConversationChain.from_llm(
            llm, verbose=verbose
        )
        if "use_tools" in kwargs.keys() and kwargs["use_tools"] is False:
            sales_agent_executor = None
@@ -421,8 +430,7 @@ class ProfitPilot(Chain, BaseModel):
            # WARNING: this output parser is NOT reliable yet
            # It makes assumptions about output from LLM which can break and throw an error
            output_parser = SalesConvoOutputParser(ai_prefix=kwargs["salesperson_name"])
            sales_agent_with_tools = LLMSingleActionAgent(
                llm_chain=llm_chain,
@@ -433,12 +441,12 @@ class ProfitPilot(Chain, BaseModel):
            )
            sales_agent_executor = AgentExecutor.from_agent_and_tools(
                agent=sales_agent_with_tools, tools=tools, verbose=verbose
            )
        return cls(
            stage_analyzer_chain=stage_analyzer_chain,
            sales_conversation_utterance_chain=sales_conversation_utterance_chain,
            sales_agent_executor=sales_agent_executor,
            verbose=verbose,
            **kwargs,
@@ -450,27 +458,32 @@ config = dict(
    salesperson_name="Ted Lasso",
    salesperson_role="Business Development Representative",
    company_name="Sleep Haven",
    company_business=(
        "Sleep Haven is a premium mattress company that provides customers with the"
        " most comfortable and supportive sleeping experience possible. We offer a"
        " range of high-quality mattresses, pillows, and bedding accessories that are"
        " designed to meet the unique needs of our customers."
    ),
    company_values=(
        "Our mission at Sleep Haven is to help people achieve a better night's sleep by"
        " providing them with the best possible sleep solutions. We believe that"
        " quality sleep is essential to overall health and well-being, and we are"
        " committed to helping our customers achieve optimal sleep by offering"
        " exceptional products and customer service."
    ),
    conversation_purpose=(
        "find out whether they are looking to achieve better sleep via buying a premier"
        " mattress."
    ),
    conversation_history=[],
    conversation_type="call",
    conversation_stage=conversation_stages.get(
        "1",
        (
            "Introduction: Start the conversation by introducing yourself and your"
            " company. Be polite and respectful while keeping the tone of the"
            " conversation professional."
        ),
    ),
    use_tools=True,
    product_catalog="sample_product_catalog.txt",

@@ -1,11 +1,9 @@
class PromptRefiner:
    def __init__(self, system_prompt: str, llm):
        super().__init__()
        self.system_prompt = system_prompt
        self.llm = llm
    def run(self, task: str):
        refine = self.llm(f"System Prompt: {self.system_prompt} Current task: {task}")
        return refine
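
A minimal usage sketch for PromptRefiner, assuming llm is any callable that maps a prompt string to a completion string (the stub below is illustrative):

# Illustrative stub LLM; any prompt -> text callable works with this class.
def fake_llm(prompt: str) -> str:
    return f"[refined] {prompt}"

refiner = PromptRefiner(system_prompt="You rewrite prompts to be concise.", llm=fake_llm)
print(refiner.run("Summarize the Q3 sales report in three bullet points."))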

@@ -10,7 +10,6 @@ class Registry(BaseModel):
    entries: Dict = {}
    def register(self, key: str):
        def decorator(class_builder):
            self.entries[key] = class_builder
            return class_builder
@@ -21,7 +20,8 @@ class Registry(BaseModel):
        if type not in self.entries:
            raise ValueError(
                f"{type} is not registered. Please register with the"
                f' .register("{type}") method provided in {self.name} registry'
            )
        return self.entries[type](**kwargs)
    def get_all_entries(self):
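
A short sketch of the register-then-build flow these hunks imply. The build name for the lookup method, the name field on the model, and the Worker class are assumptions; only register() and the entries lookup are visible above:

# Hypothetical usage; see the assumptions noted above.
registry = Registry(name="agents")

@registry.register("worker")
class Worker:
    def __init__(self, task: str):
        self.task = task

# Assumes the lookup method shown in the second hunk is named build.
worker = registry.build("worker", task="scrape pricing pages")
print(worker.task)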

@@ -29,8 +29,7 @@ class SimpleAgent:
    def run(self, task: str) -> str:
        """Run method"""
        metrics = print(colored(f"Agent {self.name} is running task: {task}", "red"))
        print(metrics)
        response = self.flow.run(task)

@@ -10,8 +10,9 @@ from marshmallow.exceptions import RegistryError
@define
class BaseArtifact(ABC):
    id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
    name: str = field(
        default=Factory(lambda self: self.id, takes_self=True), kw_only=True
    )
    value: any = field()
    type: str = field(
        default=Factory(lambda self: self.__class__.__name__, takes_self=True),
@@ -53,8 +54,7 @@ class BaseArtifact(ABC):
        class_registry.register("ListArtifact", ListArtifactSchema)
        try:
            return class_registry.get_class(artifact_dict["type"])().load(artifact_dict)
        except RegistryError:
            raise ValueError("Unsupported artifact type")

@ -15,7 +15,8 @@ class Artifact(BaseModel):
artifact_id: StrictStr = Field(..., description="ID of the artifact") artifact_id: StrictStr = Field(..., description="ID of the artifact")
file_name: StrictStr = Field(..., description="Filename of the artifact") file_name: StrictStr = Field(..., description="Filename of the artifact")
relative_path: Optional[StrictStr] = Field( relative_path: Optional[StrictStr] = Field(
None, description="Relative path of the artifact") None, description="Relative path of the artifact"
)
__properties = ["artifact_id", "file_name", "relative_path"] __properties = ["artifact_id", "file_name", "relative_path"]
class Config: class Config:
@ -48,10 +49,12 @@ class Artifact(BaseModel):
if not isinstance(obj, dict): if not isinstance(obj, dict):
return Artifact.parse_obj(obj) return Artifact.parse_obj(obj)
_obj = Artifact.parse_obj({ _obj = Artifact.parse_obj(
"artifact_id": obj.get("artifact_id"), {
"file_name": obj.get("file_name"), "artifact_id": obj.get("artifact_id"),
"relative_path": obj.get("relative_path"), "file_name": obj.get("file_name"),
}) "relative_path": obj.get("relative_path"),
}
)
return _obj return _obj
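The hunk above shows a from_dict-style constructor that accepts either an existing model (forwarded straight to parse_obj) or a plain dictionary whose keys are mapped explicitly. Building the same Artifact directly from a dict looks like this (pydantic v1 style, matching the parse_obj call shown):

artifact = Artifact.parse_obj(
    {
        "artifact_id": "b225e278-8b4c-4f99-a696-8facf19f0e56",
        "file_name": "main.py",
        "relative_path": None,
    }
)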

@ -48,13 +48,15 @@ class BaseChunker(ABC):
kw_only=True, kw_only=True,
) )
tokenizer: OpenAITokenizer = field( tokenizer: OpenAITokenizer = field(
default=Factory(lambda: OpenAITokenizer( default=Factory(
model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), lambda: OpenAITokenizer(
model=OpenAITokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL
)
),
kw_only=True, kw_only=True,
) )
max_tokens: int = field( max_tokens: int = field(
default=Factory(lambda self: self.tokenizer.max_tokens, default=Factory(lambda self: self.tokenizer.max_tokens, takes_self=True),
takes_self=True),
kw_only=True, kw_only=True,
) )
@ -64,9 +66,8 @@ class BaseChunker(ABC):
return [TextArtifact(c) for c in self._chunk_recursively(text)] return [TextArtifact(c) for c in self._chunk_recursively(text)]
def _chunk_recursively( def _chunk_recursively(
self, self, chunk: str, current_separator: Optional[ChunkSeparator] = None
chunk: str, ) -> list[str]:
current_separator: Optional[ChunkSeparator] = None) -> list[str]:
token_count = self.tokenizer.count_tokens(chunk) token_count = self.tokenizer.count_tokens(chunk)
if token_count <= self.max_tokens: if token_count <= self.max_tokens:
@ -78,8 +79,7 @@ class BaseChunker(ABC):
half_token_count = token_count // 2 half_token_count = token_count // 2
if current_separator: if current_separator:
separators = self.separators[self.separators. separators = self.separators[self.separators.index(current_separator) :]
index(current_separator):]
else: else:
separators = self.separators separators = self.separators
@ -102,19 +102,26 @@ class BaseChunker(ABC):
if separator.is_prefix: if separator.is_prefix:
first_subchunk = separator.value + separator.value.join( first_subchunk = separator.value + separator.value.join(
subchanks[:balance_index + 1]) subchanks[: balance_index + 1]
)
second_subchunk = separator.value + separator.value.join( second_subchunk = separator.value + separator.value.join(
subchanks[balance_index + 1:]) subchanks[balance_index + 1 :]
)
else: else:
first_subchunk = (separator.value.join( first_subchunk = (
subchanks[:balance_index + 1]) + separator.value) separator.value.join(subchanks[: balance_index + 1])
+ separator.value
)
second_subchunk = separator.value.join( second_subchunk = separator.value.join(
subchanks[balance_index + 1:]) subchanks[balance_index + 1 :]
)
first_subchunk_rec = self._chunk_recursively( first_subchunk_rec = self._chunk_recursively(
first_subchunk.strip(), separator) first_subchunk.strip(), separator
)
second_subchunk_rec = self._chunk_recursively( second_subchunk_rec = self._chunk_recursively(
second_subchunk.strip(), separator) second_subchunk.strip(), separator
)
if first_subchunk_rec and second_subchunk_rec: if first_subchunk_rec and second_subchunk_rec:
return first_subchunk_rec + second_subchunk_rec return first_subchunk_rec + second_subchunk_rec
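The recursive strategy above splits on the coarsest separator first, balances the pieces into two halves, and recurses until every chunk fits the token budget. A simplified standalone sketch of the same idea, using character counts instead of the OpenAITokenizer purely to keep the example self-contained:

def chunk_recursively(text: str, separators=("\n\n", "\n", " "), max_len: int = 200) -> list[str]:
    # Base case: the chunk already fits the budget.
    if len(text) <= max_len:
        return [text]
    # Try separators from coarsest to finest and split roughly in half.
    for sep in separators:
        parts = text.split(sep)
        if len(parts) > 1:
            mid = len(parts) // 2
            left, right = sep.join(parts[:mid]), sep.join(parts[mid:])
            return chunk_recursively(left, separators, max_len) + chunk_recursively(
                right, separators, max_len
            )
    # No separator applies: fall back to a hard cut.
    return [text[:max_len]] + chunk_recursively(text[max_len:], separators, max_len)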

@ -76,7 +76,8 @@ class OmniChunker:
colored( colored(
f"Could not decode file with extension {file_extension}: {e}", f"Could not decode file with extension {file_extension}: {e}",
"yellow", "yellow",
)) )
)
return "" return ""
def chunk_content(self, content: str) -> List[str]: def chunk_content(self, content: str) -> List[str]:
@ -90,7 +91,7 @@ class OmniChunker:
List[str]: The list of chunks. List[str]: The list of chunks.
""" """
return [ return [
content[i:i + self.chunk_size] content[i : i + self.chunk_size]
for i in range(0, len(content), self.chunk_size) for i in range(0, len(content), self.chunk_size)
] ]
@ -112,4 +113,5 @@ class OmniChunker:
{self.metrics()} {self.metrics()}
""", """,
"cyan", "cyan",
)) )
)
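chunk_content above is plain fixed-width slicing with no overlap. For example, with a hypothetical chunk_size of 4:

content = "abcdefghij"
chunk_size = 4
chunks = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]
# -> ["abcd", "efgh", "ij"]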

@ -18,9 +18,9 @@ class AsanaReader(BaseReader):
self.client = asana.Client.access_token(asana_token) self.client = asana.Client.access_token(asana_token)
def load_data(self, def load_data(
workspace_id: Optional[str] = None, self, workspace_id: Optional[str] = None, project_id: Optional[str] = None
project_id: Optional[str] = None) -> List[Document]: ) -> List[Document]:
"""Load data from the workspace. """Load data from the workspace.
Args: Args:
@ -31,20 +31,18 @@ class AsanaReader(BaseReader):
""" """
if workspace_id is None and project_id is None: if workspace_id is None and project_id is None:
raise ValueError( raise ValueError("Either workspace_id or project_id must be provided")
"Either workspace_id or project_id must be provided")
if workspace_id is not None and project_id is not None: if workspace_id is not None and project_id is not None:
raise ValueError( raise ValueError(
"Only one of workspace_id or project_id should be provided") "Only one of workspace_id or project_id should be provided"
)
results = [] results = []
if workspace_id is not None: if workspace_id is not None:
workspace_name = self.client.workspaces.find_by_id( workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"]
workspace_id)["name"] projects = self.client.projects.find_all({"workspace": workspace_id})
projects = self.client.projects.find_all(
{"workspace": workspace_id})
# Case: Only project_id is provided # Case: Only project_id is provided
else: # since we've handled the other cases, this means project_id is not None else: # since we've handled the other cases, this means project_id is not None
@ -52,58 +50,54 @@ class AsanaReader(BaseReader):
workspace_name = projects[0]["workspace"]["name"] workspace_name = projects[0]["workspace"]["name"]
for project in projects: for project in projects:
tasks = self.client.tasks.find_all({ tasks = self.client.tasks.find_all(
"project": {
project["gid"], "project": project["gid"],
"opt_fields": "opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields",
"name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields", }
}) )
for task in tasks: for task in tasks:
stories = self.client.tasks.stories(task["gid"], stories = self.client.tasks.stories(task["gid"], opt_fields="type,text")
opt_fields="type,text") comments = "\n".join(
comments = "\n".join([ [
story["text"] story["text"]
for story in stories for story in stories
if story.get("type") == "comment" and "text" in story if story.get("type") == "comment" and "text" in story
]) ]
)
task_metadata = { task_metadata = {
"task_id": "task_id": task.get("gid", ""),
task.get("gid", ""), "name": task.get("name", ""),
"name":
task.get("name", ""),
"assignee": (task.get("assignee") or {}).get("name", ""), "assignee": (task.get("assignee") or {}).get("name", ""),
"completed_on": "completed_on": task.get("completed_at", ""),
task.get("completed_at", ""), "completed_by": (task.get("completed_by") or {}).get("name", ""),
"completed_by": (task.get("completed_by") or "project_name": project.get("name", ""),
{}).get("name", ""),
"project_name":
project.get("name", ""),
"custom_fields": [ "custom_fields": [
i["display_value"] i["display_value"]
for i in task.get("custom_fields") for i in task.get("custom_fields")
if task.get("custom_fields") is not None if task.get("custom_fields") is not None
], ],
"workspace_name": "workspace_name": workspace_name,
workspace_name, "url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}",
"url":
f"https://app.asana.com/0/{project['gid']}/{task['gid']}",
} }
if task.get("followers") is not None: if task.get("followers") is not None:
task_metadata["followers"] = [ task_metadata["followers"] = [
i.get("name") i.get("name") for i in task.get("followers") if "name" in i
for i in task.get("followers")
if "name" in i
] ]
else: else:
task_metadata["followers"] = [] task_metadata["followers"] = []
results.append( results.append(
Document( Document(
text=task.get("name", "") + " " + text=task.get("name", "")
task.get("notes", "") + " " + comments, + " "
+ task.get("notes", "")
+ " "
+ comments,
extra_info=task_metadata, extra_info=task_metadata,
)) )
)
return results return results
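AsanaReader collects each task's name, notes, and comment stories into a Document with rich extra_info. A usage sketch, assuming the constructor takes the personal access token (as the asana.Client.access_token call above implies) and that exactly one of workspace_id or project_id is passed, as the validation requires; the ids are placeholders:

reader = AsanaReader(asana_token="<ASANA_PERSONAL_ACCESS_TOKEN>")
docs = reader.load_data(project_id="1203456789012345")  # hypothetical project gid
for doc in docs:
    print(doc.extra_info["name"], doc.extra_info["url"])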

@ -47,8 +47,7 @@ class BaseComponent(BaseModel):
# TODO: return type here not supported by current mypy version # TODO: return type here not supported by current mypy version
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any], def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
**kwargs: Any) -> Self: # type: ignore
if isinstance(kwargs, dict): if isinstance(kwargs, dict):
data.update(kwargs) data.update(kwargs)
@ -119,10 +118,12 @@ class BaseNode(BaseComponent):
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True
id_: str = Field(default_factory=lambda: str(uuid.uuid4()), id_: str = Field(
description="Unique ID of the node.") default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
)
embedding: Optional[List[float]] = Field( embedding: Optional[List[float]] = Field(
default=None, description="Embedding of the node.") default=None, description="Embedding of the node."
)
"""" """"
metadata fields metadata fields
- injected as part of the text shown to LLMs as context - injected as part of the text shown to LLMs as context
@ -137,8 +138,7 @@ class BaseNode(BaseComponent):
) )
excluded_embed_metadata_keys: List[str] = Field( excluded_embed_metadata_keys: List[str] = Field(
default_factory=list, default_factory=list,
description= description="Metadata keys that are excluded from text for the embed model.",
"Metadata keys that are excluded from text for the embed model.",
) )
excluded_llm_metadata_keys: List[str] = Field( excluded_llm_metadata_keys: List[str] = Field(
default_factory=list, default_factory=list,
@ -156,8 +156,7 @@ class BaseNode(BaseComponent):
"""Get Object type.""" """Get Object type."""
@abstractmethod @abstractmethod
def get_content(self, def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
"""Get object content.""" """Get object content."""
@abstractmethod @abstractmethod
@ -188,8 +187,7 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.SOURCE] relation = self.relationships[NodeRelationship.SOURCE]
if isinstance(relation, list): if isinstance(relation, list):
raise ValueError( raise ValueError("Source object must be a single RelatedNodeInfo object")
"Source object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -200,8 +198,7 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.PREVIOUS] relation = self.relationships[NodeRelationship.PREVIOUS]
if not isinstance(relation, RelatedNodeInfo): if not isinstance(relation, RelatedNodeInfo):
raise ValueError( raise ValueError("Previous object must be a single RelatedNodeInfo object")
"Previous object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -212,8 +209,7 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.NEXT] relation = self.relationships[NodeRelationship.NEXT]
if not isinstance(relation, RelatedNodeInfo): if not isinstance(relation, RelatedNodeInfo):
raise ValueError( raise ValueError("Next object must be a single RelatedNodeInfo object")
"Next object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -224,8 +220,7 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.PARENT] relation = self.relationships[NodeRelationship.PARENT]
if not isinstance(relation, RelatedNodeInfo): if not isinstance(relation, RelatedNodeInfo):
raise ValueError( raise ValueError("Parent object must be a single RelatedNodeInfo object")
"Parent object must be a single RelatedNodeInfo object")
return relation return relation
@property @property
@ -236,8 +231,7 @@ class BaseNode(BaseComponent):
relation = self.relationships[NodeRelationship.CHILD] relation = self.relationships[NodeRelationship.CHILD]
if not isinstance(relation, list): if not isinstance(relation, list):
raise ValueError( raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
"Child objects must be a list of RelatedNodeInfo objects.")
return relation return relation
@property @property
@ -254,10 +248,12 @@ class BaseNode(BaseComponent):
return self.metadata return self.metadata
def __str__(self) -> str: def __str__(self) -> str:
source_text_truncated = truncate_text(self.get_content().strip(), source_text_truncated = truncate_text(
TRUNCATE_LENGTH) self.get_content().strip(), TRUNCATE_LENGTH
source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n", )
width=WRAP_WIDTH) source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Node ID: {self.node_id}\n{source_text_wrapped}" return f"Node ID: {self.node_id}\n{source_text_wrapped}"
def get_embedding(self) -> List[float]: def get_embedding(self) -> List[float]:
@ -283,23 +279,28 @@ class BaseNode(BaseComponent):
class TextNode(BaseNode): class TextNode(BaseNode):
text: str = Field(default="", description="Text content of the node.") text: str = Field(default="", description="Text content of the node.")
start_char_idx: Optional[int] = Field( start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node.") default=None, description="Start char index of the node."
)
end_char_idx: Optional[int] = Field( end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node.") default=None, description="End char index of the node."
)
text_template: str = Field( text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL, default=DEFAULT_TEXT_NODE_TMPL,
description=("Template for how text is formatted, with {content} and " description=(
"{metadata_str} placeholders."), "Template for how text is formatted, with {content} and "
"{metadata_str} placeholders."
),
) )
metadata_template: str = Field( metadata_template: str = Field(
default=DEFAULT_METADATA_TMPL, default=DEFAULT_METADATA_TMPL,
description=("Template for how metadata is formatted, with {key} and " description=(
"{value} placeholders."), "Template for how metadata is formatted, with {key} and "
"{value} placeholders."
),
) )
metadata_seperator: str = Field( metadata_seperator: str = Field(
default="\n", default="\n",
description= description="Separator between metadata fields when converting to string.",
"Separator between metadata fields when converting to string.",
) )
@classmethod @classmethod
@ -313,7 +314,8 @@ class TextNode(BaseNode):
metadata = values.get("metadata", {}) metadata = values.get("metadata", {})
doc_identity = str(text) + str(metadata) doc_identity = str(text) + str(metadata)
values["hash"] = str( values["hash"] = str(
sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
)
return values return values
@classmethod @classmethod
@ -321,15 +323,15 @@ class TextNode(BaseNode):
"""Get Object type.""" """Get Object type."""
return ObjectType.TEXT return ObjectType.TEXT
def get_content(self, def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get object content.""" """Get object content."""
metadata_str = self.get_metadata_str(mode=metadata_mode).strip() metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
if not metadata_str: if not metadata_str:
return self.text return self.text
return self.text_template.format(content=self.text, return self.text_template.format(
metadata_str=metadata_str).strip() content=self.text, metadata_str=metadata_str
).strip()
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str: def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""Metadata info string.""" """Metadata info string."""
@ -346,11 +348,13 @@ class TextNode(BaseNode):
if key in usable_metadata_keys: if key in usable_metadata_keys:
usable_metadata_keys.remove(key) usable_metadata_keys.remove(key)
return self.metadata_seperator.join([ return self.metadata_seperator.join(
self.metadata_template.format(key=key, value=str(value)) [
for key, value in self.metadata.items() self.metadata_template.format(key=key, value=str(value))
if key in usable_metadata_keys for key, value in self.metadata.items()
]) if key in usable_metadata_keys
]
)
def set_content(self, value: str) -> None: def set_content(self, value: str) -> None:
"""Set the content of the node.""" """Set the content of the node."""
@ -474,8 +478,7 @@ class NodeWithScore(BaseComponent):
else: else:
raise ValueError("Node must be a TextNode to get text.") raise ValueError("Node must be a TextNode to get text.")
def get_content(self, def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
return self.node.get_content(metadata_mode=metadata_mode) return self.node.get_content(metadata_mode=metadata_mode)
def get_embedding(self) -> List[float]: def get_embedding(self) -> List[float]:
@ -512,10 +515,12 @@ class Document(TextNode):
return self.id_ return self.id_
def __str__(self) -> str: def __str__(self) -> str:
source_text_truncated = truncate_text(self.get_content().strip(), source_text_truncated = truncate_text(
TRUNCATE_LENGTH) self.get_content().strip(), TRUNCATE_LENGTH
source_text_wrapped = textwrap.fill(f"Text: {source_text_truncated}\n", )
width=WRAP_WIDTH) source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}" return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
def get_doc_id(self) -> str: def get_doc_id(self) -> str:
@ -531,27 +536,22 @@ class Document(TextNode):
"""Convert struct to Haystack document format.""" """Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument from haystack.schema import Document as HaystackDocument
return HaystackDocument(content=self.text, return HaystackDocument(
meta=self.metadata, content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_
embedding=self.embedding, )
id=self.id_)
@classmethod @classmethod
def from_haystack_format(cls, doc: "HaystackDocument") -> "Document": def from_haystack_format(cls, doc: "HaystackDocument") -> "Document":
"""Convert struct from Haystack document format.""" """Convert struct from Haystack document format."""
return cls(text=doc.content, return cls(
metadata=doc.meta, text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id
embedding=doc.embedding, )
id_=doc.id)
def to_embedchain_format(self) -> Dict[str, Any]: def to_embedchain_format(self) -> Dict[str, Any]:
"""Convert struct to EmbedChain document format.""" """Convert struct to EmbedChain document format."""
return { return {
"doc_id": self.id_, "doc_id": self.id_,
"data": { "data": {"content": self.text, "meta_data": self.metadata},
"content": self.text,
"meta_data": self.metadata
},
} }
@classmethod @classmethod
@ -581,8 +581,7 @@ class Document(TextNode):
return cls( return cls(
text=doc._text, text=doc._text,
metadata={"additional_metadata": doc._additional_metadata}, metadata={"additional_metadata": doc._additional_metadata},
embedding=doc._embedding.tolist() embedding=doc._embedding.tolist() if doc._embedding is not None else None,
if doc._embedding is not None else None,
id_=doc._id, id_=doc._id,
) )
@ -590,10 +589,7 @@ class Document(TextNode):
def example(cls) -> "Document": def example(cls) -> "Document":
return Document( return Document(
text=SAMPLE_TEXT, text=SAMPLE_TEXT,
metadata={ metadata={"filename": "README.md", "category": "codebase"},
"filename": "README.md",
"category": "codebase"
},
) )
@classmethod @classmethod
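The Document class above carries converters for several frameworks. A minimal sketch of the EmbedChain direction, using the example() constructor defined in the same class:

doc = Document.example()
payload = doc.to_embedchain_format()
# payload == {"doc_id": doc.id_, "data": {"content": doc.text, "meta_data": doc.metadata}}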

@ -30,25 +30,32 @@ class BaseVectorStore(ABC):
embedding_driver: Any embedding_driver: Any
futures_executor: futures.Executor = field( futures_executor: futures.Executor = field(
default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True
)
def upsert_text_artifacts(self,
artifacts: dict[str, list[TextArtifact]], def upsert_text_artifacts(
meta: Optional[dict] = None, self,
**kwargs) -> None: artifacts: dict[str, list[TextArtifact]],
execute_futures_dict({ meta: Optional[dict] = None,
namespace: **kwargs
self.futures_executor.submit(self.upsert_text_artifact, a, ) -> None:
namespace, meta, **kwargs) execute_futures_dict(
for namespace, artifact_list in artifacts.items() {
for a in artifact_list namespace: self.futures_executor.submit(
}) self.upsert_text_artifact, a, namespace, meta, **kwargs
)
def upsert_text_artifact(self, for namespace, artifact_list in artifacts.items()
artifact: TextArtifact, for a in artifact_list
namespace: Optional[str] = None, }
meta: Optional[dict] = None, )
**kwargs) -> str:
def upsert_text_artifact(
self,
artifact: TextArtifact,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs
) -> str:
if not meta: if not meta:
meta = {} meta = {}
@ -59,37 +66,39 @@ class BaseVectorStore(ABC):
else: else:
vector = artifact.generate_embedding(self.embedding_driver) vector = artifact.generate_embedding(self.embedding_driver)
return self.upsert_vector(vector, return self.upsert_vector(
vector_id=artifact.id, vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs
namespace=namespace, )
meta=meta,
**kwargs) def upsert_text(
self,
def upsert_text(self, string: str,
string: str, vector_id: Optional[str] = None,
vector_id: Optional[str] = None, namespace: Optional[str] = None,
namespace: Optional[str] = None, meta: Optional[dict] = None,
meta: Optional[dict] = None, **kwargs
**kwargs) -> str: ) -> str:
return self.upsert_vector(self.embedding_driver.embed_string(string), return self.upsert_vector(
vector_id=vector_id, self.embedding_driver.embed_string(string),
namespace=namespace, vector_id=vector_id,
meta=meta if meta else {}, namespace=namespace,
**kwargs) meta=meta if meta else {},
**kwargs
)
@abstractmethod @abstractmethod
def upsert_vector(self, def upsert_vector(
vector: list[float], self,
vector_id: Optional[str] = None, vector: list[float],
namespace: Optional[str] = None, vector_id: Optional[str] = None,
meta: Optional[dict] = None, namespace: Optional[str] = None,
**kwargs) -> str: meta: Optional[dict] = None,
**kwargs
) -> str:
... ...
@abstractmethod @abstractmethod
def load_entry(self, def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry:
vector_id: str,
namespace: Optional[str] = None) -> Entry:
... ...
@abstractmethod @abstractmethod
@ -97,10 +106,12 @@ class BaseVectorStore(ABC):
... ...
@abstractmethod @abstractmethod
def query(self, def query(
query: str, self,
count: Optional[int] = None, query: str,
namespace: Optional[str] = None, count: Optional[int] = None,
include_vectors: bool = False, namespace: Optional[str] = None,
**kwargs) -> list[QueryResult]: include_vectors: bool = False,
**kwargs
) -> list[QueryResult]:
... ...
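A concrete store only has to supply the abstract upsert_vector, load_entry/load_entries, and query methods; upsert_text and the artifact helpers above are layered on top of them. A call-side sketch against a hypothetical subclass (MyVectorStore and the embedding driver are illustrative, not part of this commit; score and meta on the results match the QueryResult fields used by the PgVector driver later in this diff):

store = MyVectorStore(embedding_driver=my_embedding_driver)
vector_id = store.upsert_text("the quick brown fox", namespace="docs")
for result in store.query("fast animals", count=3, namespace="docs"):
    print(result.score, result.meta)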

@ -80,8 +80,10 @@ class Chroma(VectorStore):
import chromadb import chromadb
import chromadb.config import chromadb.config
except ImportError: except ImportError:
raise ImportError("Could not import chromadb python package. " raise ImportError(
"Please install it with `pip install chromadb`.") "Could not import chromadb python package. "
"Please install it with `pip install chromadb`."
)
if client is not None: if client is not None:
self._client_settings = client_settings self._client_settings = client_settings
@ -92,7 +94,8 @@ class Chroma(VectorStore):
# If client_settings is provided with persist_directory specified, # If client_settings is provided with persist_directory specified,
# then it is "in-memory and persisting to disk" mode. # then it is "in-memory and persisting to disk" mode.
client_settings.persist_directory = ( client_settings.persist_directory = (
persist_directory or client_settings.persist_directory) persist_directory or client_settings.persist_directory
)
if client_settings.persist_directory is not None: if client_settings.persist_directory is not None:
# Maintain backwards compatibility with chromadb < 0.4.0 # Maintain backwards compatibility with chromadb < 0.4.0
major, minor, _ = chromadb.__version__.split(".") major, minor, _ = chromadb.__version__.split(".")
@ -105,23 +108,25 @@ class Chroma(VectorStore):
major, minor, _ = chromadb.__version__.split(".") major, minor, _ = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) < 4: if int(major) == 0 and int(minor) < 4:
_client_settings = chromadb.config.Settings( _client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",) chroma_db_impl="duckdb+parquet",
)
else: else:
_client_settings = chromadb.config.Settings( _client_settings = chromadb.config.Settings(is_persistent=True)
is_persistent=True)
_client_settings.persist_directory = persist_directory _client_settings.persist_directory = persist_directory
else: else:
_client_settings = chromadb.config.Settings() _client_settings = chromadb.config.Settings()
self._client_settings = _client_settings self._client_settings = _client_settings
self._client = chromadb.Client(_client_settings) self._client = chromadb.Client(_client_settings)
self._persist_directory = (_client_settings.persist_directory or self._persist_directory = (
persist_directory) _client_settings.persist_directory or persist_directory
)
self._embedding_function = embedding_function self._embedding_function = embedding_function
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
name=collection_name, name=collection_name,
embedding_function=self._embedding_function.embed_documents embedding_function=self._embedding_function.embed_documents
if self._embedding_function is not None else None, if self._embedding_function is not None
else None,
metadata=collection_metadata, metadata=collection_metadata,
) )
self.override_relevance_score_fn = relevance_score_fn self.override_relevance_score_fn = relevance_score_fn
@ -144,8 +149,10 @@ class Chroma(VectorStore):
try: try:
import chromadb # noqa: F401 import chromadb # noqa: F401
except ImportError: except ImportError:
raise ValueError("Could not import chromadb python package. " raise ValueError(
"Please install it with `pip install chromadb`.") "Could not import chromadb python package. "
"Please install it with `pip install chromadb`."
)
return self._collection.query( return self._collection.query(
query_texts=query_texts, query_texts=query_texts,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
@ -195,9 +202,9 @@ class Chroma(VectorStore):
if non_empty_ids: if non_empty_ids:
metadatas = [metadatas[idx] for idx in non_empty_ids] metadatas = [metadatas[idx] for idx in non_empty_ids]
texts_with_metadatas = [texts[idx] for idx in non_empty_ids] texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
embeddings_with_metadatas = ([ embeddings_with_metadatas = (
embeddings[idx] for idx in non_empty_ids [embeddings[idx] for idx in non_empty_ids] if embeddings else None
] if embeddings else None) )
ids_with_metadata = [ids[idx] for idx in non_empty_ids] ids_with_metadata = [ids[idx] for idx in non_empty_ids]
try: try:
self._collection.upsert( self._collection.upsert(
@ -218,7 +225,8 @@ class Chroma(VectorStore):
if empty_ids: if empty_ids:
texts_without_metadatas = [texts[j] for j in empty_ids] texts_without_metadatas = [texts[j] for j in empty_ids]
embeddings_without_metadatas = ( embeddings_without_metadatas = (
[embeddings[j] for j in empty_ids] if embeddings else None) [embeddings[j] for j in empty_ids] if embeddings else None
)
ids_without_metadatas = [ids[j] for j in empty_ids] ids_without_metadatas = [ids[j] for j in empty_ids]
self._collection.upsert( self._collection.upsert(
embeddings=embeddings_without_metadatas, embeddings=embeddings_without_metadatas,
@ -250,9 +258,7 @@ class Chroma(VectorStore):
Returns: Returns:
List[Document]: List of documents most similar to the query text. List[Document]: List of documents most similar to the query text.
""" """
docs_and_scores = self.similarity_search_with_score(query, docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
k,
filter=filter)
return [doc for doc, _ in docs_and_scores] return [doc for doc, _ in docs_and_scores]
def similarity_search_by_vector( def similarity_search_by_vector(
@ -375,7 +381,8 @@ class Chroma(VectorStore):
raise ValueError( raise ValueError(
"No supported normalization function" "No supported normalization function"
f" for distance metric of type: {distance}." f" for distance metric of type: {distance}."
"Consider providing relevance_score_fn to Chroma constructor.") "Consider providing relevance_score_fn to Chroma constructor."
)
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, self,
@ -421,9 +428,7 @@ class Chroma(VectorStore):
candidates = _results_to_docs(results) candidates = _results_to_docs(results)
selected_results = [ selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
r for i, r in enumerate(candidates) if i in mmr_selected
]
return selected_results return selected_results
def max_marginal_relevance_search( def max_marginal_relevance_search(
@ -518,8 +523,10 @@ class Chroma(VectorStore):
It will also be called automatically when the object is destroyed. It will also be called automatically when the object is destroyed.
""" """
if self._persist_directory is None: if self._persist_directory is None:
raise ValueError("You must specify a persist_directory on" raise ValueError(
"creation to persist the collection.") "You must specify a persist_directory on"
"creation to persist the collection."
)
import chromadb import chromadb
# Maintain backwards compatibility with chromadb < 0.4.0 # Maintain backwards compatibility with chromadb < 0.4.0
@ -536,8 +543,7 @@ class Chroma(VectorStore):
""" """
return self.update_documents([document_id], [document]) return self.update_documents([document_id], [document])
def update_documents(self, ids: List[str], def update_documents(self, ids: List[str], documents: List[Document]) -> None:
documents: List[Document]) -> None:
"""Update a document in the collection. """Update a document in the collection.
Args: Args:
@ -552,16 +558,17 @@ class Chroma(VectorStore):
) )
embeddings = self._embedding_function.embed_documents(text) embeddings = self._embedding_function.embed_documents(text)
if hasattr(self._collection._client, if hasattr(
"max_batch_size"): # for Chroma 0.4.10 and above self._collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
for batch in create_batches( for batch in create_batches(
api=self._collection._client, api=self._collection._client,
ids=ids, ids=ids,
metadatas=metadata, metadatas=metadata,
documents=text, documents=text,
embeddings=embeddings, embeddings=embeddings,
): ):
self._collection.update( self._collection.update(
ids=batch[0], ids=batch[0],
@ -621,15 +628,16 @@ class Chroma(VectorStore):
) )
if ids is None: if ids is None:
ids = [str(uuid.uuid1()) for _ in texts] ids = [str(uuid.uuid1()) for _ in texts]
if hasattr(chroma_collection._client, if hasattr(
"max_batch_size"): # for Chroma 0.4.10 and above chroma_collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
for batch in create_batches( for batch in create_batches(
api=chroma_collection._client, api=chroma_collection._client,
ids=ids, ids=ids,
metadatas=metadatas, metadatas=metadatas,
documents=texts, documents=texts,
): ):
chroma_collection.add_texts( chroma_collection.add_texts(
texts=batch[3] if batch[3] else [], texts=batch[3] if batch[3] else [],
@ -637,9 +645,7 @@ class Chroma(VectorStore):
ids=batch[0], ids=batch[0],
) )
else: else:
chroma_collection.add_texts(texts=texts, chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
metadatas=metadatas,
ids=ids)
return chroma_collection return chroma_collection
@classmethod @classmethod
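A usage sketch for the Chroma wrapper above; my_embeddings stands for any embedding object exposing embed_documents/embed_query and is not part of this commit:

db = Chroma(
    collection_name="notes",
    embedding_function=my_embeddings,
    persist_directory="./chroma_store",  # the "in-memory and persisting to disk" mode described above
)
db.add_texts(["alpha", "beta"], metadatas=[{"k": 1}, {"k": 2}])
docs = db.similarity_search("alpha", k=1)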

@ -19,7 +19,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
if X.shape[1] != Y.shape[1]: if X.shape[1] != Y.shape[1]:
raise ValueError( raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} " f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}.") f"and Y has shape {Y.shape}."
)
try: try:
import simsimd as simd import simsimd as simd
@ -32,7 +33,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
except ImportError: except ImportError:
logger.info( logger.info(
"Unable to import simsimd, defaulting to NumPy implementation. If you want " "Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`.") "to use simsimd please install with `pip install simsimd`."
)
X_norm = np.linalg.norm(X, axis=1) X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1) Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide by zero errors run time warnings as those are handled below. # Ignore divide by zero errors run time warnings as those are handled below.
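cosine_similarity above returns the pairwise cosine matrix between the rows of X and Y, preferring simsimd and falling back to NumPy. A tiny worked example of the value it computes:

import numpy as np

X = np.array([[1.0, 0.0], [0.0, 1.0]])
Y = np.array([[1.0, 1.0]])
# Each unit axis is at 45 degrees to (1, 1), so both similarities are 1/sqrt(2) ~= 0.7071:
# cosine_similarity(X, Y) -> [[0.7071], [0.7071]]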

@ -27,7 +27,6 @@ class NotFoundException(Exception):
class TaskDB(ABC): class TaskDB(ABC):
async def create_task( async def create_task(
self, self,
input: Optional[str], input: Optional[str],
@ -68,9 +67,9 @@ class TaskDB(ABC):
async def list_tasks(self) -> List[Task]: async def list_tasks(self) -> List[Task]:
raise NotImplementedError raise NotImplementedError
async def list_steps(self, async def list_steps(
task_id: str, self, task_id: str, status: Optional[Status] = None
status: Optional[Status] = None) -> List[Step]: ) -> List[Step]:
raise NotImplementedError raise NotImplementedError
@ -137,8 +136,8 @@ class InMemoryTaskDB(TaskDB):
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact: async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
task = await self.get_task(task_id) task = await self.get_task(task_id)
artifact = next( artifact = next(
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
None) )
if not artifact: if not artifact:
raise NotFoundException("Artifact", artifact_id) raise NotFoundException("Artifact", artifact_id)
return artifact return artifact
@ -151,9 +150,9 @@ class InMemoryTaskDB(TaskDB):
step_id: Optional[str] = None, step_id: Optional[str] = None,
) -> Artifact: ) -> Artifact:
artifact_id = str(uuid.uuid4()) artifact_id = str(uuid.uuid4())
artifact = Artifact(artifact_id=artifact_id, artifact = Artifact(
file_name=file_name, artifact_id=artifact_id, file_name=file_name, relative_path=relative_path
relative_path=relative_path) )
task = await self.get_task(task_id) task = await self.get_task(task_id)
task.artifacts.append(artifact) task.artifacts.append(artifact)
@ -166,9 +165,9 @@ class InMemoryTaskDB(TaskDB):
async def list_tasks(self) -> List[Task]: async def list_tasks(self) -> List[Task]:
return [task for task in self._tasks.values()] return [task for task in self._tasks.values()]
async def list_steps(self, async def list_steps(
task_id: str, self, task_id: str, status: Optional[Status] = None
status: Optional[Status] = None) -> List[Step]: ) -> List[Step]:
task = await self.get_task(task_id) task = await self.get_task(task_id)
steps = task.steps steps = task.steps
if status: if status:

@ -63,7 +63,8 @@ class OceanDB:
try: try:
embedding_function = MultiModalEmbeddingFunction(modality=modality) embedding_function = MultiModalEmbeddingFunction(modality=modality)
collection = self.client.create_collection( collection = self.client.create_collection(
collection_name, embedding_function=embedding_function) collection_name, embedding_function=embedding_function
)
return collection return collection
except Exception as e: except Exception as e:
logging.error(f"Failed to create collection. Error {e}") logging.error(f"Failed to create collection. Error {e}")
@ -90,8 +91,7 @@ class OceanDB:
try: try:
return collection.add(documents=[document], ids=[id]) return collection.add(documents=[document], ids=[id])
except Exception as e: except Exception as e:
logging.error( logging.error(f"Failed to append document to the collection. Error {e}")
f"Failed to append document to the collection. Error {e}")
raise raise
def add_documents(self, collection, documents: List[str], ids: List[str]): def add_documents(self, collection, documents: List[str], ids: List[str]):
@ -137,8 +137,7 @@ class OceanDB:
the results of the query the results of the query
""" """
try: try:
results = collection.query(query_texts=query_texts, results = collection.query(query_texts=query_texts, n_results=n_results)
n_results=n_results)
return results return results
except Exception as e: except Exception as e:
logging.error(f"Failed to query the collection. Error {e}") logging.error(f"Failed to query the collection. Error {e}")

@ -88,12 +88,12 @@ class PgVectorVectorStore(BaseVectorStore):
create_engine_params: dict = field(factory=dict, kw_only=True) create_engine_params: dict = field(factory=dict, kw_only=True)
engine: Optional[Engine] = field(default=None, kw_only=True) engine: Optional[Engine] = field(default=None, kw_only=True)
table_name: str = field(kw_only=True) table_name: str = field(kw_only=True)
_model: any = field(default=Factory( _model: any = field(
lambda self: self.default_vector_model(), takes_self=True)) default=Factory(lambda self: self.default_vector_model(), takes_self=True)
)
@connection_string.validator @connection_string.validator
def validate_connection_string(self, _, def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
connection_string: Optional[str]) -> None:
# If an engine is provided, the connection string is not used. # If an engine is provided, the connection string is not used.
if self.engine is not None: if self.engine is not None:
return return
@ -122,8 +122,9 @@ class PgVectorVectorStore(BaseVectorStore):
If not, a connection string is used to create a new database connection here. If not, a connection string is used to create a new database connection here.
""" """
if self.engine is None: if self.engine is None:
self.engine = create_engine(self.connection_string, self.engine = create_engine(
**self.create_engine_params) self.connection_string, **self.create_engine_params
)
def setup( def setup(
self, self,
@ -141,12 +142,14 @@ class PgVectorVectorStore(BaseVectorStore):
if create_schema: if create_schema:
self._model.metadata.create_all(self.engine) self._model.metadata.create_all(self.engine)
def upsert_vector(self, def upsert_vector(
vector: list[float], self,
vector_id: Optional[str] = None, vector: list[float],
namespace: Optional[str] = None, vector_id: Optional[str] = None,
meta: Optional[dict] = None, namespace: Optional[str] = None,
**kwargs) -> str: meta: Optional[dict] = None,
**kwargs
) -> str:
"""Inserts or updates a vector in the collection.""" """Inserts or updates a vector in the collection."""
with Session(self.engine) as session: with Session(self.engine) as session:
obj = self._model( obj = self._model(
@ -161,9 +164,9 @@ class PgVectorVectorStore(BaseVectorStore):
return str(obj.id) return str(obj.id)
def load_entry(self, def load_entry(
vector_id: str, self, vector_id: str, namespace: Optional[str] = None
namespace: Optional[str] = None) -> BaseVectorStore.Entry: ) -> BaseVectorStore.Entry:
"""Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
with Session(self.engine) as session: with Session(self.engine) as session:
result = session.get(self._model, vector_id) result = session.get(self._model, vector_id)
@ -176,8 +179,8 @@ class PgVectorVectorStore(BaseVectorStore):
) )
def load_entries( def load_entries(
self, self, namespace: Optional[str] = None
namespace: Optional[str] = None) -> list[BaseVectorStore.Entry]: ) -> list[BaseVectorStore.Entry]:
"""Retrieves all vector entries from the collection, optionally filtering to only """Retrieves all vector entries from the collection, optionally filtering to only
those that match the provided namespace. those that match the provided namespace.
""" """
@ -194,16 +197,19 @@ class PgVectorVectorStore(BaseVectorStore):
vector=result.vector, vector=result.vector,
namespace=result.namespace, namespace=result.namespace,
meta=result.meta, meta=result.meta,
) for result in results )
for result in results
] ]
def query(self, def query(
query: str, self,
count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT, query: str,
namespace: Optional[str] = None, count: Optional[int] = BaseVectorStore.DEFAULT_QUERY_COUNT,
include_vectors: bool = False, namespace: Optional[str] = None,
distance_metric: str = "cosine_distance", include_vectors: bool = False,
**kwargs) -> list[BaseVectorStore.QueryResult]: distance_metric: str = "cosine_distance",
**kwargs
) -> list[BaseVectorStore.QueryResult]:
"""Performs a search on the collection to find vectors similar to the provided input vector, """Performs a search on the collection to find vectors similar to the provided input vector,
optionally filtering to only those that match the provided namespace. optionally filtering to only those that match the provided namespace.
""" """
@ -239,7 +245,8 @@ class PgVectorVectorStore(BaseVectorStore):
score=result[1], score=result[1],
meta=result[0].meta, meta=result[0].meta,
namespace=result[0].namespace, namespace=result[0].namespace,
) for result in results )
for result in results
] ]
def default_vector_model(self) -> any: def default_vector_model(self) -> any:

@ -102,12 +102,14 @@ class PineconeVectorStoreStore(BaseVector):
self.index = pinecone.Index(self.index_name) self.index = pinecone.Index(self.index_name)
def upsert_vector(self, def upsert_vector(
vector: list[float], self,
vector_id: Optional[str] = None, vector: list[float],
namespace: Optional[str] = None, vector_id: Optional[str] = None,
meta: Optional[dict] = None, namespace: Optional[str] = None,
**kwargs) -> str: meta: Optional[dict] = None,
**kwargs
) -> str:
"""Upsert vector""" """Upsert vector"""
vector_id = vector_id if vector_id else str_to_hash(str(vector)) vector_id = vector_id if vector_id else str_to_hash(str(vector))
@ -118,12 +120,10 @@ class PineconeVectorStoreStore(BaseVector):
return vector_id return vector_id
def load_entry( def load_entry(
self, self, vector_id: str, namespace: Optional[str] = None
vector_id: str, ) -> Optional[BaseVector.Entry]:
namespace: Optional[str] = None) -> Optional[BaseVector.Entry]:
"""Load entry""" """Load entry"""
result = self.index.fetch(ids=[vector_id], result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
namespace=namespace).to_dict()
vectors = list(result["vectors"].values()) vectors = list(result["vectors"].values())
if len(vectors) > 0: if len(vectors) > 0:
@ -138,8 +138,7 @@ class PineconeVectorStoreStore(BaseVector):
else: else:
return None return None
def load_entries(self, def load_entries(self, namespace: Optional[str] = None) -> list[BaseVector.Entry]:
namespace: Optional[str] = None) -> list[BaseVector.Entry]:
"""Load entries""" """Load entries"""
# This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
# all values from a namespace: # all values from a namespace:
@ -158,18 +157,20 @@ class PineconeVectorStoreStore(BaseVector):
vector=r["values"], vector=r["values"],
meta=r["metadata"], meta=r["metadata"],
namespace=results["namespace"], namespace=results["namespace"],
) for r in results["matches"] )
for r in results["matches"]
] ]
def query( def query(
self, self,
query: str, query: str,
count: Optional[int] = None, count: Optional[int] = None,
namespace: Optional[str] = None, namespace: Optional[str] = None,
include_vectors: bool = False, include_vectors: bool = False,
# PineconeVectorStoreStorageDriver-specific params: # PineconeVectorStoreStorageDriver-specific params:
include_metadata=True, include_metadata=True,
**kwargs) -> list[BaseVector.QueryResult]: **kwargs
) -> list[BaseVector.QueryResult]:
"""Query vectors""" """Query vectors"""
vector = self.embedding_driver.embed_string(query) vector = self.embedding_driver.embed_string(query)
@ -189,14 +190,12 @@ class PineconeVectorStoreStore(BaseVector):
score=r["score"], score=r["score"],
meta=r["metadata"], meta=r["metadata"],
namespace=results["namespace"], namespace=results["namespace"],
) for r in results["matches"] )
for r in results["matches"]
] ]
def create_index(self, name: str, **kwargs) -> None: def create_index(self, name: str, **kwargs) -> None:
"""Create index""" """Create index"""
params = { params = {"name": name, "dimension": self.embedding_driver.dimensions} | kwargs
"name": name,
"dimension": self.embedding_driver.dimensions
} | kwargs
pinecone.create_index(**params) pinecone.create_index(**params)
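create_index above assembles its parameters with the dict-union operator, so extra keyword arguments extend or override the defaults before being passed to pinecone.create_index (dict union requires Python 3.9+). For example:

params = {"name": "idx", "dimension": 768} | {"metric": "cosine", "dimension": 1536}
# -> {"name": "idx", "dimension": 1536, "metric": "cosine"}  (the right-hand side wins on collisions)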

@ -20,9 +20,9 @@ class Artifact(BaseModel):
description="Id of the artifact", description="Id of the artifact",
example="b225e278-8b4c-4f99-a696-8facf19f0e56", example="b225e278-8b4c-4f99-a696-8facf19f0e56",
) )
file_name: str = Field(..., file_name: str = Field(
description="Filename of the artifact", ..., description="Filename of the artifact", example="main.py"
example="main.py") )
relative_path: Optional[str] = Field( relative_path: Optional[str] = Field(
None, None,
description="Relative path of the artifact in the agent's workspace", description="Relative path of the artifact in the agent's workspace",
@ -50,8 +50,7 @@ class StepInput(BaseModel):
class StepOutput(BaseModel): class StepOutput(BaseModel):
__root__: Any = Field( __root__: Any = Field(
..., ...,
description= description="Output that the task step has produced. Any value is allowed.",
"Output that the task step has produced. Any value is allowed.",
example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}', example='{\n"tokens": 7894,\n"estimated_cost": "0,24$"\n}',
) )
@ -82,9 +81,9 @@ class Task(TaskRequestBody):
class StepRequestBody(BaseModel): class StepRequestBody(BaseModel):
input: Optional[str] = Field(None, input: Optional[str] = Field(
description="Input prompt for the step.", None, description="Input prompt for the step.", example="Washington"
example="Washington") )
additional_input: Optional[StepInput] = None additional_input: Optional[StepInput] = None
@ -105,19 +104,22 @@ class Step(StepRequestBody):
description="The ID of the task step.", description="The ID of the task step.",
example="6bb1801a-fd80-45e8-899a-4dd723cc602e", example="6bb1801a-fd80-45e8-899a-4dd723cc602e",
) )
name: Optional[str] = Field(None, name: Optional[str] = Field(
description="The name of the task step.", None, description="The name of the task step.", example="Write to file"
example="Write to file") )
status: Status = Field(..., description="The status of the task step.") status: Status = Field(..., description="The status of the task step.")
output: Optional[str] = Field( output: Optional[str] = Field(
None, None,
description="Output of the task step.", description="Output of the task step.",
example= example=(
("I am going to use the write_to_file command and write Washington to a file" "I am going to use the write_to_file command and write Washington to a file"
" called output.txt <write_to_file('output.txt', 'Washington')"), " called output.txt <write_to_file('output.txt', 'Washington')"
),
) )
additional_output: Optional[StepOutput] = None additional_output: Optional[StepOutput] = None
artifacts: List[Artifact] = Field( artifacts: List[Artifact] = Field(
[], description="A list of artifacts that the step has produced.") [], description="A list of artifacts that the step has produced."
)
is_last: Optional[bool] = Field( is_last: Optional[bool] = Field(
False, description="Whether this is the last step in the task.") False, description="Whether this is the last step in the task."
)

@ -43,8 +43,9 @@ def maximal_marginal_relevance(
if i in idxs: if i in idxs:
continue continue
redundant_score = max(similarity_to_selected[i]) redundant_score = max(similarity_to_selected[i])
equation_score = (lambda_mult * query_score - equation_score = (
(1 - lambda_mult) * redundant_score) lambda_mult * query_score - (1 - lambda_mult) * redundant_score
)
if equation_score > best_score: if equation_score > best_score:
best_score = equation_score best_score = equation_score
idx_to_add = i idx_to_add = i
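With the scoring above, lambda_mult trades relevance against redundancy. For instance, at lambda_mult = 0.7 a candidate with query_score = 0.9 that is 0.8 similar to an already selected document scores 0.7 * 0.9 - 0.3 * 0.8 = 0.39, while a slightly less relevant but novel candidate (query_score = 0.8, redundancy 0.1) scores 0.7 * 0.8 - 0.3 * 0.1 = 0.53 and is picked first.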
@ -56,8 +57,8 @@ def maximal_marginal_relevance(
def filter_complex_metadata( def filter_complex_metadata(
documents: List[Document], documents: List[Document],
*, *,
allowed_types: Tuple[Type, allowed_types: Tuple[Type, ...] = (str, bool, int, float)
...] = (str, bool, int, float)) -> List[Document]: ) -> List[Document]:
"""Filter out metadata types that are not supported for a vector store.""" """Filter out metadata types that are not supported for a vector store."""
updated_documents = [] updated_documents = []
for document in documents: for document in documents:

@ -41,24 +41,21 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive.""" """Validate specified keyword args are mutually exclusive."""
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Validate exactly one arg in each group is not None.""" """Validate exactly one arg in each group is not None."""
counts = [ counts = [
sum(1 sum(1 for arg in arg_group if kwargs.get(arg) is not None)
for arg in arg_group
if kwargs.get(arg) is not None)
for arg_group in arg_groups for arg_group in arg_groups
] ]
invalid_groups = [i for i, count in enumerate(counts) if count != 1] invalid_groups = [i for i, count in enumerate(counts) if count != 1]
if invalid_groups: if invalid_groups:
invalid_group_names = [ invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
", ".join(arg_groups[i]) for i in invalid_groups raise ValueError(
] "Exactly one argument in each of the following"
raise ValueError("Exactly one argument in each of the following" " groups must be defined:"
" groups must be defined:" f" {', '.join(invalid_group_names)}"
f" {', '.join(invalid_group_names)}") )
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
@ -108,10 +105,9 @@ def mock_now(dt_value): # type: ignore
datetime.datetime = real_datetime datetime.datetime = real_datetime
def guard_import(module_name: str, def guard_import(
*, module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
pip_name: Optional[str] = None, ) -> Any:
package: Optional[str] = None) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not """Dynamically imports a module and raises a helpful exception if the module is not
installed.""" installed."""
try: try:
@ -119,7 +115,8 @@ def guard_import(module_name: str,
except ImportError: except ImportError:
raise ImportError( raise ImportError(
f"Could not import {module_name} python package. " f"Could not import {module_name} python package. "
f"Please install it with `pip install {pip_name or module_name}`.") f"Please install it with `pip install {pip_name or module_name}`."
)
return module return module
@ -135,19 +132,23 @@ def check_package_version(
if lt_version is not None and imported_version >= parse(lt_version): if lt_version is not None and imported_version >= parse(lt_version):
raise ValueError( raise ValueError(
f"Expected {package} version to be < {lt_version}. Received " f"Expected {package} version to be < {lt_version}. Received "
f"{imported_version}.") f"{imported_version}."
        )
    if lte_version is not None and imported_version > parse(lte_version):
        raise ValueError(
            f"Expected {package} version to be <= {lte_version}. Received "
            f"{imported_version}."
        )
    if gt_version is not None and imported_version <= parse(gt_version):
        raise ValueError(
            f"Expected {package} version to be > {gt_version}. Received "
            f"{imported_version}."
        )
    if gte_version is not None and imported_version < parse(gte_version):
        raise ValueError(
            f"Expected {package} version to be >= {gte_version}. Received "
            f"{imported_version}."
        )
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
@ -179,17 +180,19 @@ def build_extra_kwargs(
        if field_name in extra_kwargs:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(
                f"""WARNING! {field_name} is not default parameter.
                {field_name} was transferred to model_kwargs.
                Please confirm that {field_name} is what you intended."""
            )
            extra_kwargs[field_name] = values.pop(field_name)
    invalid_model_kwargs = all_required_field_names.intersection(extra_kwargs.keys())
    if invalid_model_kwargs:
        raise ValueError(
            f"Parameters {invalid_model_kwargs} should be specified explicitly. "
            "Instead they were passed in as part of `model_kwargs` parameter."
        )
    return extra_kwargs
@ -238,16 +241,17 @@ class _AnthropicCommon(BaseLanguageModel):
    def build_extra(cls, values: Dict) -> Dict:
        extra = values.get("model_kwargs", {})
        all_required_field_names = get_pydantic_field_names(cls)
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["anthropic_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
        )
        # Get custom api url from environment.
        values["anthropic_api_url"] = get_from_dict_or_env(
            values,
@ -277,7 +281,8 @@ class _AnthropicCommon(BaseLanguageModel):
        except ImportError:
            raise ImportError(
                "Could not import anthropic python package. "
                "Please it install it with `pip install anthropic`."
            )
        return values
    @property
@ -300,8 +305,7 @@ class _AnthropicCommon(BaseLanguageModel):
        """Get the identifying parameters."""
        return {**{}, **self._default_params}
    def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
            raise NameError("Please ensure the anthropic package is loaded")
@ -368,8 +372,7 @@ class Anthropic(LLM, _AnthropicCommon):
            return prompt  # Already wrapped.
        # Guard against common errors in specifying wrong number of newlines.
        corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
        if n_subs == 1:
            return corrected_prompt
@ -402,10 +405,9 @@ class Anthropic(LLM, _AnthropicCommon):
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
            return completion
@ -431,10 +433,9 @@ class Anthropic(LLM, _AnthropicCommon):
        """Call out to Anthropic's completion endpoint asynchronously."""
        if self.streaming:
            completion = ""
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
            return completion
@ -475,10 +476,8 @@ class Anthropic(LLM, _AnthropicCommon):
        params = {**self._default_params, **kwargs}
        for token in self.client.completions.create(
            prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
        ):
            chunk = GenerationChunk(text=token.completion)
            yield chunk
            if run_manager:
@ -510,10 +509,10 @@ class Anthropic(LLM, _AnthropicCommon):
        params = {**self._default_params, **kwargs}
        async for token in await self.async_client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            stream=True,
            **params,
        ):
            chunk = GenerationChunk(text=token.completion)
            yield chunk
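The build_extra_kwargs helper reformatted above moves any constructor argument that is not a declared field into model_kwargs and rejects duplicates; a minimal self-contained sketch of that behavior (the names below are illustrative and not part of this commit):

import warnings
from typing import Any, Dict, Set


def build_extra_kwargs_sketch(
    extra_kwargs: Dict[str, Any],
    values: Dict[str, Any],
    all_required_field_names: Set[str],
) -> Dict[str, Any]:
    # Undeclared fields are popped out of `values` and collected for model_kwargs.
    for field_name in list(values):
        if field_name in extra_kwargs:
            raise ValueError(f"Found {field_name} supplied twice.")
        if field_name not in all_required_field_names:
            warnings.warn(f"{field_name} was transferred to model_kwargs.")
            extra_kwargs[field_name] = values.pop(field_name)
    return extra_kwargs


# `temperature` is a declared field and stays in values; `top_k` is not and gets moved.
print(build_extra_kwargs_sketch({}, {"temperature": 0.7, "top_k": 40}, {"temperature"}))
# -> {'top_k': 40}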

@ -97,8 +97,9 @@ class BioClip:
            self.preprocess_val,
        ) = open_clip.create_model_and_transforms(model_path)
        self.tokenizer = open_clip.get_tokenizer(model_path)
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model.to(self.device)
        self.model.eval()
@ -109,17 +110,18 @@ class BioClip:
        template: str = "this is a photo of ",
        context_length: int = 256,
    ):
        image = torch.stack([self.preprocess_val(Image.open(img_path))]).to(self.device)
        texts = self.tokenizer(
            [template + l for l in labels], context_length=context_length
        ).to(self.device)
        with torch.no_grad():
            image_features, text_features, logit_scale = self.model(image, texts)
            logits = (
                (logit_scale * image_features @ text_features.t())
                .detach()
                .softmax(dim=-1)
            )
            sorted_indices = torch.argsort(logits, dim=-1, descending=True)
            logits = logits.cpu().numpy()
            sorted_indices = sorted_indices.cpu().numpy()
@ -137,8 +139,11 @@ class BioClip:
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(img)
        ax.axis("off")
        title = (
            metadata["filename"]
            + "\n"
            + "\n".join([f"{k}: {v*100:.1f}" for k, v in metadata["top_probs"].items()])
        )
        ax.set_title(title, fontsize=14)
        plt.tight_layout()
        plt.show()
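For reference, a hedged usage sketch of the BioClip wrapper shown above; the import path and the BiomedCLIP checkpoint id are assumptions, and the image path and labels are placeholders:

# Assumed import path and checkpoint; adjust to the installed package layout.
from swarms.models.bioclip import BioClip

clip = BioClip("hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
labels = ["chest X-ray", "brain MRI", "histopathology slide"]
# __call__(img_path, labels) runs the zero-shot scoring shown in the diff above.
scores = clip("sample_scan.png", labels)
print(scores)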

@ -102,9 +102,9 @@ class BioGPT:
            list[dict]: A list of generated texts.
        """
        set_seed(42)
        generator = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )
        out = generator(
            text,
            max_length=self.max_length,
@ -149,11 +149,13 @@ class BioGPT:
        inputs = self.tokenizer(sentence, return_tensors="pt")
        set_seed(42)
        with torch.no_grad():
            beam_output = self.model.generate(
                **inputs,
                min_length=self.min_length,
                max_length=self.max_length,
                num_beams=num_beams,
                early_stopping=early_stopping
            )
        return self.tokenizer.decode(beam_output[0], skip_special_tokens=True)
    # Feature 1: Set a new tokenizer and model

@ -124,10 +124,13 @@ class Dalle3:
            # Handling exceptions and printing the errors details
            print(
                colored(
                    (
                        f"Error running Dalle3: {error} try optimizing your api key and"
                        " or try again"
                    ),
                    "red",
                )
            )
            raise error
    def create_variations(self, img: str):
@ -154,19 +157,22 @@ class Dalle3:
        """
        try:
            response = self.client.images.create_variation(
                img=open(img, "rb"), n=self.n, size=self.size
            )
            img = response.data[0].url
            return img
        except (Exception, openai.OpenAIError) as error:
            print(
                colored(
                    (
                        f"Error running Dalle3: {error} try optimizing your api key and"
                        " or try again"
                    ),
                    "red",
                )
            )
            print(colored(f"Error running Dalle3: {error.http_status}", "red"))
            print(colored(f"Error running Dalle3: {error.error}", "red"))
            raise error

@ -18,7 +18,6 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1):
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            retries = max_retries
@ -29,9 +28,7 @@ def async_retry(max_retries=3, exceptions=(Exception,), delay=1):
                    retries -= 1
                    if retries <= 0:
                        raise
                    print(f"Retry after exception: {e}, Attempts remaining: {retries}")
                    await asyncio.sleep(delay)
        return wrapper
@ -65,8 +62,7 @@ class DistilWhisperModel:
    def __init__(self, model_id="distil-whisper/distil-large-v2"):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model_id = model_id
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
@ -123,14 +119,14 @@ class DistilWhisperModel:
        try:
            with torch.no_grad():
                # Load the whole audio file, but process and transcribe it in chunks
                audio_input = self.processor.audio_file_to_array(audio_file_path)
                sample_rate = audio_input.sampling_rate
                total_duration = len(audio_input.array) / sample_rate
                chunks = [
                    audio_input.array[i : i + sample_rate * chunk_duration]
                    for i in range(
                        0, len(audio_input.array), sample_rate * chunk_duration
                    )
                ]
                print(colored("Starting real-time transcription...", "green"))
@ -143,22 +139,22 @@ class DistilWhisperModel:
                        return_tensors="pt",
                        padding=True,
                    )
                    processed_inputs = processed_inputs.input_values.to(self.device)
                    # Generate transcription for the chunk
                    logits = self.model.generate(processed_inputs)
                    transcription = self.processor.batch_decode(
                        logits, skip_special_tokens=True
                    )[0]
                    # Print the chunk's transcription
                    print(
                        colored(f"Chunk {i+1}/{len(chunks)}: ", "yellow")
                        + transcription
                    )
                    # Wait for the chunk's duration to simulate real-time processing
                    time.sleep(chunk_duration)
        except Exception as e:
            print(colored(f"An error occurred during transcription: {e}", "red"))

@ -11,8 +11,7 @@ from pydantic import BaseModel, StrictFloat, StrictInt, validator
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the classes for image classification
with open(os.path.join(os.path.dirname(__file__), "fast_vit_classes.json")) as f:
    FASTVIT_IMAGENET_1K_CLASSES = json.load(f)
@ -22,8 +21,7 @@ class ClassificationResult(BaseModel):
    @validator("class_id", "confidence", pre=True, each_item=True)
    def check_list_contents(cls, v):
        assert isinstance(v, int) or isinstance(v, float), "must be integer or float"
        return v
@ -49,16 +47,16 @@ class FastViT:
    """
    def __init__(self):
        self.model = timm.create_model(
            "hf_hub:timm/fastvit_s12.apple_in1k", pretrained=True
        ).to(DEVICE)
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)
        self.model.eval()
    def __call__(
        self, img: str, confidence_threshold: float = 0.5
    ) -> ClassificationResult:
        """classifies the input image and returns the top k classes and their probabilities"""
        img = Image.open(img).convert("RGB")
        img_tensor = self.transforms(img).unsqueeze(0).to(DEVICE)
@ -67,8 +65,9 @@ class FastViT:
        probabilities = torch.nn.functional.softmax(output, dim=1)
        # Get top k classes and their probabilities
        top_probs, top_classes = torch.topk(
            probabilities, k=FASTVIT_IMAGENET_1K_CLASSES
        )
        # Filter by confidence threshold
        mask = top_probs > confidence_threshold
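A usage sketch for the FastViT classifier above; the import path is an assumption, the image path is a placeholder, and the __call__ signature matches the diff:

from swarms.models.fastvit import FastViT  # assumed import path

fastvit = FastViT()
result = fastvit(img="photo.jpg", confidence_threshold=0.5)
# ClassificationResult carries the class ids and confidences validated above.
print(result.class_id, result.confidence)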

@ -45,9 +45,9 @@ class Fuyu:
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        self.image_processor = FuyuImageProcessor()
        self.processor = FuyuProcessor(
            image_processor=self.image_processor, tokenizer=self.tokenizer, **kwargs
        )
        self.model = FuyuForCausalLM.from_pretrained(
            pretrained_path,
            device_map=device_map,
@ -62,17 +62,15 @@ class Fuyu:
    def __call__(self, text: str, img: str):
        """Call the model with text and img paths"""
        image_pil = Image.open(img)
        model_inputs = self.processor(
            text=text, images=[image_pil], device=self.device_map
        )
        for k, v in model_inputs.items():
            model_inputs[k] = v.to(self.device_map)
        output = self.model.generate(**model_inputs, max_new_tokens=self.max_new_tokens)
        text = self.processor.batch_decode(output[:, -7:], skip_special_tokens=True)
        return print(str(text))
    def get_img_from_web(self, img_url: str):

@ -69,7 +69,9 @@ class GPT4Vision:
    quality: str = "low"
    # Max tokens to use for the API request, the maximum might be 3,000 but we don't know
    max_tokens: int = 200
    client = OpenAI(
        api_key=openai_api_key,
    )
    dashboard: bool = True
    call_limit: int = 1
    period_seconds: int = 60
@ -88,8 +90,9 @@ class GPT4Vision:
            return base64.b64encode(image_file.read()).decode("utf-8")
    @sleep_and_retry
    @limits(
        calls=call_limit, period=period_seconds
    )  # Rate limit of 10 calls per minute
    def run(self, task: str, img: str):
        """
        Run the GPT-4 Vision model
@ -105,22 +108,20 @@ class GPT4Vision:
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-vision-preview",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": task},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": str(img),
                                },
                            },
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
            )
@ -160,22 +161,20 @@ class GPT4Vision:
        try:
            response = await self.client.chat.completions.create(
                model="gpt-4-vision-preview",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": task},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": img,
                                },
                            },
                        ],
                    }
                ],
                max_tokens=self.max_tokens,
            )
@ -190,14 +189,12 @@ class GPT4Vision:
        """Process a batch of tasks and images"""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.run, task, img) for task, img in tasks_images
            ]
            results = [future.result() for future in futures]
            return results
    async def run_batch_async(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
        """Process a batch of tasks and images asynchronously"""
        loop = asyncio.get_event_loop()
        futures = [
@ -207,7 +204,8 @@ class GPT4Vision:
        return await asyncio.gather(*futures)
    async def run_batch_async_with_retries(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of tasks and images asynchronously with retries"""
        loop = asyncio.get_event_loop()
        futures = [
@ -231,7 +229,8 @@ class GPT4Vision:
            """,
                "green",
            )
        )
        return dashboard
    def health_check(self):
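A hedged usage sketch of the GPT4Vision wrapper; the import path and keyword-style construction are assumptions based on the attribute defaults above, and the key and image URL are placeholders:

from swarms.models.gpt4v import GPT4Vision  # assumed import path

gpt4v = GPT4Vision(openai_api_key="sk-placeholder")
# run(task, img) posts a gpt-4-vision-preview chat request as shown in the diff.
answer = gpt4v.run(task="Describe this screenshot", img="https://example.com/screen.png")
print(answer)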

@ -47,8 +47,9 @@ class HuggingfaceLLM:
        **kwargs,
    ):
        self.logger = logging.getLogger(__name__)
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model_id = model_id
        self.max_length = max_length
        self.verbose = verbose
@ -57,8 +58,9 @@ class HuggingfaceLLM:
        self.model, self.tokenizer = None, None
        if self.distributed:
            assert (
                torch.cuda.device_count() > 1
            ), "You need more than 1 gpu for distributed processing"
        bnb_config = None
        if quantize:
@ -73,17 +75,17 @@ class HuggingfaceLLM:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, *args, **kwargs
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id, quantization_config=bnb_config, *args, **kwargs
            )
            self.model  # .to(self.device)
        except Exception as e:
            # self.logger.error(f"Failed to load the model or the tokenizer: {e}")
            # raise
            print(colored(f"Failed to load the model and or the tokenizer: {e}", "red"))
    def print_error(self, error: str):
        """Print error"""
@ -95,18 +97,20 @@ class HuggingfaceLLM:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            bnb_config = (
                BitsAndBytesConfig(**self.quantization_config)
                if self.quantization_config
                else None
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id, quantization_config=bnb_config
            ).to(self.device)
            if self.distributed:
                self.model = DDP(self.model)
        except Exception as error:
            self.logger.error(f"Failed to load the model or the tokenizer: {error}")
            raise
    def run(self, task: str):
@ -127,8 +131,7 @@ class HuggingfaceLLM:
        self.print_dashboard(task)
        try:
            inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device)
            # self.log.start()
@ -137,36 +140,39 @@ class HuggingfaceLLM:
                    for _ in range(max_length):
                        output_sequence = []
                        outputs = self.model.generate(
                            inputs, max_length=len(inputs) + 1, do_sample=True
                        )
                        output_tokens = outputs[0][-1]
                        output_sequence.append(output_tokens.item())
                        # print token in real-time
                        print(
                            self.tokenizer.decode(
                                [output_tokens], skip_special_tokens=True
                            ),
                            end="",
                            flush=True,
                        )
                        inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs, max_length=max_length, do_sample=True
                    )
            del inputs
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            print(
                colored(
                    (
                        f"HuggingfaceLLM could not generate text because of error: {e},"
                        " try optimizing your arguments"
                    ),
                    "red",
                )
            )
            raise
    async def run_async(self, task: str, *args, **kwargs) -> str:
@ -210,8 +216,7 @@ class HuggingfaceLLM:
        self.print_dashboard(task)
        try:
            inputs = self.tokenizer.encode(task, return_tensors="pt").to(self.device)
            # self.log.start()
@ -220,26 +225,26 @@ class HuggingfaceLLM:
                    for _ in range(max_length):
                        output_sequence = []
                        outputs = self.model.generate(
                            inputs, max_length=len(inputs) + 1, do_sample=True
                        )
                        output_tokens = outputs[0][-1]
                        output_sequence.append(output_tokens.item())
                        # print token in real-time
                        print(
                            self.tokenizer.decode(
                                [output_tokens], skip_special_tokens=True
                            ),
                            end="",
                            flush=True,
                        )
                        inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs, max_length=max_length, do_sample=True
                    )
            del inputs
@ -300,7 +305,8 @@ class HuggingfaceLLM:
            """,
                "red",
            )
        )
        print(dashboard)
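A usage sketch of HuggingfaceLLM as reformatted above; the import path is assumed and the model id is illustrative:

from swarms.models.huggingface import HuggingfaceLLM  # assumed import path

llm = HuggingfaceLLM(model_id="gpt2", max_length=128)
# run(task) tokenizes the prompt, generates, and decodes as shown in the diff.
print(llm.run("Write one sentence about autonomous swarms."))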

@ -65,8 +65,9 @@ class Idefics:
        torch_dtype=torch.bfloat16,
        max_length=100,
    ):
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model = IdeficsForVisionText2Text.from_pretrained(
            checkpoint,
            torch_dtype=torch_dtype,
@ -95,17 +96,21 @@ class Idefics:
        list
            A list of generated text strings.
        """
        inputs = (
            self.processor(
                prompts, add_end_of_utterance_token=False, return_tensors="pt"
            ).to(self.device)
            if batched_mode
            else self.processor(prompts[0], return_tensors="pt").to(self.device)
        )
        exit_condition = self.processor.tokenizer(
            "<end_of_utterance>", add_special_tokens=False
        ).input_ids
        bad_words_ids = self.processor.tokenizer(
            ["<image>", "<fake_token_around_image"], add_special_tokens=False
        ).input_ids
        generated_ids = self.model.generate(
            **inputs,
@ -113,8 +118,9 @@ class Idefics:
            bad_words_ids=bad_words_ids,
            max_length=self.max_length,
        )
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return generated_text
    def __call__(self, prompts, batched_mode=True):
@ -135,17 +141,21 @@ class Idefics:
        list
            A list of generated text strings.
        """
        inputs = (
            self.processor(
                prompts, add_end_of_utterance_token=False, return_tensors="pt"
            ).to(self.device)
            if batched_mode
            else self.processor(prompts[0], return_tensors="pt").to(self.device)
        )
        exit_condition = self.processor.tokenizer(
            "<end_of_utterance>", add_special_tokens=False
        ).input_ids
        bad_words_ids = self.processor.tokenizer(
            ["<image>", "<fake_token_around_image"], add_special_tokens=False
        ).input_ids
        generated_ids = self.model.generate(
            **inputs,
@ -153,8 +163,9 @@ class Idefics:
            bad_words_ids=bad_words_ids,
            max_length=self.max_length,
        )
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return generated_text
    def chat(self, user_input):
@ -191,7 +202,8 @@ class Idefics:
            The name of the new pre-trained model checkpoint.
        """
        self.model = IdeficsForVisionText2Text.from_pretrained(
            checkpoint, torch_dtype=torch.bfloat16
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(checkpoint)
    def set_device(self, device):

@ -53,8 +53,9 @@ class JinaEmbeddings:
        **kwargs,
    ):
        self.logger = logging.getLogger(__name__)
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model_id = model_id
        self.max_length = max_length
        self.verbose = verbose
@ -65,8 +66,9 @@ class JinaEmbeddings:
        self.cos_sim = cos_sim
        if self.distributed:
            assert (
                torch.cuda.device_count() > 1
            ), "You need more than 1 gpu for distributed processing"
        bnb_config = None
        if quantize:
@ -81,9 +83,8 @@ class JinaEmbeddings:
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id, quantization_config=bnb_config, trust_remote_code=True
            )
            self.model  # .to(self.device)
        except Exception as e:
@ -96,8 +97,11 @@ class JinaEmbeddings:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            bnb_config = (
                BitsAndBytesConfig(**self.quantization_config)
                if self.quantization_config
                else None
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
@ -108,8 +112,7 @@ class JinaEmbeddings:
            if self.distributed:
                self.model = DDP(self.model)
        except Exception as error:
            self.logger.error(f"Failed to load the model or the tokenizer: {error}")
            raise
    def run(self, task: str):

@ -14,8 +14,11 @@ class Detections(BaseModel):
    @root_validator
    def check_length(cls, values):
        assert (
            len(values.get("xyxy"))
            == len(values.get("class_id"))
            == len(values.get("confidence"))
        ), "All fields must have the same length."
        return values
    @validator("xyxy", "class_id", "confidence", pre=True, each_item=True)
@ -36,9 +39,11 @@ class Kosmos2(BaseModel):
    @classmethod
    def initialize(cls):
        model = AutoModelForVision2Seq.from_pretrained(
            "ydshieh/kosmos-2-patch14-224", trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(
            "ydshieh/kosmos-2-patch14-224", trust_remote_code=True
        )
        return cls(model=model, processor=processor)
    def __call__(self, img: str) -> Detections:
@ -46,12 +51,11 @@ class Kosmos2(BaseModel):
        prompt = "<grounding>An image of"
        inputs = self.processor(text=prompt, images=image, return_tensors="pt")
        outputs = self.model.generate(**inputs, use_cache=True, max_new_tokens=64)
        generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[
            0
        ]
        # The actual processing of generated_text to entities would go here
        # For the purpose of this example, assume a mock function 'extract_entities' exists:
@ -62,8 +66,8 @@ class Kosmos2(BaseModel):
        return detections
    def extract_entities(
        self, text: str
    ) -> List[Tuple[str, Tuple[float, float, float, float]]]:
        # Placeholder function for entity extraction
        # This should be replaced with the actual method of extracting entities
        return []
@ -76,19 +80,19 @@ class Kosmos2(BaseModel):
        if not entities:
            return Detections.empty()
        class_ids = [0] * len(entities)  # Replace with actual class ID extraction logic
        xyxys = [
            (
                e[1][0] * image.width,
                e[1][1] * image.height,
                e[1][2] * image.width,
                e[1][3] * image.height,
            )
            for e in entities
        ]
        confidences = [1.0] * len(entities)  # Placeholder confidence
        return Detections(xyxy=xyxys, class_id=class_ids, confidence=confidences)
# Usage:
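Following the # Usage: marker above, a hedged sketch of calling the wrapper, with a placeholder image path:

# Kosmos2.initialize() and __call__(img) are defined in the diff above.
kosmos = Kosmos2.initialize()
detections = kosmos(img="sample_image.jpg")
print(detections.xyxy, detections.class_id, detections.confidence)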

@ -46,9 +46,11 @@ class Kosmos:
        model_name="ydshieh/kosmos-2-patch14-224",
    ):
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True
        )
    def get_image(self, url):
        """Image"""
@ -71,7 +73,8 @@ class Kosmos:
            skip_special_tokens=True,
        )[0]
        processed_text, entities = self.processor.post_process_generation(
            generated_texts
        )
    def __call__(self, prompt, image):
        """Run call"""
@ -90,7 +93,8 @@ class Kosmos:
            skip_special_tokens=True,
        )[0]
        processed_text, entities = self.processor.post_process_generation(
            generated_texts
        )
    # tasks
    def multimodal_grounding(self, phrase, image_url):
@ -141,10 +145,12 @@ class Kosmos:
        elif isinstance(image, torch.Tensor):
            # pdb.set_trace()
            image_tensor = image.cpu()
            reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[
                :, None, None
            ]
            reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[
                :, None, None
            ]
            image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
            pil_img = T.ToPILImage()(image_tensor)
            image_h = pil_img.height
@ -163,9 +169,9 @@ class Kosmos:
        # thickness of text
        text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
        box_line = 3
        (c_width, text_height), _ = cv2.getTextSize(
            "F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line
        )
        base_height = int(text_height * 0.675)
        text_offset_original = text_height - base_height
        text_spaces = 3
@ -181,8 +187,9 @@ class Kosmos:
                # draw bbox
                # random color
                color = tuple(np.random.randint(0, 255, size=3).tolist())
                new_image = cv2.rectangle(
                    new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line
                )
                l_o, r_o = (
                    box_line // 2 + box_line % 2,
@ -193,15 +200,19 @@ class Kosmos:
                y1 = orig_y1 - l_o
                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = (
                        orig_y1
                        + r_o
                        + text_height
                        + text_offset_original
                        + 2 * text_spaces
                    )
                    x1 = orig_x1 + r_o
                # add text background
                (text_width, text_height), _ = cv2.getTextSize(
                    f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line
                )
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = (
                    x1,
                    y1 - (text_height + text_offset_original + 2 * text_spaces),
@ -211,19 +222,23 @@ class Kosmos:
                for prev_bbox in previous_bboxes:
                    while is_overlapping(
                        (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox
                    ):
                        text_bg_y1 += (
                            text_height + text_offset_original + 2 * text_spaces
                        )
                        text_bg_y2 += (
                            text_height + text_offset_original + 2 * text_spaces
                        )
                        y1 += text_height + text_offset_original + 2 * text_spaces
                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(
                                0,
                                image_h
                                - (
                                    text_height + text_offset_original + 2 * text_spaces
                                ),
                            )
                            text_bg_y2 = image_h
                            y1 = image_h
@ -240,9 +255,9 @@ class Kosmos:
                            # white
                            bg_color = [255, 255, 255]
                            new_image[i, j] = (
                                alpha * new_image[i, j]
                                + (1 - alpha) * np.array(bg_color)
                            ).astype(np.uint8)
                cv2.putText(
                    new_image,
@ -255,8 +270,7 @@ class Kosmos:
                    cv2.LINE_AA,
                )
                # previous_locations.append((x1, y1))
                previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
        pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
        if save_path:

@ -48,8 +48,9 @@ class MultiModalLlava:
            revision=revision,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, use_fast=True
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,

@ -49,8 +49,7 @@ class Mistral:
        # Check if the specified device is available
        if not torch.cuda.is_available() and device == "cuda":
            raise ValueError("CUDA is not available. Please choose a different device.")
        # Load the model and tokenizer
        self.model = None
@ -71,8 +70,7 @@ class Mistral:
        """Run the model on a given task."""
        try:
            model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
            generated_ids = self.model.generate(
                **model_inputs,
                max_length=self.max_length,
@ -89,8 +87,7 @@ class Mistral:
        """Run the model on a given task."""
        try:
            model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
            generated_ids = self.model.generate(
                **model_inputs,
                max_length=self.max_length,

@ -26,10 +26,7 @@ class MPT7B:
    """
    def __init__(self, model_name: str, tokenizer_name: str, max_tokens: int = 100):
        # Loading model and tokenizer details
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
@ -40,9 +37,11 @@ class MPT7B:
        self.logger = logging.getLogger(__name__)
        config = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).config
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, config=config, trust_remote_code=True
        )
        # Initializing a text-generation pipeline
        self.pipe = pipeline(
@ -115,10 +114,9 @@ class MPT7B:
        """
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return self.pipe(
                prompt, max_new_tokens=self.max_tokens, do_sample=True, use_cache=True
            )[0]["generated_text"]
    async def generate_async(self, prompt: str) -> str:
        """Generate Async"""

@ -41,10 +41,8 @@ class Nougat:
        self.min_length = min_length
        self.max_new_tokens = max_new_tokens
        self.processor = NougatProcessor.from_pretrained(self.model_name_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name_or_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
@ -65,10 +63,8 @@ class Nougat:
            max_new_tokens=self.max_new_tokens,
        )
        sequence = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        sequence = self.processor.post_process_generation(sequence, fix_markdown=False)
        out = print(sequence)
        return out
@ -76,7 +72,8 @@ class Nougat:
    def clean_nougat_output(raw_output):
        # Define the pattern to extract the relevant data
        daily_balance_pattern = (
            r"\*\*(\d{2}/\d{2}/\d{4})\*\*\n\n\*\*([\d,]+\.\d{2})\*\*"
        )
        # Find all matches of the pattern
        matches = re.findall(daily_balance_pattern, raw_output)

@ -55,9 +55,9 @@ class OpenAIAssistant:
        return thread
    def add_message_to_thread(self, thread_id: str, message: str):
        message = self.client.beta.threads.add_message(
            thread_id=thread_id, role=self.user, content=message
        )
        return message
    def run(self, task: str):
@ -67,7 +67,8 @@ class OpenAIAssistant:
            instructions=self.instructions,
        )
        out = self.client.beta.threads.runs.retrieve(
            thread_id=run.thread_id, run_id=run.id
        )
        return out

@ -28,10 +28,9 @@ from tenacity import (
from swarms.models.embeddings_base import Embeddings from swarms.models.embeddings_base import Embeddings
def get_from_dict_or_env(values: dict, def get_from_dict_or_env(
key: str, values: dict, key: str, env_key: str, default: Any = None
env_key: str, ) -> Any:
default: Any = None) -> Any:
import os import os
return values.get(key) or os.getenv(env_key) or default return values.get(key) or os.getenv(env_key) or default
@ -44,8 +43,7 @@ def get_pydantic_field_names(cls: Any) -> Set[str]:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _create_retry_decorator( def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]:
embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]:
import llm import llm
min_seconds = 4 min_seconds = 4
@ -56,11 +54,13 @@ def _create_retry_decorator(
reraise=True, reraise=True,
stop=stop_after_attempt(embeddings.max_retries), stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(llm.error.Timeout) | retry=(
retry_if_exception_type(llm.error.APIError) | retry_if_exception_type(llm.error.Timeout)
retry_if_exception_type(llm.error.APIConnectionError) | | retry_if_exception_type(llm.error.APIError)
retry_if_exception_type(llm.error.RateLimitError) | | retry_if_exception_type(llm.error.APIConnectionError)
retry_if_exception_type(llm.error.ServiceUnavailableError)), | retry_if_exception_type(llm.error.RateLimitError)
| retry_if_exception_type(llm.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING), before_sleep=before_sleep_log(logger, logging.WARNING),
) )
@ -76,16 +76,17 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
reraise=True, reraise=True,
stop=stop_after_attempt(embeddings.max_retries), stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(llm.error.Timeout) | retry=(
retry_if_exception_type(llm.error.APIError) | retry_if_exception_type(llm.error.Timeout)
retry_if_exception_type(llm.error.APIConnectionError) | | retry_if_exception_type(llm.error.APIError)
retry_if_exception_type(llm.error.RateLimitError) | | retry_if_exception_type(llm.error.APIConnectionError)
retry_if_exception_type(llm.error.ServiceUnavailableError)), | retry_if_exception_type(llm.error.RateLimitError)
| retry_if_exception_type(llm.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING), before_sleep=before_sleep_log(logger, logging.WARNING),
) )
def wrap(func: Callable) -> Callable: def wrap(func: Callable) -> Callable:
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
async for _ in async_retrying: async for _ in async_retrying:
return await func(*args, **kwargs) return await func(*args, **kwargs)
@ -117,8 +118,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
return _embed_with_retry(**kwargs) return _embed_with_retry(**kwargs)
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
**kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call.""" """Use tenacity to retry the embedding call."""
@_async_retry_decorator(embeddings) @_async_retry_decorator(embeddings)
@ -225,11 +225,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
warnings.warn( warnings.warn(
f"""WARNING! {field_name} is not default parameter. f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs. {field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended.""") Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name) extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection( invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
extra.keys())
if invalid_model_kwargs: if invalid_model_kwargs:
raise ValueError( raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. " f"Parameters {invalid_model_kwargs} should be specified explicitly. "
@ -242,9 +242,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(values, values["openai_api_key"] = get_from_dict_or_env(
"openai_api_key", values, "openai_api_key", "OPENAI_API_KEY"
"OPENAI_API_KEY") )
values["openai_api_base"] = get_from_dict_or_env( values["openai_api_base"] = get_from_dict_or_env(
values, values,
"openai_api_base", "openai_api_base",
@ -284,8 +284,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
values["client"] = llm.Embedding values["client"] = llm.Embedding
except ImportError: except ImportError:
raise ImportError("Could not import openai python package. " raise ImportError(
"Please install it with `pip install openai`.") "Could not import openai python package. "
"Please install it with `pip install openai`."
)
return values return values
@property @property
@ -313,11 +315,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return openai_args return openai_args
def _get_len_safe_embeddings( def _get_len_safe_embeddings(
self, self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
texts: List[str], ) -> List[List[float]]:
*,
engine: str,
chunk_size: Optional[int] = None) -> List[List[float]]:
embeddings: List[List[float]] = [[] for _ in range(len(texts))] embeddings: List[List[float]] = [[] for _ in range(len(texts))]
try: try:
import tiktoken import tiktoken
@ -325,7 +324,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
raise ImportError( raise ImportError(
"Could not import tiktoken python package. " "Could not import tiktoken python package. "
"This is needed in order to for OpenAIEmbeddings. " "This is needed in order to for OpenAIEmbeddings. "
"Please install it with `pip install tiktoken`.") "Please install it with `pip install tiktoken`."
)
tokens = [] tokens = []
indices = [] indices = []
@ -333,8 +333,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
try: try:
encoding = tiktoken.encoding_for_model(model_name) encoding = tiktoken.encoding_for_model(model_name)
except KeyError: except KeyError:
logger.warning( logger.warning("Warning: model not found. Using cl100k_base encoding.")
"Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base" model = "cl100k_base"
encoding = tiktoken.get_encoding(model) encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts): for i, text in enumerate(texts):
@ -348,7 +347,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
disallowed_special=self.disallowed_special, disallowed_special=self.disallowed_special,
) )
for j in range(0, len(token), self.embedding_ctx_length): for j in range(0, len(token), self.embedding_ctx_length):
tokens.append(token[j:j + self.embedding_ctx_length]) tokens.append(token[j : j + self.embedding_ctx_length])
indices.append(i) indices.append(i)
batched_embeddings: List[List[float]] = [] batched_embeddings: List[List[float]] = []
@ -367,7 +366,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in _iter: for i in _iter:
response = embed_with_retry( response = embed_with_retry(
self, self,
input=tokens[i:i + _chunk_size], input=tokens[i : i + _chunk_size],
**self._invocation_params, **self._invocation_params,
) )
batched_embeddings.extend(r["embedding"] for r in response["data"]) batched_embeddings.extend(r["embedding"] for r in response["data"])
@ -385,11 +384,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
self, self,
input="", input="",
**self._invocation_params, **self._invocation_params,
)["data"][0]["embedding"] )[
"data"
][0]["embedding"]
else: else:
average = np.average(_result, average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
axis=0,
weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist() embeddings[i] = (average / np.linalg.norm(average)).tolist()
return embeddings return embeddings
@ -397,11 +396,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# please refer to # please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
async def _aget_len_safe_embeddings( async def _aget_len_safe_embeddings(
self, self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
texts: List[str], ) -> List[List[float]]:
*,
engine: str,
chunk_size: Optional[int] = None) -> List[List[float]]:
embeddings: List[List[float]] = [[] for _ in range(len(texts))] embeddings: List[List[float]] = [[] for _ in range(len(texts))]
try: try:
import tiktoken import tiktoken
@ -409,7 +405,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
raise ImportError( raise ImportError(
"Could not import tiktoken python package. " "Could not import tiktoken python package. "
"This is needed in order to for OpenAIEmbeddings. " "This is needed in order to for OpenAIEmbeddings. "
"Please install it with `pip install tiktoken`.") "Please install it with `pip install tiktoken`."
)
tokens = [] tokens = []
indices = [] indices = []
@ -417,8 +414,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
try: try:
encoding = tiktoken.encoding_for_model(model_name) encoding = tiktoken.encoding_for_model(model_name)
except KeyError: except KeyError:
logger.warning( logger.warning("Warning: model not found. Using cl100k_base encoding.")
"Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base" model = "cl100k_base"
encoding = tiktoken.get_encoding(model) encoding = tiktoken.get_encoding(model)
for i, text in enumerate(texts): for i, text in enumerate(texts):
@ -432,7 +428,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
disallowed_special=self.disallowed_special, disallowed_special=self.disallowed_special,
) )
for j in range(0, len(token), self.embedding_ctx_length): for j in range(0, len(token), self.embedding_ctx_length):
tokens.append(token[j:j + self.embedding_ctx_length]) tokens.append(token[j : j + self.embedding_ctx_length])
indices.append(i) indices.append(i)
batched_embeddings: List[List[float]] = [] batched_embeddings: List[List[float]] = []
@ -440,7 +436,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in range(0, len(tokens), _chunk_size): for i in range(0, len(tokens), _chunk_size):
response = await async_embed_with_retry( response = await async_embed_with_retry(
self, self,
input=tokens[i:i + _chunk_size], input=tokens[i : i + _chunk_size],
**self._invocation_params, **self._invocation_params,
) )
batched_embeddings.extend(r["embedding"] for r in response["data"]) batched_embeddings.extend(r["embedding"] for r in response["data"])
@ -454,22 +450,22 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
for i in range(len(texts)): for i in range(len(texts)):
_result = results[i] _result = results[i]
if len(_result) == 0: if len(_result) == 0:
average = (await async_embed_with_retry( average = (
self, await async_embed_with_retry(
input="", self,
**self._invocation_params, input="",
))["data"][0]["embedding"] **self._invocation_params,
)
)["data"][0]["embedding"]
else: else:
average = np.average(_result, average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
axis=0,
weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist() embeddings[i] = (average / np.linalg.norm(average)).tolist()
return embeddings return embeddings
def embed_documents(self, def embed_documents(
texts: List[str], self, texts: List[str], chunk_size: Optional[int] = 0
chunk_size: Optional[int] = 0) -> List[List[float]]: ) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs. """Call out to OpenAI's embedding endpoint for embedding search docs.
Args: Args:
@ -485,9 +481,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return self._get_len_safe_embeddings(texts, engine=self.deployment) return self._get_len_safe_embeddings(texts, engine=self.deployment)
async def aembed_documents( async def aembed_documents(
self, self, texts: List[str], chunk_size: Optional[int] = 0
texts: List[str], ) -> List[List[float]]:
chunk_size: Optional[int] = 0) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint async for embedding search docs. """Call out to OpenAI's embedding endpoint async for embedding search docs.
Args: Args:
@ -500,8 +495,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
""" """
# NOTE: to keep things simple, we assume the list may contain texts longer # NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function. # than the maximum context and use length-safe embedding function.
return await self._aget_len_safe_embeddings(texts, return await self._aget_len_safe_embeddings(texts, engine=self.deployment)
engine=self.deployment)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text. """Call out to OpenAI's embedding endpoint for embedding query text.

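The hunks above reflow the length-safe embedding logic without changing its behaviour: each text is tokenized, the token list is split into chunks of at most embedding_ctx_length tokens, every chunk is embedded separately, and the per-text vector is the token-count-weighted average of its chunk vectors, L2-normalized. Below is a minimal, self-contained sketch of that scheme; `embed_batch` is a hypothetical stand-in for the retrying OpenAI call, and empty texts are not handled here (the class falls back to embedding an empty string for those).

# Minimal sketch of length-safe embedding: chunk tokens, embed chunks,
# then combine per-text chunk vectors with a token-count-weighted average.
# `embed_batch` is a hypothetical stand-in for the OpenAI embedding call.
from typing import Callable, List

import numpy as np
import tiktoken


def length_safe_embed(
    texts: List[str],
    embed_batch: Callable[[List[List[int]]], List[List[float]]],
    ctx_length: int = 8191,
    encoding_name: str = "cl100k_base",
) -> List[List[float]]:
    enc = tiktoken.get_encoding(encoding_name)
    chunks: List[List[int]] = []
    owners: List[int] = []  # index of the text each chunk came from
    for i, text in enumerate(texts):
        token_ids = enc.encode(text)
        for j in range(0, len(token_ids), ctx_length):
            chunks.append(token_ids[j : j + ctx_length])
            owners.append(i)

    chunk_vectors = embed_batch(chunks)

    embeddings: List[List[float]] = []
    for i in range(len(texts)):
        vecs = [np.array(v) for v, o in zip(chunk_vectors, owners) if o == i]
        weights = [len(c) for c, o in zip(chunks, owners) if o == i]
        average = np.average(vecs, axis=0, weights=weights)
        embeddings.append((average / np.linalg.norm(average)).tolist())
    return embeddings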
@ -33,8 +33,9 @@ from langchain.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def update_token_usage(keys: Set[str], response: Dict[str, Any], def update_token_usage(
token_usage: Dict[str, Any]) -> None: keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
"""Update token usage.""" """Update token usage."""
_keys_to_use = keys.intersection(response["usage"]) _keys_to_use = keys.intersection(response["usage"])
for _key in _keys_to_use: for _key in _keys_to_use:
@ -45,42 +46,44 @@ def update_token_usage(keys: Set[str], response: Dict[str, Any],
def _stream_response_to_generation_chunk( def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],) -> GenerationChunk: stream_response: Dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk.""" """Convert a stream response to a generation chunk."""
return GenerationChunk( return GenerationChunk(
text=stream_response["choices"][0]["text"], text=stream_response["choices"][0]["text"],
generation_info=dict( generation_info=dict(
finish_reason=stream_response["choices"][0].get( finish_reason=stream_response["choices"][0].get("finish_reason", None),
"finish_reason", None),
logprobs=stream_response["choices"][0].get("logprobs", None), logprobs=stream_response["choices"][0].get("logprobs", None),
), ),
) )
def _update_response(response: Dict[str, Any], def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
stream_response: Dict[str, Any]) -> None:
"""Update response from the stream response.""" """Update response from the stream response."""
response["choices"][0]["text"] += stream_response["choices"][0]["text"] response["choices"][0]["text"] += stream_response["choices"][0]["text"]
response["choices"][0]["finish_reason"] = stream_response["choices"][0].get( response["choices"][0]["finish_reason"] = stream_response["choices"][0].get(
"finish_reason", None) "finish_reason", None
response["choices"][0]["logprobs"] = stream_response["choices"][0][ )
"logprobs"] response["choices"][0]["logprobs"] = stream_response["choices"][0]["logprobs"]
def _streaming_response_template() -> Dict[str, Any]: def _streaming_response_template() -> Dict[str, Any]:
return { return {
"choices": [{ "choices": [
"text": "", {
"finish_reason": None, "text": "",
"logprobs": None, "finish_reason": None,
}] "logprobs": None,
}
]
} }
def _create_retry_decorator( def _create_retry_decorator(
llm: Union[BaseOpenAI, OpenAIChat], llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[Union[AsyncCallbackManagerForLLMRun, run_manager: Optional[
CallbackManagerForLLMRun]] = None, Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]: ) -> Callable[[Any], Any]:
import openai import openai
@ -91,9 +94,9 @@ def _create_retry_decorator(
openai.error.RateLimitError, openai.error.RateLimitError,
openai.error.ServiceUnavailableError, openai.error.ServiceUnavailableError,
] ]
return create_base_retry_decorator(error_types=errors, return create_base_retry_decorator(
max_retries=llm.max_retries, error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
run_manager=run_manager) )
def completion_with_retry( def completion_with_retry(
@ -203,8 +206,7 @@ class BaseOpenAI(BaseLLM):
API but with different models. In those cases, in order to avoid erroring API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here.""" when tiktoken is called, you can specify a model name to use here."""
def __new__(cls, def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
**data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore
"""Initialize the OpenAI object.""" """Initialize the OpenAI object."""
data.get("model_name", "") data.get("model_name", "")
return super().__new__(cls) return super().__new__(cls)
@ -219,16 +221,17 @@ class BaseOpenAI(BaseLLM):
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls) all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
values["model_kwargs"] = build_extra_kwargs(extra, values, values["model_kwargs"] = build_extra_kwargs(
all_required_field_names) extra, values, all_required_field_names
)
return values return values
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(values, values["openai_api_key"] = get_from_dict_or_env(
"openai_api_key", values, "openai_api_key", "OPENAI_API_KEY"
"OPENAI_API_KEY") )
values["openai_api_base"] = get_from_dict_or_env( values["openai_api_base"] = get_from_dict_or_env(
values, values,
"openai_api_base", "openai_api_base",
@ -252,8 +255,10 @@ class BaseOpenAI(BaseLLM):
values["client"] = openai.Completion values["client"] = openai.Completion
except ImportError: except ImportError:
raise ImportError("Could not import openai python package. " raise ImportError(
"Please install it with `pip install openai`.") "Could not import openai python package. "
"Please install it with `pip install openai`."
)
if values["streaming"] and values["n"] > 1: if values["streaming"] and values["n"] > 1:
raise ValueError("Cannot stream results when n > 1.") raise ValueError("Cannot stream results when n > 1.")
if values["streaming"] and values["best_of"] > 1: if values["streaming"] and values["best_of"] > 1:
@ -290,10 +295,9 @@ class BaseOpenAI(BaseLLM):
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True} params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(self, for stream_resp in completion_with_retry(
prompt=prompt, self, prompt=prompt, run_manager=run_manager, **params
run_manager=run_manager, ):
**params):
chunk = _stream_response_to_generation_chunk(stream_resp) chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk yield chunk
if run_manager: if run_manager:
@ -302,7 +306,8 @@ class BaseOpenAI(BaseLLM):
chunk=chunk, chunk=chunk,
verbose=self.verbose, verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"] logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info else None, if chunk.generation_info
else None,
) )
async def _astream( async def _astream(
@ -315,7 +320,8 @@ class BaseOpenAI(BaseLLM):
params = {**self._invocation_params, **kwargs, "stream": True} params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutate params self.get_sub_prompts(params, [prompt], stop) # this mutate params
async for stream_resp in await acompletion_with_retry( async for stream_resp in await acompletion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params): self, prompt=prompt, run_manager=run_manager, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp) chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk yield chunk
if run_manager: if run_manager:
@ -324,7 +330,8 @@ class BaseOpenAI(BaseLLM):
chunk=chunk, chunk=chunk,
verbose=self.verbose, verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"] logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info else None, if chunk.generation_info
else None,
) )
def _generate( def _generate(
@ -360,32 +367,30 @@ class BaseOpenAI(BaseLLM):
for _prompts in sub_prompts: for _prompts in sub_prompts:
if self.streaming: if self.streaming:
if len(_prompts) > 1: if len(_prompts) > 1:
raise ValueError( raise ValueError("Cannot stream results with multiple prompts.")
"Cannot stream results with multiple prompts.")
generation: Optional[GenerationChunk] = None generation: Optional[GenerationChunk] = None
for chunk in self._stream(_prompts[0], stop, run_manager, for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
**kwargs):
if generation is None: if generation is None:
generation = chunk generation = chunk
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
choices.append({ choices.append(
"text": {
generation.text, "text": generation.text,
"finish_reason": "finish_reason": generation.generation_info.get("finish_reason")
generation.generation_info.get("finish_reason") if generation.generation_info
if generation.generation_info else None, else None,
"logprobs": "logprobs": generation.generation_info.get("logprobs")
generation.generation_info.get("logprobs") if generation.generation_info
if generation.generation_info else None, else None,
}) }
)
else: else:
response = completion_with_retry(self, response = completion_with_retry(
prompt=_prompts, self, prompt=_prompts, run_manager=run_manager, **params
run_manager=run_manager, )
**params)
choices.extend(response["choices"]) choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage) update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage) return self.create_llm_result(choices, prompts, token_usage)
@ -409,32 +414,32 @@ class BaseOpenAI(BaseLLM):
for _prompts in sub_prompts: for _prompts in sub_prompts:
if self.streaming: if self.streaming:
if len(_prompts) > 1: if len(_prompts) > 1:
raise ValueError( raise ValueError("Cannot stream results with multiple prompts.")
"Cannot stream results with multiple prompts.")
generation: Optional[GenerationChunk] = None generation: Optional[GenerationChunk] = None
async for chunk in self._astream(_prompts[0], stop, run_manager, async for chunk in self._astream(
**kwargs): _prompts[0], stop, run_manager, **kwargs
):
if generation is None: if generation is None:
generation = chunk generation = chunk
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
choices.append({ choices.append(
"text": {
generation.text, "text": generation.text,
"finish_reason": "finish_reason": generation.generation_info.get("finish_reason")
generation.generation_info.get("finish_reason") if generation.generation_info
if generation.generation_info else None, else None,
"logprobs": "logprobs": generation.generation_info.get("logprobs")
generation.generation_info.get("logprobs") if generation.generation_info
if generation.generation_info else None, else None,
}) }
)
else: else:
response = await acompletion_with_retry(self, response = await acompletion_with_retry(
prompt=_prompts, self, prompt=_prompts, run_manager=run_manager, **params
run_manager=run_manager, )
**params)
choices.extend(response["choices"]) choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage) update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage) return self.create_llm_result(choices, prompts, token_usage)
@ -448,35 +453,39 @@ class BaseOpenAI(BaseLLM):
"""Get the sub prompts for llm call.""" """Get the sub prompts for llm call."""
if stop is not None: if stop is not None:
if "stop" in params: if "stop" in params:
raise ValueError( raise ValueError("`stop` found in both the input and default params.")
"`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
if params["max_tokens"] == -1: if params["max_tokens"] == -1:
if len(prompts) != 1: if len(prompts) != 1:
raise ValueError( raise ValueError(
"max_tokens set to -1 not supported for multiple inputs.") "max_tokens set to -1 not supported for multiple inputs."
)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0]) params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [ sub_prompts = [
prompts[i:i + self.batch_size] prompts[i : i + self.batch_size]
for i in range(0, len(prompts), self.batch_size) for i in range(0, len(prompts), self.batch_size)
] ]
return sub_prompts return sub_prompts
def create_llm_result(self, choices: Any, prompts: List[str], def create_llm_result(
token_usage: Dict[str, int]) -> LLMResult: self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
) -> LLMResult:
"""Create the LLMResult from the choices and prompts.""" """Create the LLMResult from the choices and prompts."""
generations = [] generations = []
for i, _ in enumerate(prompts): for i, _ in enumerate(prompts):
sub_choices = choices[i * self.n:(i + 1) * self.n] sub_choices = choices[i * self.n : (i + 1) * self.n]
generations.append([ generations.append(
Generation( [
text=choice["text"], Generation(
generation_info=dict( text=choice["text"],
finish_reason=choice.get("finish_reason"), generation_info=dict(
logprobs=choice.get("logprobs"), finish_reason=choice.get("finish_reason"),
), logprobs=choice.get("logprobs"),
) for choice in sub_choices ),
]) )
for choice in sub_choices
]
)
llm_output = {"token_usage": token_usage, "model_name": self.model_name} llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return LLMResult(generations=generations, llm_output=llm_output) return LLMResult(generations=generations, llm_output=llm_output)
@ -518,14 +527,14 @@ class BaseOpenAI(BaseLLM):
raise ImportError( raise ImportError(
"Could not import tiktoken python package. " "Could not import tiktoken python package. "
"This is needed in order to calculate get_num_tokens. " "This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install tiktoken`.") "Please install it with `pip install tiktoken`."
)
model_name = self.tiktoken_model_name or self.model_name model_name = self.tiktoken_model_name or self.model_name
try: try:
enc = tiktoken.encoding_for_model(model_name) enc = tiktoken.encoding_for_model(model_name)
except KeyError: except KeyError:
logger.warning( logger.warning("Warning: model not found. Using cl100k_base encoding.")
"Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base" model = "cl100k_base"
enc = tiktoken.get_encoding(model) enc = tiktoken.get_encoding(model)
@ -587,7 +596,8 @@ class BaseOpenAI(BaseLLM):
if context_size is None: if context_size is None:
raise ValueError( raise ValueError(
f"Unknown model: {modelname}. Please provide a valid OpenAI model name." f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
"Known models are: " + ", ".join(model_token_mapping.keys())) "Known models are: " + ", ".join(model_token_mapping.keys())
)
return context_size return context_size
@ -665,15 +675,14 @@ class AzureOpenAI(BaseOpenAI):
"OPENAI_API_VERSION", "OPENAI_API_VERSION",
) )
values["openai_api_type"] = get_from_dict_or_env( values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", "azure") values, "openai_api_type", "OPENAI_API_TYPE", "azure"
)
return values return values
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
return { return {
**{ **{"deployment_name": self.deployment_name},
"deployment_name": self.deployment_name
},
**super()._identifying_params, **super()._identifying_params,
} }
@ -738,9 +747,7 @@ class OpenAIChat(BaseLLM):
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = { all_required_field_names = {field.alias for field in cls.__fields__.values()}
field.alias for field in cls.__fields__.values()
}
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
for field_name in list(values): for field_name in list(values):
@ -754,8 +761,9 @@ class OpenAIChat(BaseLLM):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(values, "openai_api_key", openai_api_key = get_from_dict_or_env(
"OPENAI_API_KEY") values, "openai_api_key", "OPENAI_API_KEY"
)
openai_api_base = get_from_dict_or_env( openai_api_base = get_from_dict_or_env(
values, values,
"openai_api_base", "openai_api_base",
@ -768,10 +776,9 @@ class OpenAIChat(BaseLLM):
"OPENAI_PROXY", "OPENAI_PROXY",
default="", default="",
) )
openai_organization = get_from_dict_or_env(values, openai_organization = get_from_dict_or_env(
"openai_organization", values, "openai_organization", "OPENAI_ORGANIZATION", default=""
"OPENAI_ORGANIZATION", )
default="")
try: try:
import openai import openai
@ -786,15 +793,18 @@ class OpenAIChat(BaseLLM):
"https": openai_proxy, "https": openai_proxy,
} # type: ignore[assignment] # noqa: E501 } # type: ignore[assignment] # noqa: E501
except ImportError: except ImportError:
raise ImportError("Could not import openai python package. " raise ImportError(
"Please install it with `pip install openai`.") "Could not import openai python package. "
"Please install it with `pip install openai`."
)
try: try:
values["client"] = openai.ChatCompletion values["client"] = openai.ChatCompletion
except AttributeError: except AttributeError:
raise ValueError( raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely " "`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it " "due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`.") "with `pip install --upgrade openai`."
)
return values return values
@property @property
@ -802,27 +812,18 @@ class OpenAIChat(BaseLLM):
"""Get the default parameters for calling OpenAI API.""" """Get the default parameters for calling OpenAI API."""
return self.model_kwargs return self.model_kwargs
def _get_chat_params(self, def _get_chat_params(
prompts: List[str], self, prompts: List[str], stop: Optional[List[str]] = None
stop: Optional[List[str]] = None) -> Tuple: ) -> Tuple:
if len(prompts) > 1: if len(prompts) > 1:
raise ValueError( raise ValueError(
f"OpenAIChat currently only supports single prompt, got {prompts}" f"OpenAIChat currently only supports single prompt, got {prompts}"
) )
messages = self.prefix_messages + [{ messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
"role": "user", params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
"content": prompts[0]
}]
params: Dict[str, Any] = {
**{
"model": self.model_name
},
**self._default_params
}
if stop is not None: if stop is not None:
if "stop" in params: if "stop" in params:
raise ValueError( raise ValueError("`stop` found in both the input and default params.")
"`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
if params.get("max_tokens") == -1: if params.get("max_tokens") == -1:
# for ChatGPT api, omitting max_tokens is equivalent to having no limit # for ChatGPT api, omitting max_tokens is equivalent to having no limit
@ -838,10 +839,9 @@ class OpenAIChat(BaseLLM):
) -> Iterator[GenerationChunk]: ) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop) messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(self, for stream_resp in completion_with_retry(
messages=messages, self, messages=messages, run_manager=run_manager, **params
run_manager=run_manager, ):
**params):
token = stream_resp["choices"][0]["delta"].get("content", "") token = stream_resp["choices"][0]["delta"].get("content", "")
chunk = GenerationChunk(text=token) chunk = GenerationChunk(text=token)
yield chunk yield chunk
@ -858,7 +858,8 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params([prompt], stop) messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True} params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry( async for stream_resp in await acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params): self, messages=messages, run_manager=run_manager, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "") token = stream_resp["choices"][0]["delta"].get("content", "")
chunk = GenerationChunk(text=token) chunk = GenerationChunk(text=token)
yield chunk yield chunk
@ -884,19 +885,17 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
full_response = completion_with_retry(self, full_response = completion_with_retry(
messages=messages, self, messages=messages, run_manager=run_manager, **params
run_manager=run_manager, )
**params)
llm_output = { llm_output = {
"token_usage": full_response["usage"], "token_usage": full_response["usage"],
"model_name": self.model_name, "model_name": self.model_name,
} }
return LLMResult( return LLMResult(
generations=[[ generations=[
Generation( [Generation(text=full_response["choices"][0]["message"]["content"])]
text=full_response["choices"][0]["message"]["content"]) ],
]],
llm_output=llm_output, llm_output=llm_output,
) )
@ -909,8 +908,7 @@ class OpenAIChat(BaseLLM):
) -> LLMResult: ) -> LLMResult:
if self.streaming: if self.streaming:
generation: Optional[GenerationChunk] = None generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompts[0], stop, run_manager, async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
**kwargs):
if generation is None: if generation is None:
generation = chunk generation = chunk
else: else:
@ -920,19 +918,17 @@ class OpenAIChat(BaseLLM):
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} params = {**params, **kwargs}
full_response = await acompletion_with_retry(self, full_response = await acompletion_with_retry(
messages=messages, self, messages=messages, run_manager=run_manager, **params
run_manager=run_manager, )
**params)
llm_output = { llm_output = {
"token_usage": full_response["usage"], "token_usage": full_response["usage"],
"model_name": self.model_name, "model_name": self.model_name,
} }
return LLMResult( return LLMResult(
generations=[[ generations=[
Generation( [Generation(text=full_response["choices"][0]["message"]["content"])]
text=full_response["choices"][0]["message"]["content"]) ],
]],
llm_output=llm_output, llm_output=llm_output,
) )
@ -957,7 +953,8 @@ class OpenAIChat(BaseLLM):
raise ImportError( raise ImportError(
"Could not import tiktoken python package. " "Could not import tiktoken python package. "
"This is needed in order to calculate get_num_tokens. " "This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install tiktoken`.") "Please install it with `pip install tiktoken`."
)
enc = tiktoken.encoding_for_model(self.model_name) enc = tiktoken.encoding_for_model(self.model_name)
return enc.encode( return enc.encode(

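The _generate/_agenerate and create_llm_result hunks above keep the same batching scheme: prompts are sliced into groups of batch_size, each completion returns n choices per prompt, and the flat choices list is regrouped per prompt afterwards. A rough, client-free illustration of that slicing is below; the choice dicts produced by `fake_completion` are fabricated placeholders standing in for completion_with_retry.

# Rough illustration of how BaseOpenAI batches prompts and regroups choices.
from typing import Any, Dict, List


def batch_prompts(prompts: List[str], batch_size: int) -> List[List[str]]:
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]


def regroup_choices(choices: List[Dict[str, Any]], prompts: List[str], n: int):
    # Choices arrive flat: n entries per prompt, in prompt order.
    return [choices[i * n : (i + 1) * n] for i in range(len(prompts))]


def fake_completion(batch: List[str], n: int) -> List[Dict[str, Any]]:
    return [{"text": f"completion {k} for {p!r}"} for p in batch for k in range(n)]


if __name__ == "__main__":
    prompts = ["a", "b", "c"]
    n, batch_size = 2, 2
    choices: List[Dict[str, Any]] = []
    for batch in batch_prompts(prompts, batch_size):
        choices.extend(fake_completion(batch, n))
    print(regroup_choices(choices, prompts, n))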
@ -71,15 +71,16 @@ class OpenAITokenizer(BaseTokenizer):
@property @property
def max_tokens(self) -> int: def max_tokens(self) -> int:
tokens = next(v for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() tokens = next(
if self.model.startswith(k)) v
for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
if self.model.startswith(k)
)
offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET
return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
def count_tokens(self, def count_tokens(self, text: str | list, model: Optional[str] = None) -> int:
text: str | list,
model: Optional[str] = None) -> int:
""" """
Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook: Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@ -95,12 +96,12 @@ class OpenAITokenizer(BaseTokenizer):
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
if model in { if model in {
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k-0613",
"gpt-4-0314", "gpt-4-0314",
"gpt-4-32k-0314", "gpt-4-32k-0314",
"gpt-4-0613", "gpt-4-0613",
"gpt-4-32k-0613", "gpt-4-32k-0613",
}: }:
tokens_per_message = 3 tokens_per_message = 3
tokens_per_name = 1 tokens_per_name = 1
@ -112,18 +113,21 @@ class OpenAITokenizer(BaseTokenizer):
elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
logging.info( logging.info(
"gpt-3.5-turbo may update over time. Returning num tokens assuming" "gpt-3.5-turbo may update over time. Returning num tokens assuming"
" gpt-3.5-turbo-0613.") " gpt-3.5-turbo-0613."
)
return self.count_tokens(text, model="gpt-3.5-turbo-0613") return self.count_tokens(text, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model: elif "gpt-4" in model:
logging.info( logging.info(
"gpt-4 may update over time. Returning num tokens assuming" "gpt-4 may update over time. Returning num tokens assuming"
" gpt-4-0613.") " gpt-4-0613."
)
return self.count_tokens(text, model="gpt-4-0613") return self.count_tokens(text, model="gpt-4-0613")
else: else:
raise NotImplementedError( raise NotImplementedError(
f"""token_count() is not implemented for model {model}. f"""token_count() is not implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for See https://github.com/openai/openai-python/blob/main/chatml.md for
information on how messages are converted to tokens.""") information on how messages are converted to tokens."""
)
num_tokens = 0 num_tokens = 0
@ -140,5 +144,5 @@ class OpenAITokenizer(BaseTokenizer):
return num_tokens return num_tokens
else: else:
return len( return len(
self.encoding.encode(text, self.encoding.encode(text, allowed_special=set(self.stop_sequences))
allowed_special=set(self.stop_sequences))) )

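count_tokens above follows the OpenAI cookbook recipe for chat messages: a fixed per-message overhead, one extra token when a name field is present, and three trailing tokens for the assistant priming. A condensed, standalone version of that recipe (constants match the gpt-3.5-turbo-0613 / gpt-4-0613 message format the method special-cases):

# Condensed ChatML token counting, following the OpenAI cookbook recipe.
from typing import Dict, List

import tiktoken


def count_chat_tokens(messages: List[Dict[str, str]],
                      encoding_name: str = "cl100k_base") -> int:
    encoding = tiktoken.get_encoding(encoding_name)
    tokens_per_message = 3  # every message is wrapped as <|start|>{role}\n...<|end|>\n
    tokens_per_name = 1
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


print(count_chat_tokens([{"role": "user", "content": "hello"}]))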
@ -26,7 +26,8 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import google-api-core python package. " "Could not import google-api-core python package. "
"Please install it with `pip install google-api-core`.") "Please install it with `pip install google-api-core`."
)
multiplier = 2 multiplier = 2
min_seconds = 1 min_seconds = 1
@ -36,15 +37,12 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
return retry( return retry(
reraise=True, reraise=True,
stop=stop_after_attempt(max_retries), stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=multiplier, wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
min=min_seconds, retry=(
max=max_seconds), retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
retry=(retry_if_exception_type( | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
google.api_core.exceptions.ResourceExhausted) | | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
retry_if_exception_type( ),
google.api_core.exceptions.ServiceUnavailable) |
retry_if_exception_type(
google.api_core.exceptions.GoogleAPIError)),
before_sleep=before_sleep_log(logger, logging.WARNING), before_sleep=before_sleep_log(logger, logging.WARNING),
) )
@ -66,8 +64,7 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
The PaLM API will sometimes erroneously return a single leading space in all The PaLM API will sometimes erroneously return a single leading space in all
lines > 1. This function strips that space. lines > 1. This function strips that space.
""" """
has_leading_space = all( has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
not line or line[0] == " " for line in text.split("\n")[1:])
if has_leading_space: if has_leading_space:
return text.replace("\n ", "\n") return text.replace("\n ", "\n")
else: else:
@ -100,8 +97,9 @@ class GooglePalm(BaseLLM, BaseModel):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists.""" """Validate api key, python package exists."""
google_api_key = get_from_dict_or_env(values, "google_api_key", google_api_key = get_from_dict_or_env(
"GOOGLE_API_KEY") values, "google_api_key", "GOOGLE_API_KEY"
)
try: try:
import google.generativeai as genai import google.generativeai as genai
@ -109,12 +107,12 @@ class GooglePalm(BaseLLM, BaseModel):
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import google-generativeai python package. " "Could not import google-generativeai python package. "
"Please install it with `pip install google-generativeai`.") "Please install it with `pip install google-generativeai`."
)
values["client"] = genai values["client"] = genai
if values["temperature"] is not None and not 0 <= values[ if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
"temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]") raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
@ -123,8 +121,7 @@ class GooglePalm(BaseLLM, BaseModel):
if values["top_k"] is not None and values["top_k"] <= 0: if values["top_k"] is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive") raise ValueError("top_k must be positive")
if values["max_output_tokens"] is not None and values[ if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
"max_output_tokens"] <= 0:
raise ValueError("max_output_tokens must be greater than zero") raise ValueError("max_output_tokens must be greater than zero")
return values return values

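The retry decorator above is only reflowed; the behaviour stays the same: exponential backoff between a minimum and maximum wait, a capped attempt count, and retries restricted to transient google.api_core exceptions. A self-contained sketch of the same tenacity pattern follows, using a generic TransientError so it runs without the Google SDK installed; the concrete attempt count here is an illustrative placeholder, not the value from the file.

# Sketch of the tenacity-based retry pattern used above, with a generic
# TransientError standing in for ResourceExhausted / ServiceUnavailable / GoogleAPIError.
import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


class TransientError(Exception):
    """Placeholder for the transient google.api_core exception types."""


@retry(
    reraise=True,
    stop=stop_after_attempt(10),
    wait=wait_exponential(multiplier=2, min=1, max=60),
    retry=retry_if_exception_type(TransientError),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def call_api() -> str:
    # Replace with the real API call; raising TransientError triggers a retry.
    return "ok"


print(call_api())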
@ -33,10 +33,9 @@ class PegasusEmbedding:
""" """
def __init__(self, def __init__(
modality: str, self, modality: str, multi_process: bool = False, n_processes: int = 4
multi_process: bool = False, ):
n_processes: int = 4):
self.modality = modality self.modality = modality
self.multi_process = multi_process self.multi_process = multi_process
self.n_processes = n_processes self.n_processes = n_processes
@ -44,7 +43,8 @@ class PegasusEmbedding:
self.pegasus = Pegasus(modality, multi_process, n_processes) self.pegasus = Pegasus(modality, multi_process, n_processes)
except Exception as e: except Exception as e:
logging.error( logging.error(
f"Failed to initialize Pegasus with modality: {modality}: {e}") f"Failed to initialize Pegasus with modality: {modality}: {e}"
)
raise raise
def embed(self, data: Union[str, list[str]]): def embed(self, data: Union[str, list[str]]):

@ -21,4 +21,6 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
return openai.Embedding.create( return openai.Embedding.create(
input=[text], input=[text],
model=model, model=model,
)["data"][0]["embedding"] )["data"][
0
]["embedding"]

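get_ada_embeddings returns one embedding vector per call. A typical downstream use of such vectors is cosine similarity between texts; the sketch below assumes an `embed(text)` helper returning a list of floats (for example, the function above with a valid OPENAI_API_KEY configured).

# Cosine similarity between two embedding vectors, e.g. as returned by
# get_ada_embeddings above. `embed` is a stand-in for that helper.
from typing import Callable, List

import numpy as np


def cosine_similarity(a: List[float], b: List[float]) -> float:
    a_arr, b_arr = np.asarray(a), np.asarray(b)
    return float(a_arr @ b_arr / (np.linalg.norm(a_arr) * np.linalg.norm(b_arr)))


def most_similar(query: str, docs: List[str],
                 embed: Callable[[str], List[float]]) -> str:
    q = embed(query)
    return max(docs, key=lambda d: cosine_similarity(q, embed(d)))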
@ -90,17 +90,17 @@ class SpeechT5:
self.processor = SpeechT5Processor.from_pretrained(self.model_name) self.processor = SpeechT5Processor.from_pretrained(self.model_name)
self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name) self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_name)
self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name) self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_name)
self.embeddings_dataset = load_dataset(self.dataset_name, self.embeddings_dataset = load_dataset(self.dataset_name, split="validation")
split="validation")
def __call__(self, text: str, speaker_id: float = 7306): def __call__(self, text: str, speaker_id: float = 7306):
"""Call the model on some text and return the speech.""" """Call the model on some text and return the speech."""
speaker_embedding = torch.tensor( speaker_embedding = torch.tensor(
self.embeddings_dataset[speaker_id]["xvector"]).unsqueeze(0) self.embeddings_dataset[speaker_id]["xvector"]
).unsqueeze(0)
inputs = self.processor(text=text, return_tensors="pt") inputs = self.processor(text=text, return_tensors="pt")
speech = self.model.generate_speech(inputs["input_ids"], speech = self.model.generate_speech(
speaker_embedding, inputs["input_ids"], speaker_embedding, vocoder=self.vocoder
vocoder=self.vocoder) )
return speech return speech
def save_speech(self, speech, filename="speech.wav"): def save_speech(self, speech, filename="speech.wav"):
@ -121,8 +121,7 @@ class SpeechT5:
def set_embeddings_dataset(self, dataset_name): def set_embeddings_dataset(self, dataset_name):
"""Set the embeddings dataset to a new dataset.""" """Set the embeddings dataset to a new dataset."""
self.dataset_name = dataset_name self.dataset_name = dataset_name
self.embeddings_dataset = load_dataset(self.dataset_name, self.embeddings_dataset = load_dataset(self.dataset_name, split="validation")
split="validation")
# Feature 1: Get sampling rate # Feature 1: Get sampling rate
def get_sampling_rate(self): def get_sampling_rate(self):

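Using the SpeechT5 wrapper above is a two-step call: synthesize, then save via the class's own save_speech helper. The sketch below is hypothetical usage; it assumes the constructor supplies default model/vocoder/dataset names and that the import path shown exists in the package.

# Hypothetical usage of the SpeechT5 wrapper defined above.
from swarms.models.speecht5 import SpeechT5  # assumed import path

tts = SpeechT5()  # assumes default model_name / vocoder_name / dataset_name
speech = tts("The swarm is operational.", speaker_id=7306)
tts.save_speech(speech, filename="speech.wav")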
@ -50,8 +50,9 @@ class TimmModel:
in_chans=model_info.in_chans, in_chans=model_info.in_chans,
) )
def __call__(self, model_info: TimmModelInfo, def __call__(
input_tensor: torch.Tensor) -> torch.Size: self, model_info: TimmModelInfo, input_tensor: torch.Tensor
) -> torch.Size:
""" """
Create and run a model specified by `model_info` on `input_tensor`. Create and run a model specified by `model_info` on `input_tensor`.

@ -10,8 +10,9 @@ import requests
class TrOCR: class TrOCR:
def __init__(
def __init__(self,): self,
):
pass pass
def __call__(self): def __call__(self):

@ -23,9 +23,11 @@ class Vilt:
def __init__(self): def __init__(self):
self.processor = ViltProcessor.from_pretrained( self.processor = ViltProcessor.from_pretrained(
"dandelin/vilt-b32-finetuned-vqa") "dandelin/vilt-b32-finetuned-vqa"
)
self.model = ViltForQuestionAnswering.from_pretrained( self.model = ViltForQuestionAnswering.from_pretrained(
"dandelin/vilt-b32-finetuned-vqa") "dandelin/vilt-b32-finetuned-vqa"
)
def __call__(self, text: str, image_url: str): def __call__(self, text: str, image_url: str):
""" """

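The Vilt wrapper wires up the standard Hugging Face visual question answering flow: preprocess the image and question with ViltProcessor, run ViltForQuestionAnswering, and map the arg-max logit back to an answer string. A standalone sketch of that flow using the transformers API directly:

# Standalone VQA inference with ViLT, mirroring what the wrapper above does.
import requests
import torch
from PIL import Image
from transformers import ViltForQuestionAnswering, ViltProcessor

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
question = "How many cats are there?"

encoding = processor(image, question, return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoding)
answer_idx = outputs.logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[answer_idx])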
@ -33,8 +33,7 @@ class WizardLLMStoryTeller:
def __init__( def __init__(
self, self,
model_id: model_id: str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF",
str = "TheBloke/WizardLM-Uncensored-SuperCOT-StoryTelling-30B-GGUF",
device: str = None, device: str = None,
max_length: int = 500, max_length: int = 500,
quantize: bool = False, quantize: bool = False,
@ -45,8 +44,9 @@ class WizardLLMStoryTeller:
decoding=False, decoding=False,
): ):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.device = (device if device else self.device = (
("cuda" if torch.cuda.is_available() else "cpu")) device if device else ("cuda" if torch.cuda.is_available() else "cpu")
)
self.model_id = model_id self.model_id = model_id
self.max_length = max_length self.max_length = max_length
self.verbose = verbose self.verbose = verbose
@ -56,8 +56,9 @@ class WizardLLMStoryTeller:
# self.log = Logging() # self.log = Logging()
if self.distributed: if self.distributed:
assert (torch.cuda.device_count() > assert (
1), "You need more than 1 gpu for distributed processing" torch.cuda.device_count() > 1
), "You need more than 1 gpu for distributed processing"
bnb_config = None bnb_config = None
if quantize: if quantize:
@ -73,7 +74,8 @@ class WizardLLMStoryTeller:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config) self.model_id, quantization_config=bnb_config
)
self.model # .to(self.device) self.model # .to(self.device)
except Exception as e: except Exception as e:
@ -86,18 +88,20 @@ class WizardLLMStoryTeller:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
bnb_config = (BitsAndBytesConfig(**self.quantization_config) bnb_config = (
if self.quantization_config else None) BitsAndBytesConfig(**self.quantization_config)
if self.quantization_config
else None
)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, self.model_id, quantization_config=bnb_config
quantization_config=bnb_config).to(self.device) ).to(self.device)
if self.distributed: if self.distributed:
self.model = DDP(self.model) self.model = DDP(self.model)
except Exception as error: except Exception as error:
self.logger.error( self.logger.error(f"Failed to load the model or the tokenizer: {error}")
f"Failed to load the model or the tokenizer: {error}")
raise raise
def run(self, prompt_text: str): def run(self, prompt_text: str):
@ -116,8 +120,9 @@ class WizardLLMStoryTeller:
max_length = self.max_length max_length = self.max_length
try: try:
inputs = self.tokenizer.encode(prompt_text, inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(
return_tensors="pt").to(self.device) self.device
)
# self.log.start() # self.log.start()
@ -126,26 +131,26 @@ class WizardLLMStoryTeller:
for _ in range(max_length): for _ in range(max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=len(inputs) + inputs, max_length=len(inputs) + 1, do_sample=True
1, )
do_sample=True)
output_tokens = outputs[0][-1] output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item()) output_sequence.append(output_tokens.item())
# print token in real-time # print token in real-time
print( print(
self.tokenizer.decode([output_tokens], self.tokenizer.decode(
skip_special_tokens=True), [output_tokens], skip_special_tokens=True
),
end="", end="",
flush=True, flush=True,
) )
inputs = outputs inputs = outputs
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=max_length, inputs, max_length=max_length, do_sample=True
do_sample=True) )
del inputs del inputs
return self.tokenizer.decode(outputs[0], skip_special_tokens=True) return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
@ -169,8 +174,9 @@ class WizardLLMStoryTeller:
            max_length = self.max_length max_length = self.max_length
try: try:
inputs = self.tokenizer.encode(prompt_text, inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(
return_tensors="pt").to(self.device) self.device
)
# self.log.start() # self.log.start()
@ -179,26 +185,26 @@ class WizardLLMStoryTeller:
for _ in range(max_length): for _ in range(max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=len(inputs) + inputs, max_length=len(inputs) + 1, do_sample=True
1, )
do_sample=True)
output_tokens = outputs[0][-1] output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item()) output_sequence.append(output_tokens.item())
# print token in real-time # print token in real-time
print( print(
self.tokenizer.decode([output_tokens], self.tokenizer.decode(
skip_special_tokens=True), [output_tokens], skip_special_tokens=True
),
end="", end="",
flush=True, flush=True,
) )
inputs = outputs inputs = outputs
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=max_length, inputs, max_length=max_length, do_sample=True
do_sample=True) )
del inputs del inputs

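The decoding-assisted branch above streams output by generating one token per generate() call and printing it as it arrives. Below is a compact version of that loop with an explicit EOS check; gpt2 is used only to keep the sketch lightweight, whereas the wrappers above load much larger checkpoints.

# Compact token-by-token generation loop, as used in the run() methods above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer.encode("Once upon a time", return_tensors="pt")
for _ in range(50):
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_length=inputs.shape[1] + 1,  # exactly one new token per call
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    next_token = outputs[0, -1].item()
    print(tokenizer.decode([next_token], skip_special_tokens=True), end="", flush=True)
    if next_token == tokenizer.eos_token_id:
        break
    inputs = outputs
print()

Note that one generate() call per token re-runs the model over a growing context every step; the non-decoding branch above produces the whole sequence in a single call and is far cheaper when streaming is not needed.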
@ -44,8 +44,9 @@ class YarnMistral128:
decoding=False, decoding=False,
): ):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.device = (device if device else self.device = (
("cuda" if torch.cuda.is_available() else "cpu")) device if device else ("cuda" if torch.cuda.is_available() else "cpu")
)
self.model_id = model_id self.model_id = model_id
self.max_length = max_length self.max_length = max_length
self.verbose = verbose self.verbose = verbose
@ -55,8 +56,9 @@ class YarnMistral128:
# self.log = Logging() # self.log = Logging()
if self.distributed: if self.distributed:
assert (torch.cuda.device_count() > assert (
1), "You need more than 1 gpu for distributed processing" torch.cuda.device_count() > 1
), "You need more than 1 gpu for distributed processing"
bnb_config = None bnb_config = None
if quantize: if quantize:
@ -91,18 +93,20 @@ class YarnMistral128:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
bnb_config = (BitsAndBytesConfig(**self.quantization_config) bnb_config = (
if self.quantization_config else None) BitsAndBytesConfig(**self.quantization_config)
if self.quantization_config
else None
)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, self.model_id, quantization_config=bnb_config
quantization_config=bnb_config).to(self.device) ).to(self.device)
if self.distributed: if self.distributed:
self.model = DDP(self.model) self.model = DDP(self.model)
except Exception as error: except Exception as error:
self.logger.error( self.logger.error(f"Failed to load the model or the tokenizer: {error}")
f"Failed to load the model or the tokenizer: {error}")
raise raise
def run(self, prompt_text: str): def run(self, prompt_text: str):
@ -121,8 +125,9 @@ class YarnMistral128:
max_length = self.max_length max_length = self.max_length
try: try:
inputs = self.tokenizer.encode(prompt_text, inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(
return_tensors="pt").to(self.device) self.device
)
# self.log.start() # self.log.start()
@ -131,26 +136,26 @@ class YarnMistral128:
for _ in range(max_length): for _ in range(max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=len(inputs) + inputs, max_length=len(inputs) + 1, do_sample=True
1, )
do_sample=True)
output_tokens = outputs[0][-1] output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item()) output_sequence.append(output_tokens.item())
# print token in real-time # print token in real-time
print( print(
self.tokenizer.decode([output_tokens], self.tokenizer.decode(
skip_special_tokens=True), [output_tokens], skip_special_tokens=True
),
end="", end="",
flush=True, flush=True,
) )
inputs = outputs inputs = outputs
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=max_length, inputs, max_length=max_length, do_sample=True
do_sample=True) )
del inputs del inputs
return self.tokenizer.decode(outputs[0], skip_special_tokens=True) return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
@ -197,8 +202,9 @@ class YarnMistral128:
            max_length = self.max_length max_length = self.max_length
try: try:
inputs = self.tokenizer.encode(prompt_text, inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(
return_tensors="pt").to(self.device) self.device
)
# self.log.start() # self.log.start()
@ -207,26 +213,26 @@ class YarnMistral128:
for _ in range(max_length): for _ in range(max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=len(inputs) + inputs, max_length=len(inputs) + 1, do_sample=True
1, )
do_sample=True)
output_tokens = outputs[0][-1] output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item()) output_sequence.append(output_tokens.item())
# print token in real-time # print token in real-time
print( print(
self.tokenizer.decode([output_tokens], self.tokenizer.decode(
skip_special_tokens=True), [output_tokens], skip_special_tokens=True
),
end="", end="",
flush=True, flush=True,
) )
inputs = outputs inputs = outputs
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate(inputs, outputs = self.model.generate(
max_length=max_length, inputs, max_length=max_length, do_sample=True
do_sample=True) )
del inputs del inputs

@ -28,8 +28,7 @@ class Zephyr:
model_name: str = "HuggingFaceH4/zephyr-7b-alpha", model_name: str = "HuggingFaceH4/zephyr-7b-alpha",
tokenize: bool = False, tokenize: bool = False,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
system_prompt: system_prompt: str = "You are a friendly chatbot who always responds in the style of a pirate",
str = "You are a friendly chatbot who always responds in the style of a pirate",
max_new_tokens: int = 300, max_new_tokens: int = 300,
temperature: float = 0.5, temperature: float = 0.5,
top_k: float = 50, top_k: float = 50,

@ -24,8 +24,9 @@ class AgentOutputParser(BaseAgentOutputParser):
@staticmethod @staticmethod
def _preprocess_json_input(input_str: str) -> str: def _preprocess_json_input(input_str: str) -> str:
corrected_str = re.sub(r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', corrected_str = re.sub(
r"\\\\", input_str) r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
)
return corrected_str return corrected_str
def _parse_json(self, text: str) -> dict: def _parse_json(self, text: str) -> dict:

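_preprocess_json_input escapes any backslash that is not part of a valid JSON escape, so model output containing stray Windows paths or similar still parses with json.loads. A tiny demonstration of the same regex:

# Demonstration of the backslash-repair regex used by _preprocess_json_input.
import json
import re


def preprocess_json_input(input_str: str) -> str:
    # Escape any backslash not followed by a valid JSON escape character.
    return re.sub(r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str)


raw = '{"path": "C:\\Users\\agent", "note": "keep \\n as-is"}'
# json.loads(raw) would fail on the invalid \U and \a escapes;
# after preprocessing they are doubled and the document parses.
print(json.loads(preprocess_json_input(raw)))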
@ -13,23 +13,13 @@ class PromptGenerator:
self.performance_evaluation: List[str] = [] self.performance_evaluation: List[str] = []
self.response_format = { self.response_format = {
"thoughts": { "thoughts": {
"text": "text": "thought",
"thought", "reasoning": "reasoning",
"reasoning": "plan": "- short bulleted\n- list that conveys\n- long-term plan",
"reasoning", "criticism": "constructive self-criticism",
"plan": "speak": "thoughts summary to say to user",
"- short bulleted\n- list that conveys\n- long-term plan",
"criticism":
"constructive self-criticism",
"speak":
"thoughts summary to say to user",
},
"command": {
"name": "command name",
"args": {
"arg name": "value"
}
}, },
"command": {"name": "command name", "args": {"arg name": "value"}},
} }
def add_constraint(self, constraint: str) -> None: def add_constraint(self, constraint: str) -> None:
@ -82,6 +72,7 @@ class PromptGenerator:
f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n" f"Performance Evaluation:\n{''.join(self.performance_evaluation)}\n\n"
"You should only respond in JSON format as described below " "You should only respond in JSON format as described below "
f"\nResponse Format: \n{formatted_response_format} " f"\nResponse Format: \n{formatted_response_format} "
"\nEnsure the response can be parsed by Python json.loads") "\nEnsure the response can be parsed by Python json.loads"
)
return prompt_string return prompt_string

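The flattened response_format dict above is what PromptGenerator serializes into the instruction block; the generated prompt ends by asking the model to reply with JSON matching it. A small sketch that mirrors that final assembly step using json.dumps:

# Sketch of how response_format becomes the JSON-format instruction appended
# to the generated prompt (mirrors the prompt_string assembly above).
import json

response_format = {
    "thoughts": {
        "text": "thought",
        "reasoning": "reasoning",
        "plan": "- short bulleted\n- list that conveys\n- long-term plan",
        "criticism": "constructive self-criticism",
        "speak": "thoughts summary to say to user",
    },
    "command": {"name": "command name", "args": {"arg name": "value"}},
}

formatted_response_format = json.dumps(response_format, indent=4)
prompt_tail = (
    "You should only respond in JSON format as described below "
    f"\nResponse Format: \n{formatted_response_format} "
    "\nEnsure the response can be parsed by Python json.loads"
)
print(prompt_tail)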
@ -7,21 +7,25 @@ def generate_agent_role_prompt(agent):
"Finance Agent": ( "Finance Agent": (
"You are a seasoned finance analyst AI assistant. Your primary goal is to" "You are a seasoned finance analyst AI assistant. Your primary goal is to"
" compose comprehensive, astute, impartial, and methodically arranged" " compose comprehensive, astute, impartial, and methodically arranged"
" financial reports based on provided data and trends."), " financial reports based on provided data and trends."
),
"Travel Agent": ( "Travel Agent": (
"You are a world-travelled AI tour guide assistant. Your main purpose is to" "You are a world-travelled AI tour guide assistant. Your main purpose is to"
" draft engaging, insightful, unbiased, and well-structured travel reports" " draft engaging, insightful, unbiased, and well-structured travel reports"
" on given locations, including history, attractions, and cultural" " on given locations, including history, attractions, and cultural"
" insights."), " insights."
),
"Academic Research Agent": ( "Academic Research Agent": (
"You are an AI academic research assistant. Your primary responsibility is" "You are an AI academic research assistant. Your primary responsibility is"
" to create thorough, academically rigorous, unbiased, and systematically" " to create thorough, academically rigorous, unbiased, and systematically"
" organized reports on a given research topic, following the standards of" " organized reports on a given research topic, following the standards of"
" scholarly work."), " scholarly work."
),
"Default Agent": ( "Default Agent": (
"You are an AI critical thinker research assistant. Your sole purpose is to" "You are an AI critical thinker research assistant. Your sole purpose is to"
" write well written, critically acclaimed, objective and structured" " write well written, critically acclaimed, objective and structured"
" reports on given text."), " reports on given text."
),
} }
return prompts.get(agent, "No such agent") return prompts.get(agent, "No such agent")
@ -40,7 +44,8 @@ def generate_report_prompt(question, research_summary):
" focus on the answer to the question, should be well structured, informative," " focus on the answer to the question, should be well structured, informative,"
" in depth, with facts and numbers if available, a minimum of 1,200 words and" " in depth, with facts and numbers if available, a minimum of 1,200 words and"
" with markdown syntax and apa format. Write all source urls at the end of the" " with markdown syntax and apa format. Write all source urls at the end of the"
" report in apa format") " report in apa format"
)
def generate_search_queries_prompt(question): def generate_search_queries_prompt(question):
@ -52,7 +57,8 @@ def generate_search_queries_prompt(question):
return ( return (
"Write 4 google search queries to search online that form an objective opinion" "Write 4 google search queries to search online that form an objective opinion"
f' from the following: "{question}"You must respond with a list of strings in' f' from the following: "{question}"You must respond with a list of strings in'
' the following format: ["query 1", "query 2", "query 3", "query 4"]') ' the following format: ["query 1", "query 2", "query 3", "query 4"]'
)
def generate_resource_report_prompt(question, research_summary): def generate_resource_report_prompt(question, research_summary):
@ -74,7 +80,8 @@ def generate_resource_report_prompt(question, research_summary):
" significance of each source. Ensure that the report is well-structured," " significance of each source. Ensure that the report is well-structured,"
" informative, in-depth, and follows Markdown syntax. Include relevant facts," " informative, in-depth, and follows Markdown syntax. Include relevant facts,"
" figures, and numbers whenever available. The report should have a minimum" " figures, and numbers whenever available. The report should have a minimum"
" length of 1,200 words.") " length of 1,200 words."
)
def generate_outline_report_prompt(question, research_summary): def generate_outline_report_prompt(question, research_summary):
@ -91,7 +98,8 @@ def generate_outline_report_prompt(question, research_summary):
" research report, including the main sections, subsections, and key points to" " research report, including the main sections, subsections, and key points to"
" be covered. The research report should be detailed, informative, in-depth," " be covered. The research report should be detailed, informative, in-depth,"
" and a minimum of 1,200 words. Use appropriate Markdown syntax to format the" " and a minimum of 1,200 words. Use appropriate Markdown syntax to format the"
" outline and ensure readability.") " outline and ensure readability."
)
def generate_concepts_prompt(question, research_summary): def generate_concepts_prompt(question, research_summary):
@ -106,7 +114,8 @@ def generate_concepts_prompt(question, research_summary):
" main concepts to learn for a research report on the following question or" " main concepts to learn for a research report on the following question or"
f' topic: "{question}". The outline should provide a well-structured' f' topic: "{question}". The outline should provide a well-structured'
" frameworkYou must respond with a list of strings in the following format:" " frameworkYou must respond with a list of strings in the following format:"
' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]') ' ["concepts 1", "concepts 2", "concepts 3", "concepts 4, concepts 5"]'
)
def generate_lesson_prompt(concept): def generate_lesson_prompt(concept):
@ -122,7 +131,8 @@ def generate_lesson_prompt(concept):
f"generate a comprehensive lesson about {concept} in Markdown syntax. This" f"generate a comprehensive lesson about {concept} in Markdown syntax. This"
f" should include the definitionof {concept}, its historical background and" f" should include the definitionof {concept}, its historical background and"
" development, its applications or uses in differentfields, and notable events" " development, its applications or uses in differentfields, and notable events"
f" or facts related to {concept}.") f" or facts related to {concept}."
)
return prompt return prompt

@ -11,9 +11,9 @@ if TYPE_CHECKING:
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
def get_buffer_string(messages: Sequence[BaseMessage], def get_buffer_string(
human_prefix: str = "Human", messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
ai_prefix: str = "AI") -> str: ) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string. """Convert sequence of Messages to strings and concatenate them into one string.
Args: Args:
@ -88,9 +88,9 @@ class BaseMessage(Serializable):
class BaseMessageChunk(BaseMessage): class BaseMessageChunk(BaseMessage):
def _merge_kwargs_dict(
def _merge_kwargs_dict(self, left: Dict[str, Any], self, left: Dict[str, Any], right: Dict[str, Any]
right: Dict[str, Any]) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one.""" """Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy() merged = left.copy()
for k, v in right.items(): for k, v in right.items():
@ -99,7 +99,8 @@ class BaseMessageChunk(BaseMessage):
elif not isinstance(merged[k], type(v)): elif not isinstance(merged[k], type(v)):
raise ValueError( raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,' f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type.") " but with a different type."
)
elif isinstance(merged[k], str): elif isinstance(merged[k], str):
merged[k] += v merged[k] += v
elif isinstance(merged[k], dict): elif isinstance(merged[k], dict):
@ -118,12 +119,15 @@ class BaseMessageChunk(BaseMessage):
return self.__class__( return self.__class__(
content=self.content + other.content, content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict( additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs), self.additional_kwargs, other.additional_kwargs
),
) )
else: else:
raise TypeError('unsupported operand type(s) for +: "' raise TypeError(
f"{self.__class__.__name__}" 'unsupported operand type(s) for +: "'
f'" and "{other.__class__.__name__}"') f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage): class HumanMessage(BaseMessage):
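A minimal standalone sketch of the merge behaviour reformatted above, assuming the usual branch structure (new keys copied over, string values concatenated, dict values merged recursively); this is not the langchain implementation itself:

```python
from typing import Any, Dict

def merge_kwargs_dict(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    """Merge two additional_kwargs dicts the way BaseMessageChunk.__add__ combines chunks."""
    merged = left.copy()
    for k, v in right.items():
        if k not in merged:
            merged[k] = v  # assumed: keys missing on the left are simply copied
        elif not isinstance(merged[k], type(v)):
            raise ValueError(
                f'additional_kwargs["{k}"] already exists in this message,'
                " but with a different type."
            )
        elif isinstance(merged[k], str):
            merged[k] += v  # string fragments concatenate, e.g. streamed tool arguments
        elif isinstance(merged[k], dict):
            merged[k] = merge_kwargs_dict(merged[k], v)  # assumed: nested dicts merge recursively
        else:
            raise ValueError(f"Additional kwargs key {k} already exists in this message.")
    return merged

print(merge_kwargs_dict({"function_call": {"name": "f"}}, {"function_call": {"arguments": "{}"}}))
# {'function_call': {'name': 'f', 'arguments': '{}'}}
```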

@ -66,10 +66,9 @@ class SystemMessage(Message):
of input messages. of input messages.
""" """
def __init__(self, def __init__(
content: str, self, content: str, role: str = "System", additional_kwargs: Dict = None
role: str = "System", ):
additional_kwargs: Dict = None):
super().__init__(content, role, additional_kwargs) super().__init__(content, role, additional_kwargs)
def get_type(self) -> str: def get_type(self) -> str:
@ -107,9 +106,9 @@ class ChatMessage(Message):
return "chat" return "chat"
def get_buffer_string(messages: Sequence[Message], def get_buffer_string(
human_prefix: str = "Human", messages: Sequence[Message], human_prefix: str = "Human", ai_prefix: str = "AI"
ai_prefix: str = "AI") -> str: ) -> str:
string_messages = [] string_messages = []
for m in messages: for m in messages:
message = f"{m.role}: {m.content}" message = f"{m.role}: {m.content}"

@ -38,6 +38,7 @@ def debate_monitor(game_description, word_limit, character_names):
return prompt return prompt
def generate_character_header(game_description, topic, character_name, def generate_character_header(
character_description): game_description, topic, character_name, character_description
):
pass pass

@ -1,6 +1,7 @@
ERROR_PROMPT = ( ERROR_PROMPT = (
"An error has occurred for the following text: \n{promptedQuery} Please explain" "An error has occurred for the following text: \n{promptedQuery} Please explain"
" this error.\n {e}") " this error.\n {e}"
)
IMAGE_PROMPT = """ IMAGE_PROMPT = """
provide a figure named {filename}. The description is: {description}. provide a figure named {filename}. The description is: {description}.

@ -3,25 +3,30 @@ PY_REFLEXION_COMPLETION_INSTRUCTION = (
"You are a Python writing assistant. You will be given your past function" "You are a Python writing assistant. You will be given your past function"
" implementation, a series of unit tests, and a hint to change the implementation" " implementation, a series of unit tests, and a hint to change the implementation"
" appropriately. Write your full implementation (restate the function" " appropriately. Write your full implementation (restate the function"
" signature).\n\n-----") " signature).\n\n-----"
)
PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = ( PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = (
"You are a Python writing assistant. You will be given a function implementation" "You are a Python writing assistant. You will be given a function implementation"
" and a series of unit tests. Your goal is to write a few sentences to explain why" " and a series of unit tests. Your goal is to write a few sentences to explain why"
" your implementation is wrong as indicated by the tests. You will need this as a" " your implementation is wrong as indicated by the tests. You will need this as a"
" hint when you try again later. Only provide the few sentence description in your" " hint when you try again later. Only provide the few sentence description in your"
" answer, not the implementation.\n\n-----") " answer, not the implementation.\n\n-----"
)
USE_PYTHON_CODEBLOCK_INSTRUCTION = ( USE_PYTHON_CODEBLOCK_INSTRUCTION = (
"Use a Python code block to write your response. For" "Use a Python code block to write your response. For"
" example:\n```python\nprint('Hello world!')\n```") " example:\n```python\nprint('Hello world!')\n```"
)
PY_SIMPLE_CHAT_INSTRUCTION = ( PY_SIMPLE_CHAT_INSTRUCTION = (
"You are an AI that only responds with python code, NOT ENGLISH. You will be given" "You are an AI that only responds with python code, NOT ENGLISH. You will be given"
" a function signature and its docstring by the user. Write your full" " a function signature and its docstring by the user. Write your full"
" implementation (restate the function signature).") " implementation (restate the function signature)."
)
PY_SIMPLE_CHAT_INSTRUCTION_V2 = ( PY_SIMPLE_CHAT_INSTRUCTION_V2 = (
"You are an AI that only responds with only python code. You will be given a" "You are an AI that only responds with only python code. You will be given a"
" function signature and its docstring by the user. Write your full implementation" " function signature and its docstring by the user. Write your full implementation"
" (restate the function signature).") " (restate the function signature)."
)
PY_REFLEXION_CHAT_INSTRUCTION = ( PY_REFLEXION_CHAT_INSTRUCTION = (
"You are an AI Python assistant. You will be given your past function" "You are an AI Python assistant. You will be given your past function"
" implementation, a series of unit tests, and a hint to change the implementation" " implementation, a series of unit tests, and a hint to change the implementation"
@ -31,7 +36,8 @@ PY_REFLEXION_CHAT_INSTRUCTION_V2 = (
"You are an AI Python assistant. You will be given your previous implementation of" "You are an AI Python assistant. You will be given your previous implementation of"
" a function, a series of unit tests results, and your self-reflection on your" " a function, a series of unit tests results, and your self-reflection on your"
" previous implementation. Write your full implementation (restate the function" " previous implementation. Write your full implementation (restate the function"
" signature).") " signature)."
)
PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: PY_REFLEXION_FEW_SHOT_ADD = '''Example 1:
[previous impl]: [previous impl]:
```python ```python
@ -169,14 +175,16 @@ PY_SELF_REFLECTION_CHAT_INSTRUCTION = (
" implementation and a series of unit tests. Your goal is to write a few sentences" " implementation and a series of unit tests. Your goal is to write a few sentences"
" to explain why your implementation is wrong as indicated by the tests. You will" " to explain why your implementation is wrong as indicated by the tests. You will"
" need this as a hint when you try again later. Only provide the few sentence" " need this as a hint when you try again later. Only provide the few sentence"
" description in your answer, not the implementation.") " description in your answer, not the implementation."
)
PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = ( PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = (
"You are a Python programming assistant. You will be given a function" "You are a Python programming assistant. You will be given a function"
" implementation and a series of unit test results. Your goal is to write a few" " implementation and a series of unit test results. Your goal is to write a few"
" sentences to explain why your implementation is wrong as indicated by the tests." " sentences to explain why your implementation is wrong as indicated by the tests."
" You will need this as guidance when you try again later. Only provide the few" " You will need this as guidance when you try again later. Only provide the few"
" sentence description in your answer, not the implementation. You will be given a" " sentence description in your answer, not the implementation. You will be given a"
" few examples by the user.") " few examples by the user."
)
PY_SELF_REFLECTION_FEW_SHOT = """Example 1: PY_SELF_REFLECTION_FEW_SHOT = """Example 1:
[function impl]: [function impl]:
```python ```python

@ -3,29 +3,36 @@ conversation_stages = {
"Introduction: Start the conversation by introducing yourself and your company." "Introduction: Start the conversation by introducing yourself and your company."
" Be polite and respectful while keeping the tone of the conversation" " Be polite and respectful while keeping the tone of the conversation"
" professional. Your greeting should be welcoming. Always clarify in your" " professional. Your greeting should be welcoming. Always clarify in your"
" greeting the reason why you are contacting the prospect."), " greeting the reason why you are contacting the prospect."
),
"2": ( "2": (
"Qualification: Qualify the prospect by confirming if they are the right person" "Qualification: Qualify the prospect by confirming if they are the right person"
" to talk to regarding your product/service. Ensure that they have the" " to talk to regarding your product/service. Ensure that they have the"
" authority to make purchasing decisions."), " authority to make purchasing decisions."
),
"3": ( "3": (
"Value proposition: Briefly explain how your product/service can benefit the" "Value proposition: Briefly explain how your product/service can benefit the"
" prospect. Focus on the unique selling points and value proposition of your" " prospect. Focus on the unique selling points and value proposition of your"
" product/service that sets it apart from competitors."), " product/service that sets it apart from competitors."
),
"4": ( "4": (
"Needs analysis: Ask open-ended questions to uncover the prospect's needs and" "Needs analysis: Ask open-ended questions to uncover the prospect's needs and"
" pain points. Listen carefully to their responses and take notes."), " pain points. Listen carefully to their responses and take notes."
"5": ("Solution presentation: Based on the prospect's needs, present your" ),
" product/service as the solution that can address their pain points." "5": (
), "Solution presentation: Based on the prospect's needs, present your"
"6": " product/service as the solution that can address their pain points."
("Objection handling: Address any objections that the prospect may have" ),
" regarding your product/service. Be prepared to provide evidence or" "6": (
" testimonials to support your claims."), "Objection handling: Address any objections that the prospect may have"
" regarding your product/service. Be prepared to provide evidence or"
" testimonials to support your claims."
),
"7": ( "7": (
"Close: Ask for the sale by proposing a next step. This could be a demo, a" "Close: Ask for the sale by proposing a next step. This could be a demo, a"
" trial or a meeting with decision-makers. Ensure to summarize what has been" " trial or a meeting with decision-makers. Ensure to summarize what has been"
" discussed and reiterate the benefits."), " discussed and reiterate the benefits."
),
} }
SALES_AGENT_TOOLS_PROMPT = """ SALES_AGENT_TOOLS_PROMPT = """

@ -49,27 +49,34 @@ conversation_stages = {
"Introduction: Start the conversation by introducing yourself and your company." "Introduction: Start the conversation by introducing yourself and your company."
" Be polite and respectful while keeping the tone of the conversation" " Be polite and respectful while keeping the tone of the conversation"
" professional. Your greeting should be welcoming. Always clarify in your" " professional. Your greeting should be welcoming. Always clarify in your"
" greeting the reason why you are contacting the prospect."), " greeting the reason why you are contacting the prospect."
),
"2": ( "2": (
"Qualification: Qualify the prospect by confirming if they are the right person" "Qualification: Qualify the prospect by confirming if they are the right person"
" to talk to regarding your product/service. Ensure that they have the" " to talk to regarding your product/service. Ensure that they have the"
" authority to make purchasing decisions."), " authority to make purchasing decisions."
),
"3": ( "3": (
"Value proposition: Briefly explain how your product/service can benefit the" "Value proposition: Briefly explain how your product/service can benefit the"
" prospect. Focus on the unique selling points and value proposition of your" " prospect. Focus on the unique selling points and value proposition of your"
" product/service that sets it apart from competitors."), " product/service that sets it apart from competitors."
),
"4": ( "4": (
"Needs analysis: Ask open-ended questions to uncover the prospect's needs and" "Needs analysis: Ask open-ended questions to uncover the prospect's needs and"
" pain points. Listen carefully to their responses and take notes."), " pain points. Listen carefully to their responses and take notes."
"5": ("Solution presentation: Based on the prospect's needs, present your" ),
" product/service as the solution that can address their pain points." "5": (
), "Solution presentation: Based on the prospect's needs, present your"
"6": " product/service as the solution that can address their pain points."
("Objection handling: Address any objections that the prospect may have" ),
" regarding your product/service. Be prepared to provide evidence or" "6": (
" testimonials to support your claims."), "Objection handling: Address any objections that the prospect may have"
" regarding your product/service. Be prepared to provide evidence or"
" testimonials to support your claims."
),
"7": ( "7": (
"Close: Ask for the sale by proposing a next step. This could be a demo, a" "Close: Ask for the sale by proposing a next step. This could be a demo, a"
" trial or a meeting with decision-makers. Ensure to summarize what has been" " trial or a meeting with decision-makers. Ensure to summarize what has been"
" discussed and reiterate the benefits."), " discussed and reiterate the benefits."
),
} }

@ -18,11 +18,13 @@ class ChatbotError(Exception):
def __init__(self, *args: object) -> None: def __init__(self, *args: object) -> None:
if SUPPORT_ADD_NOTES: if SUPPORT_ADD_NOTES:
super().add_note((
"Please check that the input is correct, or you can resolve this"
" issue by filing an issue"),)
super().add_note( super().add_note(
"Project URL: https://github.com/acheong08/ChatGPT") (
"Please check that the input is correct, or you can resolve this"
" issue by filing an issue"
),
)
super().add_note("Project URL: https://github.com/acheong08/ChatGPT")
super().__init__(*args) super().__init__(*args)
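The ChatbotError above relies on BaseException.add_note, which only exists on Python 3.11+; a hedged sketch of the same pattern with the version check made explicit (SUPPORT_ADD_NOTES is assumed to be such a flag):

```python
import sys

SUPPORT_ADD_NOTES = sys.version_info >= (3, 11)

class ChatbotError(Exception):
    def __init__(self, *args: object) -> None:
        if SUPPORT_ADD_NOTES:
            # add_note() attaches extra context that is printed with the traceback (3.11+).
            self.add_note(
                "Please check that the input is correct, or you can resolve this"
                " issue by filing an issue"
            )
            self.add_note("Project URL: https://github.com/acheong08/ChatGPT")
        super().__init__(*args)

try:
    raise ChatbotError("bad request")
except ChatbotError as err:
    print(getattr(err, "__notes__", []))
```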

@ -63,8 +63,9 @@ class BaseDocumentTransformer(ABC):
""" # noqa: E501 """ # noqa: E501
@abstractmethod @abstractmethod
def transform_documents(self, documents: Sequence[Document], def transform_documents(
**kwargs: Any) -> Sequence[Document]: self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents. """Transform a list of documents.
Args: Args:
@ -74,8 +75,9 @@ class BaseDocumentTransformer(ABC):
A list of transformed Documents. A list of transformed Documents.
""" """
async def atransform_documents(self, documents: Sequence[Document], async def atransform_documents(
**kwargs: Any) -> Sequence[Document]: self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents. """Asynchronously transform a list of documents.
Args: Args:
@ -85,4 +87,5 @@ class BaseDocumentTransformer(ABC):
A list of transformed Documents. A list of transformed Documents.
""" """
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, partial(self.transform_documents, **kwargs), documents) None, partial(self.transform_documents, **kwargs), documents
)
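The async fallback above runs the synchronous transform on the event loop's default executor; a self-contained sketch of that pattern with a stand-in transform function (the real class is abstract):

```python
import asyncio
from functools import partial
from typing import Any, List

def transform_documents(documents: List[str], *, suffix: str = "") -> List[str]:
    # Stand-in for the blocking, synchronous transformation.
    return [doc + suffix for doc in documents]

async def atransform_documents(documents: List[str], **kwargs: Any) -> List[str]:
    # Default async path: offload the sync call to the loop's executor,
    # using partial() to carry the keyword arguments across.
    return await asyncio.get_running_loop().run_in_executor(
        None, partial(transform_documents, **kwargs), documents
    )

print(asyncio.run(atransform_documents(["a", "b"], suffix="!")))  # ['a!', 'b!']
```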

@ -100,7 +100,7 @@ class Flow:
self, self,
llm: Any, llm: Any,
# template: str, # template: str,
max_loops = 5, max_loops=5,
stopping_condition: Optional[Callable[[str], bool]] = None, stopping_condition: Optional[Callable[[str], bool]] = None,
loop_interval: int = 1, loop_interval: int = 1,
retry_attempts: int = 3, retry_attempts: int = 3,
@ -188,7 +188,8 @@ class Flow:
value = self.llm.__dict__.get(name, "Unknown") value = self.llm.__dict__.get(name, "Unknown")
params_str_list.append( params_str_list.append(
f" {name.capitalize().replace('_', ' ')}: {value}") f" {name.capitalize().replace('_', ' ')}: {value}"
)
return "\n".join(params_str_list) return "\n".join(params_str_list)
@ -196,7 +197,7 @@ class Flow:
""" """
Take the history and truncate it to fit into the model context length Take the history and truncate it to fit into the model context length
""" """
truncated_history = self.memory[-1][-self.context_length:] truncated_history = self.memory[-1][-self.context_length :]
self.memory[-1] = truncated_history self.memory[-1] = truncated_history
def add_task_to_memory(self, task: str): def add_task_to_memory(self, task: str):
@ -246,7 +247,8 @@ class Flow:
---------------------------------------- ----------------------------------------
""", """,
"green", "green",
)) )
)
# print(dashboard) # print(dashboard)
@ -256,17 +258,18 @@ class Flow:
print(colored("Initializing Autonomous Agent...", "yellow")) print(colored("Initializing Autonomous Agent...", "yellow"))
# print(colored("Loading modules...", "yellow")) # print(colored("Loading modules...", "yellow"))
# print(colored("Modules loaded successfully.", "green")) # print(colored("Modules loaded successfully.", "green"))
print(colored("Autonomous Agent Activated.", "cyan", print(colored("Autonomous Agent Activated.", "cyan", attrs=["bold"]))
attrs=["bold"])) print(colored("All systems operational. Executing task...", "green"))
print(colored("All systems operational. Executing task...",
"green"))
except Exception as error: except Exception as error:
print( print(
colored( colored(
("Error activating autonomous agent. Try optimizing your" (
" parameters..."), "Error activating autonomous agent. Try optimizing your"
" parameters..."
),
"red", "red",
)) )
)
print(error) print(error)
def run(self, task: str, **kwargs): def run(self, task: str, **kwargs):
@ -296,7 +299,7 @@ class Flow:
loop_count = 0 loop_count = 0
# for i in range(self.max_loops): # for i in range(self.max_loops):
while self.max_loops == 'auto' or loop_count < self.max_loops: while self.max_loops == "auto" or loop_count < self.max_loops:
loop_count += 1 loop_count += 1
print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue"))
print("\n") print("\n")
@ -315,8 +318,7 @@ class Flow:
while attempt < self.retry_attempts: while attempt < self.retry_attempts:
try: try:
response = self.llm( response = self.llm(
                        task                         task, **kwargs,
**kwargs,
) )
if self.interactive: if self.interactive:
print(f"AI: {response}") print(f"AI: {response}")
@ -344,7 +346,7 @@ class Flow:
if self.return_history: if self.return_history:
return response, history return response, history
return response return response
async def arun(self, task: str, **kwargs): async def arun(self, task: str, **kwargs):
""" """
@ -373,7 +375,7 @@ class Flow:
loop_count = 0 loop_count = 0
# for i in range(self.max_loops): # for i in range(self.max_loops):
while self.max_loops == 'auto' or loop_count < self.max_loops: while self.max_loops == "auto" or loop_count < self.max_loops:
loop_count += 1 loop_count += 1
print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue"))
print("\n") print("\n")
@ -392,8 +394,7 @@ class Flow:
while attempt < self.retry_attempts: while attempt < self.retry_attempts:
try: try:
response = self.llm( response = self.llm(
                        task                         task, **kwargs,
**kwargs,
) )
if self.interactive: if self.interactive:
print(f"AI: {response}") print(f"AI: {response}")
@ -421,7 +422,7 @@ class Flow:
if self.return_history: if self.return_history:
return response, history return response, history
return response return response
def _run(self, **kwargs: Any) -> str: def _run(self, **kwargs: Any) -> str:
"""Generate a result using the provided keyword args.""" """Generate a result using the provided keyword args."""
@ -460,9 +461,7 @@ class Flow:
Args: Args:
tasks (List[str]): A list of tasks to run. tasks (List[str]): A list of tasks to run.
""" """
task_coroutines = [ task_coroutines = [self.run_async(task, **kwargs) for task in tasks]
self.run_async(task, **kwargs) for task in tasks
]
completed_tasks = await asyncio.gather(*task_coroutines) completed_tasks = await asyncio.gather(*task_coroutines)
return completed_tasks return completed_tasks
@ -575,9 +574,7 @@ class Flow:
import boto3 import boto3
s3 = boto3.client("s3") s3 = boto3.client("s3")
s3.put_object(Bucket=bucket_name, s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory))
Key=object_name,
Body=json.dumps(self.memory))
print(f"Backed up memory to S3: {bucket_name}/{object_name}") print(f"Backed up memory to S3: {bucket_name}/{object_name}")
def analyze_feedback(self): def analyze_feedback(self):
@ -681,7 +678,7 @@ class Flow:
def get_llm_params(self): def get_llm_params(self):
""" """
Extracts and returns the parameters of the llm object for serialization. Extracts and returns the parameters of the llm object for serialization.
It assumes that the llm object has an __init__ method It assumes that the llm object has an __init__ method
with parameters that can be used to recreate it. with parameters that can be used to recreate it.
""" """
if not hasattr(self.llm, "__init__"): if not hasattr(self.llm, "__init__"):
@ -697,8 +694,8 @@ class Flow:
if hasattr(self.llm, name): if hasattr(self.llm, name):
value = getattr(self.llm, name) value = getattr(self.llm, name)
if isinstance( if isinstance(
value, value, (str, int, float, bool, list, dict, tuple, type(None))
(str, int, float, bool, list, dict, tuple, type(None))): ):
llm_params[name] = value llm_params[name] = value
else: else:
llm_params[name] = str( llm_params[name] = str(
@ -758,10 +755,7 @@ class Flow:
print(f"Flow state loaded from {file_path}") print(f"Flow state loaded from {file_path}")
def retry_on_failure(self, def retry_on_failure(self, function, retries: int = 3, retry_delay: int = 1):
function,
retries: int = 3,
retry_delay: int = 1):
"""Retry wrapper for LLM calls.""" """Retry wrapper for LLM calls."""
attempt = 0 attempt = 0
while attempt < retries: while attempt < retries:
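A condensed sketch of the control flow in Flow.run/arun as reformatted above: max_loops set to "auto" loops until a stopping condition fires, and each loop retries the LLM call up to retry_attempts times. History, memory, and the dashboard are omitted, so treat this as an outline rather than the class itself:

```python
import time
from typing import Any, Callable, Optional, Union

def run_flow(
    llm: Callable[..., str],
    task: str,
    max_loops: Union[int, str] = 5,
    retry_attempts: int = 3,
    retry_interval: int = 1,
    stopping_condition: Optional[Callable[[str], bool]] = None,
    **kwargs: Any,
) -> str:
    """Loop the LLM on a task with per-loop retries; 'auto' runs until the stopping condition fires."""
    response = ""
    loop_count = 0
    while max_loops == "auto" or loop_count < max_loops:
        loop_count += 1
        attempt = 0
        while attempt < retry_attempts:
            try:
                response = llm(task, **kwargs)  # note the comma: task plus forwarded kwargs
                break
            except Exception as error:
                print(f"Error generating response: {error}")
                attempt += 1
                time.sleep(retry_interval)
        if stopping_condition and stopping_condition(response):
            break
    return response

print(run_flow(lambda prompt: prompt.upper(), "write a haiku", max_loops=2))  # WRITE A HAIKU
```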

@ -8,10 +8,9 @@ class Task:
Task is a unit of work that can be executed by an agent Task is a unit of work that can be executed by an agent
""" """
def __init__(self, def __init__(
id: str, self, id: str, parents: List["Task"] = None, children: List["Task"] = None
parents: List["Task"] = None, ):
children: List["Task"] = None):
self.id = id self.id = id
self.parents = parents self.parents = parents
self.children = children self.children = children
@ -80,8 +79,7 @@ class NonLinearWorkflow:
for task in ordered_tasks: for task in ordered_tasks:
if task.can_execute: if task.can_execute:
future = self.executor.submit(self.agents.run, future = self.executor.submit(self.agents.run, task.task_string)
task.task_string)
futures_list[future] = task futures_list[future] = task
for future in as_completed(futures_list): for future in as_completed(futures_list):
@ -97,8 +95,7 @@ class NonLinearWorkflow:
def to_graph(self) -> Dict[str, set[str]]: def to_graph(self) -> Dict[str, set[str]]:
"""Convert the workflow to a graph""" """Convert the workflow to a graph"""
graph = { graph = {
task.id: set(child.id for child in task.children) task.id: set(child.id for child in task.children) for task in self.tasks
for task in self.tasks
} }
return graph return graph
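Two ideas from the NonLinearWorkflow hunks above, sketched with a hypothetical minimal Task: the dependency graph built by a dict comprehension, and ready tasks submitted to a ThreadPoolExecutor and collected with as_completed.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Set

@dataclass
class Task:
    id: str
    task_string: str
    children: List["Task"] = field(default_factory=list)

def to_graph(tasks: List[Task]) -> Dict[str, Set[str]]:
    # Adjacency map: task id -> ids of its children.
    return {task.id: set(child.id for child in task.children) for task in tasks}

def run_all(tasks: List[Task], agent_run: Callable[[str], str]) -> Dict[str, str]:
    results: Dict[str, str] = {}
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(agent_run, t.task_string): t for t in tasks}
        for future in as_completed(futures):
            results[futures[future].id] = future.result()
    return results

child = Task("b", "summarize findings")
parent = Task("a", "research topic", children=[child])
print(to_graph([parent, child]))            # {'a': {'b'}, 'b': set()}
print(run_all([parent, child], str.upper))  # e.g. {'a': 'RESEARCH TOPIC', 'b': 'SUMMARIZE FINDINGS'}
```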

@ -61,12 +61,13 @@ class Task:
if isinstance(self.flow, Flow): if isinstance(self.flow, Flow):
# Add a prompt to notify the Flow of the sequential workflow # Add a prompt to notify the Flow of the sequential workflow
if "prompt" in self.kwargs: if "prompt" in self.kwargs:
self.kwargs["prompt"] += (f"\n\nPrevious output: {self.result}" self.kwargs["prompt"] += (
if self.result else "") f"\n\nPrevious output: {self.result}" if self.result else ""
)
else: else:
self.kwargs["prompt"] = f"Main task: {self.description}" + ( self.kwargs["prompt"] = f"Main task: {self.description}" + (
f"\n\nPrevious output: {self.result}" f"\n\nPrevious output: {self.result}" if self.result else ""
if self.result else "") )
self.result = self.flow.run(*self.args, **self.kwargs) self.result = self.flow.run(*self.args, **self.kwargs)
else: else:
self.result = self.flow(*self.args, **self.kwargs) self.result = self.flow(*self.args, **self.kwargs)
@ -110,8 +111,7 @@ class SequentialWorkflow:
restore_state_filepath: Optional[str] = None restore_state_filepath: Optional[str] = None
dashboard: bool = False dashboard: bool = False
def add(self, task: str, flow: Union[Callable, Flow], *args, def add(self, task: str, flow: Union[Callable, Flow], *args, **kwargs) -> None:
**kwargs) -> None:
""" """
Add a task to the workflow. Add a task to the workflow.
@ -127,7 +127,8 @@ class SequentialWorkflow:
# Append the task to the tasks list # Append the task to the tasks list
self.tasks.append( self.tasks.append(
Task(description=task, flow=flow, args=list(args), kwargs=kwargs)) Task(description=task, flow=flow, args=list(args), kwargs=kwargs)
)
def reset_workflow(self) -> None: def reset_workflow(self) -> None:
"""Resets the workflow by clearing the results of each task.""" """Resets the workflow by clearing the results of each task."""
@ -179,9 +180,8 @@ class SequentialWorkflow:
raise ValueError(f"Task {task_description} not found in workflow.") raise ValueError(f"Task {task_description} not found in workflow.")
def save_workflow_state( def save_workflow_state(
self, self, filepath: Optional[str] = "sequential_workflow_state.json", **kwargs
filepath: Optional[str] = "sequential_workflow_state.json", ) -> None:
**kwargs) -> None:
""" """
Saves the workflow state to a json file. Saves the workflow state to a json file.
@ -202,13 +202,16 @@ class SequentialWorkflow:
with open(filepath, "w") as f: with open(filepath, "w") as f:
            # Saving the state as JSON for simplicity             # Saving the state as JSON for simplicity
state = { state = {
"tasks": [{ "tasks": [
"description": task.description, {
"args": task.args, "description": task.description,
"kwargs": task.kwargs, "args": task.args,
"result": task.result, "kwargs": task.kwargs,
"history": task.history, "result": task.result,
} for task in self.tasks], "history": task.history,
}
for task in self.tasks
],
"max_loops": self.max_loops, "max_loops": self.max_loops,
} }
json.dump(state, f, indent=4) json.dump(state, f, indent=4)
@ -220,7 +223,8 @@ class SequentialWorkflow:
Sequential Workflow Initializing...""", Sequential Workflow Initializing...""",
"green", "green",
attrs=["bold", "underline"], attrs=["bold", "underline"],
)) )
)
def workflow_dashboard(self, **kwargs) -> None: def workflow_dashboard(self, **kwargs) -> None:
""" """
@ -259,7 +263,8 @@ class SequentialWorkflow:
""", """,
"cyan", "cyan",
attrs=["bold", "underline"], attrs=["bold", "underline"],
)) )
)
def workflow_shutdown(self, **kwargs) -> None: def workflow_shutdown(self, **kwargs) -> None:
print( print(
@ -268,7 +273,8 @@ class SequentialWorkflow:
Sequential Workflow Shutdown...""", Sequential Workflow Shutdown...""",
"red", "red",
attrs=["bold", "underline"], attrs=["bold", "underline"],
)) )
)
def add_objective_to_workflow(self, task: str, **kwargs) -> None: def add_objective_to_workflow(self, task: str, **kwargs) -> None:
print( print(
@ -277,7 +283,8 @@ class SequentialWorkflow:
Adding Objective to Workflow...""", Adding Objective to Workflow...""",
"green", "green",
attrs=["bold", "underline"], attrs=["bold", "underline"],
)) )
)
task = Task( task = Task(
description=task, description=task,
@ -342,12 +349,13 @@ class SequentialWorkflow:
if "task" not in task.kwargs: if "task" not in task.kwargs:
raise ValueError( raise ValueError(
"The 'task' argument is required for the Flow flow" "The 'task' argument is required for the Flow flow"
f" execution in '{task.description}'") f" execution in '{task.description}'"
)
# Separate the 'task' argument from other kwargs # Separate the 'task' argument from other kwargs
flow_task_arg = task.kwargs.pop("task") flow_task_arg = task.kwargs.pop("task")
task.result = task.flow.run(flow_task_arg, task.result = task.flow.run(
*task.args, flow_task_arg, *task.args, **task.kwargs
**task.kwargs) )
else: else:
# If it's not a Flow instance, call the flow directly # If it's not a Flow instance, call the flow directly
task.result = task.flow(*task.args, **task.kwargs) task.result = task.flow(*task.args, **task.kwargs)
@ -365,17 +373,19 @@ class SequentialWorkflow:
# Autosave the workflow state # Autosave the workflow state
if self.autosave: if self.autosave:
self.save_workflow_state( self.save_workflow_state("sequential_workflow_state.json")
"sequential_workflow_state.json")
except Exception as e: except Exception as e:
print( print(
colored( colored(
(f"Error initializing the Sequential workflow: {e} try" (
" optimizing your inputs like the flow class and task" f"Error initializing the Sequential workflow: {e} try"
" description"), " optimizing your inputs like the flow class and task"
" description"
),
"red", "red",
attrs=["bold", "underline"], attrs=["bold", "underline"],
)) )
)
async def arun(self) -> None: async def arun(self) -> None:
""" """
@ -395,11 +405,13 @@ class SequentialWorkflow:
if "task" not in task.kwargs: if "task" not in task.kwargs:
raise ValueError( raise ValueError(
"The 'task' argument is required for the Flow flow" "The 'task' argument is required for the Flow flow"
f" execution in '{task.description}'") f" execution in '{task.description}'"
)
# Separate the 'task' argument from other kwargs # Separate the 'task' argument from other kwargs
flow_task_arg = task.kwargs.pop("task") flow_task_arg = task.kwargs.pop("task")
task.result = await task.flow.arun( task.result = await task.flow.arun(
flow_task_arg, *task.args, **task.kwargs) flow_task_arg, *task.args, **task.kwargs
)
else: else:
# If it's not a Flow instance, call the flow directly # If it's not a Flow instance, call the flow directly
task.result = await task.flow(*task.args, **task.kwargs) task.result = await task.flow(*task.args, **task.kwargs)
@ -417,5 +429,4 @@ class SequentialWorkflow:
# Autosave the workflow state # Autosave the workflow state
if self.autosave: if self.autosave:
self.save_workflow_state( self.save_workflow_state("sequential_workflow_state.json")
"sequential_workflow_state.json")

@ -13,7 +13,6 @@ from swarms.artifacts.error_artifact import ErrorArtifact
class BaseTask(ABC): class BaseTask(ABC):
class State(Enum): class State(Enum):
PENDING = 1 PENDING = 1
EXECUTING = 2 EXECUTING = 2
@ -34,15 +33,11 @@ class BaseTask(ABC):
@property @property
def parents(self) -> List[BaseTask]: def parents(self) -> List[BaseTask]:
return [ return [self.structure.find_task(parent_id) for parent_id in self.parent_ids]
self.structure.find_task(parent_id) for parent_id in self.parent_ids
]
@property @property
def children(self) -> List[BaseTask]: def children(self) -> List[BaseTask]:
return [ return [self.structure.find_task(child_id) for child_id in self.child_ids]
self.structure.find_task(child_id) for child_id in self.child_ids
]
def __rshift__(self, child: BaseTask) -> BaseTask: def __rshift__(self, child: BaseTask) -> BaseTask:
return self.add_child(child) return self.add_child(child)
@ -123,7 +118,8 @@ class BaseTask(ABC):
def can_execute(self) -> bool: def can_execute(self) -> bool:
return self.state == self.State.PENDING and all( return self.state == self.State.PENDING and all(
parent.is_finished() for parent in self.parents) parent.is_finished() for parent in self.parents
)
def reset(self) -> BaseTask: def reset(self) -> BaseTask:
self.state = self.State.PENDING self.state = self.State.PENDING
@ -136,10 +132,10 @@ class BaseTask(ABC):
class Task(BaseModel): class Task(BaseModel):
input: Optional[StrictStr] = Field(None, input: Optional[StrictStr] = Field(None, description="Input prompt for the task")
description="Input prompt for the task")
additional_input: Optional[Any] = Field( additional_input: Optional[Any] = Field(
None, description="Input parameters for the task. Any value is allowed") None, description="Input parameters for the task. Any value is allowed"
)
task_id: StrictStr = Field(..., description="ID of the task") task_id: StrictStr = Field(..., description="ID of the task")
class Config: class Config:

@ -65,13 +65,11 @@ class Workflow:
def context(self, task: Task) -> Dict[str, Any]: def context(self, task: Task) -> Dict[str, Any]:
"""Context in tasks""" """Context in tasks"""
return { return {
"parent_output": "parent_output": task.parents[0].output
task.parents[0].output if task.parents and task.parents[0].output
if task.parents and task.parents[0].output else None, else None,
"parent": "parent": task.parents[0] if task.parents else None,
task.parents[0] if task.parents else None, "child": task.children[0] if task.children else None,
"child":
task.children[0] if task.children else None,
} }
def __run_from_task(self, task: Optional[Task]) -> None: def __run_from_task(self, task: Optional[Task]) -> None:

@ -87,8 +87,7 @@ class AutoScaler:
while True: while True:
sleep(60) # check minute sleep(60) # check minute
pending_tasks = self.task_queue.qsize() pending_tasks = self.task_queue.qsize()
active_agents = sum( active_agents = sum([1 for agent in self.agents_pool if agent.is_busy()])
[1 for agent in self.agents_pool if agent.is_busy()])
if pending_tasks / len(self.agents_pool) > self.busy_threshold: if pending_tasks / len(self.agents_pool) > self.busy_threshold:
self.scale_up() self.scale_up()

@ -117,9 +117,7 @@ class AbstractSwarm(ABC):
pass pass
@abstractmethod @abstractmethod
def broadcast(self, def broadcast(self, message: str, sender: Optional["AbstractWorker"] = None):
message: str,
sender: Optional["AbstractWorker"] = None):
"""Broadcast a message to all workers""" """Broadcast a message to all workers"""
pass pass

@ -77,15 +77,19 @@ class BattleRoyalSwarm:
# Check for clashes and handle them # Check for clashes and handle them
for i, worker1 in enumerate(self.workers): for i, worker1 in enumerate(self.workers):
for j, worker2 in enumerate(self.workers): for j, worker2 in enumerate(self.workers):
if (i != j and worker1.is_within_proximity(worker2) and if (
set(worker1.teams) != set(worker2.teams)): i != j
and worker1.is_within_proximity(worker2)
and set(worker1.teams) != set(worker2.teams)
):
winner, loser = self.clash(worker1, worker2, question) winner, loser = self.clash(worker1, worker2, question)
print(f"Worker {winner.id} won over Worker {loser.id}") print(f"Worker {winner.id} won over Worker {loser.id}")
def communicate(self, sender: Worker, reciever: Worker, message: str): def communicate(self, sender: Worker, reciever: Worker, message: str):
"""Communicate a message from one worker to another.""" """Communicate a message from one worker to another."""
if sender.is_within_proximity(reciever) or any( if sender.is_within_proximity(reciever) or any(
team in sender.teams for team in reciever.teams): team in sender.teams for team in reciever.teams
):
pass pass
def clash(self, worker1: Worker, worker2: Worker, question: str): def clash(self, worker1: Worker, worker2: Worker, question: str):

@ -49,8 +49,9 @@ class GodMode:
table.append([f"LLM {i+1}", response]) table.append([f"LLM {i+1}", response])
print( print(
colored( colored(
tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan"
"cyan")) )
)
def run_all(self, task): def run_all(self, task):
"""Run the task on all LLMs""" """Run the task on all LLMs"""
@ -73,15 +74,18 @@ class GodMode:
table.append([f"LLM {i+1}", response]) table.append([f"LLM {i+1}", response])
print( print(
colored( colored(
tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan"
"cyan")) )
)
# New Features # New Features
def save_responses_to_file(self, filename): def save_responses_to_file(self, filename):
"""Save responses to file""" """Save responses to file"""
with open(filename, "w") as file: with open(filename, "w") as file:
table = [[f"LLM {i+1}", response] table = [
for i, response in enumerate(self.last_responses)] [f"LLM {i+1}", response]
for i, response in enumerate(self.last_responses)
]
file.write(tabulate(table, headers=["LLM", "Response"])) file.write(tabulate(table, headers=["LLM", "Response"]))
@classmethod @classmethod
@ -101,9 +105,11 @@ class GodMode:
for i, task in enumerate(self.task_history): for i, task in enumerate(self.task_history):
print(f"{i + 1}. {task}") print(f"{i + 1}. {task}")
print("\nLast Responses:") print("\nLast Responses:")
table = [[f"LLM {i+1}", response] table = [
for i, response in enumerate(self.last_responses)] [f"LLM {i+1}", response] for i, response in enumerate(self.last_responses)
]
print( print(
colored( colored(
tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan"
"cyan")) )
)
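The GodMode table above is plain tabulate plus termcolor; a sketch assuming both packages are installed:

```python
from tabulate import tabulate
from termcolor import colored

last_responses = [
    "Paris is the capital of France.",
    "The capital of France is Paris.",
]

# One row per model, mirroring the [f"LLM {i+1}", response] rows built in GodMode.
table = [[f"LLM {i+1}", response] for i, response in enumerate(last_responses)]
print(colored(tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan"))
```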

@ -33,8 +33,7 @@ class GroupChat:
def next_agent(self, agent: Flow) -> Flow: def next_agent(self, agent: Flow) -> Flow:
"""Return the next agent in the list.""" """Return the next agent in the list."""
return self.agents[(self.agent_names.index(agent.name) + 1) % return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]
len(self.agents)]
def select_speaker_msg(self): def select_speaker_msg(self):
"""Return the message for selecting the next speaker.""" """Return the message for selecting the next speaker."""
@ -55,17 +54,24 @@ class GroupChat:
if n_agents < 3: if n_agents < 3:
logger.warning( logger.warning(
f"GroupChat is underpopulated with {n_agents} agents. Direct" f"GroupChat is underpopulated with {n_agents} agents. Direct"
" communication would be more efficient.") " communication would be more efficient."
)
name = selector.generate_reply( name = selector.generate_reply(
self.format_history(self.messages + [{ self.format_history(
"role": self.messages
"system", + [
"content": {
("Read the above conversation. Then select the next most" "role": "system",
f" suitable role from {self.agent_names} to play. Only" "content": (
" return the role."), "Read the above conversation. Then select the next most"
}])) f" suitable role from {self.agent_names} to play. Only"
" return the role."
),
}
]
)
)
try: try:
return self.agent_by_name(name["content"]) return self.agent_by_name(name["content"])
except ValueError: except ValueError:
@ -73,7 +79,8 @@ class GroupChat:
def _participant_roles(self): def _participant_roles(self):
return "\n".join( return "\n".join(
[f"{agent.name}: {agent.system_message}" for agent in self.agents]) [f"{agent.name}: {agent.system_message}" for agent in self.agents]
)
def format_history(self, messages: List[Dict]) -> str: def format_history(self, messages: List[Dict]) -> str:
formatted_messages = [] formatted_messages = []
@ -84,21 +91,19 @@ class GroupChat:
class GroupChatManager: class GroupChatManager:
def __init__(self, groupchat: GroupChat, selector: Flow): def __init__(self, groupchat: GroupChat, selector: Flow):
self.groupchat = groupchat self.groupchat = groupchat
self.selector = selector self.selector = selector
def __call__(self, task: str): def __call__(self, task: str):
self.groupchat.messages.append({ self.groupchat.messages.append({"role": self.selector.name, "content": task})
"role": self.selector.name,
"content": task
})
for i in range(self.groupchat.max_round): for i in range(self.groupchat.max_round):
speaker = self.groupchat.select_speaker(last_speaker=self.selector, speaker = self.groupchat.select_speaker(
selector=self.selector) last_speaker=self.selector, selector=self.selector
)
reply = speaker.generate_reply( reply = speaker.generate_reply(
self.groupchat.format_history(self.groupchat.messages)) self.groupchat.format_history(self.groupchat.messages)
)
self.groupchat.messages.append(reply) self.groupchat.messages.append(reply)
print(reply) print(reply)
if i == self.groupchat.max_round - 1: if i == self.groupchat.max_round - 1:
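The next_agent reformatting above is a modular round-robin over the agent list; a sketch with a stand-in agent type that only carries a name:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Agent:
    name: str

def next_agent(agents: List[Agent], current: Agent) -> Agent:
    # Pick the agent after `current`, wrapping back to the start of the list.
    agent_names = [a.name for a in agents]
    return agents[(agent_names.index(current.name) + 1) % len(agents)]

agents = [Agent("planner"), Agent("coder"), Agent("critic")]
print(next_agent(agents, agents[-1]).name)  # planner
```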

@ -5,16 +5,16 @@ from langchain.output_parsers import RegexParser
# utils # utils
class BidOutputParser(RegexParser): class BidOutputParser(RegexParser):
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return ( return (
"Your response should be an integrater delimited by angled brackets like" "Your response should be an integrater delimited by angled brackets like"
" this: <int>") " this: <int>"
)
bid_parser = BidOutputParser(regex=r"<(\d+)>", bid_parser = BidOutputParser(
output_keys=["bid"], regex=r"<(\d+)>", output_keys=["bid"], default_output_key="bid"
default_output_key="bid") )
def select_next_speaker(step: int, agents, director) -> int: def select_next_speaker(step: int, agents, director) -> int:
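The BidOutputParser above only forwards a regex, an output key, and a default to RegexParser; the extraction itself is ordinary re matching, sketched here without langchain:

```python
import re

BID_REGEX = r"<(\d+)>"

def parse_bid(text: str, default: int = 0) -> int:
    """Pull the integer bid out of a reply such as 'I bid <3> to speak next.'"""
    match = re.search(BID_REGEX, text)
    return int(match.group(1)) if match else default

print(parse_bid("My eagerness to speak next: <7>"))  # 7
print(parse_bid("no bid given"))                     # 0
```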
@ -29,7 +29,6 @@ def select_next_speaker(step: int, agents, director) -> int:
# main # main
class MultiAgentCollaboration: class MultiAgentCollaboration:
def __init__( def __init__(
self, self,
agents, agents,

@ -46,6 +46,7 @@ class MultiAgentDebate:
def format_results(self, results): def format_results(self, results):
formatted_results = "\n".join( formatted_results = "\n".join(
[f"Agent responded: {result['response']}" for result in results]) [f"Agent responded: {result['response']}" for result in results]
)
return formatted_results return formatted_results

@ -111,8 +111,7 @@ class Orchestrator:
self.chroma_client = chromadb.Client() self.chroma_client = chromadb.Client()
self.collection = self.chroma_client.create_collection( self.collection = self.chroma_client.create_collection(name=collection_name)
name=collection_name)
self.current_tasks = {} self.current_tasks = {}
@ -138,8 +137,9 @@ class Orchestrator:
result = self.worker.run(task["content"]) result = self.worker.run(task["content"])
# using the embed method to get the vector representation of the result # using the embed method to get the vector representation of the result
vector_representation = self.embed(result, self.api_key, vector_representation = self.embed(
self.model_name) result, self.api_key, self.model_name
)
self.collection.add( self.collection.add(
embeddings=[vector_representation], embeddings=[vector_representation],
@ -154,7 +154,8 @@ class Orchestrator:
except Exception as error: except Exception as error:
logging.error( logging.error(
f"Failed to process task {id(task)} by agent {id(agent)}. Error:" f"Failed to process task {id(task)} by agent {id(agent)}. Error:"
f" {error}") f" {error}"
)
finally: finally:
with self.condition: with self.condition:
self.agents.put(agent) self.agents.put(agent)
@ -162,7 +163,8 @@ class Orchestrator:
def embed(self, input, api_key, model_name): def embed(self, input, api_key, model_name):
openai = embedding_functions.OpenAIEmbeddingFunction( openai = embedding_functions.OpenAIEmbeddingFunction(
api_key=api_key, model_name=model_name) api_key=api_key, model_name=model_name
)
embedding = openai(input) embedding = openai(input)
return embedding return embedding
@ -173,13 +175,13 @@ class Orchestrator:
try: try:
# Query the vector database for documents created by the agents # Query the vector database for documents created by the agents
results = self.collection.query(query_texts=[str(agent_id)], results = self.collection.query(query_texts=[str(agent_id)], n_results=10)
n_results=10)
return results return results
except Exception as e: except Exception as e:
logging.error( logging.error(
f"Failed to retrieve results from agent {agent_id}. Error {e}") f"Failed to retrieve results from agent {agent_id}. Error {e}"
)
raise raise
# @abstractmethod # @abstractmethod
@ -210,8 +212,7 @@ class Orchestrator:
self.collection.add(documents=[result], ids=[str(id(result))]) self.collection.add(documents=[result], ids=[str(id(result))])
except Exception as e: except Exception as e:
logging.error( logging.error(f"Failed to append the agent output to database. Error: {e}")
f"Failed to append the agent output to database. Error: {e}")
raise raise
def run(self, objective: str): def run(self, objective: str):
@ -224,8 +225,8 @@ class Orchestrator:
self.task_queue.append(objective) self.task_queue.append(objective)
results = [ results = [
self.assign_task(agent_id, task) for agent_id, task in zip( self.assign_task(agent_id, task)
range(len(self.agents)), self.task_queue) for agent_id, task in zip(range(len(self.agents)), self.task_queue)
] ]
for result in results: for result in results:

@ -2,7 +2,6 @@ from queue import Queue, PriorityQueue
class SimpleSwarm: class SimpleSwarm:
def __init__( def __init__(
self, self,
llm, llm,

@ -8,7 +8,8 @@ import torch
from langchain.agents import tool from langchain.agents import tool
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.chains.qa_with_sources.loading import ( from langchain.chains.qa_with_sources.loading import (
BaseCombineDocumentsChain,) BaseCombineDocumentsChain,
)
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool from langchain.tools import BaseTool
@ -36,10 +37,9 @@ def pushd(new_dir):
@tool @tool
def process_csv(llm, def process_csv(
csv_file_path: str, llm, csv_file_path: str, instructions: str, output_path: Optional[str] = None
instructions: str, ) -> str:
output_path: Optional[str] = None) -> str:
"""Process a CSV by with pandas in a limited REPL.\ """Process a CSV by with pandas in a limited REPL.\
Only use this after writing data to disk as a csv file.\ Only use this after writing data to disk as a csv file.\
Any figures must be saved to disk to be viewed by the human.\ Any figures must be saved to disk to be viewed by the human.\
@ -49,10 +49,7 @@ def process_csv(llm,
df = pd.read_csv(csv_file_path) df = pd.read_csv(csv_file_path)
except Exception as e: except Exception as e:
return f"Error: {e}" return f"Error: {e}"
agent = create_pandas_dataframe_agent(llm, agent = create_pandas_dataframe_agent(llm, df, max_iterations=30, verbose=False)
df,
max_iterations=30,
verbose=False)
if output_path is not None: if output_path is not None:
instructions += f" Save output to disk at {output_path}" instructions += f" Save output to disk at {output_path}"
try: try:
@ -82,8 +79,7 @@ async def async_load_playwright(url: str) -> str:
text = soup.get_text() text = soup.get_text()
lines = (line.strip() for line in text.splitlines()) lines = (line.strip() for line in text.splitlines())
chunks = ( chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
phrase.strip() for line in lines for phrase in line.split(" "))
results = "\n".join(chunk for chunk in chunks if chunk) results = "\n".join(chunk for chunk in chunks if chunk)
except Exception as e: except Exception as e:
results = f"Error: {e}" results = f"Error: {e}"
@ -117,7 +113,8 @@ class WebpageQATool(BaseTool):
"Browse a webpage and retrieve the information relevant to the question." "Browse a webpage and retrieve the information relevant to the question."
) )
text_splitter: RecursiveCharacterTextSplitter = Field( text_splitter: RecursiveCharacterTextSplitter = Field(
default_factory=_get_text_splitter) default_factory=_get_text_splitter
)
qa_chain: BaseCombineDocumentsChain qa_chain: BaseCombineDocumentsChain
def _run(self, url: str, question: str) -> str: def _run(self, url: str, question: str) -> str:
@ -128,12 +125,9 @@ class WebpageQATool(BaseTool):
results = [] results = []
# TODO: Handle this with a MapReduceChain # TODO: Handle this with a MapReduceChain
for i in range(0, len(web_docs), 4): for i in range(0, len(web_docs), 4):
input_docs = web_docs[i:i + 4] input_docs = web_docs[i : i + 4]
window_result = self.qa_chain( window_result = self.qa_chain(
{ {"input_documents": input_docs, "question": question},
"input_documents": input_docs,
"question": question
},
return_only_outputs=True, return_only_outputs=True,
) )
results.append(f"Response from window {i} - {window_result}") results.append(f"Response from window {i} - {window_result}")
@ -141,10 +135,7 @@ class WebpageQATool(BaseTool):
Document(page_content="\n".join(results), metadata={"source": url}) Document(page_content="\n".join(results), metadata={"source": url})
] ]
return self.qa_chain( return self.qa_chain(
{ {"input_documents": results_docs, "question": question},
"input_documents": results_docs,
"question": question
},
return_only_outputs=True, return_only_outputs=True,
) )
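WebpageQATool._run above feeds the QA chain four documents at a time; the windowing is just a stride-4 slice loop, sketched with a toy combine function in place of qa_chain:

```python
from typing import Callable, List

def windowed_qa(
    docs: List[str], question: str, qa: Callable[[List[str], str], str], window: int = 4
) -> List[str]:
    # Run the QA step over consecutive windows of `window` documents.
    results = []
    for i in range(0, len(docs), window):
        input_docs = docs[i : i + window]
        results.append(f"Response from window {i} - {qa(input_docs, question)}")
    return results

fake_qa = lambda docs, q: f"answered '{q}' over {len(docs)} chunks"
print(windowed_qa([f"chunk {n}" for n in range(10)], "What is swarms?", fake_qa))
```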
@ -180,17 +171,18 @@ def VQAinference(self, inputs):
torch_dtype = torch.float16 if "cuda" in device else torch.float32 torch_dtype = torch.float16 if "cuda" in device else torch.float32
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained( model = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base", torch_dtype=torch_dtype).to(device) "Salesforce/blip-vqa-base", torch_dtype=torch_dtype
).to(device)
image_path, question = inputs.split(",") image_path, question = inputs.split(",")
raw_image = Image.open(image_path).convert("RGB") raw_image = Image.open(image_path).convert("RGB")
inputs = processor(raw_image, question, inputs = processor(raw_image, question, return_tensors="pt").to(device, torch_dtype)
return_tensors="pt").to(device, torch_dtype)
out = model.generate(**inputs) out = model.generate(**inputs)
answer = processor.decode(out[0], skip_special_tokens=True) answer = processor.decode(out[0], skip_special_tokens=True)
logger.debug( logger.debug(
f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input" f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
f" Question: {question}, Output Answer: {answer}") f" Question: {question}, Output Answer: {answer}"
)
return answer return answer
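VQAinference above is the standard Hugging Face BLIP VQA recipe; a standalone sketch of the same calls (the checkpoint downloads on first run, and the image path is a placeholder):

```python
import torch
from PIL import Image
from transformers import BlipForQuestionAnswering, BlipProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if "cuda" in device else torch.float32

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained(
    "Salesforce/blip-vqa-base", torch_dtype=torch_dtype
).to(device)

raw_image = Image.open("example.jpg").convert("RGB")  # placeholder image path
inputs = processor(raw_image, "What is in the picture?", return_tensors="pt").to(device, torch_dtype)
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))
```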

@ -25,14 +25,13 @@ from swarms.utils.main import BaseHandler, get_new_image_name
class MaskFormer: class MaskFormer:
def __init__(self, device): def __init__(self, device):
print("Initializing MaskFormer to %s" % device) print("Initializing MaskFormer to %s" % device)
self.device = device self.device = device
self.processor = CLIPSegProcessor.from_pretrained( self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
"CIDAS/clipseg-rd64-refined")
self.model = CLIPSegForImageSegmentation.from_pretrained( self.model = CLIPSegForImageSegmentation.from_pretrained(
"CIDAS/clipseg-rd64-refined").to(device) "CIDAS/clipseg-rd64-refined"
).to(device)
def inference(self, image_path, text): def inference(self, image_path, text):
threshold = 0.5 threshold = 0.5
@ -40,10 +39,9 @@ class MaskFormer:
padding = 20 padding = 20
original_image = Image.open(image_path) original_image = Image.open(image_path)
image = original_image.resize((512, 512)) image = original_image.resize((512, 512))
inputs = self.processor(text=text, inputs = self.processor(
images=image, text=text, images=image, padding="max_length", return_tensors="pt"
padding="max_length", ).to(self.device)
return_tensors="pt").to(self.device)
with torch.no_grad(): with torch.no_grad():
outputs = self.model(**inputs) outputs = self.model(**inputs)
mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
@ -54,7 +52,8 @@ class MaskFormer:
mask_array = np.zeros_like(mask, dtype=bool) mask_array = np.zeros_like(mask, dtype=bool)
for idx in true_indices: for idx in true_indices:
padded_slice = tuple( padded_slice = tuple(
slice(max(0, i - padding), i + padding + 1) for i in idx) slice(max(0, i - padding), i + padding + 1) for i in idx
)
mask_array[padded_slice] = True mask_array[padded_slice] = True
visual_mask = (mask_array * 255).astype(np.uint8) visual_mask = (mask_array * 255).astype(np.uint8)
image_mask = Image.fromarray(visual_mask) image_mask = Image.fromarray(visual_mask)
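The padding step in MaskFormer.inference above dilates the CLIPSeg mask by writing a square around every True pixel; the same logic with numpy alone:

```python
import numpy as np

def pad_mask(mask: np.ndarray, padding: int = 20) -> np.ndarray:
    """Expand each True pixel into a (2*padding+1)-wide square neighbourhood."""
    true_indices = np.argwhere(mask)
    mask_array = np.zeros_like(mask, dtype=bool)
    for idx in true_indices:
        padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
        mask_array[padded_slice] = True
    return mask_array

mask = np.zeros((8, 8), dtype=bool)
mask[4, 4] = True
print(pad_mask(mask, padding=1).sum())  # 9: the centre pixel plus its eight neighbours
```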
@ -62,7 +61,6 @@ class MaskFormer:
class ImageEditing:
    def __init__(self, device):
        print("Initializing ImageEditing to %s" % device)
        self.device = device
@ -77,24 +75,25 @@ class ImageEditing:
    @tool(
        name="Remove Something From The Photo",
        description=(
            "useful when you want to remove and object or something from the photo "
            "from its description or location. "
            "The input to this tool should be a comma separated string of two, "
            "representing the image_path and the object need to be removed. "
        ),
    )
    def inference_remove(self, inputs):
        image_path, to_be_removed_txt = inputs.split(",")
        return self.inference_replace(f"{image_path},{to_be_removed_txt},background")

    @tool(
        name="Replace Something From The Photo",
        description=(
            "useful when you want to replace an object from the object description or"
            " location with another object from its description. The input to this tool"
            " should be a comma separated string of three, representing the image_path,"
            " the object to be replaced, the object to be replaced with "
        ),
    )
    def inference_replace(self, inputs):
        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
@ -106,21 +105,22 @@ class ImageEditing:
            image=original_image.resize((512, 512)),
            mask_image=mask_image.resize((512, 512)),
        ).images[0]
        updated_image_path = get_new_image_name(
            image_path, func_name="replace-something"
        )
        updated_image = updated_image.resize(original_size)
        updated_image.save(updated_image_path)
        logger.debug(
            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace"
            f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:"
            f" {updated_image_path}"
        )
        return updated_image_path
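A quick illustration of the comma separated inputs these tool methods parse (a sketch, not part of the diff; the paths are placeholders):

    editor = ImageEditing(device="cpu")
    editor.inference_remove("photos/cat.png,hat")              # "image_path,object to remove"
    editor.inference_replace("photos/cat.png,hat,party hat")   # "image_path,old object,new object"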
class InstructPix2Pix:
    def __init__(self, device):
        print("Initializing InstructPix2Pix to %s" % device)
        self.device = device
@ -131,56 +131,60 @@ class InstructPix2Pix:
            torch_dtype=self.torch_dtype,
        ).to(device)
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    @tool(
        name="Instruct Image Using Text",
        description=(
            "useful when you want to the style of the image to be like the text. "
            "like: make it look like a painting. or make it like a robot. "
            "The input to this tool should be a comma separated string of two, "
            "representing the image_path and the text. "
        ),
    )
    def inference(self, inputs):
        """Change style of image."""
        logger.debug("===> Starting InstructPix2Pix Inference")
        image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:])
        original_image = Image.open(image_path)
        image = self.pipe(
            text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2
        ).images[0]
        updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
        image.save(updated_image_path)
        logger.debug(
            f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text:"
            f" {text}, Output Image: {updated_image_path}"
        )
        return updated_image_path
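Note that only the first comma splits the input, so the instruction text itself may contain commas. A hedged sketch (placeholder path and prompt):

    pix = InstructPix2Pix(device="cuda")
    pix.inference("photos/street.png,make it look like a watercolor, muted colors")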
class Text2Image:
    def __init__(self, device):
        print("Initializing Text2Image to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype
        )
        self.pipe.to(device)
        self.a_prompt = "best quality, extremely detailed"
        self.n_prompt = (
            "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, "
            "fewer digits, cropped, worst quality, low quality"
        )

    @tool(
        name="Generate Image From User Input Text",
        description=(
            "useful when you want to generate an image from a user input text and save"
            " it to a file. like: generate an image of an object or something, or"
            " generate an image that includes some objects. The input to this tool"
            " should be a string, representing the text used to generate image. "
        ),
    )
    def inference(self, text):
        image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
@ -190,59 +194,59 @@ class Text2Image:
        logger.debug(
            f"\nProcessed Text2Image, Input Text: {text}, Output Image:"
            f" {image_filename}"
        )
        return image_filename
class VisualQuestionAnswering:
    def __init__(self, device):
        print("Initializing VisualQuestionAnswering to %s" % device)
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    @tool(
        name="Answer Question About The Image",
        description=(
            "useful when you need an answer for a question based on an image. like:"
            " what is the background color of the last image, how many cats in this"
            " figure, what is in this figure. The input to this tool should be a comma"
            " separated string of two, representing the image_path and the question"
        ),
    )
    def inference(self, inputs):
        image_path, question = inputs.split(",")
        raw_image = Image.open(image_path).convert("RGB")
        inputs = self.processor(raw_image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        answer = self.processor.decode(out[0], skip_special_tokens=True)
        logger.debug(
            f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
            f" Question: {question}, Output Answer: {answer}"
        )
        return answer
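The tool wraps BLIP VQA end to end; a sketch assuming a local image (placeholder path and question, not part of this commit):

    vqa = VisualQuestionAnswering(device="cpu")
    print(vqa.inference("photos/kitchen.png,how many cats are in this figure"))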
class ImageCaptioning(BaseHandler):
    def __init__(self, device):
        print("Initializing ImageCaptioning to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def handle(self, filename: str):
        img = Image.open(filename)
@ -254,13 +258,14 @@ class ImageCaptioning(BaseHandler):
        img.save(filename, "PNG")
        print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
        inputs = self.processor(Image.open(filename), return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        description = self.processor.decode(out[0], skip_special_tokens=True)
        print(
            f"\nProcessed ImageCaptioning, Input Image: {filename}, Output Text:"
            f" {description}"
        )
        return IMAGE_PROMPT.format(filename=filename, description=description)

@ -9,7 +9,6 @@ from pytube import YouTube
class SpeechToText:
    def __init__(
        self,
        video_url,
@ -62,15 +61,14 @@ class SpeechToText:
        compute_type = "float16"
        # 1. Transcribe with original Whisper (batched) 🗣️
        model = whisperx.load_model("large-v2", device, compute_type=compute_type)
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=batch_size)
        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
@ -82,7 +80,8 @@ class SpeechToText:
        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=device
        )
        diarize_model(audio_file)
        try:
@ -99,7 +98,8 @@ class SpeechToText:
        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
@ -112,7 +112,8 @@ class SpeechToText:
        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        diarize_model(audio_file)
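The three whisperx stages above (transcribe, align, diarize) can also be run standalone; a hedged sketch that reuses the calls visible in this diff, with a placeholder audio file and HF token, and an align call whose trailing arguments are assumed from the whisperx documentation:

    import whisperx

    device = "cuda"
    model = whisperx.load_model("large-v2", device, compute_type="float16")
    audio = whisperx.load_audio("talk.wav")                      # placeholder file
    result = model.transcribe(audio, batch_size=16)
    model_a, metadata = whisperx.load_align_model(
        language_code=result["language"], device=device
    )
    result = whisperx.align(result["segments"], model_a, metadata, audio, device)
    diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_xxx", device=device)
    diarize_segments = diarize_model(audio)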

@ -34,8 +34,9 @@ class SchemaAnnotationError(TypeError):
    """Raised when 'args_schema' is missing or has an incorrect type annotation."""

def _create_subset_model(
    name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
    """Create a pydantic model with only a subset of model's fields."""
    fields = {}
    for field_name in field_names:
@ -51,11 +52,7 @@ def _get_filtered_args(
    """Get the arguments from a function's signature."""
    schema = inferred_model.schema()["properties"]
    valid_keys = signature(func).parameters
    return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}

class _SchemaConfig:
@ -85,8 +82,9 @@ def create_schema_from_function(
    del inferred_model.__fields__["callbacks"]
    # Pydantic adds placeholder virtual fields we need to strip
    valid_properties = _get_filtered_args(inferred_model, func)
    return _create_subset_model(
        f"{model_name}Schema", inferred_model, list(valid_properties)
    )

class ToolException(Exception):
@ -127,7 +125,8 @@ class ChildTool(BaseTool):
                "Expected annotation of 'Type[BaseModel]'"
                f" but got '{args_schema_type}'.\n"
                "Expected class looks like:\n"
                f"{typehint_mandate}"
            )

    name: str
    """The unique name of the tool that clearly communicates its purpose."""
@ -148,8 +147,7 @@ class ChildTool(BaseTool):
    callbacks: Callbacks = Field(default=None, exclude=True)
    """Callbacks to be called during tool execution."""
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """Deprecated. Please use callbacks instead."""
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the tool. Defaults to None
@ -164,8 +162,9 @@ class ChildTool(BaseTool):
    You can use these to eg identify a specific instance of a tool with its use case.
    """
    handle_tool_error: Optional[
        Union[bool, str, Callable[[ToolException], str]]
    ] = False
    """Handle the content of the ToolException thrown."""
    class Config(Serializable.Config):
@ -245,9 +244,7 @@ class ChildTool(BaseTool):
        else:
            if input_args is not None:
                result = input_args.parse_obj(tool_input)
                return {k: v for k, v in result.dict().items() if k in tool_input}
        return tool_input

    @root_validator()
@ -289,8 +286,7 @@ class ChildTool(BaseTool):
            *args,
        )

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        # For backwards compatibility, if run_input is a string,
        # pass as a positional argument.
        if isinstance(tool_input, str):
@ -329,10 +325,7 @@ class ChildTool(BaseTool):
        # TODO: maybe also pass through run_manager is _run supports kwargs
        new_arg_supported = signature(self._run).parameters.get("run_manager")
        run_manager = callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input if isinstance(tool_input, str) else str(tool_input),
            color=start_color,
            name=run_name,
@ -342,7 +335,9 @@ class ChildTool(BaseTool):
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            observation = (
                self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
                if new_arg_supported
                else self._run(*tool_args, **tool_kwargs)
            )
        except ToolException as e:
            if not self.handle_tool_error:
                run_manager.on_tool_error(e)
@ -359,20 +354,19 @@ class ChildTool(BaseTool):
            else:
                raise ValueError(
                    "Got unexpected type of `handle_tool_error`. Expected bool, str "
                    f"or callable. Received: {self.handle_tool_error}"
                )
            run_manager.on_tool_end(
                str(observation), color="red", name=self.name, **kwargs
            )
            return observation
        except (Exception, KeyboardInterrupt) as e:
            run_manager.on_tool_error(e)
            raise e
        else:
            run_manager.on_tool_end(
                str(observation), color=color, name=self.name, **kwargs
            )
            return observation

    async def arun(
@ -405,10 +399,7 @@ class ChildTool(BaseTool):
        )
        new_arg_supported = signature(self._arun).parameters.get("run_manager")
        run_manager = await callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input if isinstance(tool_input, str) else str(tool_input),
            color=start_color,
            name=run_name,
@ -417,10 +408,11 @@ class ChildTool(BaseTool):
        try:
            # We then call the tool on the tool input to get an observation
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            observation = (
                await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
                if new_arg_supported
                else await self._arun(*tool_args, **tool_kwargs)
            )
        except ToolException as e:
            if not self.handle_tool_error:
                await run_manager.on_tool_error(e)
@ -437,20 +429,19 @@ class ChildTool(BaseTool):
            else:
                raise ValueError(
                    "Got unexpected type of `handle_tool_error`. Expected bool, str "
                    f"or callable. Received: {self.handle_tool_error}"
                )
            await run_manager.on_tool_end(
                str(observation), color="red", name=self.name, **kwargs
            )
            return observation
        except (Exception, KeyboardInterrupt) as e:
            await run_manager.on_tool_error(e)
            raise e
        else:
            await run_manager.on_tool_end(
                str(observation), color=color, name=self.name, **kwargs
            )
            return observation

    def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
@ -477,7 +468,8 @@ class Tool(BaseTool):
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )
        return await super().ainvoke(input, config, **kwargs)
@ -492,8 +484,7 @@ class Tool(BaseTool):
        # assume it takes a single string input.
        return {"tool_input": {"type": "string"}}

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Convert tool input to pydantic model."""
        args, kwargs = super()._to_args_and_kwargs(tool_input)
        # For backwards compatibility. The tool must be run with a single input
@ -512,13 +503,16 @@ class Tool(BaseTool):
    ) -> Any:
        """Use the tool."""
        if self.func:
            new_argument_supported = signature(self.func).parameters.get("callbacks")
            return (
                self.func(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else self.func(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support sync")

    async def _arun(
@ -529,27 +523,31 @@ class Tool(BaseTool):
    ) -> Any:
        """Use the tool asynchronously."""
        if self.coroutine:
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        else:
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self._run, run_manager=run_manager, **kwargs), *args
            )

    # TODO: this is for backwards compatibility, remove in future
    def __init__(
        self, name: str, func: Optional[Callable], description: str, **kwargs: Any
    ) -> None:
        """Initialize tool."""
        super(Tool, self).__init__(
            name=name, func=func, description=description, **kwargs
        )

    @classmethod
    def from_function(
@ -559,8 +557,9 @@ class Tool(BaseTool):
        description: str,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        coroutine: Optional[
            Callable[..., Awaitable[Any]]
        ] = None,  # This is last for compatibility, but should be after func
        **kwargs: Any,
    ) -> Tool:
        """Initialize tool from a function."""
@ -598,7 +597,8 @@ class StructuredTool(BaseTool):
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )
        return await super().ainvoke(input, config, **kwargs)
@ -617,13 +617,16 @@ class StructuredTool(BaseTool):
    ) -> Any:
        """Use the tool."""
        if self.func:
            new_argument_supported = signature(self.func).parameters.get("callbacks")
            return (
                self.func(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else self.func(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support sync")
    async def _arun(
@ -634,14 +637,18 @@ class StructuredTool(BaseTool):
    ) -> str:
        """Use the tool asynchronously."""
        if self.coroutine:
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._run, run_manager=run_manager, **kwargs),
@ -698,7 +705,8 @@ class StructuredTool(BaseTool):
        description = description or source_function.__doc__
        if description is None:
            raise ValueError(
                "Function must have a docstring if description not provided."
            )
        # Description example:
        # search_api(query: str) - Searches the API for the query.
@ -706,8 +714,7 @@ class StructuredTool(BaseTool):
        description = f"{name}{sig} - {description.strip()}"
        _args_schema = args_schema
        if _args_schema is None and infer_schema:
            _args_schema = create_schema_from_function(f"{name}Schema", source_function)
        return cls(
            name=name,
            func=func,
@ -755,7 +762,6 @@ def tool(
    """

    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
            if isinstance(dec_func, Runnable):
                runnable = dec_func
@ -763,13 +769,14 @@ def tool(
                if runnable.input_schema.schema().get("type") != "object":
                    raise ValueError("Runnable must have an object schema.")

                async def ainvoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return await runnable.ainvoke(kwargs, {"callbacks": callbacks})

                def invoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return runnable.invoke(kwargs, {"callbacks": callbacks})

                coroutine = ainvoke_wrapper
@ -802,7 +809,8 @@ def tool(
            if func.__doc__ is None:
                raise ValueError(
                    "Function must have a docstring if "
                    "description not provided and infer_schema is False."
                )
            return Tool(
                name=tool_name,
                func=func,
@ -813,8 +821,7 @@ def tool(
        return _make_tool

    if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
        return _make_with_name(args[0])(args[1])
    elif len(args) == 1 and isinstance(args[0], str):
        # if the argument is a string, then we use the string as the tool name
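As a usage sketch of the decorator above (hedged; search_api is a made-up function in the spirit of the docstring example earlier in this file):

    @tool("search-api", return_direct=True)
    def search_api(query: str) -> str:
        """Searches the API for the query."""
        return "results"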

@ -6,7 +6,6 @@ FuncToolBuilder = Callable[[], ToolBuilder]
class ToolsRegistry:
    def __init__(self) -> None:
        self.tools: Dict[str, FuncToolBuilder] = {}
@ -19,7 +18,8 @@ class ToolsRegistry:
        if isinstance(ret, tool):
            return ret
        raise ValueError(
            "Tool builder {} did not return a Tool instance".format(tool_name)
        )

    def list_tools(self) -> List[str]:
        return list(self.tools.keys())
@ -29,7 +29,6 @@ tools_registry = ToolsRegistry()
def register(tool_name):
    def decorator(tool: FuncToolBuilder):
        tools_registry.register(tool_name, tool)
        return tool

@ -118,19 +118,14 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
                # Most of the time it doesn't matter, but we should figure out why it happens frequently with:
                # applescript
                yield {"output": traceback.format_exc()}
                yield {"output": f"Retrying... ({retry_count}/{max_retries})"}
                yield {"output": "Restarting process."}
                self.start_process()
                retry_count += 1
                if retry_count > max_retries:
                    yield {"output": "Maximum retries reached. Could not execute code."}
                    return

        while True:
@ -139,8 +134,7 @@ class SubprocessCodeInterpreter(BaseCodeInterpreter):
            else:
                time.sleep(0.1)
            try:
                output = self.output_queue.get(timeout=0.3)  # Waits for 0.3 seconds
                yield output
            except queue.Empty:
                if self.done.is_set():

@ -6,7 +6,6 @@ import warnings
def log_decorator(func):
    def wrapper(*args, **kwargs):
        logging.info(f"Entering {func.__name__}")
        result = func(*args, **kwargs)
@ -17,7 +16,6 @@ def log_decorator(func):
def error_decorator(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
@ -29,22 +27,18 @@ def error_decorator(func):
def timing_decorator(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        logging.info(f"{func.__name__} executed in {end_time - start_time} seconds")
        return result

    return wrapper
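A quick sketch of applying the timing decorator above (slow_add is a toy example, not part of the diff):

    @timing_decorator
    def slow_add(a, b):
        time.sleep(0.2)
        return a + b

    slow_add(1, 2)  # logs something like "slow_add executed in 0.20... seconds"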
def retry_decorator(max_retries=5):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for _ in range(max_retries):
@ -83,20 +77,16 @@ def synchronized_decorator(func):
def deprecated_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(f"{func.__name__} is deprecated", category=DeprecationWarning)
        return func(*args, **kwargs)

    return wrapper

def validate_inputs_decorator(validator):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not validator(*args, **kwargs):

@ -5,8 +5,6 @@ T = TypeVar("T")
def execute_futures_dict(fs_dict: dict[str, futures.Future[T]]) -> dict[str, T]:
    futures.wait(fs_dict.values(), timeout=None, return_when=futures.ALL_COMPLETED)
    return {key: future.result() for key, future in fs_dict.items()}
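A minimal usage sketch for the helper above, assuming a thread pool and a toy workload:

    from concurrent import futures

    def square(x: int) -> int:
        return x * x

    with futures.ThreadPoolExecutor() as pool:
        fs = {"a": pool.submit(square, 2), "b": pool.submit(square, 3)}
        print(execute_futures_dict(fs))  # {'a': 4, 'b': 9}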

@ -4,7 +4,8 @@ import hashlib
def dataframe_to_hash(dataframe: pd.DataFrame) -> str:
    return hashlib.sha256(
        pd.util.hash_pandas_object(dataframe, index=True).values
    ).hexdigest()
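For illustration, the digest is deterministic for identical values and index; a small sketch with toy data:

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    assert dataframe_to_hash(df) == dataframe_to_hash(df.copy())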
def str_to_hash(text: str, hash_algorithm: str = "sha256") -> str:

@ -51,16 +51,16 @@ def get_new_image_name(org_img_name, func_name="update"):
    if len(name_split) == 1:
        most_org_file_name = name_split[0]
        recent_prev_file_name = name_split[0]
        new_file_name = "{}_{}_{}_{}.png".format(
            this_new_uuid, func_name, recent_prev_file_name, most_org_file_name
        )
    else:
        assert len(name_split) == 4
        most_org_file_name = name_split[3]
        recent_prev_file_name = name_split[0]
        new_file_name = "{}_{}_{}_{}.png".format(
            this_new_uuid, func_name, recent_prev_file_name, most_org_file_name
        )
    return os.path.join(head, new_file_name)
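To make the naming scheme concrete: each output is named <new_uuid>_<func_name>_<previous_uuid>_<original_name>.png, so edit provenance can be read off the filename. A hypothetical chain (uuids invented):

    # cat.png                                 -> 1a2b3c4d_replace-something_cat_cat.png
    # editing that output again with pix2pix  -> 9e8f7a6b_pix2pix_1a2b3c4d_cat.png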
@ -73,16 +73,16 @@ def get_new_dataframe_name(org_img_name, func_name="update"):
    if len(name_split) == 1:
        most_org_file_name = name_split[0]
        recent_prev_file_name = name_split[0]
        new_file_name = "{}_{}_{}_{}.csv".format(
            this_new_uuid, func_name, recent_prev_file_name, most_org_file_name
        )
    else:
        assert len(name_split) == 4
        most_org_file_name = name_split[3]
        recent_prev_file_name = name_split[0]
        new_file_name = "{}_{}_{}_{}.csv".format(
            this_new_uuid, func_name, recent_prev_file_name, most_org_file_name
        )
    return os.path.join(head, new_file_name)
@ -92,7 +92,6 @@ def get_new_dataframe_name(org_img_name, func_name="update"):
class Code:
    def __init__(self, value: int):
        self.value = value
@ -101,7 +100,6 @@ class Code:
class Color(Code):
    def bg(self) -> "Color":
        self.value += 10
        return self
@ -148,7 +146,6 @@ class Color(Code):
class Style(Code):
    @staticmethod
    def reset() -> "Style":
        return Style(0)
@ -205,8 +202,7 @@ def dim_multiline(message: str) -> str:
    lines = message.split("\n")
    if len(lines) <= 1:
        return lines[0]
    return lines[0] + ANSI("\n... ".join([""] + lines[1:])).to(Color.black().bright())

# +=============================> ANSI Ending
@ -217,7 +213,6 @@ STATIC_DIR = "static"
class AbstractUploader(ABC):
    @abstractmethod
    def upload(self, filepath: str) -> str:
        pass
@ -233,9 +228,7 @@ class AbstractUploader(ABC):
class S3Uploader(AbstractUploader):
    def __init__(self, accessKey: str, secretKey: str, region: str, bucket: str):
        self.accessKey = accessKey
        self.secretKey = secretKey
        self.region = region
@ -270,7 +263,6 @@ class S3Uploader(AbstractUploader):
class StaticUploader(AbstractUploader):
    def __init__(self, server: str, path: Path, endpoint: str):
        self.server = server
        self.path = path
@ -338,19 +330,16 @@ class FileType(Enum):
class BaseHandler:
    def handle(self, filename: str) -> str:
        raise NotImplementedError

class FileHandler:
    def __init__(self, handlers: Dict[FileType, BaseHandler], path: Path):
        self.handlers = handlers
        self.path = path

    def register(self, filetype: FileType, handler: BaseHandler) -> "FileHandler":
        self.handlers[filetype] = handler
        return self
@ -358,8 +347,8 @@ class FileHandler:
        filetype = FileType.from_url(url)
        data = requests.get(url).content
        local_filename = os.path.join(
            "file", str(uuid.uuid4())[0:8] + filetype.to_extension()
        )
        os.makedirs(os.path.dirname(local_filename), exist_ok=True)
        with open(local_filename, "wb") as f:
            size = f.write(data)
@ -368,15 +357,17 @@ class FileHandler:
    def handle(self, url: str) -> str:
        try:
            if url.startswith(os.environ.get("SERVER", "http://localhost:8000")):
                local_filepath = url[
                    len(os.environ.get("SERVER", "http://localhost:8000")) + 1 :
                ]
                local_filename = Path("file") / local_filepath.split("/")[-1]
                src = self.path / local_filepath
                dst = (
                    self.path
                    / os.environ.get("PLAYGROUND_DIR", "./playground")
                    / local_filename
                )
                os.makedirs(os.path.dirname(dst), exist_ok=True)
                shutil.copy(src, dst)
            else:
@ -386,7 +377,8 @@ class FileHandler:
                if FileType.from_url(url) == FileType.IMAGE:
                    raise Exception(
                        f"No handler for {FileType.from_url(url)}. "
                        "Please set USE_GPU to True in env/settings.py"
                    )
                else:
                    raise Exception(f"No handler for {FileType.from_url(url)}")
            return handler.handle(local_filename)
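A hedged wiring sketch for the handler registry above (the path, device, and URL are placeholders; only FileType.IMAGE, which appears in this diff, is registered):

    handler = FileHandler(handlers={}, path=Path("."))
    handler.register(FileType.IMAGE, ImageCaptioning(device="cpu"))
    prompt_snippet = handler.handle("http://localhost:8000/file/cat.png")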
@ -400,17 +392,17 @@ class FileHandler:
class CsvToDataframe(BaseHandler):
    def handle(self, filename: str):
        df = pd.read_csv(filename)
        description = (
            f"Dataframe with {len(df)} rows and {len(df.columns)} columns. "
            "Columns are: "
            f"{', '.join(df.columns)}"
        )
        print(
            f"\nProcessed CsvToDataframe, Input CSV: {filename}, Output Description:"
            f" {description}"
        )
        return DATAFRAME_PROMPT.format(filename=filename, description=description)

@ -7,6 +7,5 @@ def extract_code_in_backticks_in_string(message: str) -> str:
    """
    pattern = r"`` ``(.*?)`` "  # Non-greedy match between six backticks
    match = re.search(pattern, message, re.DOTALL)  # re.DOTALL to match newline chars
    return match.group(1).strip() if match else None

@ -49,12 +49,16 @@ def get_input(
    """
    Multiline input function.
    """
    return (
        session.prompt(
            completer=completer,
            multiline=True,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=key_bindings,
        )
        if session
        else prompt(multiline=True)
    )

async def get_input_async(
@ -64,11 +68,15 @@ async def get_input_async(
    """
    Multiline input function.
    """
    return (
        await session.prompt_async(
            completer=completer,
            multiline=True,
            auto_suggest=AutoSuggestFromHistory(),
        )
        if session
        else prompt(multiline=True)
    )

def get_filtered_keys_from_object(obj: object, *keys: str) -> any:
@ -86,7 +94,9 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any:
        return {key for key in class_keys if key not in keys[1:]}
    # Check if all passed keys are valid
    if invalid_keys := set(keys) - class_keys:
        raise ValueError(
            f"Invalid keys: {invalid_keys}",
        )
    # Only return specified keys that are in class_keys
    return {key for key in keys if key in class_keys}
@ -114,8 +124,8 @@ def random_int(min: int, max: int) -> int:
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s",
    )
    log = logging.getLogger(__name__)
