Former-commit-id: 5677d821e0376f1e1fcf9e4df889f5dede7b1c1a
NewTools
Kye 1 year ago
parent 650804d4b7
commit 926f154b02

@@ -13,7 +13,8 @@ class LLM:
                  openai_api_key: Optional[str] = None,
                  hf_repo_id: Optional[str] = None,
                  hf_api_token: Optional[str] = None,
-                 model_kwargs: Optional[dict] = None):
+                 temperature: Optional[float] = 0.5,
+                 max_length: Optional[int] = 64):
         # Check if keys are in the environment variables
         openai_api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
@@ -22,18 +23,23 @@ class LLM:
         self.openai_api_key = openai_api_key
         self.hf_repo_id = hf_repo_id
         self.hf_api_token = hf_api_token
-        self.model_kwargs = model_kwargs if model_kwargs else {}
+        self.temperature = temperature
+        self.max_length = max_length
 
         # If the HuggingFace API token is provided, set it in environment variables
         if self.hf_api_token:
             os.environ["HUGGINGFACEHUB_API_TOKEN"] = self.hf_api_token
 
-        # Create the LLM object based on the provided keys
+        # Initialize the LLM
+        self.initialize_llm()
+
+    def initialize_llm(self):
+        model_kwargs = {"temperature": self.temperature, "max_length": self.max_length}
         try:
             if self.hf_repo_id and self.hf_api_token:
-                self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=self.model_kwargs)
+                self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=model_kwargs)
             elif self.openai_api_key:
-                self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=self.model_kwargs)
+                self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=model_kwargs)
             else:
                 raise ValueError("Please provide either OpenAI API key or both HuggingFace repository ID and API token.")
         except Exception as e:
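A minimal usage sketch for the HuggingFace path after this hunk; the parameter values below are illustrative, not from the commit:

    llm_instance = LLM(hf_repo_id="google/flan-t5-xl", hf_api_token="your_hf_api_token",
                       temperature=0.7, max_length=128)
    # initialize_llm() builds model_kwargs = {"temperature": 0.7, "max_length": 128}
    # and hands it to HuggingFaceHub(repo_id="google/flan-t5-xl", model_kwargs=model_kwargs)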
@@ -43,7 +49,6 @@ class LLM:
     def run(self, prompt: str) -> str:
         template = """Question: {question}
         Answer: Let's think step by step."""
-
         try:
             prompt_template = PromptTemplate(template=template, input_variables=["question"])
             llm_chain = LLMChain(prompt=prompt_template, llm=self.llm)
@@ -62,3 +67,8 @@ class LLM:
 # llm_instance = LLM(hf_repo_id="google/flan-t5-xl", hf_api_token="your_hf_api_token")
 # result = llm_instance.run("Who won the FIFA World Cup in 1998?")
 # print(result)
+
+# make it super easy to change parameters in the class; use cpu and
+# add qlora, 8bit inference
+# look into adding deepspeed
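Similarly, a hedged sketch of the OpenAI-backed path with the new default generation settings (temperature=0.5, max_length=64); the API key below is a placeholder:

    llm_instance = LLM(openai_api_key="your_openai_api_key")
    result = llm_instance.run("Who won the FIFA World Cup in 1998?")
    print(result)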