feat: add vllm download option

pull/282/head^2
Zack 1 year ago
parent a4a0cb7d63
commit 8907b077af

@@ -8,7 +8,7 @@ import warnings
 from swarms.modelui.modules.block_requests import OpenMonkeyPatch, RequestBlocker
 from swarms.modelui.modules.logging_colors import logger
-from vllm import LLM, SamplingParams
+from vllm import LLM
 os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 os.environ['BITSANDBYTES_NOWELCOME'] = '1'
@@ -59,6 +59,7 @@ from tool_server import run_tool_server
 from threading import Thread
 from multiprocessing import Process
 import time
+from langchain.llms import VLLM
 tool_server_flag = False
 def start_tool_server():
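
The hunk above pulls in langchain's VLLM wrapper, but none of the visible hunks call it. For reference, a minimal sketch of how that wrapper is typically used; the repo id and sampling settings below are placeholders, not values taken from this commit:

from langchain.llms import VLLM

llm = VLLM(
    model="mosaicml/mpt-7b",   # placeholder Hugging Face repo id
    trust_remote_code=True,    # needed for repos that ship custom model code
    max_new_tokens=128,
    temperature=0.8,
)
print(llm("What is the capital of France?"))
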
@@ -123,6 +124,27 @@ chat_history = ""
 MAX_SLEEP_TIME = 40
+def download_model(model_url: str):
+    # Extract model name from the URL
+    model_name = model_url.split('/')[-1]
+    # response = requests.get(model_url, stream=True)
+    # total_size = int(response.headers.get('content-length', 0))
+    # block_size = 1024 #1 Kibibyte
+    # progress_bar = gr.outputs.Progress_Bar(total_size)
+    # model_data = b""
+    # for data in response.iter_content(block_size):
+    #     model_data += data
+    #     progress_bar.update(len(data))
+    #     yield progress_bar
+    # Save the model data to a file, or load it into a model here
+    vllm_model = LLM(
+        model=model_url,
+        trust_remote_code=True,
+        device="cuda",
+    )
+    available_models.append((model_name, vllm_model))
+    return gr.update(choices=available_models)
 def load_tools():
     global valid_tools_info
     global all_tools_list
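
For orientation, a self-contained sketch of what download_model produces: vLLM's LLM(model=...) accepts a Hugging Face repo id, downloads the weights on first load, and the resulting object can be queried directly. The repo id below is a placeholder, and the device argument used in the hunk is left out because not every vLLM release accepts it:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", trust_remote_code=True)  # placeholder repo id
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.8, max_tokens=32),
)
print(outputs[0].outputs[0].text)
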
@@ -131,14 +153,9 @@ def load_tools():
     except BaseException as e:
         print(repr(e))
     all_tools_list = sorted(list(valid_tools_info.keys()))
-    # Download the VLLM model from the provided URL and add it to the dropdown array
-    vllm_model_url = os.environ.get("VLLM_MODEL_URL", "")
-    if vllm_model_url:
-        vllm_model = LLM.from_pretrained(vllm_model_url)
-        available_models.append(vllm_model)
     return gr.update(choices=all_tools_list)
-def set_environ(OPENAI_API_KEY: str,
+def set_environ(OPENAI_API_KEY: str = "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",
                 WOLFRAMALPH_APP_ID: str = "",
                 WEATHER_API_KEYS: str = "",
                 BING_SUBSCRIPT_KEY: str = "",
@@ -152,7 +169,8 @@ def set_environ(OPENAI_API_KEY: str,
                 STEAMSHIP_API_KEY: str = "",
                 HUGGINGFACE_API_KEY: str = "",
                 AMADEUS_ID: str = "",
-                AMADEUS_KEY: str = "",):
+                AMADEUS_KEY: str = "",
+                VLLM_MODEL_URL: str = ""):
     os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
     os.environ["WOLFRAMALPH_APP_ID"] = WOLFRAMALPH_APP_ID
     os.environ["WEATHER_API_KEYS"] = WEATHER_API_KEYS
@@ -168,7 +186,6 @@ def set_environ(OPENAI_API_KEY: str,
     os.environ["HUGGINGFACE_API_KEY"] = HUGGINGFACE_API_KEY
     os.environ["AMADEUS_ID"] = AMADEUS_ID
     os.environ["AMADEUS_KEY"] = AMADEUS_KEY
-    os.environ["VLLM_MODEL_URL"] = VLLM_MODEL_URL
     if not tool_server_flag:
         start_tool_server()
         time.sleep(MAX_SLEEP_TIME)
@@ -295,7 +312,6 @@ with gr.Blocks() as demo:
         HUGGINGFACE_API_KEY = gr.Textbox(label="Huggingface api key:", placeholder="Key to use models in huggingface hub", type="text")
         AMADEUS_ID = gr.Textbox(label="Amadeus id:", placeholder="Id to use Amadeus", type="text")
         AMADEUS_KEY = gr.Textbox(label="Amadeus key:", placeholder="Key to use Amadeus", type="text")
-        VLLM_MODEL_URL = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text")
         key_set_btn = gr.Button(value="Set keys!")
@@ -309,11 +325,16 @@ with gr.Blocks() as demo:
                 with gr.Column(scale=0.15, min_width=0):
                     buttonChat = gr.Button("Chat")
+            CUDA_DEVICE = gr.Textbox(label="CUDA Device:", placeholder="Enter CUDA device number", type="text")
+            MEMORY_UTILIZATION = gr.Slider(label="Memory Utilization:", minimum=0, maximum=1, step=0.1, value=0.5)
             chatbot = gr.Chatbot(show_label=False, visible=True).style(height=600)
             buttonClear = gr.Button("Clear History")
             buttonStop = gr.Button("Stop", visible=False)
         with gr.Column(scale=1):
+            model_url = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text")
+            buttonDownload = gr.Button("Download Model")
+            buttonDownload.click(fn=download_model, inputs=[model_url])
             model_chosen = gr.Dropdown(
                 list(available_models), value=DEFAULTMODEL, multiselect=False, label="Model provided",
                 info="Choose the model to solve your question, Default means ChatGPT."
@@ -348,7 +369,6 @@ with gr.Blocks() as demo:
         HUGGINGFACE_API_KEY,
         AMADEUS_ID,
         AMADEUS_KEY,
-        VLLM_MODEL_URL,
     ], outputs=key_set_btn)
     key_set_btn.click(fn=load_tools, outputs=tools_chosen)
@@ -365,3 +385,6 @@ with gr.Blocks() as demo:
 # demo.queue().launch(share=False, inbrowser=True, server_name="127.0.0.1", server_port=7001)
 demo.queue().launch()

@@ -25,8 +25,8 @@ def run_tool_server():
     # def load_wikidata_tool():
     #     server.load_tool("wikidata")
-    def load_travel_tool():
-        server.load_tool("travel")
+    # def load_travel_tool():
+    #     server.load_tool("travel")
     # def load_wolframalpha_tool():
     #     WOLFRAMALPH_APP_ID = os.environ.get("WOLFRAMALPH_APP_ID", None)
@@ -165,7 +165,7 @@ def run_tool_server():
     # load_image_generation_tool()
     # load_hugging_tools()
     load_gradio_tools()
-    load_travel_tool()
+    # load_travel_tool()
     server.serve()
