call functions for all llms to make it easer for call

pull/43/head
Kye 1 year ago
parent 9493604f63
commit d3b9d912ad

@ -30,6 +30,19 @@ class Anthropic:
return d return d
def generate(self, prompt, stop=None): def generate(self, prompt, stop=None):
"""Call out to Anthropic's completion endpoint."""
stop = stop or []
params = self._default_params()
headers = {"Authorization": f"Bearer {self.anthropic_api_key}"}
data = {
"prompt": prompt,
"stop_sequences": stop,
**params
}
response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout)
return response.json().get("completion")
def __call__(self, prompt, stop=None):
"""Call out to Anthropic's completion endpoint.""" """Call out to Anthropic's completion endpoint."""
stop = stop or [] stop = stop or []
params = self._default_params() params = self._default_params()

@ -27,6 +27,19 @@ class HuggingFaceLLM:
except Exception as e: except Exception as e:
self.logger.error(f"Failed to load the model or the tokenizer: {e}") self.logger.error(f"Failed to load the model or the tokenizer: {e}")
raise raise
def __call__(self, prompt_text: str, max_length: int = None):
max_length = max_length if max_length else self.max_length
try:
inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(inputs, max_length=max_length, do_sample=True)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
self.logger.error(f"Failed to generate the text: {e}")
raise
def generate(self, prompt_text: str, max_length: int = None): def generate(self, prompt_text: str, max_length: int = None):
max_length = max_length if max_length else self.max_length max_length = max_length if max_length else self.max_length
try: try:

@ -97,4 +97,3 @@ class OpenAI:
#async #async
# async_responses = asyncio.run(chat.ask_multiple(['id1', 'id2'], "How is {id}")) # async_responses = asyncio.run(chat.ask_multiple(['id1', 'id2'], "How is {id}"))
# print(async_responses) # print(async_responses)
#

@ -112,6 +112,26 @@ class GooglePalm(BaseModel):
return await self.client.chat_async(**kwargs) return await self.client.chat_async(**kwargs)
return await _achat_with_retry(**kwargs) return await _achat_with_retry(**kwargs)
def __call__(
self,
messages: List[Dict[str, Any]],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = self.chat_with_retry(
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
**kwargs,
)
return _response_to_result(response, stop)
def generate( def generate(
self, self,

Loading…
Cancel
Save