fix: model loading

pull/282/head^2
Zack 1 year ago
parent 6f9d716250
commit d76ba69311

@@ -127,8 +127,7 @@ chat_history = ""
 MAX_SLEEP_TIME = 40
-def download_model(model_url: str, memory_utilization: int):
+def download_model(model_url: str, memory_utilization: int , model_dir: str):
     # Extract model name from the URL
     model_name = model_url.split('/')[-1]
     # Download the model using VLLM
     vllm_model = VLLM(

File diff suppressed because one or more lines are too long

@@ -1,3 +1,4 @@
+from pathlib import Path
 from langchain.llms import OpenAI
 from langchain import OpenAI, LLMChain, PromptTemplate, SerpAPIWrapper
 from langchain.agents import ZeroShotAgent, AgentExecutor, initialize_agent, Tool
@@ -5,11 +6,13 @@ import importlib
 import json
 import os
 import requests
+from vllm import LLM
 import yaml
 from .apitool import RequestTool
 from .executor import Executor, AgentExecutorWithTranslation
 from swarms.utils import get_logger
 from .BabyagiTools import BabyAGI
+from langchain.llms import VLLM
 # from .models.customllm import CustomLLM
@@ -62,6 +65,9 @@ def load_single_tools(tool_name, tool_url):
     return tool_name, tool_config_json

+# Read the model/ directory and get the list of models
+model_dir = Path("./models/")
+available_models = ["ChatGPT", "GPT-3.5"] + [f.name for f in model_dir.iterdir() if f.is_dir()]

 class STQuestionAnswerer:
     def __init__(self, openai_api_key="", stream_output=False, llm="ChatGPT"):
@@ -83,6 +89,8 @@ class STQuestionAnswerer:
             self.llm = OpenAI(
                 model_name="gpt-3.5-turbo", temperature=0.0, openai_api_key=key
             )  # use chatgpt
+        elif self.llm_model in available_models:
+            self.llm = VLLM(model=f"models/{self.llm_model}")
         else:
             raise RuntimeError("Your model is not available.")

@@ -12,6 +12,7 @@ from .executor import Executor, AgentExecutorWithTranslation
 from vllm import LLM
 from swarms.utils import get_logger
 from pathlib import Path
+from langchain.llms import VLLM

 logger = get_logger(__name__)
@@ -45,7 +46,7 @@ class MTQuestionAnswerer:
         self.openai_api_key = openai_api_key
         self.stream_output = stream_output
         self.llm_model = llm
-        self.model_path = model_path
+        self.model = model_path
         self.set_openai_api_key(openai_api_key)
         self.load_tools(all_tools)
@@ -59,7 +60,7 @@ class MTQuestionAnswerer:
                 model_name="gpt-3.5-turbo", temperature=0.0, openai_api_key=key
             )  # use chatgpt
-        elif self.llm_model in available_models:  # If the selected model is a vLLM model
-            self.llm = LLM(model_path=f"model/{self.llm_model}")  # Load the vLLM model
+        elif self.llm_model in available_models:
+            self.llm = VLLM(model=f"models/{self.llm_model}")
         else:
             raise RuntimeError("Your model is not available.")

Loading…
Cancel
Save