Squashed commit of the following:

commit df18590934 [formerly b0a609641c82dcbdff91e36267de47010c102980]
Author: Kye <kye@apacmediasolutions.com>
Date:   Tue Jul 18 11:51:40 2023 -0400

    clean up

commit 926f154b02 [formerly 5677d821e0376f1e1fcf9e4df889f5dede7b1c1a]
Author: Kye <kye@apacmediasolutions.com>
Date:   Tue Jul 18 11:49:07 2023 -0400

    clean up


Former-commit-id: 2bb11ff046
kyegomez-patch-1
Kye 2 years ago
parent 650804d4b7
commit f5e8667a02

@ -8,12 +8,14 @@ from langchain import PromptTemplate, HuggingFaceHub, ChatOpenAI, LLMChain
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLM:
def __init__(self,
openai_api_key: Optional[str] = None,
hf_repo_id: Optional[str] = None,
hf_api_token: Optional[str] = None,
model_kwargs: Optional[dict] = None):
temperature: Optional[float] = 0.5,
max_length: Optional[int] = 64):
# Check if keys are in the environment variables
openai_api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
@ -22,18 +24,23 @@ class LLM:
self.openai_api_key = openai_api_key
self.hf_repo_id = hf_repo_id
self.hf_api_token = hf_api_token
self.model_kwargs = model_kwargs if model_kwargs else {}
self.temperature = temperature
self.max_length = max_length
# If the HuggingFace API token is provided, set it in environment variables
if self.hf_api_token:
os.environ["HUGGINGFACEHUB_API_TOKEN"] = self.hf_api_token
# Create the LLM object based on the provided keys
# Initialize the LLM object
self.initialize_llm()
def initialize_llm(self):
model_kwargs = {"temperature": self.temperature, "max_length": self.max_length}
try:
if self.hf_repo_id and self.hf_api_token:
self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=self.model_kwargs)
self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=model_kwargs)
elif self.openai_api_key:
self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=self.model_kwargs)
self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=model_kwargs)
else:
raise ValueError("Please provide either OpenAI API key or both HuggingFace repository ID and API token.")
except Exception as e:
@ -43,7 +50,6 @@ class LLM:
def run(self, prompt: str) -> str:
template = """Question: {question}
Answer: Let's think step by step."""
try:
prompt_template = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt_template, llm=self.llm)
@ -62,3 +68,8 @@ class LLM:
# llm_instance = LLM(hf_repo_id="google/flan-t5-xl", hf_api_token="your_hf_api_token")
# result = llm_instance.run("Who won the FIFA World Cup in 1998?")
# print(result)
# make super easy to chaneg parameters, in class, use cpu and
#add qlora, 8bit inference
# look into adding deepspeed
Loading…
Cancel
Save