From 8907b077af16ec25d0af7f3905baff1c0232492a Mon Sep 17 00:00:00 2001 From: Zack Date: Tue, 5 Dec 2023 17:47:24 -0800 Subject: [PATCH] feat: add vllm download option --- app.py | 45 ++++++++++++++++++++++++++++++++++----------- tool_server.py | 6 +++--- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/app.py b/app.py index 8526c58b..9bcdb51d 100644 --- a/app.py +++ b/app.py @@ -8,7 +8,7 @@ import warnings from swarms.modelui.modules.block_requests import OpenMonkeyPatch, RequestBlocker from swarms.modelui.modules.logging_colors import logger -from vllm import LLM, SamplingParams +from vllm import LLM os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['BITSANDBYTES_NOWELCOME'] = '1' @@ -59,6 +59,7 @@ from tool_server import run_tool_server from threading import Thread from multiprocessing import Process import time +from langchain.llms import VLLM tool_server_flag = False def start_tool_server(): @@ -123,6 +124,27 @@ chat_history = "" MAX_SLEEP_TIME = 40 +def download_model(model_url: str): + # Extract model name from the URL + model_name = model_url.rstrip('/').split('/')[-1] + # response = requests.get(model_url, stream=True) + # total_size = int(response.headers.get('content-length', 0)) + # block_size = 1024 #1 Kibibyte + # progress_bar = gr.outputs.Progress_Bar(total_size) + # model_data = b"" + # for data in response.iter_content(block_size): + # model_data += data + # progress_bar.update(len(data)) + # yield progress_bar + # Save the model data to a file, or load it into a model here + vllm_model = LLM( + model=model_url, + trust_remote_code=True, + device="cuda", + ) + available_models.append((model_name, vllm_model)) + return gr.update(choices=available_models) + def load_tools(): global valid_tools_info global all_tools_list @@ -131,14 +153,9 @@ def load_tools(): except BaseException as e: print(repr(e)) all_tools_list = sorted(list(valid_tools_info.keys())) - # Download the VLLM model from the provided URL and add it to the dropdown array - 
vllm_model_url = os.environ.get("VLLM_MODEL_URL", "") - if vllm_model_url: - vllm_model = LLM.from_pretrained(vllm_model_url) - available_models.append(vllm_model) return gr.update(choices=all_tools_list) -def set_environ(OPENAI_API_KEY: str, +def set_environ(OPENAI_API_KEY: str = "", WOLFRAMALPH_APP_ID: str = "", WEATHER_API_KEYS: str = "", BING_SUBSCRIPT_KEY: str = "", @@ -152,7 +169,8 @@ def set_environ(OPENAI_API_KEY: str, STEAMSHIP_API_KEY: str = "", HUGGINGFACE_API_KEY: str = "", AMADEUS_ID: str = "", - AMADEUS_KEY: str = "",): + AMADEUS_KEY: str = "", + VLLM_MODEL_URL: str = ""): os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY os.environ["WOLFRAMALPH_APP_ID"] = WOLFRAMALPH_APP_ID os.environ["WEATHER_API_KEYS"] = WEATHER_API_KEYS @@ -168,7 +186,6 @@ def set_environ(OPENAI_API_KEY: str, os.environ["HUGGINGFACE_API_KEY"] = HUGGINGFACE_API_KEY os.environ["AMADEUS_ID"] = AMADEUS_ID os.environ["AMADEUS_KEY"] = AMADEUS_KEY - os.environ["VLLM_MODEL_URL"] = VLLM_MODEL_URL if not tool_server_flag: start_tool_server() time.sleep(MAX_SLEEP_TIME) @@ -295,7 +312,6 @@ with gr.Blocks() as demo: HUGGINGFACE_API_KEY = gr.Textbox(label="Huggingface api key:", placeholder="Key to use models in huggingface hub", type="text") AMADEUS_ID = gr.Textbox(label="Amadeus id:", placeholder="Id to use Amadeus", type="text") AMADEUS_KEY = gr.Textbox(label="Amadeus key:", placeholder="Key to use Amadeus", type="text") - VLLM_MODEL_URL = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text") key_set_btn = gr.Button(value="Set keys!") @@ -309,11 +325,16 @@ with gr.Blocks() as demo: with gr.Column(scale=0.15, min_width=0): buttonChat = gr.Button("Chat") + CUDA_DEVICE = gr.Textbox(label="CUDA Device:", placeholder="Enter CUDA device number", type="text") + MEMORY_UTILIZATION = gr.Slider(label="Memory Utilization:", minimum=0, maximum=1, step=0.1, value=0.5) chatbot = gr.Chatbot(show_label=False, 
visible=True).style(height=600) buttonClear = gr.Button("Clear History") buttonStop = gr.Button("Stop", visible=False) with gr.Column(scale=1): + model_url = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text") + buttonDownload = gr.Button("Download Model") + buttonDownload.click(fn=download_model, inputs=[model_url]) model_chosen = gr.Dropdown( list(available_models), value=DEFAULTMODEL, multiselect=False, label="Model provided", info="Choose the model to solve your question, Default means ChatGPT." @@ -348,7 +369,6 @@ with gr.Blocks() as demo: HUGGINGFACE_API_KEY, AMADEUS_ID, AMADEUS_KEY, - VLLM_MODEL_URL, ], outputs=key_set_btn) key_set_btn.click(fn=load_tools, outputs=tools_chosen) @@ -365,3 +385,6 @@ with gr.Blocks() as demo: # demo.queue().launch(share=False, inbrowser=True, server_name="127.0.0.1", server_port=7001) demo.queue().launch() + + + diff --git a/tool_server.py b/tool_server.py index fe3d21de..b2b253d0 100644 --- a/tool_server.py +++ b/tool_server.py @@ -25,8 +25,8 @@ def run_tool_server(): # def load_wikidata_tool(): # server.load_tool("wikidata") - def load_travel_tool(): - server.load_tool("travel") + # def load_travel_tool(): + # server.load_tool("travel") # def load_wolframalpha_tool(): # WOLFRAMALPH_APP_ID = os.environ.get("WOLFRAMALPH_APP_ID", None) @@ -165,7 +165,7 @@ def run_tool_server(): # load_image_generation_tool() # load_hugging_tools() load_gradio_tools() - load_travel_tool() + # load_travel_tool() server.serve()