@@ -8,6 +8,7 @@ from langchain import PromptTemplate, HuggingFaceHub, ChatOpenAI, LLMChain
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class LLM:
 
     def __init__(self,
                  openai_api_key: Optional[str] = None,
@@ -30,9 +31,9 @@ class LLM:
         if self.hf_api_token:
             os.environ["HUGGINGFACEHUB_API_TOKEN"] = self.hf_api_token
 
-        #initialize the LLM project
+        # Initialize the LLM object
         self.initialize_llm()
 
     def initialize_llm(self):
         model_kwargs = {"temperature": self.temperature, "max_length": self.max_length}
         try:
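The second hunk cuts off at the `try:` line, so the body of `initialize_llm` is not visible in this diff. Purely as a hypothetical sketch (the standalone `build_llm` helper, the backend-selection order, and the `repo_id="google/flan-t5-base"` default are assumptions, not part of the PR), the kind of selection logic such a method typically performs looks like this:

```python
# Hypothetical sketch only; not the PR's implementation.
import logging
from typing import Optional

from langchain import HuggingFaceHub
from langchain.chat_models import ChatOpenAI

logger = logging.getLogger(__name__)


def build_llm(openai_api_key: Optional[str], temperature: float, max_length: int):
    """Pick an LLM backend from whichever credential is available (illustrative)."""
    model_kwargs = {"temperature": temperature, "max_length": max_length}
    try:
        if openai_api_key:
            # Prefer the OpenAI chat model when a key was supplied.
            return ChatOpenAI(openai_api_key=openai_api_key, temperature=temperature)
        # Otherwise fall back to the Hugging Face Hub; this relies on
        # HUGGINGFACEHUB_API_TOKEN already being exported, as __init__ does above.
        return HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs=model_kwargs)
    except Exception:
        logger.exception("Failed to initialize LLM")
        raise
```

Whichever backend the real implementation chooses, wrapping the construction in `try`/`except` with `logger.exception` keeps the full stack trace in the logs configured by `logging.basicConfig(level=logging.INFO)` above.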