feat: add test

pull/282/head^2
Zack 2 years ago
parent d0f3bb5077
commit 287eb8526f

@@ -124,7 +124,7 @@ chat_history = ""
MAX_SLEEP_TIME = 40
def download_model(model_url: str):
def download_model(model_url: str, memory_utilization: float):
    # Extract model name from the URL
    model_name = model_url.split('/')[-1]
    # TODO continue debugging
@@ -141,7 +141,7 @@ def download_model(model_url: str):
    vllm_model = LLM(
        model=model_url,
        trust_remote_code=True,
        device="cuda",
        gpu_memory_utilization=memory_utilization,
    )
    available_models.append((model_name, vllm_model))
    return gr.update(choices=available_models)
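
For reference (not shown in the diff itself): vLLM's LLM constructor takes gpu_memory_utilization as a fraction between 0 and 1 of GPU memory to reserve, which is what the new slider supplies. A minimal sketch of the resulting call, reusing the facebook/opt-125m repo id from the new test and an assumed 0.5 setting:

from vllm import LLM

# Build an engine that reserves roughly half of the GPU memory.
# The 0.5 value is an assumption for illustration; in the app it comes
# from the Memory Utilization slider.
llm = LLM(
    model="facebook/opt-125m",
    trust_remote_code=True,
    gpu_memory_utilization=0.5,
)
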
@@ -326,19 +326,25 @@ with gr.Blocks() as demo:
    with gr.Column(scale=0.15, min_width=0):
        buttonChat = gr.Button("Chat")
    CUDA_DEVICE = gr.Checkbox(label="CUDA Device:", placeholder="Enter CUDA device number", type="text")
    MEMORY_UTILIZATION = gr.Slider(label="Memory Utilization:", min=0, max=1, step=0.1, default=0.5)
    memory_utilization = gr.Slider(label="Memory Utilization:", minimum=0, maximum=1, step=0.1, value=0.5)
    iface = gr.Interface(
        fn=download_model,
        inputs=["text", memory_utilization],
    )
    chatbot = gr.Chatbot(show_label=False, visible=True).style(height=600)
    buttonClear = gr.Button("Clear History")
    buttonStop = gr.Button("Stop", visible=False)
    with gr.Column(scale=1):
        model_url = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text")
        buttonDownload = gr.Button("Download Model")
        buttonDownload.click(fn=download_model, inputs=[model_url])
        model_chosen = gr.Dropdown(
            list(available_models), value=DEFAULTMODEL, multiselect=False, label="Model provided",
            info="Choose the model to solve your question, Default means ChatGPT."
    with gr.Column(scale=4):
        with gr.Row():
            with gr.Column(scale=1):
                model_url = gr.Textbox(label="VLLM Model URL:", placeholder="URL to download VLLM model from Hugging Face", type="text")
                buttonDownload = gr.Button("Download Model")
                buttonDownload.click(fn=download_model, inputs=[model_url])
                model_chosen = gr.Dropdown(
                    list(available_models), value=DEFAULTMODEL, multiselect=False, label="Model provided",
                    info="Choose the model to solve your question, Default means ChatGPT."
                )
    with gr.Row():
        tools_search = gr.Textbox(

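Note that the hunk above still registers buttonDownload.click(fn=download_model, inputs=[model_url]) even though download_model now takes two arguments. A hedged sketch of one way to route the new slider value through the existing click handler (the outputs target is an assumption, chosen because the function returns gr.update(choices=...) for the model dropdown):

buttonDownload.click(
    fn=download_model,
    inputs=[model_url, memory_utilization],  # pass the slider value as the second argument
    outputs=[model_chosen],                  # assumed sink for the returned gr.update(choices=...)
)
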
@@ -0,0 +1,31 @@
import gradio as gr
import pytest
import os
from unittest.mock import patch, MagicMock
from app import set_environ, load_tools, download_model
def test_set_environ():
@patch('app.LLM')
def test_download_model(mock_llm):
    # Arrange
    model_url = "facebook/opt-125m"
    memory_utilization = 0.8
    mock_model = MagicMock()
    mock_llm.return_value = mock_model
    # Act
    result = download_model(model_url, memory_utilization)
    # Assert
    mock_llm.assert_called_once_with(model=model_url, trust_remote_code=True, gpu_memory_utilization=memory_utilization)
    assert result == gr.update(choices=[(model_url.split('/')[-1], mock_model)])


def test_load_tools():
    # Call the function
    result = load_tools()
    # Check if the function returns the expected result
    assert result is not None
    assert isinstance(result, list)
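
A note on the pattern used above: @patch('app.LLM') replaces the LLM name where app.py looks it up, so no real vLLM engine is built during the test. Assuming the new file sits next to app.py and is named test_app.py (the filename is not visible in this view), the suite can be run with pytest, for example programmatically:

import pytest

# Equivalent to running `pytest test_app.py -v` from the shell.
# The filename test_app.py is an assumption; the diff does not show it.
pytest.main(["test_app.py", "-v"])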