feat: Add swarms tools

Former-commit-id: a851f7057bf636d02e0ab29b545f2bbc85fc8c2d
pull/160/head
Zack 1 year ago
parent 91444c650a
commit a1edd8e6cf

1
.gitignore vendored

@ -155,6 +155,7 @@ venv/
ENV/
env.bak/
venv.bak/
secret_keys.sh
# Spyder project settings
.spyderproject

@ -0,0 +1,17 @@
# Example: run a single swarms Flow against the OpenAIChat model.
import os
from swarms.models import OpenAIChat
from swarms.structs.flow import Flow, stop_when_repeats
from dotenv import load_dotenv

# Pull OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# Initialize the OpenAIChat model
openai_api_key = os.getenv("OPENAI_API_KEY")
llm = OpenAIChat(openai_api_key=openai_api_key)

# Initialize the Flow: at most 3 loops, stopping early when the model
# starts repeating itself.
flow = Flow(llm=llm, max_loops=3, stopping_condition=stop_when_repeats)

# Run the Flow with a task
# NOTE(review): the task string is empty — presumably a placeholder;
# fill in a real prompt before running this example.
response = flow.run("")
print(response)

@ -0,0 +1,173 @@
from swarms.tools.serve import ToolServer
def run_tool_server():
    """Start a ToolServer with the default set of key-free tools loaded.

    Creates the server, prints the tools it knows about, loads the tools
    that need no external API key (in a fixed order), then blocks in
    ``server.serve()``.

    The original version defined one single-line wrapper function per tool
    (all closing over a ``server`` created after them) plus a large set of
    commented-out loaders for tools that require API keys (weather,
    bing_search, wolframalpha, stock, bing_map/baidu_map, zillow, airbnb,
    job_search, translation, office-ppt, google_places/serper/scholar,
    sceneXplain, image_generation, hugging_tools, database, db_diag,
    wikidata).  To enable one of those, read its key from the environment
    and call ``server.load_tool(name, {"subscription_key": key})``.
    """
    server = ToolServer()
    # Show what the server can load before loading anything.
    print(server.list_tools())

    # Key-free tools, loaded in the same order as before.
    default_tools = [
        "chemical-prop",
        "douban-film",
        "wikipedia",
        "tutorial",
        "file_operation",
        "meta_analysis",
        "code_interpreter",
        "arxiv",
        "python",
        "shell",
        "gradio_tools",
        "travel",
    ]
    for tool_name in default_tools:
        server.load_tool(tool_name)

    # Blocking call: serve until interrupted.
    server.serve()


if __name__ == "__main__":
    run_tool_server()

@ -0,0 +1,173 @@
from swarms.tools.serve import ToolServer
def run_tool_server():
    """Build a ToolServer, load the bundled key-free tools, and serve."""
    server = ToolServer()
    print(server.list_tools())

    # NOTE(review): many more tools (weather, bing_search, wolframalpha,
    # stock, bing_map/baidu_map, zillow/airbnb/job_search, translation,
    # office-ppt, google_places/serper/scholar, sceneXplain,
    # image_generation, hugging_tools, database, db_diag, wikidata) exist
    # upstream but require API keys; they were disabled in the original
    # and remain disabled here.

    server.load_tool("chemical-prop")
    server.load_tool("douban-film")
    server.load_tool("wikipedia")
    server.load_tool("tutorial")
    server.load_tool("file_operation")
    server.load_tool("meta_analysis")
    server.load_tool("code_interpreter")
    server.load_tool("arxiv")
    server.load_tool("python")
    server.load_tool("shell")
    server.load_tool("gradio_tools")
    server.load_tool("travel")

    server.serve()


if __name__ == "__main__":
    run_tool_server()

@ -0,0 +1,28 @@
# Multi-tool demo: answer one question using several BMTools services
# served by a local BMTools server on port 8079.
from bmtools.agent.tools_controller import load_valid_tools, MTQuestionAnswerer
import jsonlines  # NOTE(review): appears unused in this script — confirm before removing

# Choose the tools that you need; commented entries are available but disabled.
tools_mappings = {
    #"klarna": "https://www.klarna.com/",
    #"chemical-prop": "http://127.0.0.1:8079/tools/chemical-prop/",
    "wolframalpha": "http://127.0.0.1:8079/tools/wolframalpha/",
    #"meta_analysis": "http://127.0.0.1:8079/tools/meta_analysis/",
    #"map": "http://127.0.0.1:8079/tools/map/",
    #"douban": "http://127.0.0.1:8079/tools/douban-film/",
    #"weather": "http://127.0.0.1:8079/tools/weather/",
    "office-ppt": "http://127.0.0.1:8079/tools/office-ppt/",
    "wikipedia": "http://127.0.0.1:8079/tools/wikipedia/",
    #"nllb-translation": "http://127.0.0.1:8079/tools/nllb-translation/",
    "file_operation": "http://127.0.0.1:8079/tools/file_operation/",
    "bing_search": "http://127.0.0.1:8079/tools/bing_search/",
}

# Presumably filters the mapping down to endpoints that respond — verify
# against load_valid_tools in bmtools.
tools = load_valid_tools(tools_mappings)

qa = MTQuestionAnswerer(all_tools=tools)
agent = qa.build_runner()

# The runner is invoked with a list of questions.
agent(["Who's the main actress of Titanic? What did she do apart from this film? Help me make slides with this information."])
#agent(['I want to go to Berkeley for one-week vacation. Please help me recommend some tourisms, restaurants, as well as the recent weather conditions for the place.'])
#agent(["How many benzene rings are there in 9H-Carbazole-3-carboxaldehyde? and what is sin(x)*exp(x)'s plot, what is it integrated from 0 to 1? "])

@ -0,0 +1,42 @@
# Single-tool demos for BMTools: the same local weather tool driven by
# different agent prompt styles (ReAct, BabyAGI, Auto-GPT).
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer

# Langchain
tool_name, tool_url = 'weather', "http://127.0.0.1:8079/tools/weather/"
tool_name, tool_config = load_single_tools(tool_name, tool_url)
print(tool_name, tool_config)

stqa = STQuestionAnswerer()
# ReAct-style agent that embeds the tool description in its prompt.
agent = stqa.load_tools(tool_name, tool_config, prompt_type="react-with-tool-description")
agent("write a weather report for SF today")

# BabyAGI
# tool_name, tool_url = 'weather', "http://127.0.0.1:8079/tools/weather/"
# tool_name, tool_config = load_single_tools(tool_name, tool_url)
# print(tool_name, tool_config)
# stqa = STQuestionAnswerer()
# agent = stqa.load_tools(tool_name, tool_config, prompt_type="babyagi")
# agent("write a weather report for SF today")

# Auto-GPT
# tool_name, tool_url = 'weather', "http://127.0.0.1:8079/tools/weather/"
# tool_name, tool_config = load_single_tools(tool_name, tool_url)
# print(tool_name, tool_config)
# stqa = STQuestionAnswerer()
# agent = stqa.load_tools(tool_name, tool_config, prompt_type="autogpt")
# agent.run(["write a weather report for SF today"])

# Disabled wikipedia variant kept for reference (dead string expression).
"""
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
tool_name, tool_url = 'wikipedia', "http://127.0.0.1:8079/tools/wikipedia/"
tool_name, tool_config = load_single_tools(tool_name, tool_url)
print(tool_name, tool_config)
stqa = STQuestionAnswerer()
agent = stqa.load_tools(tool_name, tool_config, prompt_type="babyagi")
# agent = stqa.load_tools(tool_name, tool_config, prompt_type="react-with-tool-description")# prompt_type="babyagi")
agent("Where is Yaoming Born?")
"""

@ -0,0 +1,13 @@
# Minimal multi-tool demo: weather + file_operation via a local BMTools server.
from bmtools.agent.tools_controller import load_valid_tools, MTQuestionAnswerer

tools_mappings = {
    'weather': "http://127.0.0.1:8079/tools/weather/",
    'file_operation': "http://127.0.0.1:8079/tools/file_operation/",
}

tools = load_valid_tools(tools_mappings)

# NOTE(review): openai_api_key is passed as an empty string — confirm the
# key is meant to come from the environment instead.
qa = MTQuestionAnswerer(openai_api_key='', all_tools=tools)
agent = qa.build_runner()

# NOTE(review): the sibling multi-tool demo calls the runner with a *list*
# of questions; here a bare string is passed — verify both are accepted.
agent("what is the weather in Beijing?")

@ -0,0 +1,184 @@
import gradio as gr
from bmtools.agent.tools_controller import MTQuestionAnswerer, load_valid_tools
from bmtools.agent.singletool import STQuestionAnswerer
from langchain.schema import AgentFinish
import os
import requests
# Models selectable in the UI dropdown.
available_models = ["ChatGPT", "GPT-3.5"]
DEFAULTMODEL = "ChatGPT"  # "GPT-3.5"

# Tool name -> endpoint of the locally running BMTools server (port 8079).
# Commented entries require API keys or services that are not enabled.
tools_mappings = {
    "klarna": "https://www.klarna.com/",
    "weather": "http://127.0.0.1:8079/tools/weather/",
    # "database": "http://127.0.0.1:8079/tools/database/",
    # "db_diag": "http://127.0.0.1:8079/tools/db_diag/",
    "chemical-prop": "http://127.0.0.1:8079/tools/chemical-prop/",
    "douban-film": "http://127.0.0.1:8079/tools/douban-film/",
    "wikipedia": "http://127.0.0.1:8079/tools/wikipedia/",
    # "wikidata": "http://127.0.0.1:8079/tools/wikidata/",
    "wolframalpha": "http://127.0.0.1:8079/tools/wolframalpha/",
    "bing_search": "http://127.0.0.1:8079/tools/bing_search/",
    "office-ppt": "http://127.0.0.1:8079/tools/office-ppt/",
    "stock": "http://127.0.0.1:8079/tools/stock/",
    "bing_map": "http://127.0.0.1:8079/tools/bing_map/",
    # "baidu_map": "http://127.0.0.1:8079/tools/baidu_map/",
    "zillow": "http://127.0.0.1:8079/tools/zillow/",
    "airbnb": "http://127.0.0.1:8079/tools/airbnb/",
    "job_search": "http://127.0.0.1:8079/tools/job_search/",
    # "baidu-translation": "http://127.0.0.1:8079/tools/baidu-translation/",
    # "nllb-translation": "http://127.0.0.1:8079/tools/nllb-translation/",
    "tutorial": "http://127.0.0.1:8079/tools/tutorial/",
    "file_operation": "http://127.0.0.1:8079/tools/file_operation/",
    "meta_analysis": "http://127.0.0.1:8079/tools/meta_analysis/",
    "code_interpreter": "http://127.0.0.1:8079/tools/code_interpreter/",
    "arxiv": "http://127.0.0.1:8079/tools/arxiv/",
    "google_places": "http://127.0.0.1:8079/tools/google_places/",
    "google_serper": "http://127.0.0.1:8079/tools/google_serper/",
    "google_scholar": "http://127.0.0.1:8079/tools/google_scholar/",
    "python": "http://127.0.0.1:8079/tools/python/",
    "sceneXplain": "http://127.0.0.1:8079/tools/sceneXplain/",
    "shell": "http://127.0.0.1:8079/tools/shell/",
    "image_generation": "http://127.0.0.1:8079/tools/image_generation/",
    "hugging_tools": "http://127.0.0.1:8079/tools/hugging_tools/",
    "gradio_tools": "http://127.0.0.1:8079/tools/gradio_tools/",
}

# Presumably probes each endpoint and keeps only tools that respond —
# verify against load_valid_tools in bmtools.
valid_tools_info = load_valid_tools(tools_mappings)
print(valid_tools_info)
all_tools_list = sorted(list(valid_tools_info.keys()))
gr.close_all()

# Chat display limits and shared, module-level conversation state.
MAX_TURNS = 30
MAX_BOXES = MAX_TURNS * 2
return_msg = []     # list of (user, bot) message tuples shown in the chatbot
chat_history = ""   # plain-text transcript prepended to each new question
def show_avatar_imgs(tools_chosen):
    """Return Gradio updates rendering avatar thumbnails for the chosen tools.

    An empty selection means "all tools".  Tools whose ``avatar`` entry is
    None are skipped entirely.
    """
    if len(tools_chosen) == 0:
        tools_chosen = list(valid_tools_info.keys())
    img_template = '<a href="{}" style="float: left"> <img style="margin:5px" src="{}.png" width="24" height="24" alt="avatar" /> {} </a>'
    # Pair each tool with its avatar, skipping tools without one.  (The
    # original filtered the avatar list but zipped it against the
    # unfiltered tool list, misaligning names once any avatar was None.)
    pairs = [
        (valid_tools_info[tool]['avatar'], tool)
        for tool in tools_chosen
        if valid_tools_info[tool]['avatar'] is not None
    ]
    imgs = ' '.join(img_template.format(img, img, tool) for img, tool in pairs)
    return [gr.update(value='<span class="">' + imgs + '</span>', visible=True), gr.update(visible=True)]
def answer_by_tools(question, tools_chosen, model_chosen):
    """Stream the agent's reasoning and final answer into the chatbot.

    Generator used by Gradio: yields [chatbot update, clear-button update,
    stop-button update] triples as intermediate agent steps arrive.
    Mutates the module-level ``return_msg`` (the transcript shown in the
    UI) and ``chat_history`` (plain-text history prepended to the next
    question).
    """
    global return_msg
    # Show the question immediately, with a '...' placeholder answer.
    return_msg += [(question, None), (None, '...')]
    yield [gr.update(visible=True, value=return_msg), gr.update(), gr.update()]
    OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
    if len(tools_chosen) == 0:  # if there is no tools chosen, we use all todo (TODO: What if the pool is too large.)
        tools_chosen = list(valid_tools_info.keys())
    if len(tools_chosen) == 1:
        # Exactly one tool selected: use the lighter single-tool answerer.
        answerer = STQuestionAnswerer(OPENAI_API_KEY.strip(), stream_output=True, llm=model_chosen)
        agent_executor = answerer.load_tools(tools_chosen[0], valid_tools_info[tools_chosen[0]], prompt_type="react-with-tool-description", return_intermediate_steps=True)
    else:
        # Several tools: multi-tool answerer over the chosen subset.
        answerer = MTQuestionAnswerer(OPENAI_API_KEY.strip(), load_valid_tools({k: tools_mappings[k] for k in tools_chosen}), stream_output=True, llm=model_chosen)
        agent_executor = answerer.build_runner()
    # Prepend the running history so the agent sees prior turns.
    global chat_history
    chat_history += "Question: " + question + "\n"
    question = chat_history
    for inter in agent_executor(question):
        if isinstance(inter, AgentFinish):
            continue
        result_str = []
        # Replace the previous '...' placeholder with real content.
        return_msg.pop()
        if isinstance(inter, dict):
            # Final answer from the agent.
            result_str.append("<font color=red>Answer:</font> {}".format(inter['output']))
            chat_history += "Answer:" + inter['output'] + "\n"
            result_str.append("...")
        else:
            # Intermediate (action, observation) step.
            not_observation = inter[0].log
            if not not_observation.startswith('Thought:'):
                not_observation = "Thought: " + not_observation
            chat_history += not_observation
            # Colorize the ReAct keywords for display only (history keeps plain text).
            not_observation = not_observation.replace('Thought:', '<font color=green>Thought: </font>')
            not_observation = not_observation.replace('Action:', '<font color=purple>Action: </font>')
            not_observation = not_observation.replace('Action Input:', '<font color=purple>Action Input: </font>')
            result_str.append("{}".format(not_observation))
            result_str.append("<font color=blue>Action output:</font>\n{}".format(inter[1]))
            chat_history += "\nAction output:" + inter[1] + "\n"
            result_str.append("...")
        return_msg += [(None, result) for result in result_str]
        yield [gr.update(visible=True, value=return_msg), gr.update(), gr.update()]
    # Drop the trailing '...' and relabel the last answer as final.
    return_msg.pop()
    if return_msg[-1][1].startswith("<font color=red>Answer:</font> "):
        return_msg[-1] = (return_msg[-1][0], return_msg[-1][1].replace("<font color=red>Answer:</font> ", "<font color=green>Final Answer:</font> "))
    yield [gr.update(visible=True, value=return_msg), gr.update(visible=True), gr.update(visible=False)]
def retrieve(tools_search):
    """Update the tool-checkbox choices from a retrieval query.

    An empty query restores the full tool list; otherwise the local
    retrieval endpoint is asked for matching tool names.
    """
    if tools_search == "":
        return gr.update(choices=all_tools_list)
    url = "http://127.0.0.1:8079/retrieve"
    param = {
        "query": tools_search
    }
    # Bounded timeout so a wedged retrieval server cannot hang the UI,
    # and an explicit error on non-2xx instead of an opaque JSON failure.
    response = requests.post(url, json=param, timeout=10)
    response.raise_for_status()
    result = response.json()
    retrieved_tools = result["tools"]
    return gr.update(choices=retrieved_tools)
def clear_retrieve():
    """Blank the search box and restore the complete tool list."""
    cleared_query = gr.update(value="")
    full_choices = gr.update(choices=all_tools_list)
    return [cleared_query, full_choices]
def clear_history():
    """Reset the shared conversation state and blank the chatbot widget."""
    global return_msg, chat_history
    return_msg = []
    chat_history = ""
    # Generator on purpose: Gradio streams the single update it yields.
    yield gr.update(visible=True, value=return_msg)
# ---- Gradio UI layout and event wiring ----
with gr.Blocks() as demo:
    # Header row: title + logo.
    with gr.Row():
        with gr.Column(scale=14):
            gr.Markdown("<h1 align='left'> Swarm Tools </h1>")
        with gr.Column(scale=1):
            gr.Image('../../images/swarmslogobanner.png', show_download_button=False, show_label=False )
            # gr.Markdown('<img src="../../images/swarmslogobanner.png" alt="swarms">')
    with gr.Row():
        # Left column: question box, chat transcript, control buttons.
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column(scale=0.85):
                    txt = gr.Textbox(show_label=False, placeholder="Question here. Use Shift+Enter to add new line.", lines=1).style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    buttonChat = gr.Button("Chat")
            chatbot = gr.Chatbot(show_label=False, visible=True).style(height=600)
            buttonClear = gr.Button("Clear History")
            buttonStop = gr.Button("Stop", visible=False)  # shown only while inference runs
        # Right column: model selection plus tool search/selection.
        with gr.Column(scale=1):
            model_chosen = gr.Dropdown(
                list(available_models), value=DEFAULTMODEL, multiselect=False, label="Model provided",
                info="Choose the model to solve your question, Default means ChatGPT."
            )
            with gr.Row():
                tools_search = gr.Textbox(
                    lines=1,
                    label="Tools Search",
                    placeholder="Please input some text to search tools.",
                )
                buttonSearch = gr.Button("Reset search condition")
            tools_chosen = gr.CheckboxGroup(
                choices=all_tools_list,
                value=["chemical-prop"],
                label="Tools provided",
                info="Choose the tools to solve your question.",
            )

    # Event wiring.
    tools_search.change(retrieve, tools_search, tools_chosen)
    buttonSearch.click(clear_retrieve, [], [tools_search, tools_chosen])
    # Submitting clears the textbox and swaps Clear for Stop while running.
    txt.submit(lambda : [gr.update(value=''), gr.update(visible=False), gr.update(visible=True)], [], [txt, buttonClear, buttonStop])
    inference_event = txt.submit(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
    buttonChat.click(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
    # Stop cancels the in-flight streaming inference event.
    buttonStop.click(lambda : [gr.update(visible=True), gr.update(visible=False)], [], [buttonClear, buttonStop], cancels=[inference_event])
    buttonClear.click(clear_history, [], chatbot)

demo.queue().launch(share=False, inbrowser=True, server_name="127.0.0.1", server_port=7001)

@ -68,7 +68,6 @@ cohere
torchvision
rich
mkdocs
mkdocs-material
mkdocs-glightbox

@ -20,9 +20,18 @@ from swarms.models.kosmos_two import Kosmos
from swarms.models.vilt import Vilt
from swarms.models.nougat import Nougat
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA
from swarms.models.gpt4v import GPT4Vision
from swarms.models.dalle3 import Dalle3
from swarms.models.distilled_whisperx import DistilWhisperModel
# from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel
# from swarms.models.fuyu import Fuyu # Not working, wait until they update
import sys
# log_file = open("errors.txt", "w")
# sys.stderr = log_file
__all__ = [
"Anthropic",

@ -0,0 +1,51 @@
#!/usr/bin/env python
# coding=utf-8
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
import torch
from cpm_live.generation.bee import CPMBeeBeamSearch
from cpm_live.models import CPMBeeTorch, CPMBeeConfig
from cpm_live.tokenizers import CPMBeeTokenizer
class CpmBeeLLM(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded CPM-Bee checkpoint."""

    # Declared as class-level fields because langchain's LLM is a pydantic
    # model; all are populated in __init__.
    model_name : str = ""
    config: CPMBeeConfig = None
    tokenizer: CPMBeeTokenizer = None
    model: CPMBeeTorch = None

    def __init__(self, config_path: str, ckpt_path: str, device: str="cuda") -> None:
        """Load config, tokenizer, and weights; move to GPU when device == "cuda"."""
        super().__init__()
        self.model_name = ckpt_path  # the checkpoint path doubles as the model id
        self.config = CPMBeeConfig.from_json_file(config_path)
        self.tokenizer = CPMBeeTokenizer()
        self.model = CPMBeeTorch(config=self.config)
        self.model.load_state_dict(torch.load(ckpt_path))
        if device == "cuda":
            self.model.cuda()

    @property
    def _llm_type(self) -> str:
        # LangChain uses this as the model-type identifier string.
        return self.model_name

    def _call(self, prompt, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for *prompt* with beam search (beam_size=1).

        NOTE(review): ``stop`` sequences are accepted but ignored — confirm
        that is intended.
        """
        # use beam search
        beam_search = CPMBeeBeamSearch(
            model=self.model,
            tokenizer=self.tokenizer,
        )
        inference_results = beam_search.generate([{"source":prompt, "<ans>":""}], max_length=512, repetition_penalty=1.2, beam_size=1)
        output = inference_results[0]["<ans>"]
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name}


if __name__ == "__main__":
    # Smoke test: the paths below are placeholders and must point at a real
    # CPM-Bee config/checkpoint before this will run.
    llm = CpmBeeLLM(config_path="path/to/cpm-bee/config.json", ckpt_path="path/to/cpm-bee/checkpoint/")
    print(llm("You are an task creation AI that uses the result of an execution agent to create new tasks with the following objective: What's the weather in Shanghai today? Should I bring an umbrella?, The last completed task has the result: According to the weather report, it is sunny in Shanghai today and there is no precipitation, so you do not need to bring an umbrella.. This result was based on this task description: Make a todo list about this objective: What's the weather in Shanghai today? Should I bring an umbrella?. These are incomplete tasks: . Based on the result, create new tasks to be completed by the AI system that do not overlap with incomplete tasks. Do not generate repetitive tasks (e.g., tasks that have already been completed). If there is not futher task needed to complete the objective, only return NO TASK. Now return the tasks as an array."))

@ -0,0 +1,68 @@
#!/usr/bin/env python
# coding=utf-8
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
class LoraModel(LLM):
    """LangChain ``LLM`` wrapper: a base causal LM with LoRA adapter weights (PEFT)."""

    # Declared as class-level fields because langchain's LLM is a pydantic
    # model; all are populated in __init__.
    model_name: str = ""
    tokenizer: AutoTokenizer = None
    model: PeftModel = None
    use_gpu: bool = True

    def __init__(self, base_name_or_path: str, model_name_or_path: str, device: str="cuda", cpu_offloading: bool=False, load_8bit: bool=False) -> None:
        """Load the base model, apply the LoRA adapter, and place it on *device*.

        Args:
            base_name_or_path: HF id/path of the base causal LM.
            model_name_or_path: HF id/path of the LoRA adapter; also used as
                the reported model name.
            device: "cuda", "mps", or "cpu".
            cpu_offloading: when True with device="cuda", leave placement to
                ``device_map="auto"`` instead of moving the whole model.
            load_8bit: load the base weights in 8-bit.
        """
        super().__init__()
        self.model_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(base_name_or_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            base_name_or_path,
            load_in_8bit=load_8bit,
            device_map="auto"
        )
        self.model = PeftModel.from_pretrained(
            model,
            model_name_or_path
        )
        # Fix: identity comparison with None (was `== None`).
        if self.tokenizer.pad_token_id is None:
            # Ensure special tokens exist so padded batches can be built,
            # then grow the embedding table to match the new vocab size.
            self.tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.use_gpu = (True if device == "cuda" else False)
        if (device == "cuda" and not cpu_offloading) or device == "mps":
            self.model.to(device)

    @property
    def _llm_type(self) -> str:
        # LangChain uses this as the model-type identifier string.
        return self.model_name

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate up to 512 new tokens for *prompt*.

        NOTE(review): ``stop`` sequences are accepted but ignored — confirm
        that is intended.
        """
        inputs = self.tokenizer(
            prompt,
            padding=True,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        inputs_len = inputs["input_ids"].shape[1]
        generated_outputs = self.model.generate(
            input_ids=(inputs["input_ids"].cuda() if self.use_gpu else inputs["input_ids"]),
            attention_mask=(inputs["attention_mask"].cuda() if self.use_gpu else inputs["attention_mask"]),
            max_new_tokens=512,
            eos_token_id=self.tokenizer.eos_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Strip the prompt tokens and decode only the generated suffix.
        decoded_output = self.tokenizer.batch_decode(
            generated_outputs[..., inputs_len:], skip_special_tokens=True)
        output = decoded_output[0]
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name}


if __name__ == "__main__":
    llm = LoraModel(base_name_or_path="huggyllama/llama-7b", model_name_or_path="pooruss-lsh/tool-llama7b-single-tool-lora")
    print(llm("You are an task creation AI that uses the result of an execution agent to create new tasks with the following objective: What's the weather in Shanghai today? Should I bring an umbrella?, The last completed task has the result: According to the weather report, it is sunny in Shanghai today and there is no precipitation, so you do not need to bring an umbrella.. This result was based on this task description: Make a todo list about this objective: What's the weather in Shanghai today? Should I bring an umbrella?. These are incomplete tasks: . Based on the result, create new tasks to be completed by the AI system that do not overlap with incomplete tasks. Do not generate repetitive tasks (e.g., tasks that have already been completed). If there is not futher task needed to complete the objective, only return NO TASK. Now return the tasks as an array."))

@ -0,0 +1,54 @@
#!/usr/bin/env python
# coding=utf-8
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
from transformers import AutoTokenizer, OPTForCausalLM
class OPTModel(LLM):
    """LangChain LLM wrapper around a HuggingFace OPT causal language model.

    The tokenizer and model weights are loaded eagerly in ``__init__`` from
    the given HuggingFace hub name (e.g. ``facebook/opt-350m``).
    """

    model_name: str = ""
    tokenizer: AutoTokenizer = None
    model: OPTForCausalLM = None

    def __init__(self, huggingface_model_name: str) -> None:
        """Load tokenizer and model weights for *huggingface_model_name*."""
        super().__init__()
        self.model_name = huggingface_model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = OPTForCausalLM.from_pretrained(self.model_name)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return self.model_name

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: Input text to complete.
            stop: Unused; accepted only for LangChain interface compatibility.

        Returns:
            The generated continuation with the prompt tokens stripped.
        """
        inputs = self.tokenizer(
            prompt,
            padding=True,
            max_length=512,  # 512 by defaulttokenizer.model_max_length=1000000000000000019884624838656
            truncation=True,
            return_tensors="pt"
        )
        inputs_len = inputs["input_ids"].shape[1]
        generated_outputs = self.model.generate(
            inputs['input_ids'],
            # Fix: pass the attention mask so padded positions are ignored;
            # omitting it triggers a transformers warning and can skew results
            # whenever padding is applied.
            attention_mask=inputs["attention_mask"],
            max_new_tokens=512,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # OPT is decoder-only: the output repeats the prompt, so decode only
        # the newly generated suffix.
        decoded_output = self.tokenizer.batch_decode(
            generated_outputs[..., inputs_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        output = decoded_output[0]
        return output
if __name__ == "__main__":
    # Manual smoke test: load a small OPT checkpoint (downloads weights on
    # first run) and print its answer to a sample task-creation prompt.
    llm = OPTModel("facebook/opt-350m")
    # print(llm("Hey, are you consciours? Can you talk to me?"))
    print(llm("You are an task creation AI that uses the result of an execution agent to create new tasks with the following objective: What's the weather in Shanghai today? Should I bring an umbrella?, The last completed task has the result: According to the weather report, it is sunny in Shanghai today and there is no precipitation, so you do not need to bring an umbrella.. This result was based on this task description: Make a todo list about this objective: What's the weather in Shanghai today? Should I bring an umbrella?. These are incomplete tasks: . Based on the result, create new tasks to be completed by the AI system that do not overlap with incomplete tasks. Do not generate repetitive tasks (e.g., tasks that have already been completed). If there is not futher task needed to complete the objective, only return NO TASK. Now return the tasks as an array."))

@ -0,0 +1,54 @@
#!/usr/bin/env python
# coding=utf-8
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
from transformers import T5Tokenizer, T5ForConditionalGeneration
class T5Model(LLM):
    """LangChain LLM wrapper around a HuggingFace T5 seq2seq model.

    The tokenizer and model weights are loaded eagerly in ``__init__`` from
    the given HuggingFace hub name (e.g. ``t5-small``).
    """

    model_name: str = ""
    tokenizer: T5Tokenizer = None
    model: T5ForConditionalGeneration = None

    def __init__(self, huggingface_model_name: str) -> None:
        """Load tokenizer and model weights for *huggingface_model_name*."""
        super().__init__()
        self.model_name = huggingface_model_name
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return self.model_name

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate text for *prompt*.

        Args:
            prompt: Input text (T5 task prefix included by the caller).
            stop: Unused; accepted only for LangChain interface compatibility.

        Returns:
            The decoded model output.
        """
        inputs = self.tokenizer(
            prompt,
            padding=True,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        generated_outputs = self.model.generate(
            inputs["input_ids"],
            # Fix: supply the attention mask so padded positions are masked
            # out during generation (transformers warns when it is omitted).
            attention_mask=inputs["attention_mask"],
            max_new_tokens=512,
        )
        # T5 is encoder-decoder: the output contains only generated tokens,
        # so no prompt-stripping is needed (unlike decoder-only models).
        decoded_output = self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        output = decoded_output[0]
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name}
if __name__ == "__main__":
    # Manual smoke test: load a small T5 checkpoint (downloads weights on
    # first run) and print its answer to a sample task-creation prompt.
    llm = T5Model("t5-small")
    # print(llm("translate English to German: The house is wonderful."))
    print(llm("You are an task creation AI that uses the result of an execution agent to create new tasks with the following objective: What's the weather in Shanghai today? Should I bring an umbrella?, The last completed task has the result: According to the weather report, it is sunny in Shanghai today and there is no precipitation, so you do not need to bring an umbrella.. This result was based on this task description: Make a todo list about this objective: What's the weather in Shanghai today? Should I bring an umbrella?. These are incomplete tasks: . Based on the result, create new tasks to be completed by the AI system that do not overlap with incomplete tasks. Do not generate repetitive tasks (e.g., tasks that have already been completed). If there is not futher task needed to complete the objective, only return NO TASK. Now return the tasks as an array."))

@ -1,16 +1,16 @@
from swarms.swarms.dialogue_simulator import DialogueSimulator
from swarms.swarms.autoscaler import AutoScaler
from swarms.swarms.orchestrate import Orchestrator
from swarms.swarms.god_mode import GodMode
from swarms.swarms.simple_swarm import SimpleSwarm
from swarms.swarms.multi_agent_debate import MultiAgentDebate, select_speaker
# from swarms.swarms.dialogue_simulator import DialogueSimulator
# # from swarms.swarms.autoscaler import AutoScaler
# from swarms.swarms.orchestrate import Orchestrator
# from swarms.swarms.god_mode import GodMode
# from swarms.swarms.simple_swarm import SimpleSwarm
# from swarms.swarms.multi_agent_debate import MultiAgentDebate, select_speaker
# Public API of the swarms.swarms package.
__all__ = [
    "DialogueSimulator",
    "AutoScaler",
    "Orchestrator",
    "GodMode",
    "SimpleSwarm",
    "MultiAgentDebate",
    "select_speaker",
]
# NOTE(review): commented-out duplicate of the export list left over from a
# partially-applied edit; consider deleting once the intended API is settled.
# __all__ = [
#     "DialogueSimulator",
#     "AutoScaler",
#     "Orchestrator",
#     "GodMode",
#     "SimpleSwarm",
#     "MultiAgentDebate",
#     "select_speaker",
# ]

@ -1,117 +1,117 @@
import queue
import threading
from time import sleep
from swarms.utils.decorators import error_decorator, log_decorator, timing_decorator
from swarms.structs.flow import Flow
class AutoScaler:
    """
    The AutoScaler is like a kubernetes pod, that autoscales an agent or worker or boss!

    # TODO Handle task assignment and task delegation
    # TODO: User task => decomposed into very small sub tasks => sub tasks assigned to workers => workers complete and update the swarm, can ask for help from other agents.
    # TODO: Missing, Task Assignment, Task delegation, Task completion, Swarm level communication with vector db

    Args:
        initial_agents (int, optional): Number of initial agents. Defaults to 10.
        scale_up_factor (int, optional): Scale up factor. Defaults to 1.
        idle_threshold (float, optional): Idle threshold. Defaults to 0.2.
        busy_threshold (float, optional): Busy threshold. Defaults to 0.7.
        agent ([type], optional): Agent class to instantiate. Defaults to Flow.

    Methods:
        add_task: Add task to queue
        scale_up: Scale up
        scale_down: Scale down
        monitor_and_scale: Monitor and scale
        start: Start scaling

    Usage
    ```
    auto_scaler = AutoScaler(agent=YourCustomAgent)
    auto_scaler.start()

    for i in range(100):
        auto_scaler.add_task(f"task {i}")
    ```
    """

    @log_decorator
    @error_decorator
    @timing_decorator
    def __init__(
        self,
        initial_agents=10,
        scale_up_factor=1,
        idle_threshold=0.2,
        busy_threshold=0.7,
        agent=None,
    ):
        self.agent = agent or Flow
        self.agents_pool = [self.agent() for _ in range(initial_agents)]
        self.task_queue = queue.Queue()
        self.scale_up_factor = scale_up_factor
        self.idle_threshold = idle_threshold
        # Fix: busy_threshold was accepted but never stored, which made
        # monitor_and_scale() raise AttributeError when reading it.
        self.busy_threshold = busy_threshold
        self.lock = threading.Lock()

    def add_task(self, task):
        """Enqueue a task for the agent pool."""
        # Fix: previously wrote to the non-existent `self.tasks_queue`.
        self.task_queue.put(task)

    @log_decorator
    @error_decorator
    @timing_decorator
    def scale_up(self):
        """Grow the pool by scale_up_factor * current size new agents."""
        with self.lock:
            new_agents_counts = len(self.agents_pool) * self.scale_up_factor
            for _ in range(new_agents_counts):
                # Fix: previously hardcoded Flow(), ignoring a custom `agent`
                # class supplied to __init__.
                self.agents_pool.append(self.agent())

    def scale_down(self):
        """Remove one agent from the pool."""
        with self.lock:
            if len(self.agents_pool) > 10:  # ensure minmum of 10 agents
                del self.agents_pool[-1]  # remove last agent

    @log_decorator
    @error_decorator
    @timing_decorator
    def monitor_and_scale(self):
        """Poll once a minute and scale the pool up or down based on load."""
        while True:
            sleep(60)  # check minute
            pending_tasks = self.task_queue.qsize()
            active_agents = sum([1 for agent in self.agents_pool if agent.is_busy()])

            if pending_tasks / len(self.agents_pool) > self.busy_threshold:
                self.scale_up()
            elif active_agents / len(self.agents_pool) < self.idle_threshold:
                self.scale_down()

    @log_decorator
    @error_decorator
    @timing_decorator
    def start(self):
        """Start the monitor thread and dispatch queued tasks to agents."""
        monitor_thread = threading.Thread(target=self.monitor_and_scale)
        monitor_thread.start()

        while True:
            task = self.task_queue.get()
            if task:
                # Fix: the bare generator previously raised StopIteration on an
                # empty pool; a default of None avoids the crash.
                # NOTE(review): this always picks the first agent regardless of
                # busyness — confirm the intended dispatch policy.
                available_agent = next(iter(self.agents_pool), None)
                if available_agent:
                    available_agent.run(task)

    # def del_agent(self):
    #     """Delete an agent"""
    #     with self.lock:
    #         if self.agents_pool:
    #             self.agents_poo.pop()
    #             del agent_to_remove
# del agent_to_remove
# import queue
# import threading
# from time import sleep
# from swarms.utils.decorators import error_decorator, log_decorator, timing_decorator
# from swarms.workers.worker import Worker
# class AutoScaler:
# """
# The AutoScaler is like a kubernetes pod, that autoscales an agent or worker or boss!
# # TODO Handle task assignment and task delegation
# # TODO: User task => decomposed into very small sub tasks => sub tasks assigned to workers => workers complete and update the swarm, can ask for help from other agents.
# # TODO: Missing, Task Assignment, Task delegation, Task completion, Swarm level communication with vector db
# Args:
# initial_agents (int, optional): Number of initial agents. Defaults to 10.
# scale_up_factor (int, optional): Scale up factor. Defaults to 1.
# idle_threshold (float, optional): Idle threshold. Defaults to 0.2.
# busy_threshold (float, optional): Busy threshold. Defaults to 0.7.
# agent ([type], optional): Agent. Defaults to None.
# Methods:
# add_task: Add task to queue
# scale_up: Scale up
# scale_down: Scale down
# monitor_and_scale: Monitor and scale
# start: Start scaling
# del_agent: Delete an agent
# Usage
# ```
# # usage of usage
# auto_scaler = AutoScaler(agent=YourCustomAgent)
# auto_scaler.start()
# for i in range(100):
# auto_scaler.add_task9f"task {I}})
# ```
# """
# @log_decorator
# @error_decorator
# @timing_decorator
# def __init__(
# self,
# initial_agents=10,
# scale_up_factor=1,
# idle_threshold=0.2,
# busy_threshold=0.7,
# agent=None,
# ):
# self.agent = agent or Worker
# self.agents_pool = [self.agent() for _ in range(initial_agents)]
# self.task_queue = queue.Queue()
# self.scale_up_factor = scale_up_factor
# self.idle_threshold = idle_threshold
# self.lock = threading.Lock()
# def add_task(self, task):
# """Add tasks to queue"""
# self.tasks_queue.put(task)
# @log_decorator
# @error_decorator
# @timing_decorator
# def scale_up(self):
# """Add more agents"""
# with self.lock:
# new_agents_counts = len(self.agents_pool) * self.scale_up_factor
# for _ in range(new_agents_counts):
# self.agents_pool.append(Worker())
# def scale_down(self):
# """scale down"""
# with self.lock:
# if len(self.agents_pool) > 10: # ensure minmum of 10 agents
# del self.agents_pool[-1] # remove last agent
# @log_decorator
# @error_decorator
# @timing_decorator
# def monitor_and_scale(self):
# """Monitor and scale"""
# while True:
# sleep(60) # check minute
# pending_tasks = self.task_queue.qsize()
# active_agents = sum([1 for agent in self.agents_pool if agent.is_busy()])
# if pending_tasks / len(self.agents_pool) > self.busy_threshold:
# self.scale_up()
# elif active_agents / len(self.agents_pool) < self.idle_threshold:
# self.scale_down()
# @log_decorator
# @error_decorator
# @timing_decorator
# def start(self):
# """Start scaling"""
# monitor_thread = threading.Thread(target=self.monitor_and_scale)
# monitor_thread.start()
# while True:
# task = self.task_queue.get()
# if task:
# available_agent = next((agent for agent in self.agents_pool))
# if available_agent:
# available_agent.run(task)
# # def del_agent(self):
# # """Delete an agent"""
# # with self.lock:
# # if self.agents_pool:
# # self.agents_poo.pop()
# # del agent_to_remove

@ -0,0 +1,22 @@
from langchain.tools import tool
from swarms.tools.base import BaseToolSet, SessionGetter, ToolScope
from swarms.utils.logger import logger
class ExitConversation(BaseToolSet):
    # Tool set exposing a single session-scoped tool that ends the current
    # conversation and echoes the caller's farewell message back.
    @tool(
        name="Exit Conversation",
        description="A tool to exit the conversation. "
        "Use this when you want to exit the conversation. "
        "The input should be a message that the conversation is over.",
        scope=ToolScope.SESSION,
    )
    def exit(self, message: str, get_session: SessionGetter) -> str:
        """Run the tool.

        Args:
            message: Closing message returned unchanged as the tool's output.
            get_session: Callable yielding the current session tuple; the
                second element is the session's executor.

        Returns:
            The *message* unchanged.
        """
        _, executor = get_session()
        # NOTE(review): `del` only drops the local reference; presumably the
        # executor's teardown happens via garbage collection — confirm.
        del executor
        logger.debug("\nProcessed ExitConversation.")
        return message

@ -0,0 +1,36 @@
import requests
from bs4 import BeautifulSoup
from swarms.tools.base import BaseToolSet, tool
from swarms.utils.logger import logger
class RequestsGet(BaseToolSet):
    # Tool set exposing a single GET tool that fetches a URL and returns a
    # readable, truncated text rendering of the page.
    @tool(
        name="Requests Get",
        description="A portal to the internet. "
        "Use this when you need to get specific content from a website."
        "Input should be a url (i.e. https://www.google.com)."
        "The output will be the text response of the GET request.",
    )
    def get(self, url: str) -> str:
        """Fetch *url* and return up to 300 characters of its visible text.

        Args:
            url: Absolute URL to fetch.

        Returns:
            Page text with scripts/styles/chrome removed, truncated to 300
            characters (plus an ellipsis) so it fits in an LLM context.
        """
        # Fix: without a timeout, a hung server would block this tool forever.
        html = requests.get(url, timeout=30).text
        # Explicit parser avoids bs4's "no parser was explicitly specified"
        # warning and keeps parsing consistent across environments.
        soup = BeautifulSoup(html, "html.parser")
        # Drop tags that carry no readable content.
        non_readable_tags = soup.find_all(
            ["script", "style", "header", "footer", "form"]
        )

        for non_readable_tag in non_readable_tags:
            non_readable_tag.extract()

        content = soup.get_text("\n", strip=True)

        # Keep tool output small enough for an LLM observation.
        if len(content) > 300:
            content = content[:300] + "..."

        logger.debug(
            f"\nProcessed RequestsGet, Input Url: {url} " f"Output Contents: {content}"
        )

        return content

@ -0,0 +1,125 @@
# speech to text tool
import os
import subprocess
import whisperx
from pydub import AudioSegment
from pytube import YouTube
class SpeechToText:
    """Download a YouTube video's audio and transcribe it with WhisperX.

    # Example usage
    video_url = "url"
    speech_to_text = SpeechToText(video_url)
    transcription = speech_to_text.transcribe_youtube_video()
    print(transcription)
    """

    def __init__(
        self,
        video_url,
        audio_format="mp3",
        device="cuda",
        batch_size=16,
        compute_type="float16",
        hf_api_key=None,
    ):
        """
        Args:
            video_url: YouTube URL to download and transcribe.
            audio_format: Audio container to export (defaults to "mp3").
            device: Torch device for WhisperX ("cuda" or "cpu").
            batch_size: Transcription batch size.
            compute_type: WhisperX compute precision (e.g. "float16").
            hf_api_key: HuggingFace token used by the diarization pipeline.
        """
        self.video_url = video_url
        self.audio_format = audio_format
        self.device = device
        self.batch_size = batch_size
        self.compute_type = compute_type
        self.hf_api_key = hf_api_key

    def install(self):
        """Best-effort runtime install of the third-party dependencies."""
        subprocess.run(["pip", "install", "whisperx"])
        subprocess.run(["pip", "install", "pytube"])
        subprocess.run(["pip", "install", "pydub"])

    def download_youtube_video(self):
        """Download the video's audio stream and convert it to the configured format.

        Returns:
            Path of the exported audio file (e.g. "video.mp3").
        """
        audio_file = f"video.{self.audio_format}"

        # Download video 📥
        yt = YouTube(self.video_url)
        yt_stream = yt.streams.filter(only_audio=True).first()
        yt_stream.download(filename="video.mp4")

        # Convert video to audio 🎧
        video = AudioSegment.from_file("video.mp4", format="mp4")
        video.export(audio_file, format=self.audio_format)
        os.remove("video.mp4")

        return audio_file

    def transcribe_youtube_video(self):
        """Download, transcribe, align, and diarize the configured video URL.

        Returns:
            The concatenated transcript text, or None if WhisperX returned no
            "segments" key.
        """
        audio_file = self.download_youtube_video()

        # 1. Transcribe with original Whisper (batched) 🗣️
        # Fix: previously hard-coded device/batch_size/compute_type locally,
        # silently ignoring the values passed to __init__.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )
        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        # NOTE(review): the diarization result is discarded, so speaker labels
        # never reach the returned transcript — confirm whether
        # whisperx.assign_word_speakers was intended here.
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            transcription = " ".join(segment["text"] for segment in segments)
            return transcription
        except KeyError:
            print("The key 'segments' is not found in the result.")

    def transcribe(self, audio_file):
        """Transcribe an existing audio file (same pipeline, no download)."""
        # Fix: compute_type was passed positionally into load_model's
        # device_index slot; it must be given as a keyword argument.
        model = whisperx.load_model(
            "large-v2", self.device, compute_type=self.compute_type
        )
        audio = whisperx.load_audio(audio_file)
        result = model.transcribe(audio, batch_size=self.batch_size)

        # 2. Align Whisper output 🔍
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=self.device
        )

        result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Assign speaker labels 🏷️
        diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_api_key, device=self.device
        )
        # NOTE(review): diarization result discarded here as well — confirm.
        diarize_model(audio_file)

        try:
            segments = result["segments"]
            transcription = " ".join(segment["text"] for segment in segments)
            return transcription
        except KeyError:
            print("The key 'segments' is not found in the result.")

@ -0,0 +1,845 @@
"""Base implementation for tools or skills."""
from __future__ import annotations
import asyncio
import inspect
import warnings
from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForToolRun,
CallbackManager,
CallbackManagerForToolRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import (
BaseModel,
Extra,
Field,
create_model,
root_validator,
validate_arguments,
)
from langchain.schema.runnable import Runnable, RunnableConfig, RunnableSerializable
class SchemaAnnotationError(TypeError):
    """Raised when a tool's ``args_schema`` annotation is missing or invalid."""
def _create_subset_model(
    name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
    """Build a new pydantic model exposing only *field_names* from *model*."""
    subset = {
        fname: (
            model.__fields__[fname].outer_type_,
            model.__fields__[fname].field_info,
        )
        for fname in field_names
    }
    return create_model(name, **subset)  # type: ignore
def _get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")}
class _SchemaConfig:
    """Configuration for the pydantic model."""

    # Reject any argument that is not declared on the schema.
    extra: Any = Extra.forbid
    # Allow non-pydantic (arbitrary) types, e.g. callbacks, in signatures.
    arbitrary_types_allowed: bool = True
def create_schema_from_function(
    model_name: str,
    func: Callable,
) -> Type[BaseModel]:
    """Create a pydantic schema from a function's signature.

    Args:
        model_name: Name to assign to the generated pydantic schema.
        func: Function to generate the schema from.

    Returns:
        A pydantic model with the same arguments as the function.
    """
    # https://docs.pydantic.dev/latest/usage/validation_decorator/
    validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore
    inferred_model = validated.model  # type: ignore
    # Drop the tracing-only parameters so they never appear in the schema.
    for hidden_arg in ("run_manager", "callbacks"):
        if hidden_arg in inferred_model.__fields__:
            del inferred_model.__fields__[hidden_arg]
    # Pydantic adds placeholder virtual fields we need to strip
    valid_properties = _get_filtered_args(inferred_model, func)
    return _create_subset_model(
        f"{model_name}Schema", inferred_model, list(valid_properties)
    )
class ToolException(Exception):
    """Optional exception type a tool raises on execution errors.

    Raising this does not stop the agent: the error is handled according to
    the tool's ``handle_tool_error`` setting, the outcome is returned to the
    agent as an observation, and it is printed in red on the console.
    """
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
    """Interface LangChain tools must implement."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create the definition of the new tool class."""
        super().__init_subclass__(**kwargs)

        args_schema_type = cls.__annotations__.get("args_schema", None)

        if args_schema_type is not None:
            if args_schema_type is None or args_schema_type == BaseModel:
                # Throw errors for common mis-annotations.
                # TODO: Use get_args / get_origin and fully
                # specify valid annotations.
                typehint_mandate = """
class ChildTool(BaseTool):
    ...
    args_schema: Type[BaseModel] = SchemaClass
    ..."""
                name = cls.__name__
                raise SchemaAnnotationError(
                    f"Tool definition for {name} must include valid type annotations"
                    f" for argument 'args_schema' to behave as expected.\n"
                    f"Expected annotation of 'Type[BaseModel]'"
                    f" but got '{args_schema_type}'.\n"
                    f"Expected class looks like:\n"
                    f"{typehint_mandate}"
                )

    name: str
    """The unique name of the tool that clearly communicates its purpose."""
    description: str
    """Used to tell the model how/when/why to use the tool.
    You can provide few-shot examples as a part of the description.
    """
    args_schema: Optional[Type[BaseModel]] = None
    """Pydantic model class to validate and parse the tool's input arguments."""
    return_direct: bool = False
    """Whether to return the tool's output directly. Setting this to True means
    that after the tool is called, the AgentExecutor will stop looping.
    """
    verbose: bool = False
    """Whether to log the tool's progress."""

    callbacks: Callbacks = Field(default=None, exclude=True)
    """Callbacks to be called during tool execution."""
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """Deprecated. Please use callbacks instead."""
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the tool. Defaults to None
    These tags will be associated with each call to this tool,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a tool with its use case.
    """
    metadata: Optional[Dict[str, Any]] = None
    """Optional metadata associated with the tool. Defaults to None
    This metadata will be associated with each call to this tool,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a tool with its use case.
    """

    handle_tool_error: Optional[
        Union[bool, str, Callable[[ToolException], str]]
    ] = False
    """Handle the content of the ToolException thrown."""

    class Config(Serializable.Config):
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def is_single_input(self) -> bool:
        """Whether the tool only accepts a single input."""
        keys = {k for k in self.args if k != "kwargs"}
        return len(keys) == 1

    @property
    def args(self) -> dict:
        """The tool's argument schema properties (inferred from _run if no
        explicit args_schema is set)."""
        if self.args_schema is not None:
            return self.args_schema.schema()["properties"]
        else:
            schema = create_schema_from_function(self.name, self._run)
            return schema.schema()["properties"]

    # --- Runnable ---

    @property
    def input_schema(self) -> Type[BaseModel]:
        """The tool's input schema."""
        if self.args_schema is not None:
            return self.args_schema
        else:
            return create_schema_from_function(self.name, self._run)

    def invoke(
        self,
        input: Union[str, Dict],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        """Runnable entry point: unpack the config into run() keyword args."""
        config = config or {}
        return self.run(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )

    async def ainvoke(
        self,
        input: Union[str, Dict],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        """Async Runnable entry point: unpack the config into arun() kwargs."""
        config = config or {}
        return await self.arun(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )

    # --- Tool ---

    def _parse_input(
        self,
        tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Convert tool input to pydantic model."""
        input_args = self.args_schema
        if isinstance(tool_input, str):
            # String input: validate it against the schema's first field but
            # return the raw string unchanged.
            if input_args is not None:
                key_ = next(iter(input_args.__fields__.keys()))
                input_args.validate({key_: tool_input})
            return tool_input
        else:
            # Dict input: parse through the schema, keeping only keys the
            # caller actually supplied.
            if input_args is not None:
                result = input_args.parse_obj(tool_input)
                return {k: v for k, v in result.dict().items() if k in tool_input}
        return tool_input

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            # Migrate the deprecated value onto the supported field.
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    @abstractmethod
    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Use the tool.

        Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing,
        """

    async def _arun(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously.

        Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
        to child implementations to enable tracing,
        """
        # Default async implementation: delegate the sync _run to a thread.
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._run, **kwargs),
            *args,
        )

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        # For backwards compatibility, if run_input is a string,
        # pass as a positional argument.
        if isinstance(tool_input, str):
            return (tool_input,), {}
        else:
            return (), tool_input

    def run(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the tool.

        Validates the input, fires tool-start/end/error callbacks around the
        child's _run, and applies the handle_tool_error policy when a
        ToolException is raised.
        """
        parsed_input = self._parse_input(tool_input)
        if not self.verbose and verbose is not None:
            verbose_ = verbose
        else:
            verbose_ = self.verbose
        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            verbose_,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        # TODO: maybe also pass through run_manager is _run supports kwargs
        new_arg_supported = signature(self._run).parameters.get("run_manager")
        run_manager = callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input if isinstance(tool_input, str) else str(tool_input),
            color=start_color,
            name=run_name,
            **kwargs,
        )
        try:
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            observation = (
                self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
                if new_arg_supported
                else self._run(*tool_args, **tool_kwargs)
            )
        except ToolException as e:
            # Apply the handle_tool_error policy: False re-raises, True uses
            # the exception message, a str is returned verbatim, a callable
            # computes the observation from the exception.
            if not self.handle_tool_error:
                run_manager.on_tool_error(e)
                raise e
            elif isinstance(self.handle_tool_error, bool):
                if e.args:
                    observation = e.args[0]
                else:
                    observation = "Tool execution error"
            elif isinstance(self.handle_tool_error, str):
                observation = self.handle_tool_error
            elif callable(self.handle_tool_error):
                observation = self.handle_tool_error(e)
            else:
                raise ValueError(
                    f"Got unexpected type of `handle_tool_error`. Expected bool, str "
                    f"or callable. Received: {self.handle_tool_error}"
                )
            run_manager.on_tool_end(
                str(observation), color="red", name=self.name, **kwargs
            )
            return observation
        except (Exception, KeyboardInterrupt) as e:
            run_manager.on_tool_error(e)
            raise e
        else:
            run_manager.on_tool_end(
                str(observation), color=color, name=self.name, **kwargs
            )
            return observation

    async def arun(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        start_color: Optional[str] = "green",
        color: Optional[str] = "green",
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the tool asynchronously.

        Mirrors run() with awaited callbacks and the child's _arun.
        """
        parsed_input = self._parse_input(tool_input)
        if not self.verbose and verbose is not None:
            verbose_ = verbose
        else:
            verbose_ = self.verbose
        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            self.callbacks,
            verbose_,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        new_arg_supported = signature(self._arun).parameters.get("run_manager")
        run_manager = await callback_manager.on_tool_start(
            {"name": self.name, "description": self.description},
            tool_input if isinstance(tool_input, str) else str(tool_input),
            color=start_color,
            name=run_name,
            **kwargs,
        )
        try:
            # We then call the tool on the tool input to get an observation
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            observation = (
                await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
                if new_arg_supported
                else await self._arun(*tool_args, **tool_kwargs)
            )
        except ToolException as e:
            # Same handle_tool_error policy as the sync run() path.
            if not self.handle_tool_error:
                await run_manager.on_tool_error(e)
                raise e
            elif isinstance(self.handle_tool_error, bool):
                if e.args:
                    observation = e.args[0]
                else:
                    observation = "Tool execution error"
            elif isinstance(self.handle_tool_error, str):
                observation = self.handle_tool_error
            elif callable(self.handle_tool_error):
                observation = self.handle_tool_error(e)
            else:
                raise ValueError(
                    f"Got unexpected type of `handle_tool_error`. Expected bool, str "
                    f"or callable. Received: {self.handle_tool_error}"
                )
            await run_manager.on_tool_end(
                str(observation), color="red", name=self.name, **kwargs
            )
            return observation
        except (Exception, KeyboardInterrupt) as e:
            await run_manager.on_tool_error(e)
            raise e
        else:
            await run_manager.on_tool_end(
                str(observation), color=color, name=self.name, **kwargs
            )
            return observation

    def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
        """Make tool callable."""
        return self.run(tool_input, callbacks=callbacks)
class Tool(BaseTool):
    """Tool that takes in function or coroutine directly."""

    description: str = ""
    func: Optional[Callable[..., str]]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[str]]] = None
    """The asynchronous version of the function."""

    # --- Runnable ---

    async def ainvoke(
        self,
        input: Union[str, Dict],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        """Async invoke; runs the sync path in a thread when no coroutine is set."""
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )

        return await super().ainvoke(input, config, **kwargs)

    # --- Tool ---

    @property
    def args(self) -> dict:
        """The tool's input arguments."""
        if self.args_schema is not None:
            return self.args_schema.schema()["properties"]
        # For backwards compatibility, if the function signature is ambiguous,
        # assume it takes a single string input.
        return {"tool_input": {"type": "string"}}

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Convert tool input to pydantic model."""
        args, kwargs = super()._to_args_and_kwargs(tool_input)
        # For backwards compatibility. The tool must be run with a single input
        all_args = list(args) + list(kwargs.values())
        if len(all_args) != 1:
            raise ToolException(
                f"Too many arguments to single-input tool {self.name}."
                f" Args: {all_args}"
            )
        return tuple(all_args), {}

    def _run(
        self,
        *args: Any,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool."""
        if self.func:
            # Forward child callbacks only when the wrapped function accepts them.
            new_argument_supported = signature(self.func).parameters.get("callbacks")
            return (
                self.func(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else self.func(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support sync")

    async def _arun(
        self,
        *args: Any,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously."""
        if self.coroutine:
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        else:
            # No coroutine provided: run the sync implementation in a thread.
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self._run, run_manager=run_manager, **kwargs), *args
            )

    # TODO: this is for backwards compatibility, remove in future
    def __init__(
        self, name: str, func: Optional[Callable], description: str, **kwargs: Any
    ) -> None:
        """Initialize tool."""
        super(Tool, self).__init__(
            name=name, func=func, description=description, **kwargs
        )

    @classmethod
    def from_function(
        cls,
        func: Optional[Callable],
        name: str,  # We keep these required to support backwards compatibility
        description: str,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        coroutine: Optional[
            Callable[..., Awaitable[Any]]
        ] = None,  # This is last for compatibility, but should be after func
        **kwargs: Any,
    ) -> Tool:
        """Initialize tool from a function."""
        if func is None and coroutine is None:
            raise ValueError("Function and/or coroutine must be provided")
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            description=description,
            return_direct=return_direct,
            args_schema=args_schema,
            **kwargs,
        )
class StructuredTool(BaseTool):
    """Tool that can operate on any number of inputs."""
    description: str = ""
    args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
    """The input arguments' schema."""
    func: Optional[Callable[..., Any]]
    """The function to run when the tool is called."""
    coroutine: Optional[Callable[..., Awaitable[Any]]] = None
    """The asynchronous version of the function."""
    # --- Runnable ---
    async def ainvoke(
        self,
        input: Union[str, Dict],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Any:
        # Async entry point: without a native coroutine, run the synchronous
        # ``invoke`` in the default executor so the event loop is not blocked.
        if not self.coroutine:
            # If the tool does not implement async, fall back to default implementation
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )
        return await super().ainvoke(input, config, **kwargs)
    # --- Tool ---
    @property
    def args(self) -> dict:
        """The tool's input arguments (the schema's ``properties`` mapping)."""
        return self.args_schema.schema()["properties"]
    def _run(
        self,
        *args: Any,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool."""
        if self.func:
            # Only pass ``callbacks`` when the wrapped function declares that
            # parameter, so plain callables keep working.
            new_argument_supported = signature(self.func).parameters.get("callbacks")
            return (
                self.func(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else self.func(*args, **kwargs)
            )
        raise NotImplementedError("Tool does not support sync")
    async def _arun(
        self,
        *args: Any,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any,
    ) -> Any:
        """Use the tool asynchronously.

        Falls back to running ``_run`` in the default executor when no
        coroutine was provided.
        """
        if self.coroutine:
            new_argument_supported = signature(self.coroutine).parameters.get(
                "callbacks"
            )
            return (
                await self.coroutine(
                    *args,
                    callbacks=run_manager.get_child() if run_manager else None,
                    **kwargs,
                )
                if new_argument_supported
                else await self.coroutine(*args, **kwargs)
            )
        # No coroutine: execute the synchronous implementation off-loop.
        return await asyncio.get_running_loop().run_in_executor(
            None,
            partial(self._run, run_manager=run_manager, **kwargs),
            *args,
        )
    @classmethod
    def from_function(
        cls,
        func: Optional[Callable] = None,
        coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[Type[BaseModel]] = None,
        infer_schema: bool = True,
        **kwargs: Any,
    ) -> StructuredTool:
        """Create tool from a given function.
        A classmethod that helps to create a tool from a function.
        Args:
            func: The function from which to create a tool
            coroutine: The async function from which to create a tool
            name: The name of the tool. Defaults to the function name
            description: The description of the tool. Defaults to the function docstring
            return_direct: Whether to return the result directly or as a callback
            args_schema: The schema of the tool's input arguments
            infer_schema: Whether to infer the schema from the function's signature
            **kwargs: Additional arguments to pass to the tool
        Returns:
            The tool
        Examples:
            .. code-block:: python
                def add(a: int, b: int) -> int:
                    \"\"\"Add two numbers\"\"\"
                    return a + b
                tool = StructuredTool.from_function(add)
                tool.run(1, 2) # 3
        """
        # Name and description default to those of whichever callable exists.
        if func is not None:
            source_function = func
        elif coroutine is not None:
            source_function = coroutine
        else:
            raise ValueError("Function and/or coroutine must be provided")
        name = name or source_function.__name__
        description = description or source_function.__doc__
        if description is None:
            raise ValueError(
                "Function must have a docstring if description not provided."
            )
        # Description example:
        # search_api(query: str) - Searches the API for the query.
        sig = signature(source_function)
        description = f"{name}{sig} - {description.strip()}"
        _args_schema = args_schema
        if _args_schema is None and infer_schema:
            # Derive a pydantic schema from the function signature so the tool
            # accepts structured (dict) input.
            _args_schema = create_schema_from_function(f"{name}Schema", source_function)
        return cls(
            name=name,
            func=func,
            coroutine=coroutine,
            args_schema=_args_schema,
            description=description,
            return_direct=return_direct,
            **kwargs,
        )
def tool(
    *args: Union[str, Callable, Runnable],
    return_direct: bool = False,
    args_schema: Optional[Type[BaseModel]] = None,
    infer_schema: bool = True,
) -> Callable:
    """Make tools out of functions, can be used with or without arguments.
    Args:
        *args: The arguments to the tool.
        return_direct: Whether to return directly from the tool rather
            than continuing the agent loop.
        args_schema: optional argument schema for user to specify
        infer_schema: Whether to infer the schema of the arguments from
            the function's signature. This also makes the resultant tool
            accept a dictionary input to its `run()` function.
    Requires:
        - Function must be of type (str) -> str
        - Function must have a docstring
    Examples:
        .. code-block:: python
            @tool
            def search_api(query: str) -> str:
                # Searches the API for the query.
                return
            @tool("search", return_direct=True)
            def search_api(query: str) -> str:
                # Searches the API for the query.
                return
    """
    def _make_with_name(tool_name: str) -> Callable:
        def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
            # Normalize the decorated object into (func, coroutine, schema,
            # description) regardless of whether it is a Runnable, an async
            # function, or a plain function.
            if isinstance(dec_func, Runnable):
                runnable = dec_func
                if runnable.input_schema.schema().get("type") != "object":
                    raise ValueError("Runnable must have an object schema.")
                async def ainvoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return await runnable.ainvoke(kwargs, {"callbacks": callbacks})
                def invoke_wrapper(
                    callbacks: Optional[Callbacks] = None, **kwargs: Any
                ) -> Any:
                    return runnable.invoke(kwargs, {"callbacks": callbacks})
                coroutine = ainvoke_wrapper
                func = invoke_wrapper
                schema: Optional[Type[BaseModel]] = runnable.input_schema
                description = repr(runnable)
            elif inspect.iscoroutinefunction(dec_func):
                coroutine = dec_func
                func = None
                schema = args_schema
                description = None
            else:
                coroutine = None
                func = dec_func
                schema = args_schema
                description = None
            if infer_schema or args_schema is not None:
                return StructuredTool.from_function(
                    func,
                    coroutine,
                    name=tool_name,
                    description=description,
                    return_direct=return_direct,
                    args_schema=schema,
                    infer_schema=infer_schema,
                )
            # If someone doesn't want a schema applied, we must treat it as
            # a simple string->string function
            # BUG FIX: when an async function is decorated, ``func`` is None and
            # the old ``func.__doc__`` access raised AttributeError. Read the
            # docstring from whichever callable was actually provided.
            source = func if func is not None else coroutine
            if source is None or source.__doc__ is None:
                raise ValueError(
                    "Function must have a docstring if "
                    "description not provided and infer_schema is False."
                )
            return Tool(
                name=tool_name,
                func=func,
                description=f"{tool_name} tool",
                return_direct=return_direct,
                coroutine=coroutine,
            )
        return _make_tool
    if len(args) == 2 and isinstance(args[0], str) and isinstance(args[1], Runnable):
        return _make_with_name(args[0])(args[1])
    elif len(args) == 1 and isinstance(args[0], str):
        # if the argument is a string, then we use the string as the tool name
        # Example usage: @tool("search", return_direct=True)
        return _make_with_name(args[0])
    elif len(args) == 1 and callable(args[0]):
        # if the argument is a function, then we use the function name as the tool name
        # Example usage: @tool
        return _make_with_name(args[0].__name__)(args[0])
    elif len(args) == 0:
        # if there are no arguments, then we use the function name as the tool name
        # Example usage: @tool(return_direct=True)
        def _partial(func: Callable[[str], str]) -> BaseTool:
            return _make_with_name(func.__name__)(func)
        return _partial
    else:
        raise ValueError("Too many arguments for tool decorator")

@ -0,0 +1,317 @@
from collections import deque
from typing import Dict, List, Optional, Any
import re
from langchain import LLMChain, OpenAI, PromptTemplate, SerpAPIWrapper
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import BaseLLM
from langchain.vectorstores.base import VectorStore
from pydantic import BaseModel, Field
from langchain.chains.base import Chain
from langchain.vectorstores import FAISS
import faiss
from langchain.docstore import InMemoryDocstore
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from .executor import Executor, AgentExecutorWithTranslation
class ContextAwareAgent(ZeroShotAgent):
    """ZeroShotAgent variant that bounds the scratchpad to the two most recent
    steps and nudges the LLM away from repeating invalid tool choices."""
    def get_full_inputs(
        self, intermediate_steps, **kwargs: Any
    ) -> Dict[str, Any]:
        """Create the full inputs for the LLMChain from intermediate steps."""
        thoughts = self._construct_scratchpad(intermediate_steps)
        new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
        full_inputs = {**kwargs, **new_inputs}
        return full_inputs
    def _construct_scratchpad(self, intermediate_steps):
        """Construct the scratchpad that lets the agent continue its thought process."""
        thoughts = ""
        # only modify the following line, [-2: ]
        # Keep only the last two (action, observation) pairs to bound prompt size.
        for action, observation in intermediate_steps[-2: ]:
            thoughts += action.log
            thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
            # Explicitly steer the model away from re-selecting an invalid tool.
            if "is not a valid tool, try another one" in observation:
                thoughts += "You should select another tool rather than the invalid one.\n"
        return thoughts
class TaskCreationChain(LLMChain):
    """LLM chain that generates follow-up tasks from the last task's result."""
    @classmethod
    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
        """Build the chain with the task-creation prompt.

        Prompt variables: result, task_description, incomplete_tasks, objective.
        """
        # NOTE(review): the prompt text contains typos ("an task", "futher");
        # left byte-identical because changing prompt text changes runtime
        # behavior of the LLM call.
        task_creation_template = (
            "You are an task creation AI that uses the result of an execution agent"
            " to create new tasks with the following objective: {objective},"
            " The last completed task has the result: {result}."
            " This result was based on this task description: {task_description}."
            " These are incomplete tasks: {incomplete_tasks}."
            " Based on the result, create new tasks to be completed"
            " by the AI system that do not overlap with incomplete tasks."
            " For a simple objective, do not generate complex todo lists."
            " Do not generate repetitive tasks (e.g., tasks that have already been completed)."
            " If there is not futher task needed to complete the objective, return NO TASK."
            " Now return the tasks as an array."
        )
        prompt = PromptTemplate(
            template=task_creation_template,
            input_variables=["result", "task_description", "incomplete_tasks", "objective"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
class InitialTaskCreationChain(LLMChain):
    """LLM chain that produces the very first task from the objective alone."""
    @classmethod
    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
        """Build the chain with the initial-task prompt (variable: objective)."""
        task_creation_template = (
            "You are a planner who is an expert at coming up with a todo list for a given objective. For a simple objective, do not generate a complex todo list. Generate the first (only one) task needed to do for this objective: {objective}"
        )
        prompt = PromptTemplate(
            template=task_creation_template,
            input_variables=["objective"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
class TaskPrioritizationChain(LLMChain):
    """LLM chain that re-orders the pending task list against the objective."""
    @classmethod
    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
        """Build the chain with the prioritization prompt.

        Prompt variables: task_names, next_task_id, objective. The model is
        asked to return a numbered list starting at ``next_task_id``, which
        ``prioritize_tasks`` then parses line by line.
        """
        task_prioritization_template = (
            "You are an task prioritization AI tasked with cleaning the formatting of and reprioritizing"
            " the following tasks: {task_names}."
            " Consider the ultimate objective of your team: {objective}."
            " Do not make up any tasks, just reorganize the existing tasks."
            " Do not remove any tasks. Return the result as a numbered list, like:"
            " #. First task"
            " #. Second task"
            " Start the task list with number {next_task_id}. (e.g., 2. ***, 3. ***, etc.)"
        )
        prompt = PromptTemplate(
            template=task_prioritization_template,
            input_variables=["task_names", "next_task_id", "objective"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)
def get_next_task(task_creation_chain: LLMChain, result: Dict, task_description: str, task_list: List[str], objective: str) -> List[Dict]:
    """Ask the task-creation chain for follow-up tasks.

    Args:
        task_creation_chain: Chain that generates new tasks from the last result.
        result: Result of the task that just finished.
        task_description: Description of the task that produced ``result``.
        task_list: Names of the tasks still pending (used to avoid overlap).
        objective: The overall objective driving the agent.

    Returns:
        A list of ``{"task_name": ...}`` dicts, one per numbered line in the
        chain's response (empty when the chain answers e.g. "NO TASK").
    """
    incomplete_tasks = ", ".join(task_list)
    response = task_creation_chain.run(
        result=result,
        task_description=task_description,
        incomplete_tasks=incomplete_tasks,
        objective=objective,
    )
    # change the split method to re matching
    # BUG FIX: the old pattern r'\d+\. (.+?)\n' required a trailing newline and
    # silently dropped the final task when the response did not end with one.
    # '.' does not cross newlines, so this still matches one task per line.
    task_pattern = re.compile(r"\d+\.\s*(.+)")
    new_tasks = task_pattern.findall(response)
    return [{"task_name": task_name.strip()} for task_name in new_tasks if task_name.strip()]
def prioritize_tasks(task_prioritization_chain: LLMChain, this_task_id: int, task_list: List[Dict], objective: str) -> List[Dict]:
    """Re-rank the pending tasks via the prioritization chain.

    The chain is expected to answer with a numbered list ("2. foo"); each
    numbered line becomes a ``{"task_id", "task_name"}`` dict, in response
    order. Lines without a "." separator are ignored.
    """
    names = [entry["task_name"] for entry in task_list]
    response = task_prioritization_chain.run(
        task_names=names,
        next_task_id=int(this_task_id) + 1,
        objective=objective,
    )
    reordered: List[Dict] = []
    for line in response.split("\n"):
        stripped = line.strip()
        if not stripped:
            continue
        number, sep, title = stripped.partition(".")
        if sep:
            reordered.append({"task_id": number.strip(), "task_name": title.strip()})
    return reordered
def _get_top_tasks(vectorstore, query: str, k: int) -> List[str]:
"""Get the top k tasks based on the query."""
results = vectorstore.similarity_search_with_score(query, k=k)
if not results:
return []
sorted_results, _ = zip(*sorted(results, key=lambda x: x[1], reverse=True))
return [str(item.metadata['task']) for item in sorted_results]
def execute_task(vectorstore, execution_chain: LLMChain, objective: str, task: str, k: int = 5) -> str:
    """Run one task through the execution agent, supplying the top-k most
    similar prior results from the vector store as context."""
    prior_context = _get_top_tasks(vectorstore, query=objective, k=k)
    return execution_chain.run(objective=objective, context=prior_context, task=task)
class BabyAGI(Chain, BaseModel):
    """Controller model for the BabyAGI agent.

    Runs a create -> execute -> store -> re-prioritize loop over a task
    queue, persisting each task result in the vector store so later tasks
    can retrieve it as context.
    """
    # Pending tasks, consumed FIFO from the left.
    task_list: deque = Field(default_factory=deque)
    # Generates follow-up tasks from the last result.
    task_creation_chain: TaskCreationChain = Field(...)
    # Re-ranks the pending tasks after new ones are added.
    task_prioritization_chain: TaskPrioritizationChain = Field(...)
    # Produces the very first task from the objective alone.
    initial_task_creation_chain: InitialTaskCreationChain = Field(...)
    # Agent executor that actually performs a task (with tools).
    execution_chain: AgentExecutor = Field(...)
    # Monotonic counter used to assign ids to newly created tasks.
    task_id_counter: int = Field(1)
    # Stores task results for similarity-based context retrieval.
    # NOTE(review): ``init=False`` is a dataclasses argument, not a pydantic
    # Field one — verify it has the intended effect here.
    vectorstore: VectorStore = Field(init=False)
    # Loop budget; None means run until the task list empties.
    max_iterations: Optional[int] = None
    class Config:
        """Configuration for this pydantic object."""
        arbitrary_types_allowed = True
    def add_task(self, task: Dict):
        """Append a task dict (``task_id``/``task_name`` keys) to the queue."""
        self.task_list.append(task)
    def print_task_list(self):
        # ANSI escape codes: magenta + bold header, then reset.
        print("\033[95m\033[1m" + "\n*****TASK LIST*****\n" + "\033[0m\033[0m")
        for t in self.task_list:
            print(str(t["task_id"]) + ": " + t["task_name"])
    def print_next_task(self, task: Dict):
        # Green + bold header for the task about to run.
        print("\033[92m\033[1m" + "\n*****NEXT TASK*****\n" + "\033[0m\033[0m")
        print(str(task["task_id"]) + ": " + task["task_name"])
    def print_task_result(self, result: str):
        # Yellow + bold header for the task's output.
        print("\033[93m\033[1m" + "\n*****TASK RESULT*****\n" + "\033[0m\033[0m")
        print(result)
    @property
    def input_keys(self) -> List[str]:
        """The chain's single input is the overall objective."""
        return ["objective"]
    @property
    def output_keys(self) -> List[str]:
        """No structured outputs; results are printed and stored in the vector store."""
        return []
    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the agent."""
        # not an elegant implementation, but it works for the first task
        objective = inputs['objective']
        # Seed the queue with the caller-provided first task, or generate one.
        first_task = inputs.get("first_task", self.initial_task_creation_chain.run(objective=objective))# self.task_creation_chain.llm(initial_task_prompt))
        self.add_task({"task_id": 1, "task_name": first_task})
        num_iters = 0
        while True:
            if self.task_list:
                self.print_task_list()
                # Step 1: Pull the first task
                task = self.task_list.popleft()
                self.print_next_task(task)
                # Step 2: Execute the task
                result = execute_task(
                    self.vectorstore, self.execution_chain, objective, task["task_name"]
                )
                this_task_id = int(task["task_id"])
                self.print_task_result(result)
                # Step 3: Store the result in Pinecone
                result_id = f"result_{task['task_id']}"
                self.vectorstore.add_texts(
                    texts=[result],
                    metadatas=[{"task": task["task_name"]}],
                    ids=[result_id],
                )
                # Step 4: Create new tasks and reprioritize task list
                new_tasks = get_next_task(
                    self.task_creation_chain, result, task["task_name"], [t["task_name"] for t in self.task_list], objective
                )
                for new_task in new_tasks:
                    self.task_id_counter += 1
                    new_task.update({"task_id": self.task_id_counter})
                    self.add_task(new_task)
                # Stop when the creation chain produced no further tasks.
                if len(self.task_list) == 0:
                    print("\033[91m\033[1m" + "\n*****NO TASK, ABORTING*****\n" + "\033[0m\033[0m")
                    break
                self.task_list = deque(
                    prioritize_tasks(
                        self.task_prioritization_chain, this_task_id, list(self.task_list), objective
                    )
                )
            num_iters += 1
            # Hard stop once the configured iteration budget is exhausted.
            if self.max_iterations is not None and num_iters == self.max_iterations:
                print("\033[91m\033[1m" + "\n*****TASK ENDING*****\n" + "\033[0m\033[0m")
                break
        return {}
    @classmethod
    def from_llm(
        cls,
        llm: BaseLLM,
        prompt = None,
        verbose: bool = False,
        tools = None,
        stream_output = None,
        **kwargs
    ) -> "BabyAGI":
        """Wire up the three LLM chains, a FAISS vector store over OpenAI
        embeddings, and a tool-using agent executor, then build the controller."""
        embeddings_model = OpenAIEmbeddings()
        # 1536 matches the dimensionality of OpenAI text embeddings.
        embedding_size = 1536
        index = faiss.IndexFlatL2(embedding_size)
        vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})
        task_creation_chain = TaskCreationChain.from_llm(
            llm, verbose=verbose
        )
        initial_task_creation_chain = InitialTaskCreationChain.from_llm(
            llm, verbose=verbose
        )
        task_prioritization_chain = TaskPrioritizationChain.from_llm(
            llm, verbose=verbose
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        tool_names = [tool.name for tool in tools]
        agent = ContextAwareAgent(llm_chain=llm_chain, allowed_tools=tool_names)
        # Streaming output uses the plain Executor; otherwise responses go
        # through AgentExecutorWithTranslation.
        if stream_output:
            agent_executor = Executor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
        else:
            agent_executor = AgentExecutorWithTranslation.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
        return cls(
            task_creation_chain=task_creation_chain,
            task_prioritization_chain=task_prioritization_chain,
            initial_task_creation_chain=initial_task_creation_chain,
            execution_chain=agent_executor,
            vectorstore=vectorstore,
            **kwargs
        )
if __name__ == "__main__":
# Demo entry point: run BabyAGI with a Search tool (SerpAPI) and a TODO tool.
# Requires OPENAI_API_KEY and SERPAPI_API_KEY in the environment.
todo_prompt = PromptTemplate.from_template("You are a planner who is an expert at coming up with a todo list for a given objective. For a simple objective, do not generate a complex todo list. Come up with a todo list for this objective: {objective}")
todo_chain = LLMChain(llm=OpenAI(temperature=0), prompt=todo_prompt)
search = SerpAPIWrapper()
tools = [
    Tool(
        name = "Search",
        func=search.run,
        description="useful for when you need to answer questions about current events"
    ),
    Tool(
        name = "TODO",
        func=todo_chain.run,
        description="useful for when you need to come up with todo lists. Input: an objective to create a todo list for. Output: a todo list for that objective. Please be very clear what the objective is!"
    )
]
# Zero-shot agent prompt: objective + retrieved context + current task.
prefix = """You are an AI who performs one task based on the following objective: {objective}. Take into account these previously completed tasks: {context}."""
suffix = """Question: {task}
{agent_scratchpad}"""
prompt = ZeroShotAgent.create_prompt(
    tools,
    prefix=prefix,
    suffix=suffix,
    input_variables=["objective", "task", "context","agent_scratchpad"]
)
OBJECTIVE = "Write a weather report for SF today"
llm = OpenAI(temperature=0)
# Logging of LLMChains
verbose=False
# If None, will keep on going forever
max_iterations: Optional[int] = 10
baby_agi = BabyAGI.from_llm(
    llm=llm,
    verbose=verbose,
    max_iterations=max_iterations
)
baby_agi({"objective": OBJECTIVE})

@ -0,0 +1,38 @@
from . import chemical
from . import film
from . import kg
from . import stock
from . import weather
from . import wikipedia
from . import wolframalpha
from . import office
from . import bing_search
from . import map
from . import translation
from . import tutorial
from . import file_operation
from . import meta_analysis
from . import database
from . import db_diag
from . import code_interpreter
from . import hugging_tools
from . import arxiv
from . import zillow
from . import google_scholar
from . import google_places
from . import google_serper
from . import python
from . import sceneXplain
from . import shell
from . import image_generation
from . import airbnb
from . import job_search
from . import gradio_tools
from . import travel
from . import walmart
from swarms.tools.tool import Tool
from swarms.tools.registry import register
from swarms.tools.serve import ToolServer

@ -0,0 +1,7 @@
from ..registry import register
@register("airbnb")
def airbnb():
    """Registry entry point for the Airbnb tool.

    Returns the ``build_tool`` factory itself (not a Tool instance); the
    import is deferred into the function body so the API module is only
    loaded when the tool is actually requested.
    """
    from .api import build_tool
    return build_tool

@ -0,0 +1,305 @@
import requests
import json
from datetime import date, datetime, timedelta
import os
from ..tool import Tool
from typing import Optional, Dict, List
def build_tool(config) -> Tool:
    """Build the Airbnb rental/housing information tool.

    Args:
        config: Mapping that must contain ``"subscription_key"`` — a RapidAPI
            key subscribed to the "Airbnb API" (airbnb19.p.rapidapi.com).

    Returns:
        A Tool exposing property search, destination search, property details,
        availability, reviews, and checkout-price endpoints.
    """
    tool = Tool(
        "Short-term rental and housing information",
        "Look up rental and housing information",
        name_for_model="Airbnb",
        description_for_model="Plugin for look up rental and housing information",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )
    BASE_URL = "https://airbnb19.p.rapidapi.com/api/v1"
    KEY = config["subscription_key"]
    HEADERS = {
        "X-RapidAPI-Key": KEY,
        "X-RapidAPI-Host": "airbnb19.p.rapidapi.com"
    }

    # NOTE(review): route kept as "/ssearch_property" (double "s") because the
    # path is part of the tool's public interface; renaming would break callers.
    @tool.get("/ssearch_property")
    def search_property(_id: str, display_name: Optional[str] = None,
                        total_records: Optional[str] = '10', currency: Optional[str] = 'USD',
                        offset: Optional[str] = None, category: Optional[str] = None,
                        adults: Optional[int] = 1, children: Optional[int] = None,
                        infants: Optional[int] = None, pets: Optional[int] = None,
                        checkin: Optional[str] = None, checkout: Optional[str] = None,
                        priceMin: Optional[int] = None, priceMax: Optional[int] = None,
                        minBedrooms: Optional[int] = None, minBeds: Optional[int] = None,
                        minBathrooms: Optional[int] = None, property_type: Optional[List[str]] = None,
                        host_languages: Optional[List[str]] = None, amenities: Optional[List[str]] = None,
                        type_of_place: Optional[List[str]] = None, top_tier_stays: Optional[List[str]] = None,
                        self_check_in: Optional[bool] = None, instant_book: Optional[bool] = None,
                        super_host: Optional[bool] = None, languageId: Optional[str] = None) -> dict:
        """
        This function takes various parameters to search properties on Airbnb.
        Parameters:
        _id (str): The ID of the destination (see search_destination).
        display_name (Optional[str]): The name of the destination.
        total_records (Optional[str]): The number of records to be retrieved per API call.
        currency (Optional[str]): The currency for the transaction.
        offset (Optional[str]): The offset for the search result.
        category (Optional[str]): The category of the properties.
        adults (Optional[int]): The number of adults.
        children (Optional[int]): The number of children.
        infants (Optional[int]): The number of infants.
        pets (Optional[int]): The number of pets.
        checkin (Optional[str]): The check-in date.
        checkout (Optional[str]): The check-out date.
        priceMin (Optional[int]): The minimum price.
        priceMax (Optional[int]): The maximum price.
        minBedrooms (Optional[int]): The minimum number of bedrooms.
        minBeds (Optional[int]): The minimum number of beds.
        minBathrooms (Optional[int]): The minimum number of bathrooms.
        property_type (Optional[List[str]]): The type of the property.
        host_languages (Optional[List[str]]): The languages that the host can speak.
        amenities (Optional[List[str]]): The amenities provided by the property.
        type_of_place (Optional[List[str]]): The type of the place.
        top_tier_stays (Optional[List[str]]): The list of top-tier stays.
        self_check_in (Optional[bool]): If the property has self check-in feature.
        instant_book (Optional[bool]): If the property can be booked instantly.
        super_host (Optional[bool]): If the host is a super host.
        languageId (Optional[str]): The ID of the language for the response.
        Returns:
        dict: A dictionary that contains the search results.
        """
        # ``requests`` drops params whose value is None, so passing the full
        # mapping is safe.
        params = {
            'id': _id,
            'display_name': display_name,
            'totalRecords': total_records,
            'currency': currency,
            'offset': offset,
            'category': category,
            'adults': adults,
            'children': children,
            'infants': infants,
            'pets': pets,
            'checkin': checkin,
            'checkout': checkout,
            'priceMin': priceMin,
            'priceMax': priceMax,
            'minBedrooms': minBedrooms,
            'minBeds': minBeds,
            'minBathrooms': minBathrooms,
            'property_type': property_type,
            'host_languages': host_languages,
            'amenities': amenities,
            'type_of_place': type_of_place,
            'top_tier_stays': top_tier_stays,
            'self_check_in': self_check_in,
            'instant_book': instant_book,
            'super_host': super_host,
            'languageId': languageId
        }
        response = requests.get(f"{BASE_URL}/searchPropertyByPlace", headers=HEADERS, params=params)
        # Only the first hit is returned to keep the payload small.
        return response.json()['data'][0]

    @tool.get("/search_property_by_coordinates")
    def search_property_by_coordinates(neLat: float, neLng: float, swLat: float, swLng: float,
                                       currency: Optional[str] = 'USD', nextPageCursor: Optional[str] = None,
                                       totalRecords: Optional[str] = None, infants: Optional[int] = None,
                                       adults: Optional[int] = 1, children: Optional[int] = None,
                                       pets: Optional[int] = None, checkin: Optional[str] = None,
                                       checkout: Optional[str] = None, priceMin: Optional[int] = None,
                                       priceMax: Optional[int] = None, minBedrooms: Optional[int] = None,
                                       minBeds: Optional[int] = None, minBathrooms: Optional[int] = None,
                                       property_type: Optional[List[str]] = None, host_languages: Optional[List[str]] = None,
                                       amenities: Optional[List[str]] = None, type_of_place: Optional[List[str]] = None,
                                       top_tier_stays: Optional[List[str]] = None, super_host: Optional[bool] = None) -> dict:
        """
        This function takes GEO coordinates and various other parameters to search properties on Airbnb.
        Parameters:
        neLat (float): Latitude of the northeastern corner of the search area.
        neLng (float): Longitude of the northeastern corner of the search area.
        swLat (float): Latitude of the southwestern corner of the search area.
        swLng (float): Longitude of the southwestern corner of the search area.
        Other parameters are the same as search_property function.
        Returns:
        dict: A dictionary that contains the search results.
        """
        params = {
            'neLat': neLat,
            'neLng': neLng,
            'swLat': swLat,
            'swLng': swLng,
            'currency': currency,
            'nextPageCursor': nextPageCursor,
            'totalRecords': totalRecords,
            'infants': infants,
            'adults': adults,
            'children': children,
            'pets': pets,
            'checkin': checkin,
            'checkout': checkout,
            'priceMin': priceMin,
            'priceMax': priceMax,
            'minBedrooms': minBedrooms,
            'minBeds': minBeds,
            'minBathrooms': minBathrooms,
            'property_type': property_type,
            'host_languages': host_languages,
            'amenities': amenities,
            'type_of_place': type_of_place,
            'top_tier_stays': top_tier_stays,
            'super_host': super_host
        }
        # This endpoint lives under /api/v2, unlike BASE_URL (/api/v1).
        response = requests.get(f"https://airbnb19.p.rapidapi.com/api/v2/searchPropertyByGEO", headers=HEADERS, params=params)
        return response.json()['data']['list'][0]

    @tool.get("/search_destination")
    def search_destination(query: str, country: Optional[str] = None) -> dict:
        """
        This function performs a destination search given a query and optionally a country. And return positions 'ID' information.
        Parameters:
        query (str): The search query.
        country (Optional[str]): The country for the search.
        Returns:
        dict: A dictionary that contains the search results. including ID information for a destination
        """
        # BUG FIX: removed the stray ``self`` parameter — this is a plain
        # decorated function, not a method, so ``self`` swallowed the query.
        params = {
            'query': query,
            'country': country
        }
        response = requests.get(f"{BASE_URL}/searchDestination", headers=HEADERS, params=params)
        return response.json()

    @tool.get("/property_by_coordinates")
    def property_by_coordinates(long: float, lat: float, d: Optional[float] = None, includeSold: Optional[bool] = None):
        """
        Search property by coordinates.
        Args:
        long (float): Longitude of the property. This is a required parameter.
        lat (float): Latitude of the property. This is a required parameter.
        d (float, optional): Diameter in miles. The max and low values are 0.5 and 0.05 respectively. The default value is 0.1.
        includeSold (bool, optional): Include sold properties in the results. True or 1 to include (default), False or 0 to exclude.
        Returns:
        A response object from the API with an array of property ids.
        """
        params = {
            "long": long,
            "lat": lat,
            "d": d,
            "includeSold": includeSold,
        }
        # Remove parameters that are None
        params = {k: v for k, v in params.items() if v is not None}
        url = BASE_URL + '/propertyByCoordinates'
        response = requests.get(url, headers=HEADERS, params=params)
        return response.json()

    @tool.get("/get_property_details")
    def get_property_details(propertyId: int, currency: Optional[str] = 'USD',
                             checkIn: Optional[str] = None, checkOut: Optional[str] = None,
                             adults: Optional[int] = 1, children: Optional[int] = None,
                             infants: Optional[int] = None, pets: Optional[int] = None,
                             languageId: Optional[str] = None) -> dict:
        """
        This function retrieves the details of a property given its ID.
        Parameters:
        propertyId (int): The ID of the property.
        currency (Optional[str]): The currency for the transaction.
        checkIn (Optional[str]): The check-in date.
        checkOut (Optional[str]): The check-out date.
        adults (Optional[int]): The number of adults.
        children (Optional[int]): The number of children.
        infants (Optional[int]): The number of infants.
        pets (Optional[int]): The number of pets.
        languageId (Optional[str]): The ID of the language for the response.
        Returns:
        dict: A dictionary that contains the details of the property.
        """
        params = {
            'propertyId': propertyId,
            'currency': currency,
            'checkIn': checkIn,
            'checkOut': checkOut,
            'adults': adults,
            'children': children,
            'infants': infants,
            'pets': pets,
            'languageId': languageId
        }
        # v2 endpoint (BASE_URL points at v1).
        response = requests.get(f"https://airbnb19.p.rapidapi.com/api/v2/getPropertyDetails", headers=HEADERS, params=params)
        return response.json()

    @tool.get("/check_availability")
    def check_availability(propertyId: int) -> dict:
        """
        This function checks the availability of a property given its ID.
        Parameters:
        propertyId (int): The ID of the property.
        Returns:
        dict: A dictionary that contains the availability of the property.
        """
        params = {
            'propertyId': propertyId,
        }
        response = requests.get(f"{BASE_URL}/checkAvailability", headers=HEADERS, params=params)
        return response.json()

    @tool.get("/get_property_reviews")
    def get_property_reviews(propertyId: int) -> dict:
        """
        This function retrieves the reviews of a property given its ID.
        Parameters:
        propertyId (int): The ID of the property.
        Returns:
        dict: A dictionary that contains the reviews of the property.
        """
        params = {
            'propertyId': propertyId,
        }
        response = requests.get(f"{BASE_URL}/getPropertyReviews", headers=HEADERS, params=params)
        return response.json()

    @tool.get("/get_property_checkout_price")
    def get_property_checkout_price(propertyId: int, checkIn: str) -> dict:
        """
        This function retrieves the checkout cost of a property given its ID and check-in date.
        Parameters:
        propertyId (int): The ID of the property.
        checkIn (str): The check-in date.
        Returns:
        dict: A dictionary that contains the checkout price of the property.
        """
        params = {
            'propertyId': propertyId,
            'checkIn': checkIn
        }
        response = requests.get(f"{BASE_URL}/getPropertyCheckoutPrice", headers=HEADERS, params=params)
        return response.json()

    return tool

@ -0,0 +1,29 @@
# Airbnb Service
Contributor: [Kunlun Zhu](https://github.com/Kunlun-Zhu)
You can get your RapidAPI key here: https://rapidapi.com/hub
You must subscribe to the 'Airbnb API' in your RapidAPI account before using this tool.
# Short-term Rental and Housing Information Tool
This tool, named `Short-term Rental and Housing Information`, is designed to interact with the Airbnb API to search for properties, get property details, check availability, get property reviews, and retrieve the checkout price. The tool operates by making HTTP requests to the Airbnb API and formatting the responses into an easily usable form.
## Main Functionality
1. **Search for Properties**: This functionality allows you to search for properties based on a variety of parameters like the number of adults, children, and infants, property type, amenities, check-in and check-out dates, and many more. This is done using the `search_property` function.
2. **Search Property by Coordinates**: This function allows you to search for properties in a specific geographic area defined by the northeast and southwest coordinates of the area. This is done using the `search_property_by_coordinates` function.
3. **Search for Destination**: The `search_destination` function helps to perform a destination search given a query and optionally a country. It returns positions 'ID' information.
4. **Get Property Details**: The `get_property_details` function is used to retrieve detailed information about a specific property. This includes the number of rooms, amenities, location, and other relevant information.
5. **Check Property Availability**: This function, `check_availability`, allows you to check if a property is available for booking.
6. **Get Property Reviews**: You can use the `get_property_reviews` function to retrieve reviews of a property.
7. **Get Property Checkout Price**: The `get_property_checkout_price` function is used to get the checkout cost of a property given its ID and check-in date.
This tool provides a simple and effective way to interact with the Airbnb API, making it easier for developers to incorporate Airbnb data into their applications.

@ -0,0 +1,11 @@
# Smoke test: load the locally served Airbnb tool and ask it one question.
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer

TOOL_NAME = 'Airbnb'
TOOL_URL = "http://127.0.0.1:8079/tools/airbnb/"

loaded_name, loaded_config = load_single_tools(TOOL_NAME, TOOL_URL)
print(loaded_name, loaded_config)

answerer = STQuestionAnswerer()
agent = answerer.load_tools(loaded_name, loaded_config)
agent("List some houses to rent in Santa Monica, CA.")

@ -0,0 +1,152 @@
"""Interface for tools."""
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Union
from langchain.agents import Tool as LangChainTool
from langchain.tools.base import BaseTool
import requests
import json
import aiohttp
import http.client
http.client._MAXLINE = 655360
from bmtools import get_logger
logger = get_logger(__name__)
class Tool(LangChainTool):
    """LangChain Tool extended with a markdown logo snippet."""

    # Markdown fragment for the tool's logo; empty string when the tool has none.
    tool_logo_md: str = ""
class RequestTool(BaseTool):
    """Tool that wraps a single REST endpoint of a tool server.

    Builds a synchronous callable (``func``) and an asynchronous callable
    (``afunc``) that forward a JSON-dict input to ``root_url + func_url``
    and return the response body, truncated to ``max_output_len`` characters.
    """

    description: str = ""
    func: Callable[[str], str]
    afunc: Callable[[str], str]
    coroutine: Optional[Callable[[str], Awaitable[str]]] = None
    # Deliberately un-annotated so pydantic treats it as a plain class
    # attribute rather than a model field.
    max_output_len = 4000
    tool_logo_md: str = ""

    def _run(self, tool_input: str) -> str:
        """Use the tool."""
        return self.func(tool_input)

    async def _arun(self, tool_input: str) -> str:
        """Use the tool asynchronously."""
        ret = await self.afunc(tool_input)
        return ret

    def convert_prompt(self, params):
        """Render OpenAPI-style parameter dicts as a JSON-schema hint string
        that is appended to the tool description shown to the model."""
        lines = "Your input should be a json (args json schema): {{"
        for p in params:
            logger.debug(p)
            optional = not p['required']
            description = p.get('description', '')
            if len(description) > 0:
                description = "("+description+")"
            lines += '"{name}" : {type}{desc}, '.format(name=p['name'],
                                                        type=p['schema']['type'],
                                                        optional=optional,
                                                        desc=description)
        lines += "}}"
        return lines

    def __init__(self, root_url, func_url, method, request_info, **kwargs):
        """ Store the function, description, and tool_name in a class to store the information
        """
        url = root_url + func_url

        def func(json_args):
            # Synchronous request path; supports GET/POST/PUT/PATCH.
            if isinstance(json_args, str):
                try:
                    json_args = json.loads(json_args)
                except json.JSONDecodeError:
                    # Narrowed from a bare ``except`` so unrelated errors
                    # (e.g. KeyboardInterrupt) still propagate.
                    return "Your input can not be parsed as json, please use thought."
                if "tool_input" in json_args:
                    json_args = json_args["tool_input"]
            # if it's post put patch, then we do json
            if method.lower() in ['post', 'put', 'patch']:
                response = getattr(requests, method.lower())(url, json=json_args)
            else:
                # for other methods, we use get, and use json_args as query params
                response = requests.get(url, params=json_args)
            if response.status_code == 200:
                message = response.text
            else:
                message = f"Error code {response.status_code}. You can try (1) Change your input (2) Call another function. (If the same error code is produced more than 4 times, please use Thought: I can not use these APIs, so I will stop. Final Answer: No Answer, please check the APIs.)"
            message = message[:self.max_output_len]  # TODO: not rigorous, to improve
            return message

        def convert_openapi_to_params(request_body):
            """Flatten an OpenAPI requestBody object into parameter dicts
            consumable by convert_prompt."""
            if not request_body:
                return []
            params = []
            for content_type, content in request_body['content'].items():
                schema = content['schema']
                properties = schema.get('properties', {})
                required = schema.get('required', [])
                for key, value in properties.items():
                    param = {
                        'name': key,
                        'schema': value,
                        'required': key in required,
                        'description': value.get('description', '')
                    }
                    if content_type == 'multipart/form-data' and value.get('format') == 'binary':
                        param['type'] = 'file'
                    elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']:
                        param['type'] = 'form'
                    else:
                        param['type'] = 'json'
                    params.append(param)
            return params

        async def afunc(json_args):
            # NOTE(review): unlike ``func`` this only issues GET requests —
            # confirm whether async POST/PUT/PATCH support is needed.
            if isinstance(json_args, str):
                try:
                    json_args = json.loads(json_args)
                except json.JSONDecodeError:
                    return "Your input can not be parsed as json, please use thought."
                if "tool_input" in json_args:
                    json_args = json_args["tool_input"]
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=json_args) as response:
                    if response.status == 200:
                        message = await response.text()
                    else:
                        # BUG FIX: aiohttp responses expose ``status``; the
                        # original read ``response.status_code``, which does not
                        # exist and raised AttributeError on every error reply.
                        message = f"Error code {response.status}. You can try (1) Change your input (2) Call another function. (If the same error code is produced more than 4 times, please use Thought: I can not use these APIs, so I will stop. Final Answer: No Answer, please check the APIs.)"
            message = message[:self.max_output_len]  # TODO: not rigorous, to improve
            return message

        tool_name = func_url.replace("/", ".").strip(".")

        # Build the human/model-facing description from the OpenAPI metadata.
        str_doc = ''
        if 'parameters' in request_info[method]:
            str_doc = self.convert_prompt(request_info[method]['parameters'])
        if 'requestBody' in request_info[method]:
            str_doc = str_doc + "\n" + self.convert_prompt(
                convert_openapi_to_params(request_info[method]['requestBody']))

        # Braces are escaped because the description is later used in a
        # .format()-style prompt template.
        description = request_info[method].get('description','').replace("{", "{{").replace("}", "}}") \
            + ". " \
            + str_doc \
            + f" The Action to trigger this API should be {tool_name} and the input parameters should be a json dict string. Pay attention to the type of parameters."

        logger.info("API Name: {}".format(tool_name))
        logger.info("API Description: {}".format(description))

        super(RequestTool, self).__init__(
            name=tool_name, func=func, afunc=afunc, description=description, **kwargs
        )

@ -0,0 +1,7 @@
from ..registry import register
@register("arxiv")
def arxiv():
    # Registry factory: defer importing the implementation (and its
    # dependencies) until the tool is actually requested.
    from .api import build_tool
    return build_tool

@ -0,0 +1,53 @@
from ..tool import Tool
from typing import Any
import arxiv
def build_tool(config) -> Tool:
    """Build the Arxiv lookup tool.

    Registers one GET endpoint that searches arxiv.org and returns the
    metadata (publication date, title, authors, summary) of the top matches.
    """
    tool = Tool(
        "Arxiv",
        "Look up for information from scientific articles on arxiv.org",
        name_for_model="Arxiv",
        description_for_model=(
            "Search information from Arxiv.org "
            "Useful for when you need to answer questions about Physics, Mathematics, "
            "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
            "Electrical Engineering, and Economics "
            "from scientific articles on arxiv.org. "
            "Input should be a search query."
        ),
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )

    # Search limits.
    top_k_results: int = 3              # number of hits returned
    ARXIV_MAX_QUERY_LENGTH = 300        # queries are truncated to this length
    doc_content_chars_max: int = 4000   # cap on the size of the returned text

    @tool.get("/get_arxiv_article_information")
    def get_arxiv_article_information(query : str):
        '''Run Arxiv search and get the article meta information.
        '''
        try:
            results = arxiv.Search(  # type: ignore
                query[: ARXIV_MAX_QUERY_LENGTH], max_results = top_k_results
            ).results()
            # BUG FIX: ``results()`` is lazy, so network failures surface while
            # iterating — the list must be built inside the try block. Also,
            # the original ``except arxiv_exceptions`` referenced an annotated
            # but never-assigned name, raising NameError instead of handling
            # the error; catch Exception instead.
            docs = [
                f"Published: {result.updated.date()}\nTitle: {result.title}\n"
                f"Authors: {', '.join(a.name for a in result.authors)}\n"
                f"Summary: {result.summary}"
                for result in results
            ]
        except Exception as ex:
            return f"Arxiv exception: {ex}"
        if docs:
            return "\n\n".join(docs)[: doc_content_chars_max]
        else:
            return "No good Arxiv Result was found"

    return tool

@ -0,0 +1,38 @@
# Arxiv Queries
Contributor: [Sihan Zhao](https://github.com/Sarah816)
## Tool Description
This Python-based tool offers a streamlined way to look up scientific articles on Arxiv.org. Named "Arxiv", this tool is particularly helpful when you need to answer questions about Physics, Mathematics, Computer Science, Quantitative Biology, Quantitative Finance, Statistics, Electrical Engineering, and Economics based on scientific articles from Arxiv.org.
### Tool Specifications
- **Name**: Arxiv
- **Purpose**: Look up for information from scientific articles on arxiv.org
- **Logo**: ![Arxiv Logo](https://your-app-url.com/.well-known/logo.png)
- **Contact Email**: hello@contact.com
- **Legal Information**: [Legal Information](hello@legal.com)
### Core Functionality
1. `get_arxiv_article_information`
This method takes a search query and returns meta-information about the Arxiv articles that match this query. The method uses an API to search articles on Arxiv.org and returns details like the date of publication, title of the article, names of the authors, and the summary of the article.
The method follows these steps:
- It takes a query as a string input.
- The query is passed to the Arxiv Search API.
- The method fetches the top three results.
- For each result, it collects information about the publication date, title, authors, and summary.
- It returns this information as a string.
If the search operation encounters an error, the method returns a message describing the Arxiv exception. If no suitable articles are found on Arxiv.org that match the query, it returns a message stating that no good Arxiv result was found.
### Constants
- **ARXIV_MAX_QUERY_LENGTH**: Maximum length of a query that can be passed to the Arxiv Search API. It's set to 300.
- **doc_content_chars_max**: Maximum characters of the Arxiv results to be returned. It's set to 4000.
- **top_k_results**: The maximum number of Arxiv Search results to be returned. It's set to 3.
Please note that the parameters can be optional and have their own default values. You should consult the method's documentation to understand the default behavior and the specific role of each parameter.

@ -0,0 +1,11 @@
# Smoke test: load the locally served Arxiv tool and ask it one question.
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer

TOOL_NAME = 'arxiv'
TOOL_URL = "http://127.0.0.1:8079/tools/arxiv/"

loaded_name, loaded_config = load_single_tools(TOOL_NAME, TOOL_URL)
print(loaded_name, loaded_config)

answerer = STQuestionAnswerer()
agent = answerer.load_tools(loaded_name, loaded_config)
agent("List some papers written by Timo Schick")

@ -0,0 +1,6 @@
from ..registry import register
@register("bing_search")
def bing_search():
    # Registry factory: defer importing the implementation until requested.
    from .api import build_tool
    return build_tool

@ -0,0 +1,181 @@
import requests
from bs4 import BeautifulSoup
from ..tool import Tool
from enum import Enum
from typing import Tuple
# search result list chunk size
SEARCH_RESULT_LIST_CHUNK_SIZE = 3
# result target page text chunk content length
RESULT_TARGET_PAGE_PER_TEXT_COUNT = 500
class BingAPI:
    """Small client for the Bing Web Search REST API.

    Wraps the ``/v7.0/search`` endpoint and adds bounded retries both for
    the search call and for fetching a result page's HTML.
    """

    def __init__(self, subscription_key : str) -> None:
        """Remember the subscription key plus the fixed endpoint/market settings.

        Parameters
        ----------
        subscription_key : str
            The subscription key to use for the Bing API.
        """
        self._headers = {'Ocp-Apim-Subscription-Key': subscription_key}
        self._endpoint = "https://api.bing.microsoft.com/v7.0/search"
        self._mkt = 'en-US'

    def search(self, key_words : str, max_retry : int = 3):
        """Query Bing for *key_words* and return the decoded JSON response.

        Retries up to *max_retry* times on network errors or non-200 replies,
        then raises RuntimeError.
        """
        for _attempt in range(max_retry):
            try:
                reply = requests.get(
                    self._endpoint,
                    headers=self._headers,
                    params={'q': key_words, 'mkt': self._mkt },
                    timeout=10,
                )
            except Exception:
                continue  # network failure — retry
            if reply.status_code != 200:
                continue  # server-side failure — retry
            # search result returned here
            return reply.json()
        raise RuntimeError("Failed to access Bing Search API.")

    def load_page(self, url : str, max_retry : int = 3) -> Tuple[bool, str]:
        """Fetch *url* and return ``(ok, text)``.

        On success the text is the concatenation of all ``<p>`` elements'
        stripped contents; on failure a canned timeout message is returned.
        """
        res = None
        content = None
        for _attempt in range(max_retry):
            try:
                res = requests.get(url, timeout=15)
                if res.status_code == 200:
                    res.raise_for_status()
                else:
                    raise RuntimeError("Failed to load page, code {}".format(res.status_code))
            except Exception:
                res = None  # mark this attempt as failed and retry
                continue
            res.encoding = res.apparent_encoding
            content = res.text
            break
        if res is None:
            return False, "Timeout for loading this page, Please try to load another one or search again."
        try:
            paragraphs = BeautifulSoup(content, 'html.parser').find_all('p')
            page_detail = "".join(p.get_text().strip() for p in paragraphs)
            return True, page_detail
        except Exception:
            return False, "Timeout for loading this page, Please try to load another one or search again."
class CONTENT_TYPE(Enum):
    # Kind of payload stored in a ContentItem.
    SEARCH_RESULT = 0       # raw JSON returned by the search endpoint
    RESULT_TARGET_PAGE = 1  # text extracted from a result's target page
class ContentItem:
    """One piece of session content tagged with its CONTENT_TYPE."""

    def __init__(self, type: CONTENT_TYPE, data):
        # NOTE: the parameter name shadows the ``type`` builtin inside __init__.
        self.type = type
        self.data = data
class DigestData:
    """Summary of one search hit (annotation-only; never instantiated here)."""
    title: str
    desc: str
    chunkIndex: int
class Digest:
    """Collection of DigestData entries (annotation-only; never instantiated here)."""
    datas: list
    checked: bool
class SessionData:
    # Per-process session state shared by the tool endpoints below.
    # NOTE(review): these are *class* attributes, so all SessionData instances
    # share the same lists — acceptable here because only the single
    # module-level instance ``data`` is ever used.
    topic = None
    content = []
    digests = []
    curResultChunk = 0
    curTargetPageResultChunk = 0

# Module-level singleton threaded into the endpoint closures as a default arg.
data = SessionData()
def build_tool(config) -> Tool:
    """Build the Bing search tool.

    ``config`` may contain ``{"debug": True, "bing_api": mock}`` to inject a
    stub BingAPI for offline tests; otherwise a real client is constructed
    from ``config["subscription_key"]``.
    """
    tool = Tool(
        "Bing_search",
        "Bing_search",
        name_for_model="Bing_search",
        name_for_human="Bing_search",
        description_for_model="""Perform Search on Bing Search engine.
Use search_top3(key: str) to get top 3 search results after input the key to search.
Use load_page_index(idx: int) to load the detailed page of the search result.""",
        description_for_human="Bing search API for browsing the internet and search for results.",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )

    if "debug" in config and config["debug"]:
        bing_api = config["bing_api"]
    else:
        bing_api = BingAPI(config["subscription_key"])

    @tool.get("/search_top3")
    def search_top3(key_words: str) -> str:
        """Search key words, return top 3 search results.
        """
        # CONSISTENCY FIX: use the module constant (value 3) instead of a
        # magic number, matching its declared purpose.
        top3 = search_all(key_words)[:SEARCH_RESULT_LIST_CHUNK_SIZE]
        output = ""
        for idx, item in enumerate(top3):
            output += "page: " + str(idx+1) + "\n"
            output += "title: " + item['name'] + "\n"
            output += "summary: " + item['snippet'] + "\n"
        return output

    def search_all(key_words: str, data: SessionData = data) -> list:
        """Search key_words, return a list of class SearchResult.

        Keyword arguments:
        key_words -- key words want to search
        """
        # ``data`` defaults to the shared module-level SessionData singleton.
        result = bing_api.search(key_words)
        data.content = []
        data.content.append(ContentItem(CONTENT_TYPE.SEARCH_RESULT, result))
        data.curResultChunk = 0
        return data.content[-1].data["webPages"]["value"]

    @tool.get("/load_page_index")
    def load_page_index(idx: str) -> str:
        """Load page detail of the search result indexed as 'idx', and return the content of the page.
        """
        idx = int(idx)
        href, text = load_page(idx-1)
        # CONSISTENCY FIX: use the module constant (value 500) instead of a
        # magic number.
        if len(text) > RESULT_TARGET_PAGE_PER_TEXT_COUNT:
            return text[:RESULT_TARGET_PAGE_PER_TEXT_COUNT]
        else:
            return text

    def load_page(idx : int, data: SessionData = data):
        """Resolve search-result *idx* to its URL and fetch the page text."""
        top = data.content[-1].data["webPages"]["value"]
        ok, content = bing_api.load_page(top[idx]['url'])
        if ok:
            return top[idx]['url'], content
        else:
            return " ", "Timeout for loading this page, Please try to load another one or search again."

    return tool

@ -0,0 +1,3 @@
# Bing search tool
Contributor: [ChengQian](https://github.com/qiancheng0)

@ -0,0 +1,67 @@
from fastapi.testclient import TestClient
from .api import build_tool, BingAPI
from typing import Tuple
# Canned Bing /search response used by the tests below: three fake hits whose
# "url" values ("a", "b", "c") are resolved by MockBingAPI.load_page.
BING_TEST_SEARCH = {
    "webPages": {
        "value": [
            {
                "url": "a",
                "name": "test a",
                "snippet": "page a"
            },
            {
                "url": "b",
                "name": "test b",
                "snippet": "page b"
            },
            {
                "url": "c",
                "name": "test c",
                "snippet": "page c"
            }
        ]
    }
}
class MockBingAPI(BingAPI):
    """Offline stand-in for BingAPI used by the unit tests."""

    # Canned page texts keyed by the fake URLs in BING_TEST_SEARCH.
    _PAGES = {
        "a": "This is page a",
        "b": "This is page b",
        "c": "This is page c",
    }

    def __init__(self):
        # Deliberately skip BingAPI.__init__: no subscription key is needed.
        pass

    def search(self, key_words : str, max_retry : int = 3):
        return BING_TEST_SEARCH

    def load_page(self, url : str, max_retry : int = 3) -> Tuple[bool, str]:
        if url in self._PAGES:
            return True, self._PAGES[url]
        return False, "Timeout for loading this page, Please try to load another one or search again."
# Build the tool against the mock API; TestClient drives the FastAPI app.
app = build_tool({"debug": True, "bing_api": MockBingAPI()})
client = TestClient(app)

def test_bing():
    """End-to-end check of /search_top3 and /load_page_index with mocked Bing."""
    # test search top 3
    response = client.get("/search_top3", params={"key_words": "test"})
    # Rebuild the expected formatted listing from the canned search payload.
    output = ""
    for idx, item in enumerate(BING_TEST_SEARCH["webPages"]["value"]):
        output += "page: " + str(idx+1) + "\n"
        output += "title: " + item['name'] + "\n"
        output += "summary: " + item['snippet'] + "\n"
    assert response.status_code == 200
    assert response.json() == output
    # test load page
    response = client.get("/load_page_index", params={"idx": "1"})
    assert response.status_code == 200
    assert response.json() == "This is page a"
    response = client.get("/load_page_index", params={"idx": "2"})
    assert response.status_code == 200
    assert response.json() == "This is page b"

@ -0,0 +1,6 @@
from ..registry import register
@register("chemical-prop")
def chemical_prop():
    # Registry factory: defer importing the implementation until requested.
    from .prop import build_tool
    return build_tool

@ -0,0 +1 @@
from .api import build_tool

@ -0,0 +1,150 @@
import requests
from pydantic import BaseModel
from bs4 import BeautifulSoup
import json, random
from ...tool import Tool
from typing import List, Optional, Union
class ChemicalPropAPI:
    """Thin client for the PubChem PUG REST API (XML and JSON endpoints)."""

    def __init__(self) -> None:
        self._endpoint = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/"

    def get_name_by_cid(self, cid : str, top_k : Optional[int] = None) -> List[str]:
        """Return up to *top_k* synonyms for compound *cid* (all when None)."""
        html_doc = requests.get(f"{self._endpoint}cid/{cid}/synonyms/XML").text
        soup = BeautifulSoup(html_doc, "html.parser", from_encoding="utf-8")
        syns = soup.find_all('synonym')
        if top_k is None:
            top_k = len(syns)
        # IDIOM: comprehension instead of a manual append loop.
        return [syn.text for syn in syns[:top_k]]

    def get_cid_by_struct(self, smiles : str) -> List[str]:
        """Return the CIDs matching the SMILES string (empty list when none)."""
        html_doc = requests.get(f"{self._endpoint}smiles/{smiles}/cids/XML").text
        soup = BeautifulSoup(html_doc, "html.parser", from_encoding="utf-8")
        # find_all returns a (possibly empty) ResultSet, never None, so the
        # original ``if cids is None`` guard was dead code.
        return [cid.text for cid in soup.find_all('cid')]

    def get_cid_by_name(self, name : str, name_type : Optional[str] = None) -> List[str]:
        """Return the CIDs matching *name*; pass name_type='word' for fuzzy matching."""
        url = f"{self._endpoint}name/{name}/cids/XML"
        if name_type is not None:
            url += f"?name_type={name_type}"
        html_doc = requests.get(url).text
        soup = BeautifulSoup(html_doc, "html.parser", from_encoding="utf-8")
        return [cid.text for cid in soup.find_all('cid')]

    def get_prop_by_cid(self, cid : str) -> str:
        """Return the property table (dict) for compound *cid*."""
        html_doc = requests.get(
            f"{self._endpoint}cid/{cid}/property/MolecularFormula,MolecularWeight,CanonicalSMILES,IsomericSMILES,IUPACName,XLogP,ExactMass,MonoisotopicMass,TPSA,Complexity,Charge,HBondDonorCount,HBondAcceptorCount,RotatableBondCount,HeavyAtomCount,CovalentUnitCount/json"
        ).text
        return json.loads(html_doc)['PropertyTable']['Properties'][0]
class GetNameResponse(BaseModel):
    """Response model: list of synonym names for a compound."""
    names: List[str]
class GetStructureResponse(BaseModel):
    """Response model: structure lookup result (state plus optional content)."""
    state : int
    content : Optional[str] = None
class GetIDResponse(BaseModel):
    """Response model: ID lookup result; content is one CID or a list of candidates."""
    state : int
    content : Union[str, List[str]]
def build_tool(config) -> Tool:
    """Build the chemical-property tool backed by the PubChem API.

    ``config`` may contain ``{"debug": True, "chemical_prop_api": mock}`` to
    inject a stub API implementation for testing.
    """
    tool = Tool(
        "Chemical Property Plugin",
        description="looking up a chemical's property",
        name_for_model="Chemical Property",
        description_for_model="Plugin for looking up a chemical's property using a chemical knowledge base. All input should be a json like {'input': 'some input'}. Please use the provided questions and search step by step.",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com",
    )

    if "debug" in config and config["debug"]:
        chemical_prop_api = config["chemical_prop_api"]
    else:
        chemical_prop_api = ChemicalPropAPI()

    @tool.get("/get_name")
    def get_name( cid: str ):
        """prints the possible 3 synonyms of the queried compound ID"""
        ans = chemical_prop_api.get_name_by_cid(cid, top_k=3)
        return {
            "names": ans
        }

    @tool.get("/get_allname")
    def get_allname( cid: str ):
        """prints all the possible synonyms (might be too many, use this function carefully).
        """
        ans = chemical_prop_api.get_name_by_cid(cid)
        return {
            "names": ans
        }

    @tool.get("/get_id_by_struct")
    def get_id_by_struct(smiles : str):
        """prints the ID of the queried compound SMILES. This should only be used if smiles is provided or retrieved in the previous step. The input should not be a string, but a SMILES formula.
        """
        cids = chemical_prop_api.get_cid_by_struct(smiles)
        if len(cids) == 0:
            return {
                "state": "no result"
            }
        else:
            return {
                "state": "matched",
                "content": cids[0]
            }

    @tool.get("/get_id")
    def get_id(name : str):
        """prints the ID of the queried compound name, and prints the possible 5 names if the queried name can not been precisely matched,
        """
        cids = chemical_prop_api.get_cid_by_name(name)
        if len(cids) > 0:
            return {
                "state": "precise",
                "content": cids[0]
            }
        # Fall back to word-level (fuzzy) matching.
        cids = chemical_prop_api.get_cid_by_name(name, name_type="word")
        if len(cids) > 0:
            # BUG FIX: get_name returns {"names": [...]}; the original tested
            # ``name in get_name(cids[0])``, i.e. membership against the dict's
            # keys, so this precise-match branch could never fire. Check the
            # synonym list instead.
            if name in get_name(cids[0])["names"]:
                return {
                    "state": "precise",
                    "content": cids[0]
                }
        # No precise match: sample up to 5 candidates and return their names.
        ans = []
        random.shuffle(cids)
        for cid in cids[:5]:
            nms = get_name(cid)
            ans.append(nms)
        return {
            "state": "not precise",
            "content": ans
        }

    @tool.get("/get_prop")
    def get_prop(cid : str):
        """prints the properties of the queried compound ID
        """
        return chemical_prop_api.get_prop_by_cid(cid)

    return tool

@ -0,0 +1,38 @@
# Chemical Properties
Contributor: [Zheni Zeng](https://github.com/Ellenzzn)
## Tool Description
The tool, "Chemical Property Plugin," provides the ability to look up a chemical's properties by querying a chemical knowledge base. The tool accepts input in a JSON format, like {'input': 'some input'}, and guides you to ask questions and search step by step.
### Tool Specifications
- **Name**: Chemical Property Plugin
- **Purpose**: Plugin for looking up a chemical's property using a chemical knowledge base
- **Logo**: ![Chemical Property Plugin Logo](https://your-app-url.com/.well-known/logo.png)
- **Contact Email**: hello@contact.com
- **Legal Information**: [Legal Information](hello@legal.com)
### Core Functionality
1. `get_name`
This method accepts a Compound ID (CID) and returns the top 3 synonyms for the queried compound.
2. `get_allname`
This method accepts a Compound ID (CID) and returns all the possible synonyms for the queried compound. Be aware that the number of returned names can be large, so use this function with caution.
3. `get_id_by_struct`
This method accepts a SMILES formula and returns the ID of the queried compound. This method should be used only if the SMILES formula is provided or retrieved in the previous step.
4. `get_id`
This method accepts a compound name and returns the ID of the queried compound. If the name cannot be precisely matched, it will return the possible names.
5. `get_prop`
This method accepts a Compound ID (CID) and returns the properties of the queried compound.
The tool is made possible through the use of the ChemicalPropAPI, which interacts with a chemical knowledge base.

@ -0,0 +1,62 @@
from fastapi.testclient import TestClient
from .api import build_tool, ChemicalPropAPI
from typing import Tuple, Optional, List
class ChemicalPropMock(ChemicalPropAPI):
    """Canned ChemicalPropAPI used by the tests; performs no HTTP calls."""

    def __init__(self) -> None:
        self._endpoint = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/"

    def get_name_by_cid(self, cid : str, top_k : Optional[int] = None) -> List[str]:
        names = ["A", "B", "C", "D", "E"]
        limit = len(names) if top_k is None else top_k
        return names[:limit]

    def get_cid_by_struct(self, smiles : str) -> List[str]:
        return ["123"]

    def get_cid_by_name(self, name : str, name_type : Optional[str] = None) -> List[str]:
        return ["123"]

    def get_prop_by_cid(self, cid : str) -> str:
        return {"works": "well"}
# Build the tool against the mock API; TestClient drives the FastAPI app.
app = build_tool({"debug": True, "chemical_prop_api": ChemicalPropMock()})
client = TestClient(app)

def test_get_name():
    # /get_name returns only the first three synonyms.
    response = client.get("/get_name", params={"cid": 123})
    assert response.status_code == 200
    assert response.json() == {"names": ["A", "B", "C"]}

def test_get_all_names():
    # /get_allname returns every synonym the mock knows.
    response = client.get("/get_allname", params={"cid": 123})
    assert response.status_code == 200
    assert response.json() == {"names": ["A", "B", "C", "D", "E"]}

def test_get_id_by_struct():
    # SMILES lookup resolves through the mock to CID "123".
    response = client.get("/get_id_by_struct", params={"smiles": "C1=CC=CC=C1"})
    assert response.status_code == 200
    assert response.json() == {
        "state": "matched",
        "content": "123"
    }

def test_get_id():
    # Name lookup hits the precise branch on the first query.
    response = client.get("/get_id", params={"name": "benzene"})
    assert response.status_code == 200
    assert response.json() == {
        "state": "precise",
        "content": "123",
    }

def test_get_prop():
    # Property lookup passes the mock's payload through unchanged.
    response = client.get("/get_prop", params={"cid": "123"})
    assert response.status_code == 200
    assert response.json() == {
        "works": "well"
    }

@ -0,0 +1,7 @@
from ..registry import register
@register("code_interpreter")
def code_interpreter():
    # Registry factory: defer importing the implementation until requested.
    from .api import build_tool
    return build_tool

@ -0,0 +1,47 @@
from ..tool import Tool
class CodeInterpreter:
    """Stateful Python snippet executor.

    Keeps one globals/locals pair across calls so variables defined by
    earlier snippets remain visible to later ones.

    SECURITY WARNING: this executes arbitrary Python via eval/exec and must
    never be exposed to untrusted input.
    """

    def __init__(self, timeout=300):
        self.globals = {}
        self.locals = {}
        # NOTE(review): stored but never enforced anywhere in this class —
        # confirm whether a timeout mechanism was intended.
        self.timeout = timeout

    def execute_code(self, code):
        """Evaluate *code* and return its value, falling back to statement execution.

        Returns the expression's value when *code* is an expression,
        "Code executed successfully." when it only runs as statements, and an
        "Error: ..." string when both attempts fail.
        """
        try:
            # Wrap the code in an eval() call to return the result
            wrapped_code = f"__result__ = eval({repr(code)}, globals(), locals())"
            exec(wrapped_code, self.globals, self.locals)
            return self.locals.get('__result__', None)
        except Exception:
            # Not an expression (or eval failed): run it as statements.
            # (The original bound the exception to an unused ``e`` that
            # shadowed the inner handler's variable — removed.)
            try:
                exec(code, self.globals, self.locals)
                return "Code executed successfully."
            except Exception as e:
                return f"Error: {str(e)}"

    def reset_session(self):
        """Drop all state accumulated by previous execute_code calls."""
        self.globals = {}
        self.locals = {}
def build_tool(config) -> Tool:
    """Expose CodeInterpreter as a tool with a single /execute_code endpoint."""
    tool = Tool(
        "Python Code Interpreter Tool",
        "Execute Python Codes",
        name_for_model="code_interpreter",
        description_for_model="Plugin for executing python codes",
        logo_url=None,
        contact_email=None,
        legal_info_url=None
    )

    # One interpreter instance is shared by every request, so session state
    # (variables defined by earlier snippets) persists between calls.
    session_interpreter = CodeInterpreter()

    @tool.get("/execute_code")
    def execute_python_code(code: str):
        '''execute Python expressions with Python Interpreter, can be used as a simple calculator e.g., "(123 + 234) / 23 * 19"
        '''
        return session_interpreter.execute_code(code)

    return tool

@ -0,0 +1,17 @@
## Tool Description
The "Python Code Interpreter Tool" is a handy tool designed to execute Python code. It can be used as a simple Python interpreter or a calculator to compute Python expressions.
### Tool Specifications
- **Name**: Python Code Interpreter Tool
- **Purpose**: Execute Python codes
- **Name for Model**: code_interpreter
- **Model Description**: Plugin for executing python codes
### Core Functionality
1. `execute_code`
This method accepts a Python code string as input and executes it using a Python interpreter. It can also be used as a simple calculator. For example, sending a request with code `(123 + 234) / 23 * 19` will return the result of this expression.
The tool makes use of the `CodeInterpreter` class which is capable of running Python code.

@ -0,0 +1,6 @@
from ..registry import register
@register("database")
def register_database_tool():
    # Registry factory: defer importing the DB-heavy implementation until
    # the tool is actually requested.
    from .api import build_database_tool
    return build_database_tool

@ -0,0 +1,154 @@
#!/usr/bin/env python
# coding=utf-8
import json
import os
from ..tool import Tool
from bmtools.tools.database.utils.db_parser import get_conf
from bmtools.tools.database.utils.database import DBArgs, Database
import openai
from typing import Optional, List, Mapping, Any
import requests, json
def build_database_tool(config) -> Tool:
    """Build the database tool: schema lookup, NL-to-SQL translation, query
    execution, and SQL rewriting via an external service.

    NOTE(review): the incoming ``config`` argument is immediately overwritten
    by the settings loaded from my_config.ini below — confirm this is intended.
    """
    tool = Tool(
        "Data in a database",
        "Look up user data",
        name_for_model="Database",
        description_for_model="Plugin for querying the data in a database",
        logo_url="https://commons.wikimedia.org/wiki/File:Postgresql_elephant.svg",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )

    # Endpoint of the external SQL-rewrite service used by /rewrite_sql.
    URL_REWRITE= "http://8.131.229.55:5114/rewrite"

    # load db settings
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    config = get_conf(script_dir + '/my_config.ini', 'postgresql')
    dbargs = DBArgs("postgresql", config=config)  # todo assign database name

    # send request to database
    db = Database(dbargs, timeout=-1)

    # NOTE(review): the endpoints below declare ``global schema, query`` and so
    # read/write *module-level* names, not these locals — these two assignments
    # appear to be dead; confirm before removing.
    schema = ""
    query = ""

    @tool.get("/get_database_schema")
    #def get_database_schema(query : str='select * from customer limit 2;', db_name : str='tpch10x'):
    def get_database_schema(db_name : str='tpch10x'):
        """Compute the table schema of the database and cache it for later calls."""
        global schema
        #todo simplify the schema based on the query
        print("=========== database name:", db_name)
        schema = db.compute_table_schema()
        print("========schema:", schema)
        text_output = f"The database schema is:\n" + "".join(str(schema))
        return text_output

    @tool.get("/translate_nlp_to_sql")
    def translate_nlp_to_sql(description : str):
        """translate_nlp_to_sql(description: str) translates the input nlp string into sql query based on the database schema, and the sql query is the input of rewrite_sql and select_database_data API.
        description is a string that represents the description of the result data.
        schema is a string that represents the database schema.
        Final answer should be complete.

        This is an example:
        Thoughts: Now that I have the database schema, I will use the \\\'translate_nlp_to_sql\\\' command to generate the SQL query based on the given description and schema, and take the SQL query as the input of the \\\'rewrite_sql\\\' and \\\'select_database_data\\\' commands.
        Reasoning: I need to generate the SQL query accurately based on the given description. I will use the \\\'translate_nlp_to_sql\\\' command to obtain the SQL query based on the given description and schema, and take the SQL query as the input of the \\\'select_database_data\\\' command.
        Plan: - Use the \\\'translate_nlp_to_sql\\\' command to generate the SQL query. \\\\n- Use the \\\'finish\\\' command to signal that I have completed all my objectives.
        Command: {"name": "translate_nlp_to_sql", "args": {"description": "Retrieve the comments of suppliers . The results should be sorted in descending order based on the comments of the suppliers."}}
        Result: Command translate_nlp_to_sql returned: "SELECT s_comment FROM supplier BY s_comment DESC"
        """
        # BUG FIX: the docstring above now precedes the ``global`` statement so
        # it is a real docstring (the original placed it after, leaving
        # __doc__ empty and hiding it from the tool framework).
        global schema, query
        openai.api_key = os.environ["OPENAI_API_KEY"]
        # schema = db.compute_table_schema()
        prompt = """Translate the natural language description into an semantic equivalent SQL query.
The table and column names used in the sql must exactly appear in the schema. Any other table and column names are unacceptable.
The schema is:\n
{}
The description is:\n
{}
The SQL query is:
""".format(schema, description)
        # Set up the OpenAI GPT-3 model
        model_engine = "gpt-3.5-turbo"
        # BUG FIX: ChatCompletion.create takes ``model=``; the ``engine``
        # keyword is only valid for Azure OpenAI deployments and raises an
        # InvalidRequestError against the standard API.
        prompt_response = openai.ChatCompletion.create(
            model=model_engine,
            messages=[
                {"role": "assistant", "content": "The table schema is as follows: " + schema},
                {"role": "user", "content": prompt}
            ]
        )
        output_text = prompt_response['choices'][0]['message']['content']
        query = output_text
        return output_text

    @tool.get("/select_database_data")
    def select_database_data(query : str):
        """select_database_data(query : str) Read the data stored in database based on the SQL query from the translate_nlp_to_sql API.
        query : str is a string that represents the SQL query outputted by the translate_nlp_to_sql API.
        Final answer should be complete.

        This is an example:
        Thoughts: Now that I have the database schema and SQL query, I will use the \\\'select_database_data\\\' command to retrieve the data from the database based on the SQL query
        Reasoning: I will use the \\\'select_database_data\\\' command to retrieve the data from the database based on the SQL query
        Plan: - Use the \\\'select_database_data\\\' command to retrieve the data from the database based on the SQL query.\\\\n- Use the \\\'finish\\\' command to signal that I have completed all my objectives.
        Command: {"name": "select_database_data", "args": {query: "SELECT s_comment FROM supplier BY s_comment DESC"}}
        Result: Command select_database_data returned: "The number of result rows is: 394"
        """
        if query == "":
            raise RuntimeError("SQL query is empty")
        print("=========== database query:", query)
        res_completion = db.pgsql_results(query)  # list format
        if res_completion == "<fail>":
            raise RuntimeError("Database query failed")
        #data = json.loads(str(res_completion).strip())
        if isinstance(res_completion, list):
            text_output = f"The number of result rows is: "+"".join(str(len(res_completion)))
        else:
            text_output = f"The number of result rows is: "+"".join(str(res_completion))
        return text_output

    @tool.get("/rewrite_sql")
    def rewrite_sql(sql: str="select distinct l_orderkey, sum(l_extendedprice + 3 + (1 - l_discount)) as revenue, o_orderkey, o_shippriority from customer, orders, lineitem where c_mktsegment = 'BUILDING' and c_custkey = o_custkey and l_orderkey = o_orderkey and o_orderdate < date '1995-03-15' and l_shipdate > date '1995-03-15' group by l_orderkey, o_orderkey, o_shippriority order by revenue desc, o_orderkey;"):
        '''Rewrite the input sql query
        '''
        param = {
            "sql": sql
        }
        print("Rewriter param:", param)
        headers = {'Content-Type': 'application/json'}
        res_completion = requests.post(URL_REWRITE, data=json.dumps(param), headers=headers)
        #print("============ res_completion", res_completion.text)
        data = json.loads(res_completion.text.strip())
        data = data.get('data')
        text_output = f"Rewritten sql is:\n"+data.get('rewritten_sql')
        return text_output

    return tool

@ -0,0 +1,12 @@
[postgresql]
host =
port =
user =
password =
dbname =
[mysql]
host =
port =
user =
password =
dbname =

@ -0,0 +1,96 @@
import json
import json
from difflib import ndiff
import psycopg2
import time
'''
prepare the test samples
'''
def execute_sql(sql):
    """Execute *sql* against the tpch10x database and return the result row count.

    Opens a fresh connection per call. Fix: the original leaked the connection
    and cursor when the query raised; they are now closed via try/finally.

    Args:
        sql: a complete SQL statement to run.

    Returns:
        int: number of rows returned by the query.
    """
    # NOTE(review): 'xxx' values are unfilled placeholders — supply real
    # credentials before running.
    conn = psycopg2.connect(database='tpch10x',
                            user='xxx',
                            password='xxx',
                            host='xxx',
                            port=xxx)
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql)
            res = cur.fetchall()
            conn.commit()
        finally:
            cur.close()
    finally:
        conn.close()
    return len(res)
# Load the benchmark queries (a JSON array of {"sql": ...} records).
data = {}
with open('text2res_single_table.json', 'r') as f:
    data = json.load(f)
# Execute each query once and record its result-row count and wall-clock time.
# NOTE(review): an earlier comment mentioned filtering by edit distance, but no
# such filtering happens here — every record with a 'sql' key is kept.
selected_sql = []
for sql1 in data:
    if 'sql' in sql1:
        sql1 = sql1['sql']
        print("==========sql", sql1)
        start_time = time.time()
        res_cnt = execute_sql(sql1)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(res_cnt, elapsed_time)
        # NOTE(review): f"sql" is an f-string with no placeholders; plain "sql"
        # would be equivalent.
        selected_sql.append({f"sql": sql1, 'res_cnt': res_cnt, 'execution_time': elapsed_time})
# Persist the measurements for the description-generation step.
with open("text2res_single_table2.json", "w") as f:
    json.dump(selected_sql, f)
'''
add text descriptions for queries
'''
if __name__ == "__main__":
    # NOTE(review): LLM is never defined or imported in this file; running this
    # block as-is raises NameError — plug in your own LLM wrapper first.
    llm = LLM() # add the def of your llm
    with open('./tpch10x/text2res_single_table2.json', 'r') as json_file:
        json_data = json.load(json_file)
    new_json_data = []
    for i,item in enumerate(json_data):
        sql = item['sql']
        print("========= ", i, sql)
        # Ask the LLM for a one-sentence natural-language paraphrase of the SQL,
        # keeping the original table/column names.
        prompt = "Please convert the following sql query into one natural language sentence: \n" + sql + "\n Note. 1) Do not mention any other information other than the natural language sentence; 2) Must use the origin table and column names in the sql query."
        text = llm(prompt)
        item['text'] = text
        new_json_data.append(item)
        #print(llm("Describe Shanghai in 200 words."))
    with open("text2res_single_table3.json", "w") as f:
        json.dump(new_json_data, f)
'''
calculate total execution time
'''
# Sum the recorded per-query execution times and print the grand total.
with open('text2res_origin.json', 'r') as json_file:
    json_data = json.load(json_file)
total_time = 0
for record in json_data:
    print(record['execution_time'])
    total_time += float(record['execution_time'])
print(total_time)

@ -0,0 +1,13 @@
[{"sql": "SELECT l_tax, o_totalprice FROM lineitem JOIN orders ON o_orderkey = l_orderkey WHERE l_linenumber >= 3 AND o_orderkey <> 784709 AND l_orderkey <= 189383 AND o_clerk < 'Clerk#000000181'", "res_cnt": "1707", "execution_time": "4.83", "text": "Retrieve the tax rate and total price from the lineitem and orders tables where the line number is greater than or equal to 3, the order key is not equal to 784709, the order key is less than or equal to 189383, and the clerk's ID is less than 'Clerk#000000181'."},
{"sql": "SELECT p_type, ps_availqty, SUM(ps_suppkey) FROM part JOIN partsupp ON ps_partkey = p_partkey WHERE ps_suppkey <> 3804 AND ps_partkey <= 57823 AND ps_availqty < 4781 GROUP BY p_type, ps_availqty HAVING SUM(ps_suppkey) > 1089 ORDER BY SUM(ps_suppkey) ASC", "res_cnt": "100967", "execution_time": "2.25", "text": "Retrieve the part type, available quantity of parts, and the sum of supplier keys from the part and partsupp tables where the supplier key is not equal to 3804, the part key is less than or equal to 57823, and the available quantity of parts is less than 4781. Group the results by part type and available quantity of parts, and only include groups where the sum of supplier keys is greater than 1089. Sort the results in ascending order based on the sum of supplier keys."},
{"sql": "SELECT c_phone, o_totalprice, n_comment, r_comment FROM orders JOIN customer ON c_custkey = o_custkey JOIN nation ON n_nationkey = c_nationkey JOIN region ON r_regionkey = n_regionkey WHERE n_nationkey < 8 AND o_orderstatus >= 'O' AND o_comment < 'ly around the pending theodo' ORDER BY c_phone ASC, n_comment ASC, o_totalprice ASC, r_comment ASC", "res_cnt": "1249285", "execution_time": "32.55", "text": "Retrieve the phone number of the customer, the total price of the order, the comment of the nation, and the comment of the region from the orders table, customer table, nation table, and region table, respectively, where the nation key is less than 8, the order status is greater than or equal to 'O', and the order comment is less than 'ly around the pending theodo'. Sort the result by customer phone number in ascending order, nation comment in ascending order, order total price in ascending order, and region comment in ascending order."},
{"sql": "SELECT s_acctbal, ps_supplycost, r_regionkey, n_name FROM region JOIN nation ON n_regionkey = r_regionkey JOIN supplier ON s_nationkey = n_nationkey JOIN partsupp ON ps_suppkey = s_suppkey WHERE r_name <= 'AFRICA' AND n_comment >= 'l platelets. regular accounts x-ray: unusual, regular acco' AND s_nationkey >= 0", "res_cnt": "1272240", "execution_time": "22.09", "text": "Retrieve the account balance, supply cost, region key, and nation name from the region, nation, supplier, and partsupp tables where the region name is less than or equal to 'AFRICA', the nation comment is greater than or equal to 'l platelets. regular accounts x-ray: unusual, regular acco', and the supplier nation key is greater than or equal to 0."},
{"sql": "SELECT o_orderkey, c_address FROM customer JOIN orders ON o_custkey = c_custkey WHERE c_phone <> '29-716-678-7355' AND o_custkey <= 16201 AND o_totalprice > 29849.7 AND o_clerk <> 'Clerk#000000361' AND o_shippriority >= 0 ORDER BY c_address DESC", "res_cnt": "150302", "execution_time": "2.55", "text": "Retrieve the order key and customer address from the customer and orders tables where the customer phone number is not '29-716-678-7355', the customer key is less than or equal to 16201, the order total price is greater than 29849.7, the order clerk is not 'Clerk#000000361', and the order ship priority is greater than or equal to 0. Sort the results by customer address in descending order."},
{"sql": "SELECT s_comment, p_size, ps_supplycost FROM part JOIN partsupp ON ps_partkey = p_partkey JOIN supplier ON s_suppkey = ps_suppkey WHERE ps_availqty = 6331 AND p_type > 'LARGE POLISHED NICKEL' AND p_retailprice < 1758.76 ORDER BY s_comment DESC", "res_cnt": "394", "execution_time": "7.79", "text": "Retrieve the comments of suppliers, size of parts, and supply cost of part-supplier combinations where the available quantity of the part is 6331, the type of the part is greater than 'LARGE POLISHED NICKEL', and the retail price of the part is less than 1758.76. The results should be sorted in descending order based on the comments of the suppliers."},
{"sql": "SELECT l_shipdate FROM lineitem WHERE l_extendedprice >= 50883.12 AND l_linenumber > 1 AND l_shipdate <> '1992-08-30' AND l_returnflag = 'A' ORDER BY l_shipdate DESC", "res_cnt": "3376197", "execution_time": "93.75", "text": "Retrieve the shipment dates from the lineitem table where the extended price is greater than or equal to 50883.12, the linenumber is greater than 1, the shipdate is not equal to '1992-08-30', and the return flag is 'A', and sort the results in descending order based on the shipment date."},
{"sql": "SELECT c_acctbal, o_orderpriority, n_name FROM nation JOIN customer ON c_nationkey = n_nationkey JOIN orders ON o_custkey = c_custkey WHERE c_comment = 'ar deposits believe special, express foxes. packages cajole slyly e' AND n_name <> 'JAPAN' AND c_mktsegment = 'HOUSEHOLD' AND o_totalprice < 110238.65 AND c_name <= 'Customer#000013191'", "res_cnt": "7", "execution_time": "1.88", "text": "Retrieve the account balance, order priority, and nation name for customers who have a comment of 'ar deposits believe special, express foxes. packages cajole slyly e', are not from Japan, have a market segment of 'HOUSEHOLD', have a total order price less than 110238.65, and have a name less than or equal to 'Customer#000013191'."},
{"sql": "SELECT p_type, ps_comment FROM part JOIN partsupp ON ps_partkey = p_partkey WHERE ps_availqty <> 1078 AND p_type < 'PROMO BURNISHED NICKEL' AND p_size > 8 AND p_container < 'LG CAN'", "res_cnt": "974546", "execution_time": "13.71", "text": "Retrieve the part type and part supplier comment from the Part and Partsupp tables where the available quantity of the part supplier is not equal to 1078, the part type is less than 'PROMO BURNISHED NICKEL', the part size is greater than 8, and the part container is less than 'LG CAN'."},
{"sql": "SELECT ps_comment FROM partsupp WHERE ps_availqty >= 9324 AND ps_suppkey <> 1716 AND ps_partkey >= 65143 AND ps_supplycost < 164.19 AND ps_comment <> 's use slyly pending instructions. furiously final ideas shall have to are c'", "res_cnt": "85446", "execution_time": "9.15", "text": "Retrieve the comments from the \"partsupp\" table where the available quantity is greater than or equal to 9324, the supplier key is not equal to 1716, the part key is greater than or equal to 65143, the supply cost is less than 164.19, and the comment is not equal to 's use slyly pending instructions. furiously final ideas shall have to are c'."},
{"sql": "SELECT r_name, s_suppkey, n_regionkey FROM supplier, nation, region WHERE r_regionkey >= 1 AND s_suppkey <= 9696 AND r_comment <> 'uickly special accounts cajole carefully blithely close requests. carefully final asymptotes haggle furiousl' AND s_name < 'Supplier#000008309' AND s_phone <> '19-247-536-8083' ORDER BY s_suppkey ASC, n_regionkey DESC, r_name ASC", "res_cnt": "623025", "execution_time": "3.5", "text": "Retrieve the supplier name, supplier key, and region key from the supplier, nation, and region tables where the region key is greater than or equal to 1, the supplier key is less than or equal to 9696, the region comment is not equal to 'uickly special accounts cajole carefully blithely close requests. carefully final asymptotes haggle furiousl', the supplier name is less than 'Supplier#000008309', and the supplier phone is not equal to '19-247-536-8083', and sort the results by supplier key in ascending order, region key in descending order, and region name in ascending order."},
{"sql": "SELECT o_orderpriority FROM orders WHERE o_orderpriority > '3-MEDIUM' AND o_totalprice > 130861.55 AND o_comment < 'inally pending packages sleep along the furiously special' AND o_custkey <= 16480 AND o_shippriority <= 0 AND o_orderdate <> '1997-02-20' ORDER BY o_orderpriority ASC", "res_cnt": "14448", "execution_time": "1.70", "text": "Retrieve the order priority from the \"orders\" table where the order priority is greater than '3-MEDIUM', the total price is greater than 130861.55, the comment is less than 'inally pending packages sleep along the furiously special', the customer key is less than or equal to 16480, the ship priority is less than or equal to 0, and the order date is not equal to '1997-02-20', and sort the results in ascending order based on the order priority."}
]

File diff suppressed because one or more lines are too long

@ -0,0 +1,12 @@
[{"sql": "SELECT o_clerk FROM orders WHERE o_clerk < 'Clerk#000000377' AND o_shippriority = 0 AND o_orderstatus <> 'P' ORDER BY o_clerk DESC", "res_cnt": 548984, "execution_time": 7.67071008682251, "text": "Retrieve the clerk names from the orders table where the clerk name is less than 'Clerk#000000377', the ship priority is 0, and the order status is not 'P', and sort the results in descending order based on the clerk name."},
{"sql": "SELECT MIN(l_orderkey) FROM lineitem GROUP BY l_comment ORDER BY l_comment ASC", "res_cnt": 34378943, "execution_time": 606.0404827594757, "text": "Find the minimum value of the \"l_orderkey\" column from the \"lineitem\" table, grouped by the values in the \"l_comment\" column, and sort the result in ascending order based on the values in the \"l_comment\" column."},
{"sql": "SELECT s_comment FROM supplier WHERE s_nationkey >= 6 ORDER BY s_comment ASC", "res_cnt": 76084, "execution_time": 2.5723021030426025, "text": "Retrieve the comments of suppliers whose nation key is greater than or equal to 6, and sort the results in ascending order based on the comments."},
{"sql": "SELECT ps_supplycost FROM partsupp WHERE ps_availqty <> 5628", "res_cnt": 7999236, "execution_time": 40.69241118431091, "text": "Retrieve the supply cost from the \"partsupp\" table where the available quantity is not equal to 5628."},
{"sql": "SELECT MIN(l_shipdate) FROM lineitem WHERE l_shipinstruct >= 'TAKE BACK RETURN' GROUP BY l_commitdate ORDER BY l_commitdate ASC", "res_cnt": 2466, "execution_time": 83.76757621765137, "text": "Find the earliest shipment date (minimum l_shipdate) for each unique l_commitdate where the shipping instruction is 'TAKE BACK RETURN', and sort the results by l_commitdate in ascending order."},
{"sql": "SELECT MIN(p_type) FROM part WHERE p_container <> 'SM DRUM' GROUP BY p_partkey ORDER BY MIN(p_type) ASC", "res_cnt": 1949902, "execution_time": 26.379438877105713, "text": "Find the minimum value of the column \"p_type\" from the table \"part\" for each unique value in the column \"p_partkey\", but only for rows where the value in the column \"p_container\" is not equal to 'SM DRUM', and sort the results in ascending order based on the minimum value of \"p_type\"."},
{"sql": "SELECT l_suppkey FROM lineitem WHERE l_linenumber <> 1", "res_cnt": 44986052, "execution_time": 262.93306398391724, "text": "Retrieve the supplier keys from the \"lineitem\" table where the line number is not equal to 1."},
{"sql": "SELECT l_shipinstruct FROM lineitem WHERE l_linestatus > 'F' AND l_receiptdate >= '1994-07-02' AND l_tax < 0.04 AND l_returnflag > 'A' ORDER BY l_shipinstruct ASC", "res_cnt": 13322500, "execution_time": 286.6735632419586, "text": "Retrieve the shipping instructions from the lineitem table where the linestatus is greater than 'F', the receipt date is on or after '1994-07-02', the tax is less than 0.04, and the return flag is greater than 'A', and sort the results in ascending order based on the shipping instructions."},
{"sql": "SELECT l_shipdate FROM lineitem WHERE l_shipdate <= '1996-11-15' AND l_shipmode = 'RAIL' AND l_receiptdate < '1994-07-31'", "res_cnt": 3083223, "execution_time": 83.14104127883911, "text": "Retrieve the shipment date from the lineitem table where the shipment date is on or before November 15th, 1996, the shipment mode is 'RAIL', and the receipt date is before July 31st, 1994."},
{"sql": "SELECT ps_suppkey, SUM(ps_partkey), AVG(ps_availqty), SUM(ps_supplycost) FROM partsupp WHERE ps_partkey < 73880 AND ps_availqty <> 9160 AND ps_supplycost < 892.65 GROUP BY ps_suppkey ORDER BY AVG(ps_availqty) ASC", "res_cnt": 99828, "execution_time": 4.685328960418701, "text": "Retrieve the supplier key, the sum of part keys, the average available quantity, and the sum of supply costs from the PartsSupp table where the part key is less than 73880, the available quantity is not equal to 9160, and the supply cost is less than 892.65. Group the results by supplier key and sort them in ascending order based on the average available quantity."},
{"sql": "SELECT c_phone FROM customer WHERE c_custkey < 73501 AND c_nationkey <= 23", "res_cnt": 70506, "execution_time": 5.223631858825684, "text": "Retrieve the phone numbers of customers whose customer key is less than 73501 and whose nation key is less than or equal to 23 from the \"customer\" table."},
{"sql": "SELECT MIN(s_name) FROM supplier WHERE s_address > 'pIXH,lXMVPMknhTIXb4owWLtOvOmsdb' GROUP BY s_name HAVING MIN(s_name) > 'Supplier#000003443'", "res_cnt": 32766, "execution_time": 1.8696420192718506, "text": "Retrieve the minimum value of the column \"s_name\" from the table \"supplier\" for those records where the value in the column \"s_address\" is greater than 'pIXH,lXMVPMknhTIXb4owWLtOvOmsdb', group the results by the column \"s_name\", and only include those groups where the minimum value of the column \"s_name\" is greater than 'Supplier#000003443'."}]

@ -0,0 +1,52 @@
# Database Tool
Contributor: [Xuanhe Zhou](https://github.com/zhouxh19)
### API Functions
- *get_database_schema*: obtain the information of target tables
- *select_database_data*: fetch the query results from a database instance
- *rewrite_sql*: transform a sql query into a semantically equivalent but more execution-efficient sql
### Dataset
- Text2SQL Dataset
- *./data/tpch10x/text2res_multi_table.json*: relatively complex database queries (2-6 tables)
- *./data/tpch10x/text2res_single_table.json*: basic database queries
- SQL Optimization Dataset
- Samples for *[sql rewrite](https://github.com/TsinghuaDatabaseGroup/lmdb/tree/main/query_rewrite/data)*
- Samples for *[index tuning](https://github.com/TsinghuaDatabaseGroup/lmdb/tree/main/index_tuning/data)*
### Setup
1. Follow the steps in [main readme](https://github.com/OpenBMB/BMTools/blob/main/README.md)
2. Rename config.ini.template into my_config.ini
3. Configure the adopted LLM model in the 84th line of ../../agent/singletool.py, e.g.,
```bash
self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.0, openai_api_key=key)
```
4. Modify database settings in my_config.ini, e.g.,
```bash
[{db_system}]
host = 127.0.0.1
port = 5432
user = postgres
password = postgres
dbname = postgres
```
And rename *config.ini* into *my_config.ini*.
Note. {db_system} must match with that in ./api.py
5. Modify and run the test.py script to test the tool

@ -0,0 +1,148 @@
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
# Register the local "database" tool served by BMTools on port 8079.
tool_name, tool_url = 'database', "http://127.0.0.1:8079/tools/database/"
tool_name, tool_config = load_single_tools(tool_name, tool_url)
print(tool_name, tool_config)
# Wrap the single tool in an AutoGPT-style agent for the test runs below.
stqa = STQuestionAnswerer()
agent = stqa.load_tools(tool_name, tool_config, prompt_type="autogpt") # langchain: react-with-tool-description
# 394  (expected result-row count for the query described below)
# NOTE(review): this assignment is dead code — it is overwritten by the later
# `text = ...` assignment before agent.run() is invoked.
text = "Retrieve the comments of suppliers, size of parts, and supply cost of part-supplier combinations where the available quantity of the part is 6331, the type of the part is greater than 'LARGE POLISHED NICKEL', and the retail price of the part is less than 1758.76. The results should be sorted in descending order based on the comments of the suppliers."
# 1707
# text = "Retrieve the tax rate and total price from the lineitem and orders tables where the line number is greater than or equal to 3, the order key is not equal to 784709, the order key is less than or equal to 189383, and the clerk's ID is less than 'Clerk#000000181'."
# # 100967
# text = "Retrieve the part type, available quantity of parts, and the sum of supplier keys from the part and partsupp tables where the supplier key is not equal to 3804, the part key is less than or equal to 57823, and the available quantity of parts is less than 4781. Group the results by part type and available quantity of parts, and only include groups where the sum of supplier keys is greater than 1089. Sort the results in ascending order based on the sum of supplier keys."
# # 1249285 (fail to generate the sql)
# text = "Retrieve the phone number of the customer, the total price of the order, the comment of the nation, and the comment of the region from the orders table, customer table, nation table, and region table, respectively, where the nation key is less than 8, the order status is greater than or equal to 'O', and the order comment is less than 'ly around the pending theodo'. Sort the result by customer phone number in ascending order, nation comment in ascending order, order total price in ascending order, and region comment in ascending order."
# # 1272240
# text = "Retrieve the account balance, supply cost, region key, and nation name from the region, nation, supplier, and partsupp tables where the region name is less than or equal to 'AFRICA', the nation comment is greater than or equal to 'l platelets. regular accounts x-ray: unusual, regular acco', and the supplier nation key is greater than or equal to 0."
# # 150302
# text = "Retrieve the order key and customer address from the customer and orders tables where the customer phone number is not '29-716-678-7355', the customer key is less than or equal to 16201, the order total price is greater than 29849.7, the order clerk is not 'Clerk#000000361', and the order ship priority is greater than or equal to 0. Sort the results by customer address in descending order."
# 3376197
# text = "Retrieve the shipment dates from the lineitem table where the extended price is greater than or equal to 50883.12, the linenumber is greater than 1, the shipdate is not equal to '1992-08-30', and the return flag is 'A', and sort the results in descending order based on the shipment date." # 7
# 7
# text = "Retrieve the account balance, order priority, and nation name for customers who have a comment of 'ar deposits believe special, express foxes. packages cajole slyly e', are not from Japan, have a market segment of 'HOUSEHOLD', have a total order price less than 110238.65, and have a name less than or equal to 'Customer#000013191'." # 8
# 974546
# text = "Retrieve the part type and part supplier comment from the Part and Partsupp tables where the available quantity of the part supplier is not equal to 1078, the part type is less than 'PROMO BURNISHED NICKEL', the part size is greater than 8, and the part container is less than 'LG CAN'." # 9
# 85446
# text = "Retrieve the comments from the \"partsupp\" table where the available quantity is greater than or equal to 9324, the supplier key is not equal to 1716, the part key is greater than or equal to 65143, the supply cost is less than 164.19, and the comment is not equal to 's use slyly pending instructions. furiously final ideas shall have to are c'." # 10
# 623025 (wrong results ~ directly call the database tool)
# text = "Retrieve the supplier name, supplier key, and region key from the supplier, nation, and region tables where the region key is greater than or equal to 1, the supplier key is less than or equal to 9696, the region comment is not equal to 'uickly special accounts cajole carefully blithely close requests. carefully final asymptotes haggle furiousl', the supplier name is less than 'Supplier#000008309', and the supplier phone is not equal to '19-247-536-8083', and sort the results by supplier key in ascending order, region key in descending order, and region name in ascending order." # 11
# 14448 (wrong results)
# text = "Retrieve the order priority from the \"orders\" table where the order priority is greater than '3-MEDIUM', the total price is greater than 130861.55, the comment is less than 'finally pending packages sleep along the furiously special', the customer key is less than or equal to 16480, the ship priority is less than or equal to 0, and the order date is not equal to '1997-02-20', and sort the results in ascending order based on the order priority." # 12
# rewrite
#text = "SELECT s_comment FROM part As p,partsupp As ps,supplier As s WHERE p.p_partkey = ps.ps_partkey AND s.s_suppkey = ps.ps_suppkey AND ps.ps_availqty = 6331 AND p.p_type > 'LARGE POLISHED NICKEL' AND p.p_retailprice < 1758.76 ORDER BY s_comment DESC;"
# text = "Retrieve the comments of suppliers. The results should be sorted in descending order based on the comments of the suppliers"
# Active prompt: a column-level description of the query whose expected row
# count is 394 (underscores appear escaped as \_ from a markdown source).
text = "Retrieve the comments in the supplier table where the p\_partkey column in the part table matches the ps\_partkey column in the partsupp table, the ps\_availqty column in the partsupp table equals 6331, the p_type column in the part table is greater than 'LARGE POLISHED NICKEL', and the p\_retailprice column in the part table is less than 1758.76."
# Drive the agent end-to-end: fetch the schema once, synthesize SQL from the
# natural-language description, rewrite it, and report the result-row count.
agent.run([""" First get the database schema via get_database_schema. Next generate the sql query exactly based on the schema and the following description:
\"{}\"
Next rewrite the SQL query and output the total number of rows in the database results of the rewritten SQL query.
Note. 1) Only obtain the database schema once;
2) If an API is successfully called, do not call the same API again;
3) Do not use any image in the output;
4) The db_name is tpch10x;
5) Count the rows of query results by your own and do not output the whole query results.
""".format(text)])
# # unit test: get_database_schema
# agent.run(["""
# Fetch the database schema from a postgresql database named tpch10x.\"
# """])
# # unit test: rewrite_sql
# agent("Rewrite the input query: select * from customer limit 2")
# # unit test: select_database_data
# agent("""
# Output the total number of rows in the query results from a postgresql database based on the following description:
# \"Retrieve all the data from the 'customer' table and limit the output to only the first 2 rows.\"
# """)
''' output (autogpt)
> Entering new LLMChain chain...
Prompt after formatting:
System: You are Tom, Assistant
Your decisions must always be made independently
without seeking user assistance. Play to your strengths
as an LLM and pursue simple strategies with no legal complications.
If you have completed all your tasks,
make sure to use the "finish" command.
GOALS:
{input prompt}
Constraints:
1. ~4000 word limit for short term memory. Your short term memory is short, so immediately save important information to files.
2. If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.
3. No user assistance
4. Exclusively use the commands listed in double quotes e.g. "command name"
Commands:
1. get_database_schema: . Your input should be a json (args json schema): {{"db_name" : string, }} The Action to trigger this API should be get_database_schema and the input parameters should be a json dict string. Pay attention to the type of parameters.
2. select_database_data: Read the data stored in database. Your input should be a json (args json schema): {{"query" : string, }} The Action to trigger this API should be select_database_data and the input parameters should be a json dict string. Pay attention to the type of parameters.
3. rewrite_sql: Get rewritten sql from rewriter. Your input should be a json (args json schema): {{"sql" : string, }} The Action to trigger this API should be rewrite_sql and the input parameters should be a json dict string. Pay attention to the type of parameters.
4. finish: use this to signal that you have finished all your objectives, args: "response": "final response to let people know you have finished your objectives"
Resources:
1. Internet access for searches and information gathering.
2. Long Term memory management.
3. GPT-3.5 powered Agents for delegation of simple tasks.
4. File output.
Performance Evaluation:
1. Continuously review and analyze your actions to ensure you are performing to the best of your abilities.
2. Constructively self-criticize your big-picture behavior constantly.
3. Reflect on past decisions and strategies to refine your approach.
4. Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.
You should only respond in JSON format as described below
Response Format:
{
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user"
},
"command": {
"name": "command name",
"args": {
"arg name": "value"
}
}
}
Ensure the response can be parsed by Python json.loads
System: The current time and date is Wed Apr 26 11:02:25 2023
System: This reminds you of these events from your past:
[]
Human: Determine which next command to use, and respond using the format specified above:
> Finished chain.
{
"thoughts": {
"text": "Since the 'get_database_schema' command did not work, I will try to retrieve the schema manually. I will use the following SQL query to retrieve the schema: 'SELECT table_name, column_name FROM information_schema.columns WHERE table_schema = 'public' ORDER BY table_name, ordinal_position;'",
"reasoning": "If the 'get_database_schema' command is not working, I can manually retrieve the schema using an SQL query. This query will retrieve the names of all the tables and columns in the 'public' schema of the database.",
"plan": "- Use the SQL query 'SELECT table_name, column_name FROM information_schema.columns WHERE table_schema = 'public' ORDER BY table_name, ordinal_position;' to retrieve the schema\n- Generate SQL query based on given description\n- Use 'select_database_data' command to retrieve query results\n- Count number of rows in query results using Python code",
"criticism": "I need to make sure that I am using the correct table and column names from the schema in the SQL query. I also need to make sure that I am using the correct syntax for the SQL query.",
"speak": "I will try to retrieve the schema manually using an SQL query."
},
"command": {
"name": "select_database_data",
"args": {
"query": "SELECT table_name, column_name FROM information_schema.columns WHERE table_schema = 'public' ORDER BY table_name, ordinal_position;"
}
}
}
'''

@ -0,0 +1,27 @@
import random
def subsample_data(data, subsample_size):
    """Return a random subsample of *subsample_size* aligned input/output pairs.

    *data* is a tuple ``(inputs, outputs)`` of two equal-length lists; the
    pairing between inputs and outputs is preserved in the result.
    """
    inputs, outputs = data
    assert len(inputs) == len(outputs)
    chosen = random.sample(range(len(inputs)), subsample_size)
    sampled_inputs = [inputs[idx] for idx in chosen]
    sampled_outputs = [outputs[idx] for idx in chosen]
    return sampled_inputs, sampled_outputs
def create_split(data, split_size):
    """Randomly split aligned (inputs, outputs) lists into two disjoint parts.

    Args:
        data: tuple ``(inputs, outputs)`` of two equal-length lists.
        split_size: number of pairs placed in the first part; the remainder
            (in original order) forms the second part.

    Returns:
        ((inputs1, outputs1), (inputs2, outputs2)) with pairing preserved.
    """
    inputs, outputs = data
    assert len(inputs) == len(outputs)
    indices = random.sample(range(len(inputs)), split_size)
    # Fix: membership tests used the list directly (O(n) each, O(n^2) total);
    # a set makes the complement pass O(n) overall.
    index_set = set(indices)
    inputs1 = [inputs[i] for i in indices]
    outputs1 = [outputs[i] for i in indices]
    inputs2 = [inputs[i] for i in range(len(inputs)) if i not in index_set]
    outputs2 = [outputs[i] for i in range(len(outputs)) if i not in index_set]
    return (inputs1, outputs1), (inputs2, outputs2)

@ -0,0 +1,268 @@
import psycopg2
import pymysql
import json
import logging
import os
from enum import IntEnum
class DataType(IntEnum):
    """Coarse column-type categories used to constrain aggregate functions."""
    VALUE = 0  # numeric types
    TIME = 1   # temporal types
    CHAR = 2   # character / everything else


# Aggregates permitted per category, keyed by the category's integer value.
# Fix: the original keyed CHAR/TIME via member-through-member access
# (DataType.VALUE.CHAR), which is deprecated in Python 3.11 and removed in
# 3.12, and mixed int keys with enum-member keys. The keys below are the
# plain int values the lookups expect.
AGGREGATE_CONSTRAINTS = {
    DataType.VALUE.value: ['count', 'max', 'min', 'avg', 'sum'],
    DataType.CHAR.value: ['count', 'max', 'min'],
    DataType.TIME.value: ['count', 'max', 'min']
}


def transfer_field_type(database_type, server):
    """Map a backend column type name to a DataType category (as an int).

    Args:
        database_type: type name as reported by the DBMS; for mysql it is
            lower-cased and any "(...)" length suffix is stripped first.
        server: 'mysql' or 'postgresql'; any other value yields CHAR.

    Returns:
        int: DataType.VALUE, DataType.TIME, or DataType.CHAR value.
    """
    numeric, temporal = [], []
    if server == 'mysql':
        numeric = ['int', 'tinyint', 'smallint', 'mediumint', 'bigint',
                   'float', 'double', 'decimal']
        temporal = ['date', 'time', 'year', 'datetime', 'timestamp']
        database_type = database_type.lower().split('(')[0]
    elif server == 'postgresql':
        numeric = ['integer', 'numeric']
        temporal = ['date']
    # Fix: the original indexed data_type[0] on an empty list (IndexError) for
    # unknown servers; any unrecognized server/type now falls through to CHAR.
    if database_type in numeric:
        return DataType.VALUE.value
    elif database_type in temporal:
        return DataType.TIME.value
    else:
        return DataType.CHAR.value
class DBArgs(object):
    """Connection settings for a mysql or postgresql backend.

    Reads host/port/user/password/dbname from *config*; an explicit *dbname*
    argument overrides the configured one. Also exposes the matching JDBC
    driver class and URL prefix.
    """

    def __init__(self, dbtype, config, dbname=None):
        self.dbtype = dbtype
        # Common connection fields, identical for both backends.
        self.host = config['host']
        self.port = config['port']
        self.user = config['user']
        self.password = config['password']
        self.dbname = dbname if dbname else config['dbname']
        # Backend-specific JDBC metadata.
        if dbtype == 'mysql':
            self.driver = 'com.mysql.jdbc.Driver'
            self.jdbc = 'jdbc:mysql://'
        else:
            self.driver = 'org.postgresql.Driver'
            self.jdbc = 'jdbc:postgresql://'
class Database():
def __init__(self, args, timeout=-1):
    """Hold a live connection to the database described by *args* (a DBArgs).

    The connection is opened eagerly here; note that execute_sql() re-opens a
    fresh connection for each query.
    """
    self.args = args
    self.conn = self.resetConn(timeout)
    # self.schema = self.compute_table_schema()
def resetConn(self, timeout=-1):
    """Open and return a new DB-API connection for the configured backend.

    For postgresql, a positive *timeout* is applied as a server-side
    ``statement_timeout`` (in seconds); otherwise no timeout option is set.

    NOTE(review): for mysql the *timeout* value (default -1) is passed
    directly as connect/read/write timeouts — confirm pymysql accepts a
    negative value here.
    """
    if self.args.dbtype == 'mysql':
        conn = pymysql.connect(
            host=self.args.host,
            user=self.args.user,
            passwd=self.args.password,
            database=self.args.dbname,
            port=int(self.args.port),
            charset='utf8',
            connect_timeout=timeout,
            read_timeout=timeout,
            write_timeout=timeout)
    else:
        # Any non-mysql dbtype is treated as postgresql.
        if timeout > 0:
            conn = psycopg2.connect(database=self.args.dbname,
                                    user=self.args.user,
                                    password=self.args.password,
                                    host=self.args.host,
                                    port=self.args.port,
                                    options='-c statement_timeout={}s'.format(timeout))
        else:
            conn = psycopg2.connect(database=self.args.dbname,
                                    user=self.args.user,
                                    password=self.args.password,
                                    host=self.args.host,
                                    port=self.args.port)
    return conn
'''
def exec_fetch(self, statement, one=True):
cur = self.conn.cursor()
cur.execute(statement)
if one:
return cur.fetchone()
return cur.fetchall()
'''
def execute_sql(self, sql):
fail = 1
self.conn = self.resetConn()
cur = self.conn.cursor()
i = 0
cnt = 3 # retry times
while fail == 1 and i < cnt:
try:
fail = 0
cur.execute(sql)
except BaseException:
fail = 1
res = []
if fail == 0:
res = cur.fetchall()
i = i + 1
logging.debug('database {}, return flag {}, execute sql {}\n'.format(self.args.dbname, 1 - fail, sql))
if fail == 1:
# raise RuntimeError("Database query failed")
print("SQL Execution Fatal!!")
return 0, ''
elif fail == 0:
# print("SQL Execution Succeed!!")
return 1, res
def pgsql_results(self, sql):
try:
#success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
success, res = self.execute_sql(sql)
#print("pgsql_results", success, res)
if success == 1:
return res
else:
return "<fail>"
except Exception as error:
logging.error('pgsql_results Exception', error)
return "<fail>"
def pgsql_cost_estimation(self, sql):
try:
#success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
success, res = self.execute_sql('explain (FORMAT JSON) ' + sql)
if success == 1:
cost = res[0][0][0]['Plan']['Total Cost']
return cost
else:
logging.error('pgsql_cost_estimation Fails!')
return 0
except Exception as error:
logging.error('pgsql_cost_estimation Exception', error)
return 0
def pgsql_actual_time(self, sql):
try:
#success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
if success == 1:
cost = res[0][0][0]['Plan']['Actual Total Time']
return cost
else:
return -1
except Exception as error:
logging.error('pgsql_actual_time Exception', error)
return -1
def mysql_cost_estimation(self, sql):
try:
success, res = self.execute_sql('explain format=json ' + sql)
if success == 1:
total_cost = self.get_mysql_total_cost(0, json.loads(res[0][0]))
return float(total_cost)
else:
return -1
except Exception as error:
logging.error('mysql_cost_estimation Exception', error)
return -1
def get_mysql_total_cost(self, total_cost, res):
if isinstance(res, dict):
if 'query_cost' in res.keys():
total_cost += float(res['query_cost'])
else:
for key in res:
total_cost += self.get_mysql_total_cost(0, res[key])
elif isinstance(res, list):
for i in res:
total_cost += self.get_mysql_total_cost(0, i)
return total_cost
def get_tables(self):
if self.args.dbtype == 'mysql':
return self.mysql_get_tables()
else:
return self.pgsql_get_tables()
# query cost estimated by the optimizer
def cost_estimation(self, sql):
if self.args.dbtype == 'mysql':
return self.mysql_cost_estimation(sql)
else:
return self.pgsql_cost_estimation(sql)
def compute_table_schema(self):
"""
schema: {table_name: [field_name]}
:param cursor:
:return:
"""
if self.args.dbtype == 'postgresql':
# cur_path = os.path.abspath('.')
# tpath = cur_path + '/sampled_data/'+dbname+'/schema'
sql = 'SELECT table_name FROM information_schema.tables WHERE table_schema = \'public\';'
success, res = self.execute_sql(sql)
#print("======== tables", res)
if success == 1:
tables = res
schema = {}
for table_info in tables:
table_name = table_info[0]
sql = 'SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \'' + table_name + '\';'
success, res = self.execute_sql(sql)
#print("======== table columns", res)
columns = res
schema[table_name] = []
for col in columns:
''' compute the distinct value ratio of the column
if transfer_field_type(col[1], self.args.dbtype) == DataType.VALUE.value:
sql = 'SELECT count({}) FROM {};'.format(col[0], table_name)
success, res = self.execute_sql(sql)
print("======== column rows", res)
num = res
if num[0][0] != 0:
schema[table_name].append(col[0])
'''
#schema[table_name].append("column {} is of {} type".format(col[0], col[1]))
schema[table_name].append("{}".format(col[0]))
'''
with open(tpath, 'w') as f:
f.write(str(schema))
'''
#print(schema)
return schema
else:
logging.error('pgsql_cost_estimation Fails!')
return 0
def simulate_index(self, index):
#table_name = index.table()
statement = (
"SELECT * FROM hypopg_create_index(E'{}');".format(index)
)
result = self.execute_sql(statement)
return result
def drop_simulated_index(self, oid):
statement = f"select * from hypopg_drop_index({oid})"
result = self.execute_sql(statement)
assert result[0] is True, f"Could not drop simulated index with oid = {oid}."

@ -0,0 +1,71 @@
import argparse
import configparser
import logging
def get_conf(conf_file, server_name):
    """Read and return one named section from an INI configuration file."""
    parser = configparser.ConfigParser()
    parser.read(conf_file)
    return parser[server_name]
def get_parser():
    """Build the command-line parser for the tool.

    Only --db_conf (path to the database INI config) is currently exposed;
    a large set of experiment options (training data paths, seeds, GPT
    generation/eval parameters, storage budget) was disabled upstream and
    lived on as a no-op string literal, which has been removed here.

    Returns:
        argparse.ArgumentParser ready for parse_args().
    """
    parser = argparse.ArgumentParser(
        description="Instruction Induction.")
    parser.add_argument("--db_conf", type=str,
                        default='../database/configs/config.ini')
    return parser
def set_logger(log_file):
    """Configure the root logger to emit DEBUG+ records to both the console
    and *log_file*, using a timestamped format."""
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    fmt = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s: - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
    # Console handler is registered first, then the file handler,
    # mirroring the original ordering.
    for handler in (logging.StreamHandler(), logging.FileHandler(log_file)):
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(fmt)
        root.addHandler(handler)

@ -0,0 +1,36 @@
import re
import sqlparse
def remove_create_table(sql):
    """Delete every ``CREATE TABLE ... ( ... );`` statement from *sql*.

    Only all-lowercase or all-uppercase keywords are matched, as in the
    original pattern.
    """
    pattern = r'(create|CREATE)\s+(table|TABLE).+?\(.+?\)\s*;'
    return re.sub(pattern, '', sql, flags=re.DOTALL)
def remove_create_index(sql):
    """Delete every ``CREATE INDEX ... ( ... );`` statement from *sql*.

    Only all-lowercase or all-uppercase keywords are matched, as in the
    original pattern.
    """
    pattern = r'(create|CREATE)\s+(index|INDEX).+?\(.+?\)\s*;'
    return re.sub(pattern, '', sql, flags=re.DOTALL)
def remove_table(sql):
    """Delete every ``TABLE ... ( ... );`` fragment from *sql* (keyword in
    all-lowercase or all-uppercase only)."""
    pattern = r'(table|TABLE).+?\(.+?\)\s*;'
    return re.sub(pattern, '', sql, flags=re.DOTALL)
def clean_sql(sql):
    """Flatten a parsed sqlparse statement into one normalized string.

    Whitespace tokens and single-line comments are dropped; remaining
    tokens are joined with spaces and punctuation spacing is tightened
    via strip_par.
    """
    kept = [
        token for token in sql.flatten()
        if not token.is_whitespace
        and token.ttype is not sqlparse.tokens.Comment.Single
    ]
    return strip_par(' '.join(str(token) for token in kept))
def strip_par(s):
    """Remove spaces immediately before or after SQL punctuation/operators."""
    operators = ['(', ')', ',', '>', '=', '<', '>=', '<=', '!=', '<>', '.', ';']
    for op in operators:
        s = s.replace(' ' + op, op).replace(op + ' ', op)
    return s
def preprocess_execute_sql(sql):
    """Strip DDL, normalize whitespace and ensure a trailing semicolon.

    Returns [1, cleaned_sql] on success, or [0, ''] when nothing
    executable remains after cleaning.
    """
    # CREATE TABLE first, then CREATE INDEX -- same order as before.
    sql = remove_create_index(remove_create_table(sql))
    statements = sqlparse.parse(sql)
    if not statements:
        return [0, '']
    cleaned = clean_sql(statements[0])
    if not cleaned:
        return [0, '']
    if not cleaned.endswith(';'):
        cleaned += ';'
    return [1, cleaned]

@ -0,0 +1,355 @@
"""Contains classes for querying large language models."""
from math import ceil
import os
import time
from tqdm import tqdm
from abc import ABC, abstractmethod
import openai
# Per-1000-token prices (USD) for the base GPT-3 completion engines;
# consumed by gpt_get_estimated_cost when confirming API spend.
gpt_costs_per_thousand = {
    'davinci': 0.0200,
    'curie': 0.0020,
    'babbage': 0.0005,
    'ada': 0.0004
}
def model_from_config(config, disable_tqdm=True):
    """Instantiate the LLM wrapper named by config["name"].

    Raises ValueError for an unrecognized model name.
    """
    model_type = config["name"]
    if model_type == "GPT_forward":
        return GPT_Forward(config, disable_tqdm=disable_tqdm)
    if model_type == "GPT_insert":
        return GPT_Insert(config, disable_tqdm=disable_tqdm)
    raise ValueError(f"Unknown model type: {model_type}")
class LLM(ABC):
    """Abstract base class for large language models.

    Concrete subclasses (see GPT_Forward / GPT_Insert below) wrap a
    specific completion API behind this two-method interface.
    """

    @abstractmethod
    def generate_text(self, prompt):
        """Generates text from the model.

        Parameters:
            prompt: The prompt to use. This can be a string or a list of strings.
        Returns:
            A list of strings.
        """
        pass

    @abstractmethod
    def log_probs(self, text, log_prob_range):
        """Returns the log probs of the text.

        Parameters:
            text: The text to get the log probs of. This can be a string or a list of strings.
            log_prob_range: The range of characters within each string to get the log_probs of.
                This is a list of tuples of the form (start, end).
        Returns:
            A list of log probs.
        """
        pass
class GPT_Forward(LLM):
    """Wrapper for GPT-3 style forward completion.

    config: dict with 'name', 'batch_size' and a 'gpt_config' dict that is
    passed straight through to openai.Completion.create.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        """Print the estimated API cost and ask the user to confirm.

        Raises unless the user answers 'y'; setting the LLM_SKIP_CONFIRM
        environment variable skips the prompt.
        """
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)

    def generate_text(self, prompt, n):
        """Generate n completions per prompt; returns a flat list of strings."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text.

        log_prob_range, when given, is a per-string (start, end) character
        range; only tokens overlapping that range are returned.
        """
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            log_prob_range_batches = [None] * len(text)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        # Loop variable renamed so it no longer shadows the parameter.
        for text_batch, range_batch in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                            disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(
                text_batch, range_batch)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        if not isinstance(prompt, list):
            # BUG FIX: the original assigned the wrapped value to an unused
            # local `text`, so a bare-string prompt crashed below on
            # `prompt[i] = ...` (str does not support item assignment).
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                # A too-large batch is signalled so auto_reduce_n can halve n.
                if 'is greater than the maximum' in str(e):
                    raise BatchSizeException()
                print(e)
                print('Retrying...')
                time.sleep(5)
        return [response['choices'][i]['text'] for i in range(len(response['choices']))]

    def __complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            # BUG FIX: same wrong-name assignment as __generate_text.
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        return response['choices']

    def __log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text.

        Uses echo=True with max_tokens=0 so the model scores the prompt
        itself without generating new tokens.
        """
        if not isinstance(text, list):
            text = [text]
        if log_prob_range is not None:
            for i in range(len(text)):
                lower_index, upper_index = log_prob_range[i]
                assert lower_index < upper_index
                assert lower_index >= 0
                assert upper_index - 1 < len(text[i])
        config = self.config['gpt_config'].copy()
        config['logprobs'] = 1
        config['echo'] = True
        config['max_tokens'] = 0
        if isinstance(text, list):
            text = [f'\n{text[i]}' for i in range(len(text))]
        else:
            text = f'\n{text}'
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=text)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        # Drop the first token: it corresponds to the injected leading newline.
        log_probs = [response['choices'][i]['logprobs']['token_logprobs'][1:]
                     for i in range(len(response['choices']))]
        tokens = [response['choices'][i]['logprobs']['tokens'][1:]
                  for i in range(len(response['choices']))]
        offsets = [response['choices'][i]['logprobs']['text_offset'][1:]
                   for i in range(len(response['choices']))]
        # Subtract 1 from the offsets to account for the newline
        for i in range(len(offsets)):
            offsets[i] = [offset - 1 for offset in offsets[i]]
        if log_prob_range is not None:
            # First, we need to find the indices of the tokens in the log probs
            # that correspond to the tokens in the log_prob_range
            for i in range(len(log_probs)):
                lower_index, upper_index = self.get_token_indices(
                    offsets[i], log_prob_range[i])
                log_probs[i] = log_probs[i][lower_index:upper_index]
                tokens[i] = tokens[i][lower_index:upper_index]
        return log_probs, tokens

    def get_token_indices(self, offsets, log_prob_range):
        """Returns the indices of the tokens in the log probs that correspond to the tokens in the log_prob_range."""
        # For the lower index, find the highest index that is less than or equal to the lower index
        lower_index = 0
        for i in range(len(offsets)):
            if offsets[i] <= log_prob_range[0]:
                lower_index = i
            else:
                break
        upper_index = len(offsets)
        for i in range(len(offsets)):
            if offsets[i] >= log_prob_range[1]:
                upper_index = i
                break
        return lower_index, upper_index
class GPT_Insert(LLM):
    """Wrapper for GPT-3 insertion mode: each prompt contains an [APE]
    marker splitting it into a prefix and a suffix for the API call."""

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        """Show the estimated API cost and require interactive confirmation
        (skipped when LLM_SKIP_CONFIRM is set)."""
        total_estimated_cost = 0
        for text in texts:
            per_text = gpt_get_estimated_cost(self.config, text, max_tokens)
            total_estimated_cost += per_text * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            answer = input("Continue? (y/n) ")
            if answer != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Halve n until fn stops raising BatchSizeException."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            half = n // 2
            return self.auto_reduce_n(fn, prompt, half) + self.auto_reduce_n(fn, prompt, half)

    def generate_text(self, prompt, n):
        """Generate n insertions per prompt; returns a flat list of strings."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        # Insertion mode only supports one prompt per request.
        assert batch_size == 1
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        outputs = []
        for batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            outputs += self.auto_reduce_n(self.__generate_text, batch, n)
        return outputs

    def log_probs(self, text, log_prob_range=None):
        """Insertion mode does not expose token log probs."""
        raise NotImplementedError

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # Split prompts into prefixes and suffixes with the [APE] token (do not include the [APE] token in the suffix)
        parts = prompt[0].split('[APE]')
        prefix = parts[0]
        suffix = parts[1]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prefix, suffix=suffix)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        # Remove suffix from the generated text
        return [choice['text'].replace(suffix, '') for choice in response['choices']]
def gpt_get_estimated_cost(config, prompt, max_tokens):
    """Uses the current API costs/1000 tokens to estimate the cost of generating text from the model."""
    # Drop the [APE] placeholder before counting; ~4 characters per token.
    n_prompt_tokens = len(prompt.replace('[APE]', '')) // 4
    total_tokens = n_prompt_tokens + max_tokens
    model_name = config['gpt_config']['model']
    engine = model_name.split('-')[1]
    costs_per_thousand = gpt_costs_per_thousand
    if engine not in costs_per_thousand:
        # Try as if it is a fine-tuned model
        engine = model_name.split(':')[0]
        costs_per_thousand = {
            'davinci': 0.1200,
            'curie': 0.0120,
            'babbage': 0.0024,
            'ada': 0.0016
        }
    return costs_per_thousand[engine] * total_tokens / 1000
class BatchSizeException(Exception):
    """Raised when the API rejects a request for exceeding the maximum
    batch size; auto_reduce_n reacts by halving n."""
    pass

@ -0,0 +1,6 @@
from ..registry import register


@register("db_diag")
def register_db_diag_tool():
    """Register the database-diagnosis tool factory under the name "db_diag".

    The import is deferred so the api module (and its heavy dependencies)
    is only loaded when the tool is actually requested.
    """
    from .api import build_db_diag_tool
    return build_db_diag_tool

@ -0,0 +1,61 @@
import numpy as np
import requests
import json
def prometheus(url, params):
    """Issue a GET to the (hard-coded) Prometheus server and return the
    parsed JSON response.

    url is the API path (e.g. 'api/v1/query_range'); params become the
    query string.
    """
    endpoint = 'http://8.131.229.55:9090/' + url
    response = requests.get(url=endpoint, params=params)
    return response.json()
def detect_anomalies(data, significance_level=0.05):
    # assume the workload is steadily running
    """Heuristically flag a metric series as anomalous.

    The original KS-test implementation was disabled upstream (it survived
    only as a no-op string literal, removed here); the active heuristic
    simply reports an anomaly whenever the series maximum exceeds its mean,
    which is true for any non-constant series.

    Args:
        data (numpy.ndarray): 1-D array of metric values.
        significance_level (float): kept for API compatibility with the
            disabled KS-test variant; unused by the current heuristic.

    Returns:
        bool: True if the series is considered anomalous.
    """
    mean = np.mean(data)
    max_value = np.max(data)
    print("mean: ", mean)
    print("max_value: ", max_value)
    # bool() normalizes the numpy comparison to a plain Python bool,
    # matching the original's True/False return values.
    return bool(max_value > mean)

@ -0,0 +1,315 @@
import json
import os
import requests
import numpy as np
import openai
import paramiko
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet, stopwords
from nltk.tokenize import word_tokenize
import nltk
from ..tool import Tool
from bmtools.tools.database.utils.db_parser import get_conf
from bmtools.tools.database.utils.database import DBArgs, Database
from bmtools.models.customllm import CustomLLM
from bmtools.knowledge.knowledge_extraction import KnowledgeExtraction
from bmtools.tools.db_diag.anomaly_detection import detect_anomalies
from bmtools.tools.db_diag.anomaly_detection import prometheus
from bmtools.tools.db_diag.example_generate import bm25
import warnings
def obtain_values_of_metrics(start_time, end_time, metrics):
    """Query Prometheus for each metric over [start_time, end_time] and
    return {metric: max observed value}.

    Despite the original comment, the window *maximum* (not the average)
    is recorded for each metric. Raises when a metric has no samples.
    """
    # Prometheus caps a range query at 11,000 points per timeseries.
    if end_time - start_time > 11000 * 3:
        #raise Exception("The time range is too large, please reduce the time range")
        warnings.warn("The time range ({}, {}) is too large, please reduce the time range".format(start_time, end_time))
    required_values = {}
    print(" ====> metrics: ", metrics)
    for metric in metrics:
        reply = prometheus('api/v1/query_range',
                           {'query': metric, 'start': start_time, 'end': end_time, 'step': '3'})
        series = reply["data"]["result"]
        if series == []:
            raise Exception("No metric values found for the given time range")
        samples = series[0]["values"]
        # Each sample is a (timestamp, value) pair; keep the window maximum.
        required_values[metric] = np.max(np.array([float(value) for _, value in samples]))
    return required_values
def find_abnormal_metrics(start_time, end_time, monitoring_metrics, resource):
    """Return monitored metrics that look anomalous around the window and
    belong to *resource* (one of memory/cpu/disk/network).

    The query window is padded by 5 minutes on each side; metrics whose
    name contains a *different* resource keyword are filtered out, and
    metrics with no Prometheus data are skipped.
    """
    resource_keys = ["memory", "cpu", "disk", "network"]
    interval_time = 5  # minutes of padding on each side of the window
    abnormal_metrics = []
    for metric_name in monitoring_metrics:
        reply = prometheus('api/v1/query_range',
                           {'query': metric_name,
                            'start': start_time - interval_time * 60,
                            'end': end_time + interval_time * 60,
                            'step': '3'})
        series = reply["data"]["result"]
        if series == []:
            continue
        samples = series[0]["values"]
        if detect_anomalies(np.array([float(value) for _, value in samples])):
            # Keep the metric only when no *other* resource keyword matches.
            tagged_elsewhere = any(
                key in metric_name and key != resource for key in resource_keys)
            if not tagged_elsewhere:
                abnormal_metrics.append(metric_name)
    return abnormal_metrics
def build_db_diag_tool(config) -> Tool:
    """Build the "db_diag" Tool: HTTP endpoints that diagnose database
    bottlenecks from Prometheus metrics, remote logs, slow queries and an LLM.

    NOTE(review): the incoming `config` argument is overwritten below with
    values read from my_config.ini, so the caller's config is ignored.
    """
    tool = Tool(
        "Database Diagnosis",
        "Diagnose the bottlenecks of a database based on relevant metrics",
        name_for_model="db_diag",
        description_for_model="Plugin for diagnosing the bottlenecks of a database based on relevant metrics",
        logo_url="https://commons.wikimedia.org/wiki/File:Postgresql_elephant.svg",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )

    #URL_CURRENT_WEATHER= "http://api.weatherapi.com/v1/current.json"
    #URL_FORECAST_WEATHER = "http://api.weatherapi.com/v1/forecast.json"

    # PromQL expressions / metric lists used by the endpoints below.
    # NOTE(review): the Prometheus server and exporter instance addresses
    # are hard-coded here.
    URL_PROMETHEUS = 'http://8.131.229.55:9090/'
    prometheus_metrics = {"cpu_usage": "avg(rate(process_cpu_seconds_total{instance=\"172.27.58.65:9187\"}[5m]) * 1000)",
                          "cpu_metrics": ["node_scrape_collector_duration_seconds{instance=\"172.27.58.65:9100\"}", "node_procs_running{instance=\"172.27.58.65:9100\"}", "node_procs_blocked{instance=\"172.27.58.65:9100\"}", "node_entropy_available_bits{instance=\"172.27.58.65:9100\"}", "node_load1{instance=\"172.27.58.65:9100\"}", "node_load5{instance=\"172.27.58.65:9100\"}", "node_load15{instance=\"172.27.58.65:9100\"}"],
                          "memory_usage": "node_memory_MemTotal_bytes{instance=~\"172.27.58.65:9100\"} - (node_memory_Cached_bytes{instance=~\"172.27.58.65:9100\"} + node_memory_Buffers_bytes{instance=~\"172.27.58.65:9100\"} + node_memory_MemFree_bytes{instance=~\"172.27.58.65:9100\"})",
                          "memory_metrics": ["node_memory_Inactive_anon_bytes{instance=\"172.27.58.65:9100\"}", "node_memory_MemFree_bytes{instance=\"172.27.58.65:9100\"}", "node_memory_Dirty_bytes{instance=\"172.27.58.65:9100\"}", "pg_stat_activity_count{datname=~\"(imdbload|postgres|sysbench|template0|template1|tpcc|tpch)\", instance=~\"172.27.58.65:9187\", state=\"active\"} !=0"],
                          "network_metrics": ["node_sockstat_TCP_tw{instance=\"172.27.58.65:9100\"}", "node_sockstat_TCP_orphan{instance=\"172.27.58.65:9100\"}"]}
    # "node_sockstat_TCP_tw{instance=\"172.27.58.65:9100\"}",

    # load knowlege extractor
    knowledge_matcher = KnowledgeExtraction("/bmtools/tools/db_diag/root_causes_dbmind.jsonl")

    # load db settings
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    config = get_conf(script_dir + '/my_config.ini', 'postgresql')
    dbargs = DBArgs("postgresql", config=config)  # todo assign database name

    # send request to database
    db = Database(dbargs, timeout=-1)

    server_config = get_conf(script_dir + '/my_config.ini', 'benchserver')

    monitoring_metrics = []
    with open(str(os.getcwd()) + "/bmtools/tools/db_diag/database_monitoring_metrics", 'r') as f:
        monitoring_metrics = f.read()
    # NOTE(review): eval() of file contents -- the file must hold a Python
    # list literal of metric names; eval on an untrusted file is unsafe.
    monitoring_metrics = eval(monitoring_metrics)

    @tool.get("/obtain_start_and_end_time_of_anomaly")
    def obtain_start_and_end_time_of_anomaly():
        """Fetch the latest anomaly window ("start;end" epoch seconds) from
        the newest *trigger_time_log file on the bench server via SFTP."""
        # Create SSH client
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        start_time = 0
        end_time = 0
        try:
            # Connect to the remote server
            ssh.connect(server_config["server_address"], username=server_config["username"], password=server_config["password"])

            # Create an SFTP client
            sftp = ssh.open_sftp()

            # Change to the remote directory
            sftp.chdir(server_config["remote_directory"])

            # Get a list of files in the directory
            files = sftp.listdir()

            required_file_name = ""
            required_tp = -1
            # Read the contents of each file
            for filename in files:
                remote_filepath = server_config["remote_directory"] + '/' + filename

                if "trigger_time_log" not in filename:
                    continue

                # Pick the log with the largest numeric timestamp prefix.
                tp = filename.split("_")[0]
                if tp.isdigit():
                    tp = int(tp)
                    if required_tp < tp:
                        required_tp = tp
                        required_file_name = filename

            file_content = sftp.open(server_config["remote_directory"] + '/' + required_file_name).read()
            file_content = file_content.decode()
            # The first line holds "start;end" as epoch seconds.
            tps = file_content.split("\n")[0]
            start_time = tps.split(";")[0]
            end_time = tps.split(";")[1]

        finally:
            # Close the SFTP session and SSH connection
            # NOTE(review): if ssh.connect() fails, sftp was never assigned
            # and this finally block raises NameError.
            sftp.close()
            ssh.close()

        return {"start_time": start_time, "end_time": end_time}

    @tool.get("/whether_is_abnormal_metric")
    def whether_is_abnormal_metric(start_time: int, end_time: int, metric_name: str = "cpu_usage"):
        """Report whether `metric_name` looks abnormal in the padded window.

        NOTE(review): the anomaly check is commented out below, so this
        endpoint currently always answers "abnormal" whenever Prometheus
        returns any data for the metric.
        """
        interval_time = 5  # minutes of padding on each side of the window
        metric_values = prometheus('api/v1/query_range', {'query': prometheus_metrics[metric_name], 'start': start_time - interval_time * 60, 'end': end_time + interval_time * 60, 'step': '3'})
        # prometheus('api/v1/query_range', {'query': '100 - (avg(irate(node_cpu_seconds_total{instance=~"172.27.58.65:9100",mode="idle"}[1m])) * 100)', 'start': '1684412385', 'end': '1684413285', 'step': '3'})
        # print(" === metric_values", metric_values)

        if metric_values["data"]["result"] != []:
            metric_values = metric_values["data"]["result"][0]["values"]
        else:
            raise Exception("No metric values found for the given time range")

        #is_abnormal = detect_anomalies(np.array([float(value) for _, value in metric_values]))
        is_abnormal = True

        if is_abnormal:
            return "The metric is abnormal"
        else:
            return "The metric is normal"

    @tool.get("/cpu_diagnosis_agent")
    def cpu_diagnosis_agent(start_time: int, end_time: int):
        """Collect CPU metric maxima, match root-cause knowledge and ask the
        LLM for an analysis of the high CPU usage."""
        # live_tuples\n- dead_tuples\n- table_size
        cpu_metrics = prometheus_metrics["cpu_metrics"]
        cpu_metrics = cpu_metrics # + find_abnormal_metrics(start_time, end_time, monitoring_metrics, 'cpu')

        print("==== cpu_metrics", cpu_metrics)

        detailed_cpu_metrics = obtain_values_of_metrics(start_time, end_time, cpu_metrics)

        docs_str = knowledge_matcher.match(detailed_cpu_metrics)

        prompt = """The CPU metric is abnormal. Then obtain the CPU relevant metric values from Prometheus: {}.
        Next output the analysis of potential causes of the high CPU usage based on the CPU relevant metric values,
        {}""".format(detailed_cpu_metrics, docs_str)

        print(prompt)

        # response = openai.Completion.create(
        #     model="text-davinci-003",
        #     prompt=prompt,
        #     temperature=0,
        #     max_tokens=1000,
        #     top_p=1.0,
        #     frequency_penalty=0.0,
        #     presence_penalty=0.0,
        #     stop=["#", ";"]
        # )
        # output_text = response.choices[0].text.strip()

        # Set up the OpenAI GPT-3 model
        # model_engine = "gpt-3.5-turbo"

        # prompt_response = openai.ChatCompletion.create(
        #     engine="gpt-3.5-turbo",
        #     messages=[
        #         {"role": "assistant", "content": "The table schema is as follows: " + str(schema)},
        #         {"role": "user", "content": str(prompt)}
        #     ]
        # )
        # output_text = prompt_response['choices'][0]['message']['content']

        llm = CustomLLM()
        output_analysis = llm(prompt)

        return {"diagnose": output_analysis, "knowledge": docs_str}

    @tool.get("/memory_diagnosis_agent")
    def memory_diagnosis_agent(start_time: int, end_time: int):
        """Collect memory metric maxima plus historical slow queries, match
        root-cause knowledge and ask the LLM for an analysis."""
        memory_metrics = prometheus_metrics["memory_metrics"]
        # NOTE(review): duplicate assignment kept from the original.
        memory_metrics = prometheus_metrics["memory_metrics"]
        memory_metrics = memory_metrics # + find_abnormal_metrics(start_time, end_time, monitoring_metrics, 'memory')

        detailed_memory_metrics = obtain_values_of_metrics(start_time, end_time, memory_metrics)

        openai.api_key = os.environ["OPENAI_API_KEY"]

        # NOTE(review): shadows the outer `db`; also,
        # obtain_historical_slow_queries is not defined on the Database
        # class visible in this snapshot -- confirm it exists elsewhere.
        db = Database(dbargs, timeout=-1)
        slow_queries = db.obtain_historical_slow_queries()

        slow_query_state = ""
        for i, query in enumerate(slow_queries):
            slow_query_state += str(i + 1) + '. ' + str(query) + "\n"

        print(slow_query_state)

        # TODO: need a similarity match function to match the top-K examples
        # 1. get the categories of incoming metrics. Such as "The abnormal metrics include A, B, C, D"
        # 2. embedding the metrics
        # note: could the metric embeddings be precomputed? Feasible if the
        #       number of metric combinations is limited.
        # 3. match the top-K examples(embedding)
        # note: without embeddings, how do we effectively select the examples
        #       most relevant to the current metrics? Enumeration might work
        #       if we knew which metrics each example involves. How do we
        #       judge that a metric relates to a piece of text -- could a
        #       model classify which metrics a text mentions? Retraining one
        #       would likely need many samples. Could keyword counts work?
        docs_str = knowledge_matcher.match(detailed_memory_metrics)

        prompt = """The memory metric is abnormal. Then obtain the memory metric values from Prometheus: {}. The slow queries are:
        {}
        Output the analysis of potential causes of the high memory usage based on the memory metric values and slow queries, e.g.,
        {}
        Note: include the important slow queries in the output.
        """.format(detailed_memory_metrics, slow_query_state, docs_str)

        # print(prompt)

        # response = openai.Completion.create(
        #     model="text-davinci-003",
        #     prompt=prompt,
        #     temperature=0,
        #     max_tokens=1000,
        #     top_p=1.0,
        #     frequency_penalty=0.0,
        #     presence_penalty=0.0,
        #     stop=["#", ";"]
        # )
        # output_text = response.choices[0].text.strip()

        # Set up the OpenAI GPT-3 model
        # model_engine = "gpt-3.5-turbo"

        # prompt_response = openai.ChatCompletion.create(
        #     engine="gpt-3.5-turbo",
        #     messages=[
        #         {"role": "assistant", "content": "The table schema is as follows: " + str(schema)},
        #         {"role": "user", "content": str(prompt)}
        #     ]
        # )
        # output_text = prompt_response['choices'][0]['message']['content']

        llm = CustomLLM()
        output_analysis = llm(prompt)

        return {"diagnose": output_analysis, "knowledge": docs_str}

    return tool

@ -0,0 +1,18 @@
[postgresql]
host =
port =
user =
password =
dbname =
[mysql]
host =
port =
user =
password =
dbname =
[benchserver]
server_address =
username =
password =
remote_directory =

@ -0,0 +1 @@
['go_gc_duration_seconds', 'go_gc_duration_seconds_count', 'go_gc_duration_seconds_sum', 'go_goroutines', 'go_memstats_alloc_bytes', 'go_memstats_alloc_bytes_total', 'go_memstats_frees_total', 'go_memstats_gc_cpu_fraction', 'go_memstats_heap_alloc_bytes', 'go_memstats_heap_idle_bytes', 'go_memstats_heap_inuse_bytes', 'go_memstats_heap_objects', 'go_memstats_heap_released_bytes', 'go_memstats_heap_sys_bytes', 'go_memstats_last_gc_time_seconds', 'go_memstats_mallocs_total', 'go_memstats_mspan_inuse_bytes', 'go_memstats_stack_inuse_bytes', 'go_memstats_stack_sys_bytes', 'node_context_switches_total', 'node_cpu_seconds_total', 'node_disk_io_now', 'node_disk_io_time_seconds_total', 'node_disk_io_time_weighted_seconds_total', 'node_disk_read_bytes_total', 'node_disk_read_time_seconds_total', 'node_disk_reads_completed_total', 'node_disk_reads_merged_total', 'node_disk_write_time_seconds_total', 'node_disk_writes_completed_total', 'node_disk_writes_merged_total', 'node_disk_written_bytes_total', 'node_entropy_available_bits', 'node_filefd_allocated', 'node_filesystem_avail_bytes', 'node_filesystem_files_free', 'node_filesystem_free_bytes', 'node_forks_total', 'node_intr_total', 'node_load1', 'node_load15', 'node_load5', 'node_memory_Active_anon_bytes', 'node_memory_Active_bytes', 'node_memory_Active_file_bytes', 'node_memory_AnonHugePages_bytes', 'node_memory_AnonPages_bytes', 'node_memory_Cached_bytes', 'node_memory_Committed_AS_bytes', 'node_memory_Dirty_bytes', 'node_memory_Inactive_anon_bytes', 'node_memory_Inactive_bytes', 'node_memory_Inactive_file_bytes', 'node_memory_KReclaimable_bytes', 'node_memory_KernelStack_bytes', 'node_memory_Mapped_bytes', 'node_memory_MemAvailable_bytes', 'node_memory_MemFree_bytes', 'node_memory_PageTables_bytes', 'node_memory_SReclaimable_bytes', 'node_memory_SUnreclaim_bytes', 'node_memory_Shmem_bytes', 'node_memory_Slab_bytes', 'node_netstat_Icmp_InMsgs', 'node_netstat_Icmp_OutMsgs', 'node_netstat_Ip6_InOctets', 
'node_netstat_Ip6_OutOctets', 'node_netstat_IpExt_InOctets', 'node_netstat_IpExt_OutOctets', 'node_netstat_TcpExt_TCPSynRetrans', 'node_netstat_Tcp_ActiveOpens', 'node_netstat_Tcp_CurrEstab', 'node_netstat_Tcp_InSegs', 'node_netstat_Tcp_OutRsts', 'node_netstat_Tcp_OutSegs', 'node_netstat_Tcp_PassiveOpens', 'node_netstat_Tcp_RetransSegs', 'node_netstat_Udp6_InDatagrams', 'node_netstat_Udp6_InErrors', 'node_netstat_Udp6_OutDatagrams', 'node_netstat_Udp6_RcvbufErrors', 'node_netstat_Udp_InDatagrams', 'node_netstat_Udp_OutDatagrams', 'node_network_receive_bytes_total', 'node_network_receive_packets_total', 'node_network_transmit_bytes_total', 'node_network_transmit_packets_total', 'node_procs_blocked', 'node_procs_running', 'node_schedstat_running_seconds_total', 'node_schedstat_timeslices_total', 'node_schedstat_waiting_seconds_total', 'node_scrape_collector_duration_seconds', 'node_sockstat_TCP6_inuse', 'node_sockstat_TCP_alloc', 'node_sockstat_TCP_inuse', 'node_sockstat_TCP_mem', 'node_sockstat_TCP_mem_bytes', 'node_sockstat_TCP_orphan', 'node_sockstat_TCP_tw', 'node_sockstat_UDP_inuse', 'node_sockstat_UDP_mem', 'node_sockstat_UDP_mem_bytes', 'node_sockstat_sockets_used', 'node_softnet_processed_total', 'node_time_seconds', 'node_timex_estimated_error_seconds', 'node_timex_frequency_adjustment_ratio', 'node_timex_maxerror_seconds', 'node_timex_tick_seconds', 'node_vmstat_pgfault', 'node_vmstat_pgmajfault', 'node_vmstat_pgpgin', 'node_vmstat_pgpgout', 'node_xfs_block_mapping_extent_list_deletions_total', 'node_xfs_block_mapping_extent_list_insertions_total', 'node_xfs_block_mapping_extent_list_lookups_total', 'node_xfs_block_mapping_reads_total', 'node_xfs_block_mapping_unmaps_total', 'node_xfs_block_mapping_writes_total', 'node_xfs_directory_operation_create_total', 'node_xfs_directory_operation_getdents_total', 'node_xfs_directory_operation_lookup_total', 'node_xfs_directory_operation_remove_total', 'node_xfs_extent_allocation_blocks_allocated_total', 
'node_xfs_extent_allocation_blocks_freed_total', 'node_xfs_extent_allocation_extents_allocated_total', 'node_xfs_extent_allocation_extents_freed_total', 'node_xfs_inode_operation_attempts_total', 'node_xfs_inode_operation_attribute_changes_total', 'node_xfs_inode_operation_found_total', 'node_xfs_inode_operation_missed_total', 'node_xfs_inode_operation_reclaims_total', 'node_xfs_read_calls_total', 'node_xfs_vnode_active_total', 'node_xfs_vnode_reclaim_total', 'node_xfs_vnode_release_total', 'node_xfs_vnode_remove_total', 'node_xfs_write_calls_total', 'process_cpu_seconds_total', 'process_resident_memory_bytes', 'process_start_time_seconds', 'promhttp_metric_handler_requests_total', 'scrape_duration_seconds', 'scrape_samples_post_metric_relabeling', 'scrape_samples_scraped', 'up']

File diff suppressed because one or more lines are too long

@ -0,0 +1,28 @@
import requests
import json
import datetime
import numpy as np
def prometheus(url, params, base_url='http://8.131.229.55:9090/'):
    """Query a Prometheus HTTP API endpoint and return the metric values.

    Args:
        url: API path relative to the Prometheus server root,
            e.g. ``'api/v1/query_range'``.
        params: Query parameters forwarded verbatim to the API
            (PromQL expression, start/end timestamps, step, ...).
        base_url: Prometheus server root. Defaults to the previously
            hard-coded test server so existing callers are unaffected.

    Returns:
        ``numpy.ndarray`` of float metric values taken from the first
        result series, or ``None`` when the query matched no series.

    Raises:
        requests.RequestException: on network failure.
    """
    response = requests.get(url=base_url + url, params=params)
    payload = response.json()
    # The original code indexed result[0] unconditionally and raised
    # IndexError on an empty result; treat that case explicitly.
    result = payload.get("data", {}).get("result", [])
    if not result:
        return None
    # Each entry of "values" is a [timestamp, value-string] pair;
    # keep only the numeric values.
    return np.array([float(value) for _, value in result[0]["values"]])
if __name__ == '__main__':
    # Example query (CPU usage, kept for reference):
    # prometheus('api/v1/query_range', {'query': '100 - (avg(irate(node_cpu_seconds_total{instance=~"123.56.63.105:9100",mode="idle"}[1m])) * 100)', 'start': '1684412385', 'end': '1684412485', 'step': '3'})

    def _to_epoch(ts_str):
        """Convert a 'YYYY-MM-DD HH:MM:SS' string to a POSIX timestamp."""
        return datetime.datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S").timestamp()

    start_time = _to_epoch("2023-05-19 22:21:30")
    end_time = _to_epoch("2023-05-19 22:23:30")

    # Used memory = total - (cached + buffers + free) for the target node.
    memory_query = (
        'node_memory_MemTotal_bytes{instance=~"123.56.63.105:9100"} '
        '- (node_memory_Cached_bytes{instance=~"123.56.63.105:9100"} '
        '+ node_memory_Buffers_bytes{instance=~"123.56.63.105:9100"} '
        '+ node_memory_MemFree_bytes{instance=~"123.56.63.105:9100"})'
    )
    prometheus('api/v1/query_range',
               {'query': memory_query,
                'start': start_time,
                'end': end_time,
                'step': '3'})

@ -0,0 +1,24 @@
# db_diag Tool
Contributor: [Xuanhe Zhou](https://github.com/zhouxh19)
### API Functions
- *obtain_start_and_end_time_of_anomaly*: fetch the time period of an anomaly
- *whether_is_abnormal_metric*: examine whether the values of the input metric appear to be abnormal (TODO: add classic anomaly detection algorithms)
- *xxx_diagnosis_agent*: diagnose the root causes of the abnormal metrics in specific region (e.g., memory/cpu problems)
### Setup
1. Follow the steps in [main readme](https://github.com/OpenBMB/BMTools/blob/main/README.md)
2. Configure the adopted LLM model in the 84th line of ../../agent/singletool.py, e.g.,
```bash
self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.0, openai_api_key=key)
```
3. Modify the settings in *config.ini*, then rename *config.ini* to *my_config.ini*
4. Modify and run the test.py script to test the tool

@ -0,0 +1,147 @@
[
{
"cause_name": "large_table",
"desc": "This function checks whether the query related table is a root cause of performance issues. It considers two aspects: the size of the table and the number of tuples in the table. If the number of live and dead tuples in a table exceeds the tuple_number_threshold or the table size exceeds the table_total_size_threshold, then the table is considered large and added to the detail dictionary. If there are any large tables, then they are considered a root cause of performance issues. If there are no large tables, then they are not a root cause of performance issues.",
"metrics": "- live_tuples\n- dead_tuples\n- table_size"
},
{
"cause_name": "many_dead_tuples",
"desc": "This function checks whether the query related table has too many dead tuples, which can cause bloat-table and affect query performance. If the table structure is not available or the insert type is not supported, it is not a root cause. The function then collects information about the dead rate, live tuples, dead tuples, and table size of each table and checks if they exceed certain thresholds. If the dead rate of a table exceeds the dead rate threshold, it is considered a root cause. The function also provides suggestions to clean up dead tuples in time to avoid affecting query performance.",
"metrics": "- dead_rate\n- live_tuples\n- dead_tuples\n- table_size"
},
{
"cause_name": "heavy_scan_operator",
"desc": "This function diagnoses whether there is a heavy scan operator in the query related table. If the table has too many fetched tuples and the hit rate is low, it is considered a root cause. Additionally, if there are expensive sequential scans, index scans, or heap scans, it is also considered a root cause. The function provides details on the heavy scan operator, including the number of fetched tuples, returned rows, and hit rate. It also suggests adjustments to avoid large scans. If there are expensive scans, the function suggests confirming whether the inner table has an index, avoiding count operations, and considering the index filter ability. If there is a heavy scan operator, the function provides details on the operator and suggests adjustments according to business needs. If there are no suggestions, it suggests avoiding heavy scans.",
"metrics": "- hit_rate\n- n_tuples_fetched\n- n_tuples_returned\n- n_returned_rows\n- total_cost\n- table\n- name\n- parent\n- cost rate"
},
{
"cause_name": "abnormal_plan_time",
"desc": "This function checks for abnormal execution plan generation in slow SQL instances. It calculates the ratio of plan time to execution time and compares it to the plan time rate threshold. If the ratio is greater than or equal to the threshold and the number of hard parses is greater than the number of soft parses, it indicates abnormal plan time. This could be due to the lack of support for PBE in the business. The function suggests modifying the business to support PBE. If the condition is met, it is a root cause of the issue. If not, it is not a root cause.",
"metrics": "- n_soft_parse\n- n_hard_parse\n- plan_time\n- exc_time"
},
{
"cause_name": "unused_and_redundant_index",
"desc": "This function checks for the presence of unused or redundant indexes in a table that is related to a query. Unused indexes are those that have not been used for a long time, while redundant indexes are those that are not needed for the query. If the table is not large or there are no unused or redundant indexes, or if the query involves a select operation, then the function is not a root cause. Otherwise, the function identifies the unused and redundant indexes and provides suggestions to clean them up. The threshold for identifying unused indexes is not specified in the code.",
"metrics": "- Large table\n- Unused index info\n- Redundant index info\n- Select type"
},
{
"cause_name": "update_large_data",
"desc": "This function checks whether a table has a large number of tuples updated. If the number of updated tuples is greater than or equal to the specified threshold, it is considered a root cause. The function then provides details on the number of updated tuples and suggests making adjustments to the business. If the number of updated tuples is not above the threshold or if there is no plan parse information available, it is not a root cause.",
"metrics": "- n_tuples_updated\n- updated_tuples_threshold\n- live_tuples\n- rows"
},
{
"cause_name": "insert_large_data",
"desc": "This function checks whether a query related table has a large number of inserted tuples. If the number of inserted tuples is greater than or equal to the specified threshold (stored in the variable \"inserted_tuples_threshold\"), it is considered a root cause. The function then calculates the ratio of inserted tuples to live tuples and provides a suggestion to make adjustments to the business. If the number of inserted tuples is less than the threshold, the function checks for insert operators and if any of them have a large number of rows inserted (greater than the threshold), it is considered a root cause and a suggestion is provided. If neither of these conditions are met, it is not a root cause.",
"metrics": "- n_tuples_inserted\n- inserted_tuples_threshold\n- live_tuples\n- table_structure\n- plan_parse_info\n- rows"
},
{
"cause_name": "delete_large_data",
"desc": "This function checks whether a table has too many tuples to be deleted. If the number of deleted tuples is greater than or equal to the specified threshold, it is considered a root cause and will be deleted in the future. If the number of deleted tuples is less than the threshold, the function checks whether there is a delete operation on a table with a large number of tuples. If so, it is also considered a root cause and the suggestion is to make adjustments to the business. If neither condition is met, it is not a root cause.",
"metrics": "- n_tuples_deleted\n- deleted_tuples_threshold\n- live_tuples\n- deleted_tuples_rate\n- rows"
},
{
"cause_name": "too_many_index",
"desc": "This function checks for the presence of too many indexes in a table, which can negatively impact the performance of insert and update operations. If the table structure is not available or the select type is not appropriate, this is not a root cause. If there are a large number of indexes in the table, the function identifies the related tables and provides details on the number of indexes. In this case, the function is a root cause and suggests that too many indexes can slow down insert, delete, and update statements. The threshold for the number of indexes is determined by the variable \"index_number_threshold\".",
"metrics": "- index_number_threshold\n- len(table.index)"
},
{
"cause_name": "disk_spill",
"desc": "This is a function that checks whether there is a possibility of disk spill during the execution of SQL. If the plan parse information is not available, it checks whether the sort spill count or hash spill count exceeds the sort rate threshold. If the plan parse information is available, it calculates the total cost of the plan and checks whether the cost rate of the sort or hash operators exceeds the cost rate threshold. If abnormal operator details are found and the sort or hash spill count is greater than 0, it indicates that the SORT/HASH operation may spill to disk. The suggestion is to analyze whether the business needs to adjust the size of the work_mem parameter. If disk spill is detected, it is a root cause, otherwise it is not a root cause.",
"metrics": "1. sort_spill_count\n2. hash_spill_count\n3. sort_rate_threshold\n4. cost_rate_threshold\n5. plan_total_cost\n6. rows\n7. _get_operator_cost"
},
{
"cause_name": "vacuum_event",
"desc": "This function checks whether the query related table has undergone a vacuum operation, which could potentially be a root cause of slow SQL queries. It first retrieves the probable time interval for an analyze operation from the monitoring module. Then, it creates a dictionary of vacuum information for each table in the table structure, including the schema name, table name, and the time of the last vacuum operation. The function then checks whether the vacuum time falls within the time range of the slow SQL query execution or whether the slow SQL query execution starts within a certain time interval after the vacuum operation. If any table meets these conditions, it is considered a root cause and added to the detail dictionary. If no table meets these conditions, it is not a root cause.",
"metrics": "- table_structure\n- slow_sql_param\n- vacuum_delay\n- start_at\n- duration_time"
},
{
"cause_name": "analyze_event",
"desc": "This function checks whether the query related table has undergone an analyzing operation. If the table structure is not available, it is not a root cause. Otherwise, it calculates the probable time interval for the analyzing operation and creates a dictionary of table names and their corresponding analyze times. It then checks whether the analyze time falls within the slow SQL instance's start and end time or within the probable time interval before or after the slow SQL instance. If any table satisfies these conditions, it is considered a root cause and added to the 'analyze' key in the detail dictionary. Finally, the function returns True if there is at least one table in the 'analyze' key, otherwise it is not a root cause.",
"metrics": "- table_structure\n- slow_sql_param\n- analyze_delay\n- start_at\n- duration_time"
},
{
"cause_name": "workload_contention",
"desc": "This code is designed to diagnose workload contention issues in a database system. The function checks for several potential causes of contention, including abnormal CPU and memory resource usage, insufficient space in the database data directory, and excessive connections or thread pool usage. If any of these issues are detected, the function provides a detailed report of the problem and suggests potential solutions. If no issues are found, the function returns \"not a root cause\".",
"metrics": "- process_used_memory\n- max_process_memory\n- dynamic_used_memory\n- max_dynamic_memory\n- other_used_memory\n- tps\n- max_connections\n- db_cpu_usage\n- db_mem_usage\n- disk_usage\n- connection\n- thread_pool_rate"
},
{
"cause_name": "cpu_resource_contention",
"desc": "This function checks whether there is contention for CPU resources by other processes outside the database. If the maximum CPU usage of these processes exceeds the threshold specified in the variable \"cpu_usage_threshold\", the function sets the \"system_cpu_contention\" key in the \"detail\" dictionary to indicate the current user CPU usage. If this key is set, the function suggests handling exception processes in the system as a solution. If the \"system_cpu_contention\" key is not set, this issue is not a root cause.",
"metrics": "- user_cpu_usage\n- system_cpu_contention"
},
{
"cause_name": "io_resource_contention",
"desc": "This piece of code checks for IO resource contention in the system. It does so by iterating through the IO utils of each device and checking if the maximum IO utils exceed the disk_ioutils_threshold. If there is contention, the function provides details on the device and the IO utils that exceed the threshold. It also suggests two possible causes of contention: competing processes outside the database and long transactions within the database. If there is contention, it is considered a root cause of the issue. If there is no contention, it is not a root cause.",
"metrics": "- IO utilization (IO-Utils)"
},
{
"cause_name": "memory_resource_contention",
"desc": "This function checks whether there is contention for memory resources by other processes outside the database. If the maximum system memory usage exceeds the detection threshold specified in the variable \"mem_usage_threshold\", the function sets the \"system_mem_contention\" key in the \"detail\" dictionary to indicate that the current system memory usage is significant. If the \"system_mem_contention\" key exists in the \"detail\" dictionary, the function suggests checking for external processes that may be consuming resources. If the function returns True, it indicates that memory resource contention is a root cause of the problem. If the function returns False, it means that memory resource contention is not a root cause.",
"metrics": "- system_mem_usage\n- system_mem_contention"
},
{
"cause_name": "abnormal_network_status",
"desc": "This piece of code checks for abnormal network status by analyzing packet loss rate and bandwidth usage. It first checks the receive and transmit drop rates for each device and appends any abnormal rates to a list of details. It then checks the bandwidth usage for each device and appends any abnormal rates to the same list of details. If any abnormal rates are found, the function sets the detail and suggestion attributes accordingly and returns True, indicating that abnormal network status is a root cause. If no abnormal rates are found, the function returns False, indicating that abnormal network status is not a root cause. The thresholds for abnormal packet loss rate and network bandwidth usage are obtained from the monitoring module.",
"metrics": "- package_drop_rate_threshold\n- network_bandwidth_usage_threshold"
},
{
"cause_name": "os_resource_contention",
"desc": "This function checks for a potential issue where other processes outside the database may be occupying too many handle resources. If the system file descriptor (fds) occupation rate exceeds the detection threshold, it is considered a root cause and the function returns a boolean value of True. The system fds occupation rate is recorded in the diagnosis report along with a suggestion to determine whether the handle resource is occupied by the database or other processes. If the system fds occupation rate is below the tuple_number_threshold, it is not a root cause and the function returns a boolean value of False.",
"metrics": "- process_fds_rate\n- handler_occupation_threshold"
},
{
"cause_name": "database_wait_event",
"desc": "This function checks if there is a wait event in the database. If there is a wait event, it retrieves the wait status and wait event information and stores it in the detail dictionary. If the detail dictionary already has wait event information, it suggests that there is no root cause for the issue. Otherwise, it suggests that the wait event may be a root cause for the issue. If there is no wait event information, it suggests that there is no root cause for the issue. Therefore, the presence of wait event information is a root cause for the issue, while the absence of wait event information is not a root cause.",
"metrics": "- wait_event_info\n- wait_status\n- wait_event\n- detail\n- suggestion"
},
{
"cause_name": "lack_of_statistics",
"desc": "This piece of code checks for the presence of updated statistics in the business table. If the statistics have not been updated for a long time, it may lead to a serious decline in the execution plan. The code identifies abnormal tables by comparing the difference in tuples with the tuple_number_threshold. If any abnormal tables are found, the code suggests updating the statistics in a timely manner to help the planner choose the most suitable plan. If no abnormal tables are found, lack of statistics is not a root cause.",
"metrics": "- data_changed_delay\n- tuples_diff\n- schema_name\n- table_name"
},
{
"cause_name": "missing_index",
"desc": "This function checks for the presence of a required index using a workload-index-recommend interface. If the recommended index information is available, it indicates that a required index is missing and provides a suggestion for the recommended index. If the information is not available, it is not a root cause for the issue.",
"metrics": ""
},
{
"cause_name": "poor_join_performance",
"desc": "This code diagnoses poor performance in join operations. There are four main situations that can cause poor join performance: 1) when the GUC parameter 'enable_hashjoin' is set to 'off', which can result in the optimizer choosing NestLoop or other join operators even when HashJoin would be more suitable; 2) when the optimizer incorrectly chooses the NestLoop operator, even when 'set_hashjoin' is on; 3) when the join operation involves a large amount of data, which can lead to high execution costs; and 4) when the cost of the join operator is expensive. \n\nIn general, NestLoop is suitable when the inner table has a suitable index or when the tuple of the outer table is small (less than 10000), while HashJoin is suitable for tables with large amounts of data (more than 10000), although index will reduce HashJoin performance to a certain extent. Note that HashJoin requires high memory consumption.\n\nThe code checks for abnormal NestLoop, HashJoin, and MergeJoin operators, and identifies inappropriate join nodes based on the number of rows and cost rate. It also provides suggestions for optimization, such as setting 'enable_hashjoin' to 'on', optimizing SQL structure to reduce JOIN cost, and using temporary tables to filter data. \n\nIf the code finds any poor join performance, it is considered a root cause of the problem. Otherwise, it is not a root cause.",
"metrics": "- total_cost\n- cost_rate_threshold\n- nestloop_rows_threshold\n- large_join_threshold"
},
{
"cause_name": "complex_boolean_expression",
"desc": "This function checks for a specific issue in SQL queries that can lead to poor performance. The issue occurs when the \"in\" clause in a query is too long, which can cause the query to execute slowly. The function looks for instances of this issue in the SQL query and if it finds one where the length of the \"in\" clause exceeds a certain threshold, it returns a message indicating the issue and provides a suggestion for how to fix it. If the function does not find any instances of this issue, it is not a root cause of the performance problem.",
"metrics": "- slow_sql_instance.query\n- expression_number\n- len(expression)\n- monitoring.get_slow_sql_param('large_in_list_threshold')"
},
{
"cause_name": "string_matching",
"desc": "This function checks for certain conditions that may cause index columns to fail. These conditions include selecting columns using certain functions or regular expressions, and using the \"order by random()\" operation. If any of these conditions are detected, the function provides suggestions for how to rewrite the query to avoid index failure. If abnormal functions or regulations are detected, the function suggests avoiding using functions or expression operations on indexed columns or creating expression index for it. If the \"order by random()\" operation is detected, the function suggests confirming whether the scene requires this operation. If any of these conditions are detected, the function is a root cause of the index failure. Otherwise, it is not a root cause.",
"metrics": "- existing_functions\n- matching_results\n- seq_scan_properties\n- sort_operators"
},
{
"cause_name": "complex_execution_plan",
"desc": "This is a function that checks for complex execution plans in SQL statements. The function identifies two cases that may cause complex execution plans: (1) a large number of join or group operations, and (2) a very complex execution plan based on its height. If the function identifies either of these cases, it sets the corresponding details and suggestions for the user. If the number of join operators exceeds the \"complex_operator_threshold\" or the plan height exceeds the \"plan_height_threshold\", the function considers it a root cause of the problem. Otherwise, it is not a root cause.",
"metrics": "- complex_boolean_expression\n- plan_parse_info\n- plan_parse_info.height\n- join_operator\n- len(join_operator)"
},
{
"cause_name": "correlated_subquery",
"desc": "This piece of code checks for the presence of sub-queries in SQL execution that cannot be promoted. If the execution plan contains the keyword 'SubPlan' and the SQL structure does not support Sublink-Release, the user needs to rewrite the SQL. The function checks for the existence of such sub-queries and provides suggestions for rewriting the statement to support sublink-release. If there are subqueries that cannot be promoted, it is a root cause of the issue. Otherwise, it is not a root cause.",
"metrics": "- SubPlan\n- exists_subquery"
},
{
"cause_name": "poor_aggregation_performance",
"desc": "This code diagnoses poor aggregation performance in SQL queries. It identifies four potential root causes: (1) when the GUC parameter 'enable_hashagg' is set to 'off', resulting in a higher tendency to use the GroupAgg operator; (2) when the query includes scenarios like 'count(distinct col)', which makes HashAgg unavailable; (3) when the cost of the GroupAgg operator is expensive; and (4) when the cost of the HashAgg operator is expensive. The code checks for these conditions and provides detailed information and suggestions for each potential root cause. If none of these conditions are met, poor aggregation performance is not a root cause.",
"metrics": "- total_cost\n- cost_rate_threshold\n- enable_hashagg\n- GroupAggregate\n- HashAggregate\n- Group By Key\n- NDV"
},
{
"cause_name": "abnormal_sql_structure",
"desc": "This function checks for a specific issue in the SQL structure that can lead to poor performance. If the rewritten SQL information is present, it indicates that the SQL structure is abnormal and can be a root cause of performance issues. The function provides a detailed description of the issue and suggests a solution to address it. If the rewritten SQL information is not present, it is not a root cause of the performance issue.",
"metrics": "\n- rewritten_sql_info"
},
{
"cause_name": "timed_task_conflict",
"desc": "This is a function that analyzes various features related to SQL execution and returns a feature vector, system cause details, and suggestions. The features include lock contention, heavy scan operator, abnormal plan time, unused and redundant index, and many others. The function checks if each feature can be obtained and appends the feature value to the feature vector. If a feature cannot be obtained, it logs an error message and appends 0 to the feature vector. The function also sets the system cause and plan details to empty dictionaries. The \"timed_task_conflict\" feature is not a root cause of the issue being diagnosed.",
"metrics": "\n- lock_contention\n- many_dead_tuples\n- heavy_scan_operator\n- abnormal_plan_time\n- unused_and_redundant_index\n- update_large_data\n- insert_large_data\n- delete_large_data\n- too_many_index\n- disk_spill\n- vacuum_event\n- analyze_event\n- workload_contention\n- cpu_resource_contention\n- io_resource_contention\n- memory_resource_contention\n- abnormal_network_status\n- os_resource_contention\n- database_wait_event\n- lack_of_statistics\n- missing_index\n- poor_join_performance\n- complex_boolean_expression\n- string_matching\n- complex_execution_plan\n- correlated_subquery\n- poor_aggregation_performance\n- abnormal_sql_structure\n- timed_task_conflict"
}
]

@ -0,0 +1,63 @@
# Test driver for the db_diag tool: loads the tool from a locally running
# BMTools server and asks a ReAct-style agent to diagnose a database
# performance anomaly end to end.
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
import datetime
# The db_diag tool must already be served at this address (see ToolServer).
tool_name, tool_url = 'db_diag', "http://127.0.0.1:8079/tools/db_diag/"
tool_name, tool_config = load_single_tools(tool_name, tool_url)
print(tool_name, tool_config)
stqa = STQuestionAnswerer()
# langchain
agent = stqa.load_tools(tool_name, tool_config, prompt_type="react-with-tool-description") # langchain: react-with-tool-description autogpt: autogpt
# database on 123.56.63.105
# Disabled example: build an explicit anomaly time window to embed in the
# prompt (the active prompt below lets the agent fetch the window itself).
'''
start_timestamp_str = "2023-05-19 22:21:30"
dt = datetime.datetime.strptime(start_timestamp_str, "%Y-%m-%d %H:%M:%S")
timestamp = dt.timestamp()
start_time = timestamp
end_timestamp_str = "2023-05-19 22:23:30"
dt = datetime.datetime.strptime(end_timestamp_str, "%Y-%m-%d %H:%M:%S")
timestamp = dt.timestamp()
end_time = timestamp
print(" ===== time period: ", start_time, end_time)
'''
#text = "The database performance is bad during {} to {}.".format(start_timestamp_str, end_timestamp_str) # trigger database diagnosis
text = "Here is a database performance problem. Please help me to diagnose the causes and give some optimization suggestions."
# Kick off the diagnosis: the prompt instructs the agent which tool APIs to
# call, in what order, and includes a one-shot demonstration of the expected
# Thought/Action/Action Input format.
agent(""" {}
First, obtain_start_and_end_time_of_anomaly and memorize the start and end time of the anomaly.
Second, you need to diagnose the causes of the anomaly from the following two aspects:
- call the whether_is_abnormal_metric API and examine whether CPU usage is high (or abnormal). Next, if the CPU usage is high (or abnormal), cpu_diagnosis_agent and obtain the diagnosis results.
- call the whether_is_abnormal_metric API and examine whether memory usage is high (or abnormal). Next, if the memory usage is high (or abnormal), memory_diagnosis_agent and obtain the diagnosis results.
Third, you need to match each cause with potential solutions cached in the vector database.
Finally, list the above diagnosed causes and their matched solutions in easy-to-understand format using bullet points.
================================
A Demonstration example:
Thought: I need to check whether the CPU usage is high or abnormal during the given time period.
Action: whether_is_abnormal_metric
Action Input: {{"start_time": xxxx, "end_time": xxxx, "metric_name": "cpu_usage"}}
Note. 1) The first action must be obtain_start_and_end_time_of_anomaly;
2) Do not use any image in the output;
3) Give some optimization suggestions based on the diagnosis results.
""".format(text))
# Additional prompt constraint that was tried and then disabled:
'''
1) Action can only be one of the following API names: obtain_start_and_end_time_of_anomaly, whether_is_abnormal_metric, obtain_values_of_metrics, cpu_diagnosis_agent, memory_diagnosis_agent. Any other content in Action is unacceptable;
'''

@ -0,0 +1,27 @@
import random
def subsample_data(data, subsample_size):
    """Draw a random subsample of *subsample_size* (input, output) pairs.

    *data* is a tuple ``(inputs, outputs)`` of two equal-length lists.
    Returns a new ``(inputs, outputs)`` tuple of the sampled pairs.
    """
    inputs, outputs = data
    assert len(inputs) == len(outputs)
    chosen = random.sample(range(len(inputs)), subsample_size)
    sampled_inputs = list(map(inputs.__getitem__, chosen))
    sampled_outputs = list(map(outputs.__getitem__, chosen))
    return sampled_inputs, sampled_outputs
def create_split(data, split_size):
    """Randomly split *data* into two disjoint parts.

    *data* is a tuple ``(inputs, outputs)`` of two equal-length lists.
    The first part holds *split_size* randomly sampled pairs; the second
    holds the remaining pairs in their original order.
    Returns ``((inputs1, outputs1), (inputs2, outputs2))``.
    """
    inputs, outputs = data
    assert len(inputs) == len(outputs)
    indices = random.sample(range(len(inputs)), split_size)
    inputs1 = [inputs[i] for i in indices]
    outputs1 = [outputs[i] for i in indices]
    # Use a set for O(1) membership tests instead of scanning the sampled
    # index list for every position (the original was O(n * split_size)).
    chosen = set(indices)
    inputs2 = [inputs[i] for i in range(len(inputs)) if i not in chosen]
    outputs2 = [outputs[i] for i in range(len(inputs)) if i not in chosen]
    return (inputs1, outputs1), (inputs2, outputs2)

@ -0,0 +1,291 @@
import psycopg2
import pymysql
import json
import logging
import os
from enum import IntEnum
import time
class DataType(IntEnum):
    """Coarse column categories used to pick legal aggregate functions."""
    VALUE = 0   # numeric columns
    TIME = 1    # date/time columns
    CHAR = 2    # everything else (treated as text)


# Aggregates that are legal for each data type.  The original code keyed the
# CHAR/TIME entries via ``DataType.VALUE.CHAR`` / ``DataType.VALUE.TIME``
# (member-of-member access), which is removed in modern Python and produced
# the wrong keys anyway; use each member's own value directly.
AGGREGATE_CONSTRAINTS = {
    DataType.VALUE.value: ['count', 'max', 'min', 'avg', 'sum'],
    DataType.CHAR.value: ['count', 'max', 'min'],
    DataType.TIME.value: ['count', 'max', 'min'],
}
def transfer_field_type(database_type, server):
    """Map a backend column type name to a coarse ``DataType`` value.

    ``server`` selects the type vocabulary ('mysql' or 'postgresql');
    any type not recognised as numeric or temporal is classified CHAR.
    An unknown ``server`` leaves the group list empty (IndexError below).
    """
    type_groups = list()
    if server == 'mysql':
        type_groups = [
            ['int', 'tinyint', 'smallint', 'mediumint', 'bigint',
             'float', 'double', 'decimal'],
            ['date', 'time', 'year', 'datetime', 'timestamp'],
        ]
        # MySQL reports types like "INT(11)"; normalise and strip the length.
        database_type = database_type.lower().split('(')[0]
    elif server == 'postgresql':
        type_groups = [['integer', 'numeric'], ['date']]
    if database_type in type_groups[0]:
        return DataType.VALUE.value
    if database_type in type_groups[1]:
        return DataType.TIME.value
    return DataType.CHAR.value
class DBArgs(object):
    """Connection settings for a MySQL or PostgreSQL database.

    Reads host/port/user/password (and optionally dbname) from *config*
    and derives the matching JDBC driver/URL prefix from *dbtype*.
    """

    def __init__(self, dbtype, config, dbname=None):
        self.dbtype = dbtype
        # Connection coordinates are common to both backends.
        self.host = config['host']
        self.port = config['port']
        self.user = config['user']
        self.password = config['password']
        self.dbname = dbname if dbname else config['dbname']
        # Backend-specific JDBC metadata.
        if self.dbtype == 'mysql':
            self.driver = 'com.mysql.jdbc.Driver'
            self.jdbc = 'jdbc:mysql://'
        else:
            self.driver = 'org.postgresql.Driver'
            self.jdbc = 'jdbc:postgresql://'
class Database():
    """Connection/query helper over a MySQL (pymysql) or PostgreSQL
    (psycopg2) database described by a ``DBArgs`` instance.

    ``execute_sql`` opens a fresh short-timeout connection per call and
    retries; the ``pgsql_*`` / ``mysql_*`` helpers wrap EXPLAIN-based
    cost and plan queries on top of it.
    """

    def __init__(self, args, timeout=-1):
        # args: DBArgs with connection coordinates and dbtype.
        self.args = args
        self.conn = self.resetConn(timeout)
        # self.schema = self.compute_table_schema()

    def resetConn(self, timeout=-1):
        """Open and return a new connection; *timeout* in seconds (<=0 = none)."""
        if self.args.dbtype == 'mysql':
            conn = pymysql.connect(
                host=self.args.host,
                user=self.args.user,
                passwd=self.args.password,
                database=self.args.dbname,
                port=int(self.args.port),
                charset='utf8',
                connect_timeout=timeout,
                read_timeout=timeout,
                write_timeout=timeout)
        else:
            if timeout > 0:
                # Enforce a per-statement timeout on the server side.
                conn = psycopg2.connect(database=self.args.dbname,
                                        user=self.args.user,
                                        password=self.args.password,
                                        host=self.args.host,
                                        port=self.args.port,
                                        options='-c statement_timeout={}s'.format(timeout))
            else:
                conn = psycopg2.connect(database=self.args.dbname,
                                        user=self.args.user,
                                        password=self.args.password,
                                        host=self.args.host,
                                        port=self.args.port)
        return conn

    '''
    def exec_fetch(self, statement, one=True):
        cur = self.conn.cursor()
        cur.execute(statement)
        if one:
            return cur.fetchone()
        return cur.fetchall()
    '''

    def execute_sql(self, sql):
        """Execute *sql* on a fresh 2s-timeout connection with up to 5 retries
        (1s apart).  Returns ``(1, rows)`` on success or ``(0, '')`` on failure.
        """
        fail = 1
        self.conn = self.resetConn(timeout=2)
        cur = self.conn.cursor()
        i = 0
        cnt = 5  # retry times
        while fail == 1 and i < cnt:
            try:
                fail = 0
                print("========== start execution time:", time.time())
                cur.execute(sql)
            except BaseException:
                # Any failure (including timeouts) flags a retry after 1s.
                fail = 1
                time.sleep(1)
            res = []
            if fail == 0:
                res = cur.fetchall()
            i = i + 1
        logging.debug('database {}, return flag {}, execute sql {}\n'.format(self.args.dbname, 1 - fail, sql))
        cur.close()
        self.conn.close()
        print("========== finish time:", time.time())
        if fail == 1:
            # raise RuntimeError("Database query failed")
            print("SQL Execution Fatal!!")
            return 0, ''
        elif fail == 0:
            # print("SQL Execution Succeed!!")
            return 1, res

    def pgsql_results(self, sql):
        """Run *sql* and return its rows, or the string "<fail>" on failure."""
        try:
            #success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql(sql)
            #print("pgsql_results", success, res)
            if success == 1:
                return res
            else:
                return "<fail>"
        except Exception as error:
            # NOTE(review): logging.error('msg', error) passes *error* as a
            # %-format arg with no placeholder in the message — confirm intent.
            logging.error('pgsql_results Exception', error)
            return "<fail>"

    def pgsql_query_plan(self, sql):
        """Return the PostgreSQL JSON query plan for *sql*, or 0 on failure."""
        try:
            #success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql('explain (FORMAT JSON) ' + sql)
            if success == 1:
                plan = res[0][0][0]['Plan']
                return plan
            else:
                logging.error('pgsql_query_plan Fails!')
                return 0
        except Exception as error:
            logging.error('pgsql_query_plan Exception', error)
            return 0

    def pgsql_cost_estimation(self, sql):
        """Return the optimizer's total cost estimate for *sql*, or 0 on failure."""
        try:
            #success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql('explain (FORMAT JSON) ' + sql)
            if success == 1:
                cost = res[0][0][0]['Plan']['Total Cost']
                return cost
            else:
                logging.error('pgsql_cost_estimation Fails!')
                return 0
        except Exception as error:
            logging.error('pgsql_cost_estimation Exception', error)
            return 0

    def pgsql_actual_time(self, sql):
        """Return the measured execution time via EXPLAIN ANALYZE, or -1 on failure.

        Note: this actually runs the statement.
        """
        try:
            #success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            success, res = self.execute_sql('explain (FORMAT JSON, analyze) ' + sql)
            if success == 1:
                cost = res[0][0][0]['Plan']['Actual Total Time']
                return cost
            else:
                return -1
        except Exception as error:
            logging.error('pgsql_actual_time Exception', error)
            return -1

    def mysql_cost_estimation(self, sql):
        """Return MySQL's summed query_cost from EXPLAIN FORMAT=JSON, or -1."""
        try:
            success, res = self.execute_sql('explain format=json ' + sql)
            if success == 1:
                total_cost = self.get_mysql_total_cost(0, json.loads(res[0][0]))
                return float(total_cost)
            else:
                return -1
        except Exception as error:
            logging.error('mysql_cost_estimation Exception', error)
            return -1

    def get_mysql_total_cost(self, total_cost, res):
        """Recursively sum every 'query_cost' value found in a MySQL plan tree."""
        if isinstance(res, dict):
            if 'query_cost' in res.keys():
                total_cost += float(res['query_cost'])
            else:
                for key in res:
                    total_cost += self.get_mysql_total_cost(0, res[key])
        elif isinstance(res, list):
            for i in res:
                total_cost += self.get_mysql_total_cost(0, i)
        return total_cost

    def get_tables(self):
        """Dispatch to the backend-specific table listing helper.

        NOTE(review): mysql_get_tables / pgsql_get_tables are not defined in
        this chunk — confirm they exist elsewhere in the class/file.
        """
        if self.args.dbtype == 'mysql':
            return self.mysql_get_tables()
        else:
            return self.pgsql_get_tables()

    # query cost estimated by the optimizer
    def cost_estimation(self, sql):
        """Backend-agnostic optimizer cost estimate for *sql*."""
        if self.args.dbtype == 'mysql':
            return self.mysql_cost_estimation(sql)
        else:
            return self.pgsql_cost_estimation(sql)

    def compute_table_schema(self):
        """
        schema: {table_name: [field_name]}
        :param cursor:
        :return:

        Only implemented for PostgreSQL: reads information_schema to map
        each public table to its column names.  Returns 0 on failure.
        """
        if self.args.dbtype == 'postgresql':
            # cur_path = os.path.abspath('.')
            # tpath = cur_path + '/sampled_data/'+dbname+'/schema'
            sql = 'SELECT table_name FROM information_schema.tables WHERE table_schema = \'public\';'
            success, res = self.execute_sql(sql)
            #print("======== tables", res)
            if success == 1:
                tables = res
                schema = {}
                for table_info in tables:
                    table_name = table_info[0]
                    sql = 'SELECT column_name, data_type FROM information_schema.columns WHERE table_name = \'' + table_name + '\';'
                    success, res = self.execute_sql(sql)
                    #print("======== table columns", res)
                    columns = res
                    schema[table_name] = []
                    for col in columns:
                        ''' compute the distinct value ratio of the column
                        if transfer_field_type(col[1], self.args.dbtype) == DataType.VALUE.value:
                            sql = 'SELECT count({}) FROM {};'.format(col[0], table_name)
                            success, res = self.execute_sql(sql)
                            print("======== column rows", res)
                            num = res
                            if num[0][0] != 0:
                                schema[table_name].append(col[0])
                        '''
                        #schema[table_name].append("column {} is of {} type".format(col[0], col[1]))
                        schema[table_name].append("{}".format(col[0]))
                '''
                with open(tpath, 'w') as f:
                    f.write(str(schema))
                '''
                #print(schema)
                return schema
            else:
                logging.error('pgsql_cost_estimation Fails!')
                return 0

    def simulate_index(self, index):
        """Create a hypothetical index via the hypopg extension; returns (flag, rows)."""
        #table_name = index.table()
        statement = (
            "SELECT * FROM hypopg_create_index(E'{}');".format(index)
        )
        result = self.execute_sql(statement)
        return result

    def drop_simulated_index(self, oid):
        """Drop a hypopg hypothetical index by *oid*.

        NOTE(review): execute_sql returns (1, rows) on success, and
        ``1 is True`` is False in Python, so this assert appears to fail
        even on success — confirm the intended success check.
        """
        statement = f"select * from hypopg_drop_index({oid})"
        result = self.execute_sql(statement)
        assert result[0] is True, f"Could not drop simulated index with oid = {oid}."

@ -0,0 +1,71 @@
import argparse
import configparser
import logging
def get_conf(conf_file, server_name):
    """Read *conf_file* (INI format) and return the section named *server_name*."""
    parser = configparser.ConfigParser()
    parser.read(conf_file)
    return parser[server_name]
def get_parser():
    """Build the command-line argument parser for instruction induction.

    Only --db_conf is active; the remaining options are parked in the
    string block below until they are needed again.
    """
    parser = argparse.ArgumentParser(description="Instruction Induction.")
    parser.add_argument("--db_conf", type=str,
                        default='../database/configs/config.ini')
    """
    parser.add_argument("--train_data", type=str,
    default="./data/raw/train/rules.json")
    parser.add_argument("--eval_data", type=str,
    default="./data/raw/execute/zhenzhi.json")
    parser.add_argument("--data_save", type=str,
    default="./result/{}/data/")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--runlog", type=str,
    default="./result/{}/exp_runtime.log")
    parser.add_argument("--logdir", type=str,
    default="./result/{}/logdir/")
    parser.add_argument("--model_save", type=str,
    default="./result/{}/model/")
    parser.add_argument("--gen_sample", type=int, default=20)
    parser.add_argument("--gen_demo", type=int, default=16)
    parser.add_argument("--gen_prompt_per_sample", type=int, default=5)
    parser.add_argument("--gen_model", type=str, default="text-davinci-003")
    parser.add_argument("--gen_max_tokens", type=int, default=200)
    parser.add_argument("--eval_sample", type=int, default=20)
    parser.add_argument("--eval_model", type=str, default="text-davinci-003")
    parser.add_argument("--eval_max_tokens", type=int, default=1000)
    parser.add_argument("--storage_budget", type=int, default=500) # limit storage space of built indexes
    """
    return parser
def set_logger(log_file):
    """Configure the root logger to emit DEBUG+ to both *log_file* and stderr."""
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    fmt = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s: - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
    # File sink.
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(fmt)
    # Console sink.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(fmt)
    root.addHandler(console_handler)
    root.addHandler(file_handler)

@ -0,0 +1,36 @@
import re
import sqlparse
def remove_create_table(sql):
    """Strip every CREATE TABLE ... ( ... ); statement from *sql*."""
    pattern = re.compile(r'(create|CREATE)\s+(table|TABLE).+?\(.+?\)\s*;', re.DOTALL)
    return pattern.sub('', sql)
def remove_create_index(sql):
    """Strip every CREATE INDEX ... ( ... ); statement from *sql*."""
    pattern = re.compile(r'(create|CREATE)\s+(index|INDEX).+?\(.+?\)\s*;', re.DOTALL)
    return pattern.sub('', sql)
def remove_table(sql):
    """Strip every TABLE ... ( ... ); fragment from *sql*."""
    pattern = re.compile(r'(table|TABLE).+?\(.+?\)\s*;', re.DOTALL)
    return pattern.sub('', sql)
def clean_sql(sql):
    """Flatten a parsed sqlparse statement into one normalized string.

    Drops whitespace tokens and single-line comments, joins the rest with
    single spaces, then tightens punctuation spacing via ``strip_par``.
    """
    kept = [
        str(token)
        for token in sql.flatten()
        if not token.is_whitespace and token.ttype is not sqlparse.tokens.Comment.Single
    ]
    return strip_par(' '.join(kept))
def strip_par(s):
    """Remove the spaces directly adjacent to SQL punctuation in *s*.

    The operator order is significant: single-char forms are processed
    before their two-char composites, matching the replacement semantics.
    """
    for op in ('(', ')', ',', '>', '=', '<', '>=', '<=', '!=', '<>', '.', ';'):
        s = s.replace(f' {op}', op)
        s = s.replace(f'{op} ', op)
    return s
def preprocess_execute_sql(sql):
    """Clean *sql* for execution.

    Removes CREATE TABLE / CREATE INDEX statements, normalizes the first
    parsed statement, and guarantees a trailing semicolon.
    Returns ``[1, sql]`` on success or ``[0, '']`` when nothing remains.
    """
    sql = remove_create_index(remove_create_table(sql))
    statements = sqlparse.parse(sql)
    if len(statements) == 0:
        return [0, '']
    cleaned = clean_sql(statements[0])
    if not cleaned:
        return [0, '']
    if cleaned[-1] != ';':
        cleaned += ';'
    return [1, cleaned]

@ -0,0 +1,355 @@
"""Contains classes for querying large language models."""
from math import ceil
import os
import time
from tqdm import tqdm
from abc import ABC, abstractmethod
import openai
# Completion prices in USD per 1000 tokens for the base GPT-3 engines;
# consumed by gpt_get_estimated_cost (fine-tuned pricing is handled there).
gpt_costs_per_thousand = {
    'davinci': 0.0200,
    'curie': 0.0020,
    'babbage': 0.0005,
    'ada': 0.0004
}
def model_from_config(config, disable_tqdm=True):
    """Instantiate the LLM wrapper selected by ``config["name"]``.

    Raises ValueError for any unrecognised model name.
    """
    model_type = config["name"]
    if model_type == "GPT_forward":
        return GPT_Forward(config, disable_tqdm=disable_tqdm)
    if model_type == "GPT_insert":
        return GPT_Insert(config, disable_tqdm=disable_tqdm)
    raise ValueError(f"Unknown model type: {model_type}")
class LLM(ABC):
    """Interface that every large-language-model wrapper must implement."""

    @abstractmethod
    def generate_text(self, prompt):
        """Generate completions for *prompt*.

        Parameters:
            prompt: The prompt to use. This can be a string or a list of strings.
        Returns:
            A list of strings.
        """

    @abstractmethod
    def log_probs(self, text, log_prob_range):
        """Return the log probs of *text*.

        Parameters:
            text: The text to get the log probs of. This can be a string or a list of strings.
            log_prob_range: The range of characters within each string to get the log_probs of.
                This is a list of tuples of the form (start, end).
        Returns:
            A list of log probs.
        """
class GPT_Forward(LLM):
    """Wrapper for GPT-3 style completion endpoints (forward generation).

    Calls ``openai.Completion`` with ``config['gpt_config']``, batching
    prompts by ``config['batch_size']`` and retrying failed requests with
    a 5-second back-off; over-large batches are halved via auto_reduce_n.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model.

        config: dict with 'name', 'batch_size' and a 'gpt_config' sub-dict
            forwarded to openai.Completion.create.
        needs_confirmation: print the estimated cost and ask the user to
            confirm before issuing requests.
        disable_tqdm: suppress progress bars.
        """
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        """Estimate the API cost of the pending calls and ask to proceed.

        Raises Exception("Aborted.") unless the user answers 'y'; skipped
        entirely when the LLM_SKIP_CONFIRM environment variable is set.
        """
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            # Split the request into two halves and concatenate the results.
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)

    def generate_text(self, prompt, n):
        """Generate n completions per prompt; returns a flat list of strings."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Returns the log probs (and tokens) of the text, batched."""
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            log_prob_range_batches = [None] * len(text)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        for text_batch, log_prob_range in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                               disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(
                text_batch, log_prob_range)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        # BUG FIX: the original assigned the wrapped list to an unused local
        # (`text = [prompt]`), so a bare-string prompt stayed unwrapped and
        # the loop below iterated over its characters.
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                if 'is greater than the maximum' in str(e):
                    raise BatchSizeException()
                print(e)
                print('Retrying...')
                time.sleep(5)
        return [response['choices'][i]['text'] for i in range(len(response['choices']))]

    def __complete(self, prompt, n):
        """Generates text from the model and returns the raw choice data."""
        # BUG FIX: same wrong-variable listification as in __generate_text.
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        return response['choices']

    def __log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text."""
        if not isinstance(text, list):
            text = [text]
        if log_prob_range is not None:
            for i in range(len(text)):
                lower_index, upper_index = log_prob_range[i]
                assert lower_index < upper_index
                assert lower_index >= 0
                assert upper_index - 1 < len(text[i])
        # Echo the prompt with logprobs and max_tokens=0 to score it only.
        config = self.config['gpt_config'].copy()
        config['logprobs'] = 1
        config['echo'] = True
        config['max_tokens'] = 0
        # Prepend a newline; its offset contribution is compensated below.
        if isinstance(text, list):
            text = [f'\n{text[i]}' for i in range(len(text))]
        else:
            text = f'\n{text}'
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=text)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        log_probs = [response['choices'][i]['logprobs']['token_logprobs'][1:]
                     for i in range(len(response['choices']))]
        tokens = [response['choices'][i]['logprobs']['tokens'][1:]
                  for i in range(len(response['choices']))]
        offsets = [response['choices'][i]['logprobs']['text_offset'][1:]
                   for i in range(len(response['choices']))]
        # Subtract 1 from the offsets to account for the newline
        for i in range(len(offsets)):
            offsets[i] = [offset - 1 for offset in offsets[i]]
        if log_prob_range is not None:
            # First, we need to find the indices of the tokens in the log probs
            # that correspond to the tokens in the log_prob_range
            for i in range(len(log_probs)):
                lower_index, upper_index = self.get_token_indices(
                    offsets[i], log_prob_range[i])
                log_probs[i] = log_probs[i][lower_index:upper_index]
                tokens[i] = tokens[i][lower_index:upper_index]
        return log_probs, tokens

    def get_token_indices(self, offsets, log_prob_range):
        """Returns the indices of the tokens in the log probs that correspond to the tokens in the log_prob_range."""
        # For the lower index, find the highest index that is less than or equal to the lower index
        lower_index = 0
        for i in range(len(offsets)):
            if offsets[i] <= log_prob_range[0]:
                lower_index = i
            else:
                break
        upper_index = len(offsets)
        for i in range(len(offsets)):
            if offsets[i] >= log_prob_range[1]:
                upper_index = i
                break
        return lower_index, upper_index
class GPT_Insert(LLM):
    """Wrapper for GPT-3 insertion mode: each prompt contains an [APE]
    marker that splits it into a prefix and suffix for the completion API.
    """

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        """Print the estimated API cost and ask for interactive confirmation.

        Raises Exception("Aborted.") unless the user answers 'y'; skipped
        when the LLM_SKIP_CONFIRM environment variable is set.
        """
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)

    def generate_text(self, prompt, n):
        """Generate n insertions per prompt; returns a flat list of strings."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        # Only one prompt per request is supported in this mode.
        assert batch_size == 1
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def log_probs(self, text, log_prob_range=None):
        """Log probs are not supported in insertion mode."""
        raise NotImplementedError

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # Split prompts into prefixes and suffixes with the [APE] token (do not include the [APE] token in the suffix)
        prefix = prompt[0].split('[APE]')[0]
        suffix = prompt[0].split('[APE]')[1]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prefix, suffix=suffix)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        # Remove suffix from the generated text
        texts = [response['choices'][i]['text'].replace(suffix, '') for i in range(len(response['choices']))]
        return texts
def gpt_get_estimated_cost(config, prompt, max_tokens):
    """Uses the current API costs/1000 tokens to estimate the cost of generating text from the model."""
    # Drop the [APE] placeholder before counting, then approximate tokens
    # as one per four characters.
    stripped = prompt.replace('[APE]', '')
    total_tokens = len(stripped) // 4 + max_tokens
    model = config['gpt_config']['model']
    engine = model.split('-')[1]
    pricing = gpt_costs_per_thousand
    if engine not in pricing:
        # Fall back to fine-tuned pricing, keyed by the base model name.
        engine = model.split(':')[0]
        pricing = {
            'davinci': 0.1200,
            'curie': 0.0120,
            'babbage': 0.0024,
            'ada': 0.0016
        }
    return pricing[engine] * total_tokens / 1000
class BatchSizeException(Exception):
    """Raised when a completion request exceeds the API's maximum batch size;
    callers halve n and retry (see auto_reduce_n)."""
    pass

@ -0,0 +1,114 @@
import time
import types
from typing import Any, Dict, List, Tuple, Union
from langchain.agents import AgentExecutor
from langchain.input import get_color_mapping
from langchain.schema import AgentAction, AgentFinish
from .translator import Translator
class AgentExecutorWithTranslation(AgentExecutor):
    """AgentExecutor that passes validated outputs through a Translator."""

    translator: Translator = Translator()

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """Prepare chain outputs, translating them when validation succeeds."""
        try:
            prepared = super().prep_outputs(inputs, outputs, return_only_outputs)
        except ValueError:
            # Validation failed: hand back the raw outputs untranslated.
            return outputs
        if "input" in prepared:
            prepared = self.translator(prepared)
        return prepared
class Executor(AgentExecutorWithTranslation):
    """Streaming variant of the translated agent executor: _call and
    __call__ are generators that yield intermediate (AgentAction, text)
    pairs as tools run, then yield the final return value.
    """

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run text through and get agent response."""
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Let's start tracking the iterations the agent has gone through
        iterations = 0
        time_elapsed = 0.0
        start_time = time.time()
        # We now enter the agent loop (until it returns something).
        while self._should_continue(iterations, time_elapsed):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                yield self._return(next_step_output, intermediate_steps)
                return
            for i, output in enumerate(next_step_output):
                agent_action = output[0]
                # Look up the logo of the tool the agent picked, if any.
                tool_logo = None
                for tool in self.tools:
                    if tool.name == agent_action.tool:
                        tool_logo = tool.tool_logo_md
                if isinstance(output[1], types.GeneratorType):
                    # Streaming tool: forward its chunks as they arrive, then
                    # record the last chunk as this step's observation.
                    logo = f"{tool_logo}" if tool_logo is not None else ""
                    yield (AgentAction("", agent_action.tool_input, agent_action.log), f"Further use other tool {logo} to answer the question.")
                    for out in output[1]:
                        yield out
                    next_step_output[i] = (agent_action, out)
                else:
                    for tool in self.tools:
                        if tool.name == agent_action.tool:
                            yield (AgentAction(tool_logo, agent_action.tool_input, agent_action.log), output[1])
            intermediate_steps.extend(next_step_output)
            if len(next_step_output) == 1:
                next_step_action = next_step_output[0]
                # See if tool should return directly
                tool_return = self._get_tool_return(next_step_action)
                if tool_return is not None:
                    yield self._return(tool_return, intermediate_steps)
                    return
            iterations += 1
            time_elapsed = time.time() - start_time
        # Iteration/time budget exhausted: ask the agent for a stop response.
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        yield self._return(output, intermediate_steps)
        return

    def __call__(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
        """
        inputs = self.prep_inputs(inputs)
        self.callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
            verbose=self.verbose,
        )
        try:
            # Re-yield everything _call streams; dict outputs get prepped
            # (and translated) before being forwarded.
            for output in self._call(inputs):
                if type(output) is dict:
                    output = self.prep_outputs(inputs, output, return_only_outputs)
                yield output
        except (KeyboardInterrupt, Exception) as e:
            self.callback_manager.on_chain_error(e, verbose=self.verbose)
            raise e
        self.callback_manager.on_chain_end(output, verbose=self.verbose)
        # return self.prep_outputs(inputs, output, return_only_outputs)
        return output

@ -0,0 +1,3 @@
# File operation tool
Contributor: [Yujia Qin](https://github.com/thuqinyj16)

@ -0,0 +1,7 @@
from ..registry import register
@register("file_operation")
def file_operation():
    """Tool factory registered under 'file_operation'; returns its build_tool.

    The import is deferred so the tool's dependencies load only on use.
    """
    from .api import build_tool
    return build_tool

@ -0,0 +1,45 @@
from pathlib import Path
from ..tool import Tool
def build_tool(config) -> Tool:
    """Assemble the file-operation tool with /write_file and /read_file routes."""
    tool = Tool(
        "File Operation Tool",
        "Write / read file to / from disk",
        name_for_model="file_operation",
        description_for_model="Plugin for operating files",
        logo_url=None,
        contact_email=None,
        legal_info_url=None
    )

    @tool.get("/write_file")
    def write_file(file_path: str, text: str) -> str:
        '''write file to disk
        '''
        target = Path(file_path)
        try:
            # Create the immediate parent directory if needed (not recursive).
            target.parent.mkdir(exist_ok=True, parents=False)
            with target.open("w", encoding="utf-8") as f:
                f.write(text)
        except Exception as e:
            return "Error: " + str(e)
        return f"File written successfully to {file_path}."

    @tool.get("/read_file")
    def read_file(file_path: str) -> str:
        '''read file from disk
        '''
        try:
            with Path(file_path).open("r", encoding="utf-8") as f:
                return f.read()
        except Exception as e:
            return "Error: " + str(e)

    return tool

@ -0,0 +1,9 @@
# Smoke test for the locally served file_operation tool; requires the tool
# server to be running at 127.0.0.1:8079.
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
tool_name, tool_url = 'file_operation', "http://127.0.0.1:8079/tools/file_operation/"
# Fetch the tool's config from the running server.
tool_name, tool_config = load_single_tools(tool_name, tool_url)
print(tool_name, tool_config)
stqa = STQuestionAnswerer()
# Build a single-tool agent and have it perform one file-write task.
agent = stqa.load_tools(tool_name, tool_config)
agent("write hello world to test.txt")

@ -0,0 +1,6 @@
from ..registry import register
@register("douban-film")
def douban_film():
from .douban import build_tool
return build_tool

@ -0,0 +1 @@
from .api import build_tool

@ -0,0 +1,299 @@
import requests
from lxml import etree
import pandas as pd
import re
from ...tool import Tool
from typing import List
from typing_extensions import TypedDict
class ComingMovieInfo(TypedDict):
    """Listing entry for an upcoming film scraped from douban.com/coming."""
    date: str
    title: str
    cate: str
    region: str
    wantWatchPeopleNum: str
    link: str
class PlayingMovieInfo(TypedDict):
    """Listing entry for a now-playing film scraped from douban's nowplaying page."""
    title: str
    score: str
    region: str
    director: str
    actors: str
    link: str
class DoubanAPI:
    """Scraper for movie.douban.com: upcoming films, now-playing films and
    per-film detail pages, parsed with lxml XPath.
    """

    def __init__(self) -> None:
        self._endpoint = "https://movie.douban.com"
        # Desktop-browser User-Agent so douban serves the regular HTML pages.
        self._headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
            'Chrome/108.0.0.0 Safari/537.36'
        }

    def fetch_page(self, url: str):
        """Fetch *url* with a fresh non-keep-alive session and return the Response.

        NOTE(review): verify=False disables TLS certificate verification.
        """
        s = requests.session()
        s.keep_alive = False
        response = s.get(url, headers=self._headers, verify=False)
        return response

    def get_coming(self) -> List[ComingMovieInfo]:
        """Scrape the /coming page and return one ComingMovieInfo per table row."""
        response = self.fetch_page(f"{self._endpoint}/coming")
        ret : List[ComingMovieInfo] = []
        parser = etree.HTMLParser(encoding='utf-8')
        tree = etree.HTML(response.text, parser=parser)
        movies_table_path = '//*[@id="content"]/div/div[1]/table/tbody'
        movies_table = tree.xpath(movies_table_path)
        for filmChild in movies_table[0].iter('tr'):
            # Columns: release date, title(+link), category, region, want-to-watch count.
            filmTime = filmChild.xpath('td[1]/text()')[0].strip()
            filmName = filmChild.xpath('td[2]/a/text()')[0]
            filmType = filmChild.xpath('td[3]/text()')[0].strip()
            filmRegion = filmChild.xpath('td[4]/text()')[0].strip()
            filmWantWatching = filmChild.xpath('td[5]/text()')[0].strip()
            filmLink = filmChild.xpath('td[2]/a/@href')[0]
            ret.append(ComingMovieInfo(
                date=filmTime,
                title=filmName,
                cate=filmType,
                region=filmRegion,
                wantWatchPeopleNum=filmWantWatching,
                link=filmLink
            ))
        return ret

    def get_now_playing(self) -> List[PlayingMovieInfo]:
        # Get the movie list currently on show, the movie list of different cities is the same
        response = self.fetch_page(f"{self._endpoint}/cinema/nowplaying/beijing/")
        ret : List[PlayingMovieInfo] = []
        parser = etree.HTMLParser(encoding='utf-8')
        tree = etree.HTML(response.text, parser=parser)
        movies_table_path = './/div[@id="nowplaying"]/div[2]/ul'
        movies_table = tree.xpath(movies_table_path)
        for filmChild in movies_table[0]:
            # Listing data is carried in data-* attributes of each <li>.
            filmName = filmChild.xpath('@data-title')[0]
            filmScore = filmChild.xpath('@data-score')[0]
            filmRegion = filmChild.xpath('@data-region')[0]
            filmDirector = filmChild.xpath('@data-director')[0]
            filmActors = filmChild.xpath('@data-actors')[0]
            filmLink = filmChild.xpath('ul/li[1]/a/@href')[0]
            ret.append(PlayingMovieInfo(
                title=filmName,
                score=filmScore,
                region=filmRegion,
                director=filmDirector,
                actors=filmActors,
                link=filmLink
            ))
        return ret

    def get_movie_detail(self, url : str) -> str:
        """Scrape a film's detail page and return a one-paragraph Chinese summary
        (region, genres, director, top-3 actors, synopsis).
        """
        response = self.fetch_page(url)
        parser = etree.HTMLParser(encoding='utf-8')
        tree = etree.HTML(response.text, parser=parser)
        info_path = './/div[@class="subject clearfix"]/div[2]'
        director = tree.xpath(f'{info_path}/span[1]/span[2]/a/text()')[0]
        actors = []
        actors_spans = tree.xpath(f'{info_path}/span[3]/span[2]')[0]
        for actors_span in actors_spans:
            actors.append(actors_span.text)
        # Keep only the first three actors.
        actors = ''.join(actors[:3])
        types = []
        spans = tree.xpath(f'{info_path}')[0]
        for span in spans.iter('span'):
            if 'property' in span.attrib and span.attrib['property']=='v:genre':
                types.append(span.text)
        types = ''.join(types)
        # NOTE(review): if the label below is absent, `region` is never bound
        # and the f-string raises NameError — confirm pages always include it.
        for span in spans:
            if span.text=='制片国家/地区:':
                region = span.tail.strip()
                break
        Synopsis = tree.xpath('.//div[@class="related-info"]/div/span')[0].text.strip()
        detail = f'是一部{region}{types}电影,由{director}导演,{actors}等人主演.\n剧情简介:{Synopsis}'
        return detail
def build_tool(config) -> Tool:
    """Build the Douban film-search Tool with three GET routes.

    Routes: /coming_out_filter, /now_playing_out_filter, /print_detail.
    When config["debug"] is truthy the scraper is injected through
    config["douban_api"] so tests can substitute a mock; otherwise a real
    DoubanAPI instance is constructed.

    NOTE: the route docstrings below are consumed verbatim as tool
    descriptions/prompts by the agent, so their wording (including the
    worked examples) is intentionally left untouched.
    """
    tool = Tool(
        "Film Search Plugin",
        "search for up-to-date film information.",
        name_for_model="Film Search",
        description_for_model="Plugin for search for up-to-date film information.",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )
    # Debug mode lets the test suite inject a mock scraper.
    if "debug" in config and config["debug"]:
        douban_api = config["douban_api"]
    else:
        douban_api = DoubanAPI()

    @tool.get("/coming_out_filter")
    def coming_out_filter(args : str):
        """coming_out_filter(args: str) prints the details of the filtered [outNum] coming films now according to region, cate and outNum.
        args is a list like 'str1, str2, str3, str4'
        str1 represents Production country or region. If you cannot find a region, str1 is 全部
        str2 represents movie's category. If you cannot find a category, str2 is 全部
        str3 can be a integer number that agent want to get. If you cannot find a number, str2 is 100. If the found movie's num is less than str2, Final Answer only print [the found movie's num] movies.
        str4 can be a True or False that refluct whether agent want the result sorted by people number which look forward to the movie.
        Final answer should be complete.

        This is an example:
        Thought: I need to find the upcoming Chinese drama movies and the top 2 most wanted movies
        Action: coming_out_filter
        Action Input: {"args" : "中国, 剧情, 2, True"}
        Observation: {"date":{"23":"04月28日","50":"07月"},"title":{"23":"长空之王","50":"热烈"},"cate":{"23":"剧情 / 动作","50":"剧情 / 喜剧"},"region":{"23":"中国大陆","50":"中国大陆"},"wantWatchPeopleNum":{"23":"39303人","50":"26831人"}}
        Thought: I now know the top 2 upcoming Chinese drama movies
        Final Answer: 即将上映的中国剧情电影有2部长空之王热烈大家最想看的前2部分别是长空之王热烈
        """
        # Crude tokenizer: splits "str1, str2, str3, str4" on word boundaries.
        args = re.findall(r'\b\w+\b', args)
        region = args[0]
        if region=='全部':
            region = ''  # empty string substring-matches every region
        cate = args[1]
        if cate=='全部':
            cate = ''
        outNum = int(args[2])
        WantSort = True if args[3]=='True' else False
        coming_movies = []
        for movie in douban_api.get_coming():
            # Substring match so e.g. cate '剧情' matches '剧情 / 动作'.
            if (cate in movie["cate"]) and (region in movie["region"]):
                coming_movies.append({
                    "date": movie["date"],
                    "title": movie["title"],
                    "cate": movie["cate"],
                    "region": movie["region"],
                    # Strip the trailing 人 ("people") so the count sorts numerically.
                    "wantWatchPeopleNum": int(movie["wantWatchPeopleNum"].replace("人", "")),
                    "link": movie["link"]
                })
        # Sort by people that are looking forward to the movie
        if WantSort:
            coming_movies = sorted(coming_movies, key=lambda x: x["wantWatchPeopleNum"], reverse=True)
        # Column-oriented dict keyed by the row index as a string
        # (mirrors pandas DataFrame.to_dict() output shape).
        ret = {
            "date": {},
            "title": {},
            "cate": {},
            "region": {},
            "wantWatchPeopleNum": {},
        }
        for i, movie in enumerate(coming_movies[:outNum]):
            i = str(i)
            ret["date"][i] = movie["date"]
            ret["title"][i] = movie["title"]
            ret["cate"][i] = movie["cate"]
            ret["region"][i] = movie["region"]
            # Re-append the 人 suffix stripped above for display.
            ret["wantWatchPeopleNum"][i] = "{}人".format(movie["wantWatchPeopleNum"])
        return ret

    @tool.get("/now_playing_out_filter")
    def now_playing_out_filter(args : str):
        """NowPlayingOutFilter(args: str) prints the details of the filtered [outNum] playing films now according to region, scoreSort
        args is a list like 'str1, str2, str3'
        str1 can be '中国','日本' or other Production country or region. If you cannot find a region, str1 is 全部
        str2 can be a integer number that agent want to get. If you cannot find a number, str2 is 100. If the found movie's num is less than str2, Final Answer only print [the found movie's num] movies.
        str3 can be a True or False that refluct whether agent want the result sorted by score.
        Final answer should be complete.

        This is an example:
        Input: 您知道现在有正在上映中国的电影吗请输出3部
        Thought: I need to find the currently playing movies with the highest scores
        Action: now_playing_out_filter
        Action Input: {"args" : "全部, 3, True"}
        Observation: {"title":{"34":"切腹","53":"吉赛尔","31":"小森林 夏秋篇"},"score":{"34":"9.4","53":"9.2","31":"9.0"},"region":{"34":"日本","53":"西德","31":"日本"},"director":{"34":"小林正树","53":"Hugo Niebeling","31":"森淳一"},"actors":{"34":"仲代达矢 / 石浜朗 / 岩下志麻","53":"卡拉·弗拉奇 / 埃里克·布鲁恩 / Bruce Marks","31":"桥本爱 / 三浦贵大 / 松冈茉优"}}
        Thought: I now know the currently playing movies with the highest scores
        Final Answer: 现在上映的评分最高的3部电影是切腹吉赛尔小森林 夏秋篇
        """
        # Same crude tokenization as coming_out_filter.
        args = re.findall(r'\b\w+\b', args)
        region = args[0]
        if region=='全部':
            region = ''
        outNum = int(args[1])
        scoreSort = True if args[2]=='True' else False
        playing_movies = []
        for movie in douban_api.get_now_playing():
            if region in movie["region"]:
                playing_movies.append({
                    "title": movie["title"],
                    # Score is parsed to float only for sorting.
                    "score": float(movie["score"]),
                    "region": movie["region"],
                    "director": movie["director"],
                    "actors": movie["actors"],
                    "link": movie["link"]
                })
        # Sort by score
        if scoreSort:
            playing_movies = sorted(playing_movies, key=lambda x: x["score"], reverse=True)
        # Column-oriented dict keyed by row index, as in coming_out_filter.
        ret = {
            "title": {},
            "score": {},
            "region": {},
            "director": {},
            "actors": {},
        }
        for i, movie in enumerate(playing_movies[:outNum]):
            i = str(i)
            ret["title"][i] = movie["title"]
            ret["score"][i] = "{}".format(movie["score"])
            ret["region"][i] = movie["region"]
            ret["director"][i] = movie["director"]
            ret["actors"][i] = movie["actors"]
        return ret

    @tool.get("/print_detail")
    def print_detail(args : str):
        """parsing_detail_page(args) prints the details of a movie, giving its name.
        args is a list like 'str1'
        str1 is target movie's name.
        step1: apply function parse_coming_page and parse_nowplaying_page and get all movie's links and other infomation.
        step2: get the target movie's link from df_coming or df_nowplaying
        step3: get detail from step2's link

        This is an example:
        Input: "电影流浪地球2怎么样"
        Thought: I need to find the movie's information
        Action: print_detail
        Action Input: {"args" : "流浪地球2"}
        Observation: "是一部中国大陆的科幻、冒险、灾难电影,由郭帆导演,吴京、刘德华、李雪健等人主演.\n剧情简介:太阳即将毁灭,人类在地球表面建造出巨大的推进器,寻找新的家园。然而宇宙之路危机四伏,为了拯救地球,流浪地球时代的年轻人再次挺身而出,展开争分夺秒的生死之战。"
        Thought: I now know the final answer
        Final Answer: 流浪地球2是一部中国大陆的科幻冒险灾难电影由郭帆导演吴京刘德华李雪健等人主演剧情简介是太阳即将毁灭人类在地球表面建造出巨大的推进器寻找新的家园然而宇宙之路危机四伏为了拯救地球流浪地球时代的年轻人再次挺身而出
        """
        args = re.findall(r'\b\w+\b', args)
        filmName = args[0]
        # Look the title up in the coming list first, then the now-playing list.
        link = None
        if link is None:
            for movie in douban_api.get_coming():
                if movie["title"] == filmName:
                    link = movie["link"]
                    break
        if link is None:
            for movie in douban_api.get_now_playing():
                if movie["title"] == filmName:
                    link = movie["link"]
                    break
        if link is None:
            return "没有找到该电影"
        return "{}{}".format(filmName, douban_api.get_movie_detail(link))
    return tool

@ -0,0 +1,34 @@
# Douban Film Search
Contributor: [Jing Yi](https://github.com/yijing16)
## Tool Description
The "Film Search Plugin" is a robust tool that allows you to access up-to-date film information, filter through this information, and retrieve detailed descriptions of specific films. It uses the `DoubanAPI` scraper class to pull film information from douban.com.
### Tool Specifications
- **Name**: Film Search Plugin
- **Purpose**: Search for up-to-date film information.
- **Name for Model**: Film Search
- **Model Description**: Plugin for search for up-to-date film information.
- **Logo URL**: [logo](https://your-app-url.com/.well-known/logo.png)
- **Contact Email**: hello@contact.com
- **Legal Information URL**: hello@legal.com
### Core Functionality
1. `coming_out_filter`
This method accepts a string argument that contains details about the desired region, category, number of films, and a flag indicating if sorting is needed based on the number of people looking forward to the movie. It filters through the upcoming films based on these details and provides the information.
2. `now_playing_out_filter`
This method accepts a string argument that contains details about the desired region, the number of films, and a flag indicating if sorting is needed based on the film score. It filters through the currently playing films based on these details and provides the information.
3. `print_detail`
This method accepts a string argument that contains the name of a specific film. It provides detailed information about the film, including the genre, director, actors, and a brief synopsis of the film's plot.
All methods use the `DoubanAPI` to retrieve and filter information on films.
**Note**: `DoubanAPI` is implemented as an HTML scraper of douban.com, so it depends on the site's current page structure and may break if that layout changes. Use it responsibly and in accordance with douban.com's terms of service.

@ -0,0 +1,15 @@
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
# Local endpoint where the douban-film tool is expected to be served.
tool_name, tool_url = 'douban', "http://127.0.0.1:8079/tools/douban-film/"
tools_name, tools_config = load_single_tools(tool_name, tool_url)
# tools_name, tools_config = load_single_tools()
print(tools_name, tools_config)
# Build a single-tool question-answering agent over the served tool.
qa = STQuestionAnswerer()
agent = qa.load_tools(tools_name, tools_config)
# Smoke-test queries (Chinese): upcoming comedies, high-rated films to
# watch in a cinema, and a synopsis request for a specific title.
agent("有哪些即将上映的中国喜剧电影哪些是大家最想看的前5部")
agent("想去电影院看一些国产电影有评分高的吗输出3部")
agent("帮我介绍下《深海》这部电影")

@ -0,0 +1,97 @@
from fastapi.testclient import TestClient
from .api import build_tool, DoubanAPI, ComingMovieInfo, PlayingMovieInfo
from typing import List
class DoubanMock(DoubanAPI):
    """In-memory stand-in for DoubanAPI used by the route tests."""

    def __init__(self) -> None:
        # Deliberately skip DoubanAPI.__init__: no network setup is wanted.
        pass

    def get_coming(self) -> List[ComingMovieInfo]:
        # (date, shared-name, people-count) triples; every text field of a
        # record reuses the same "testN" name, matching the original fixtures.
        rows = [
            ("2020-12-12", "test1", "1"),
            ("2021-12-12", "test2", "2"),
            ("2022-12-12", "test3", "3"),
        ]
        return [
            ComingMovieInfo(date=day, title=name, cate=name, region=name,
                            wantWatchPeopleNum=count, link=name)
            for day, name, count in rows
        ]

    def get_now_playing(self) -> List[PlayingMovieInfo]:
        scores = {"test1": "1.1", "test2": "2.2", "test3": "3.3"}
        return [
            PlayingMovieInfo(title=name, score=score, region=name,
                             director=name, actors=name, link=name)
            for name, score in scores.items()
        ]

    def get_movie_detail(self, url : str) -> str:
        # Echo the link back so tests can assert on it directly.
        return url
# Build the tool in debug mode against the mock API and wrap it in a
# FastAPI test client so the routes can be exercised without a server.
app = build_tool({"debug": True, "douban_api": DoubanMock()})
client = TestClient(app)
def test_get_coming():
    # Ask for the top 2 upcoming movies, any region/category, sorted by
    # wantWatchPeopleNum descending — so test3 (3人) outranks test2 (2人).
    resp = client.get("/coming_out_filter", params={
        "args": "全部, 全部, 2, True"
    })
    assert resp.status_code == 200
    expected = {
        "date": {"0": "2022-12-12", "1": "2021-12-12"},
        "title": {"0": "test3", "1": "test2"},
        "cate": {"0": "test3", "1": "test2"},
        "region": {"0": "test3", "1": "test2"},
        "wantWatchPeopleNum": {"0": "3人", "1": "2人"},
    }
    assert resp.json() == expected
def test_get_playing():
    # Unsorted (False) request for all three mocked now-playing movies;
    # every text column of the fixtures carries the same "testN" value.
    resp = client.get("/now_playing_out_filter", params={
        "args": "全部, 3, False"
    })
    assert resp.status_code == 200
    names = {"0": "test1", "1": "test2", "2": "test3"}
    assert resp.json() == {
        "title": names,
        "score": {"0": "1.1", "1": "2.2", "2": "3.3"},
        "region": names,
        "director": names,
        "actors": names,
    }
def test_detail():
    # DoubanMock.get_movie_detail echoes the link, so the route returns
    # the film name concatenated with its link: "test1" + "test1".
    resp = client.get("/print_detail", params={"args": "test1"})
    assert resp.status_code == 200
    assert resp.json() == "test1test1"

@ -0,0 +1,6 @@
from ..registry import register
# Register the Google Places tool factory under "google_places"; the api
# import is deferred so building the registry does not pull in googlemaps.
@register("google_places")
def google_places():
    """Lazily import and return the Google Places tool builder."""
    from .api import build_tool
    return build_tool

@ -0,0 +1,93 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional

import googlemaps
import requests

from ..tool import Tool
class GooglePlacesAPIWrapper:
    """Thin wrapper around the googlemaps Places API.

    Searches for places matching a free-text query and formats the name,
    address, phone number and website of each hit into readable text.
    """

    def __init__(self, subscription_key) -> None:
        self.gplaces_api_key: str = subscription_key
        self.google_map_client = googlemaps.Client(self.gplaces_api_key)
        # When set, caps how many results run() returns; None means all.
        self.top_k_results: Optional[int] = None

    def run(self, query: str) -> str:
        """Run Places search and get k number of places that exists that match."""
        search_results = self.google_map_client.places(query)["results"]
        num_to_return = len(search_results)
        places = []
        if num_to_return == 0:
            return "Google Places did not find any places that match the description"
        num_to_return = (
            num_to_return
            if self.top_k_results is None
            else min(num_to_return, self.top_k_results)
        )
        for i in range(num_to_return):
            result = search_results[i]
            details = self.fetch_place_details(result["place_id"])
            if details is not None:
                places.append(details)
        return "\n".join([f"{i+1}. {item}" for i, item in enumerate(places)])

    def fetch_place_details(self, place_id: str) -> Optional[str]:
        """Fetch and format details for one place id; None on any failure."""
        try:
            place_details = self.google_map_client.place(place_id)
            formatted_details = self.format_place_details(place_details)
            return formatted_details
        except Exception as e:
            # Bug fix: `logging` was used here but never imported at file level.
            logging.error(f"An Error occurred while fetching place details: {e}")
            return None

    def format_place_details(self, place_details: Dict[str, Any]) -> Optional[str]:
        """Format a place-details payload into a multi-line summary string."""
        try:
            # Typo fix: default was "Unkown"; now consistent with the other fields.
            name = place_details.get("result", {}).get("name", "Unknown")
            address = place_details.get("result", {}).get(
                "formatted_address", "Unknown"
            )
            phone_number = place_details.get("result", {}).get(
                "formatted_phone_number", "Unknown"
            )
            website = place_details.get("result", {}).get("website", "Unknown")
            formatted_details = (
                f"{name}\nAddress: {address}\n"
                f"Phone: {phone_number}\nWebsite: {website}\n\n"
            )
            return formatted_details
        except Exception as e:
            logging.error(f"An error occurred while formatting place details: {e}")
            return None
def build_tool(config) -> Tool:
    """Assemble the Google Places tool and register its single search route."""
    places_tool = Tool(
        "google_places",
        "Look up for information about places and locations",
        name_for_model="google_places",
        description_for_model=(
            "A tool that query the Google Places API. "
            "Useful for when you need to validate or "
            "discover addressed from ambiguous text. "
            "Input should be a search query."
        ),
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )
    # One wrapper instance, shared by every request to the route below.
    wrapper = GooglePlacesAPIWrapper(config["subscription_key"])

    @places_tool.get("/search_places")
    def search_places(query : str):
        """Run Places search."""
        return str(wrapper.run(query))

    return places_tool

@ -0,0 +1,25 @@
# Google Places Queries
Contributor: [Sihan Zhao](https://github.com/Sarah816)
## Tool Description
The "Google Places" tool allows you to fetch information about places and locations by querying the Google Places API. This tool can be especially useful when you need to validate or discover addresses from ambiguous text.
### Tool Specifications
- **Name**: Google Places
- **Purpose**: Look up for information about places and locations
- **Name for Model**: google_places
- **Model Description**: A tool that query the Google Places API. Useful for when you need to validate or discover addressed from ambiguous text. Input should be a search query.
- **Contact Email**: hello@contact.com
- **Legal Information URL**: hello@legal.com
### Core Functionality
1. `search_places`
This method accepts a string argument that contains a search query. It queries the Google Places API with the given input and returns information about the corresponding places and locations.
The tool leverages a fictional wrapper class called `GooglePlacesAPIWrapper` to interact with the Google Places API and execute the required functionalities.
**Note**: `GooglePlacesAPIWrapper` is implemented with the `googlemaps` client library; a valid Google Places API key must be supplied as `subscription_key` in the tool configuration for the tool to work.

@ -0,0 +1,12 @@
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
# Local endpoint where the google_places tool is expected to be served.
tool_name, tool_url = 'google_places', "http://127.0.0.1:8079/tools/google_places/"
tools_name, tools_config = load_single_tools(tool_name, tool_url)
print(tools_name, tools_config)
# Build a single-tool question-answering agent over the served tool.
qa = STQuestionAnswerer()
agent = qa.load_tools(tools_name, tools_config)
# Smoke-test queries against the Places search route.
agent("Where is Tsinghua University?")
agent("List a few locations of KFC in Beijing.")

@ -0,0 +1,7 @@
from ..registry import register
# Register the Google Scholar tool factory; the api import is deferred so
# merely building the registry does not pull in the serpapi client.
@register("google_scholar")
def google_scholar():
    """Lazily import and return the Google Scholar tool builder."""
    from .api import build_tool
    return build_tool

@ -0,0 +1,223 @@
import requests
import json
from datetime import date, datetime, timedelta
import os
from ..tool import Tool
from typing import Optional, List, Dict, Any
from serpapi import GoogleSearch
def build_tool(config) -> Tool:
    """Build the Google Scholar tool backed by SerpApi.

    Registers four GET routes: /search_google_scholar, /search_author,
    /get_citation and /get_profile.  config["subscription_key"] must be a
    valid SerpApi key; it is closed over by every route handler below.
    """
    tool = Tool(
        "Google Scholar Info",
        "Look up google scholar information",
        name_for_model="Google_Scholar",
        description_for_model="Plugin for look up Google Scholar information",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )
    KEY = config["subscription_key"]

    @tool.get("/search_google_scholar")
    def search_google_scholar(
        query: str,
        engine: str = "google_scholar",
        cites: Optional[str] = None,
        as_ylo: Optional[int] = None,
        as_yhi: Optional[int] = None,
        scisbd: Optional[int] = None,
        cluster: Optional[str] = None,
        hl: Optional[str] = None,
        lr: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        as_sdt: Optional[str] = None,
        safe: Optional[str] = None,
        filter: Optional[str] = None,
        as_vis: Optional[str] = None,
    ):
        """
        Search for scholarly articles based on a query according to the google scholar
        :param query: The query to search for.
        :param engine: The search engine to use, default is "google_scholar"
        :param cites: The unique ID of an article for triggering "Cited By" searches
        :param as_ylo: The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted)
        :param as_yhi: The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted)
        :param scisbd: Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything
        :param cluster: The unique ID of an article for triggering "All Versions" searches
        :param hl: The language to use for the Google Scholar search
        :param lr: One or multiple languages to limit the search to
        :param start: The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
        :param num: The maximum number of results to return, limited to 20
        :param as_sdt: Can be used either as a search type or a filter
        :param safe: The level of filtering for adult content
        :param filter: Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off
        :param as_vis: Defines whether to include citations or not
        :return: Return a list of dictionaries of the papers
        """
        # SerpApi ignores keys whose value is None, so the full mapping can
        # be passed straight through.
        params = {
            "q": query,
            "engine": engine,
            "api_key": KEY,
            "cites": cites,
            "as_ylo": as_ylo,
            "as_yhi": as_yhi,
            "scisbd": scisbd,
            "cluster": cluster,
            "hl": hl,
            "lr": lr,
            "start": start,
            "num": num,
            "as_sdt": as_sdt,
            "safe": safe,
            "filter": filter,
            "as_vis": as_vis
        }
        search = GoogleSearch(params)
        results = search.get_dict()
        organic_results = results["organic_results"]
        return organic_results

    @tool.get("/search_author")
    def search_author(
        author_id: str,
        hl: Optional[str] = None,
        view_op: Optional[str] = None,
        sort: Optional[str] = None,
        citation_id: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        no_cache: Optional[bool] = None,
        async_req: Optional[bool] = None,
        output: Optional[str] = None
    ):
        """
        Search for an author using the Google Scholar Author API.
        :param author_id: Required. The ID of an author.
        :param hl: Optional. The language to use for the Google Scholar Author search. Default is 'en'.
        :param view_op: Optional. Used for viewing specific parts of a page.
        :param sort: Optional. Used for sorting and refining articles.
        :param citation_id: Optional. Used for retrieving individual article citation.
        :param start: Optional. Defines the result offset. Default is 0.
        :param num: Optional. Defines the number of results to return. Default is 20.
        :param no_cache: Optional. Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
        :param async_req: Optional. Defines the way you want to submit your search to SerpApi. Default is False.
        :param output: Optional. Defines the final output you want. Default is 'json'.
        :return: Returns the search results of the author basic information.
        """
        # 'async' is a Python keyword, so the parameter is named async_req
        # and mapped to the API's 'async' key here.
        params = {
            "engine": "google_scholar_author",
            "author_id": author_id,
            "api_key": KEY,
            "hl": hl,
            "view_op": view_op,
            "sort": sort,
            "citation_id": citation_id,
            "start": start,
            "num": num,
            "no_cache": no_cache,
            "async": async_req,
            "output": output
        }
        search = GoogleSearch(params)
        results = search.get_dict()
        author = results["author"]
        return author

    @tool.get("/get_citation")
    def get_citation(
        q: str,
        no_cache: Optional[bool] = None,
        async_: Optional[bool] = None,
        output: Optional[str] = 'json') -> Dict[str, Any]:
        """
        Function to get citation results from the Google Scholar organic results using the Google Scholar Cite API.
        Parameters:
        q (str): ID of an individual Google Scholar organic search result.
        engine (str, optional): Set to 'google_scholar_cite' to use the Google Scholar API engine. Defaults to 'google_scholar_cite'.
        no_cache (bool, optional): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
        async_ (bool, optional): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
        api_key (str): SerpApi private key to use.
        output (str, optional): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
        Returns:
        Dict[str, Any]: Returns the search results in the specified format.
        """
        params = {
            "q": q,
            "engine": 'google_scholar_cite',
            "no_cache": no_cache,
            "async": async_,
            "api_key": KEY,
            "output": output
        }
        search = GoogleSearch(params)
        results = search.get_dict()
        citation = results["citations"]
        return citation

    @tool.get("/get_profile")
    def get_profile(
            mauthors: str,
            hl: Optional[str] = 'en',
            after_author: Optional[str] = None,
            before_author: Optional[str] = None,
            no_cache: Optional[bool] = False,
            _async: Optional[bool] = False,
            output: Optional[str] = 'json'
            ) -> Dict:
        # Bug fix: removed the stray `self` parameter — this is a plain
        # route handler, not a method, and FastAPI would otherwise expose
        # `self` as a required query parameter and break the route.
        """
        The getProfile function is used to scrape profile results from the Google Scholar Profiles search page.
        Args:
        mauthors (str): Defines the author you want to search for.
        hl (str, optional): Defines the language to use for the Google Scholar Profiles search.
        It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French).
        Defaults to 'en'.
        after_author (str, optional): Defines the next page token.
        It is used for retrieving the next page results.
        The parameter has the precedence over before_author parameter. Defaults to None.
        before_author (str, optional): Defines the previous page token.
        It is used for retrieving the previous page results. Defaults to None.
        no_cache (bool, optional): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present.
        Defaults to False.
        _async (bool, optional): Defines the way you want to submit your search to SerpApi. Defaults to False.
        api_key (str): Defines the SerpApi private key to use.
        output (str, optional): Defines the final output you want.
        It can be set to 'json' (default) to get a structured JSON of the results,
        or 'html' to get the raw html retrieved. Defaults to 'json'.
        Returns:
        Dict: The Google Scholar Profiles search results.
        """
        # Bug fix: the dict previously listed the 'engine' key twice.
        params = {
            'mauthors': mauthors,
            'engine': 'google_scholar_profiles',
            'hl': hl,
            'after_author': after_author,
            'before_author': before_author,
            'no_cache': no_cache,
            'async': _async,
            'api_key': KEY,
            'output': output
        }
        search = GoogleSearch(params)
        results = search.get_dict()
        profile = results['profiles']
        return profile
    return tool

@ -0,0 +1,7 @@
# Google Scholar Service
Contributor: [Kunlun Zhu](https://github.com/Kunlun-Zhu)
Before you use this tool: `pip install google-search-results` (the SerpApi Python client that provides `serpapi.GoogleSearch`)
You can get your SerpAPI key here: https://serpapi.com/

@ -0,0 +1,11 @@
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer
# Local endpoint where the google_scholar tool is expected to be served.
tool_name, tool_url = 'Google_Scholar', "http://127.0.0.1:8079/tools/google_scholar/"
tools_name, tools_config = load_single_tools(tool_name, tool_url)
print(tools_name, tools_config)
# Build a single-tool question-answering agent over the served tool.
qa = STQuestionAnswerer()
agent = qa.load_tools(tools_name, tools_config)
# Smoke-test query (spelling kept as originally written).
agent("Search the profile of Jeffery Hinton?")

@ -0,0 +1,6 @@
from ..registry import register
# Register the Serper.dev search tool factory; the api import is deferred
# so building the registry does not pull in its HTTP dependencies.
@register("google_serper")
def google_serper():
    """Lazily import and return the Google Serper tool builder."""
    from .api import build_tool
    return build_tool

@ -0,0 +1,124 @@
import requests
import json
from ..tool import Tool
import os
from typing import Any, Dict, List, Optional
import aiohttp
from pydantic.main import BaseModel
from pydantic.class_validators import root_validator
class GoogleSerperAPIWrapper:
    """Client for the Serper.dev Google Search API.

    Defaults: 10 results, US geography, English, plain web search.
    """

    def __init__(self, subscription_key) -> None:
        self.k: int = 10
        self.gl: str = "us"
        self.hl: str = "en"
        self.type: str = "search"  # one of: search, images, places, news
        self.tbs: Optional[str] = None
        self.serper_api_key: str = subscription_key
        self.aiosession: Optional[aiohttp.ClientSession] = None

    def results(self, query: str, **kwargs: Any) -> Dict:
        """Run query through GoogleSearch."""
        return self._google_serper_search_results(
            query,
            gl=self.gl,
            hl=self.hl,
            num=self.k,
            tbs=self.tbs,
            search_type=self.type,
            **kwargs,
        )

    def run(self, query: str, **kwargs: Any) -> str:
        """Run query through GoogleSearch and parse result."""
        # Same request as results(); the raw payload is then flattened to text.
        return self._parse_results(self.results(query, **kwargs))

    def _parse_snippets(self, results: dict) -> List[str]:
        """Flatten a Serper response into text snippets, best answer first."""
        answer_box = results.get("answerBox") or {}
        if answer_box:
            # A direct answer short-circuits everything else.
            if answer_box.get("answer"):
                return [answer_box.get("answer")]
            if answer_box.get("snippet"):
                return [answer_box.get("snippet").replace("\n", " ")]
            if answer_box.get("snippetHighlighted"):
                return answer_box.get("snippetHighlighted")
        snippets: List[str] = []
        kg = results.get("knowledgeGraph") or {}
        if kg:
            title = kg.get("title")
            entity_type = kg.get("type")
            if entity_type:
                snippets.append(f"{title}: {entity_type}.")
            description = kg.get("description")
            if description:
                snippets.append(description)
            for attribute, value in kg.get("attributes", {}).items():
                snippets.append(f"{title} {attribute}: {value}.")
        for hit in results["organic"][: self.k]:
            if "snippet" in hit:
                snippets.append(hit["snippet"])
            for attribute, value in hit.get("attributes", {}).items():
                snippets.append(f"{attribute}: {value}.")
        if not snippets:
            return ["No good Google Search Result was found"]
        return snippets

    def _parse_results(self, results: dict) -> str:
        return " ".join(self._parse_snippets(results))

    def _google_serper_search_results(
        self, search_term: str, search_type: str = "search", **kwargs: Any
    ) -> dict:
        headers = {
            "X-API-KEY": self.serper_api_key or "",
            "Content-Type": "application/json",
        }
        params = {"q": search_term}
        # Drop unset options so they do not reach the API.
        params.update({key: value for key, value in kwargs.items() if value is not None})
        response = requests.post(
            f"https://google.serper.dev/{search_type}", headers=headers, params=params
        )
        response.raise_for_status()
        return response.json()
def build_tool(config) -> Tool:
    """Assemble the Serper.dev search tool with its single /search_general route."""
    serper_tool = Tool(
        "google_serper",
        "Look up for information from Serper.dev Google Search API",
        name_for_model="google_serper",
        description_for_model=(
            "A low-cost Google Search API."
            "Useful for when you need to answer questions about current events."
            "Input should be a search query. Output is a JSON object of the query results"
        ),
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="hello@contact.com",
        legal_info_url="hello@legal.com"
    )
    # One wrapper instance, shared by every request to the route below.
    wrapper = GoogleSerperAPIWrapper(config["subscription_key"])

    @serper_tool.get("/search_general")
    def search_general(query : str):
        """Run query through GoogleSearch and parse result."""
        return str(wrapper.run(query))

    return serper_tool

@ -0,0 +1,26 @@
# Google Serper Queries
Contributor: [Sihan Zhao](https://github.com/Sarah816)
## Tool Description
The "google_serper" tool allows you to fetch information using the Serper.dev Google Search API. This is a low-cost Google Search API, highly useful when you need to answer questions about current events. The input should be a search query and the output is a JSON object of the query results.
### Tool Specification
- **Name**: google_serper
- **Purpose**: Fetch information using the Serper.dev Google Search API
- **Model Name**: google_serper
- **Model Description**: A tool for querying the Serper.dev Google Search API
- **Logo URL**: [logo](https://your-app-url.com/.well-known/logo.png)
- **Contact Email**: hello@contact.com
- **Legal Information URL**: hello@legal.com
### Core Functionality
1. `search_general`
This method runs the query through GoogleSearch and parses the result. The result is a JSON object of the query results.
This tool uses the `GoogleSerperAPIWrapper` class (defined in `api.py`) to call the Serper.dev API.
A valid Serper.dev API key must be supplied as `subscription_key` in the tool configuration; usage is subject to Serper.dev's pricing and to Google's terms of service.

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save