From f5e8667a02b598418430d6ba79e45f36c6665a68 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 18 Jul 2023 11:59:04 -0400 Subject: [PATCH] Squashed commit of the following: commit df18590934bb3acf6736105bd1ff36793c65d562 [formerly b0a609641c82dcbdff91e36267de47010c102980] Author: Kye Date: Tue Jul 18 11:51:40 2023 -0400 clean up commit 926f154b028e83f496a39c2ef763112238c8775d [formerly 5677d821e0376f1e1fcf9e4df889f5dede7b1c1a] Author: Kye Date: Tue Jul 18 11:49:07 2023 -0400 clean up Former-commit-id: 2bb11ff0461eca973f0e2ffccc8ac8d6255cd94f --- swarms/utils/llm.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/swarms/utils/llm.py b/swarms/utils/llm.py index 0180ec74..31343b70 100644 --- a/swarms/utils/llm.py +++ b/swarms/utils/llm.py @@ -8,12 +8,14 @@ from langchain import PromptTemplate, HuggingFaceHub, ChatOpenAI, LLMChain logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class LLM: def __init__(self, openai_api_key: Optional[str] = None, hf_repo_id: Optional[str] = None, hf_api_token: Optional[str] = None, - model_kwargs: Optional[dict] = None): + temperature: Optional[float] = 0.5, + max_length: Optional[int] = 64): # Check if keys are in the environment variables openai_api_key = openai_api_key or os.getenv('OPENAI_API_KEY') @@ -22,18 +24,23 @@ class LLM: self.openai_api_key = openai_api_key self.hf_repo_id = hf_repo_id self.hf_api_token = hf_api_token - self.model_kwargs = model_kwargs if model_kwargs else {} + self.temperature = temperature + self.max_length = max_length # If the HuggingFace API token is provided, set it in environment variables if self.hf_api_token: os.environ["HUGGINGFACEHUB_API_TOKEN"] = self.hf_api_token - # Create the LLM object based on the provided keys + # Initialize the LLM object + self.initialize_llm() + + def initialize_llm(self): + model_kwargs = {"temperature": self.temperature, "max_length": self.max_length} try: if self.hf_repo_id and self.hf_api_token: - self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=self.model_kwargs) + self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=model_kwargs) elif self.openai_api_key: - self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=self.model_kwargs) + self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=model_kwargs) else: raise ValueError("Please provide either OpenAI API key or both HuggingFace repository ID and API token.") except Exception as e: @@ -43,7 +50,6 @@ class LLM: def run(self, prompt: str) -> str: template = """Question: {question} Answer: Let's think step by step.""" - try: prompt_template = PromptTemplate(template=template, input_variables=["question"]) llm_chain = LLMChain(prompt=prompt_template, llm=self.llm) @@ -62,3 +68,8 @@ class LLM: # llm_instance = LLM(hf_repo_id="google/flan-t5-xl", hf_api_token="your_hf_api_token") # result = llm_instance.run("Who won the FIFA World Cup in 1998?") # print(result) + + +# make super easy to chaneg parameters, in class, use cpu and +#add qlora, 8bit inference +# look into adding deepspeed \ No newline at end of file