call functions for all llms to make it easer for call

Former-commit-id: d3b9d912ad
pull/47/head
Kye 1 year ago
parent 1eafd5ec49
commit a04c43d7ea

@ -30,6 +30,19 @@ class Anthropic:
return d
def generate(self, prompt, stop=None):
"""Call out to Anthropic's completion endpoint."""
stop = stop or []
params = self._default_params()
headers = {"Authorization": f"Bearer {self.anthropic_api_key}"}
data = {
"prompt": prompt,
"stop_sequences": stop,
**params
}
response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout)
return response.json().get("completion")
def __call__(self, prompt, stop=None):
"""Call out to Anthropic's completion endpoint."""
stop = stop or []
params = self._default_params()

@ -27,6 +27,19 @@ class HuggingFaceLLM:
except Exception as e:
self.logger.error(f"Failed to load the model or the tokenizer: {e}")
raise
def __call__(self, prompt_text: str, max_length: int = None):
max_length = max_length if max_length else self.max_length
try:
inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(inputs, max_length=max_length, do_sample=True)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
self.logger.error(f"Failed to generate the text: {e}")
raise
def generate(self, prompt_text: str, max_length: int = None):
max_length = max_length if max_length else self.max_length
try:

@ -97,4 +97,3 @@ class OpenAI:
#async
# async_responses = asyncio.run(chat.ask_multiple(['id1', 'id2'], "How is {id}"))
# print(async_responses)
#

@ -112,6 +112,26 @@ class GooglePalm(BaseModel):
return await self.client.chat_async(**kwargs)
return await _achat_with_retry(**kwargs)
def __call__(
self,
messages: List[Dict[str, Any]],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = self.chat_with_retry(
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
**kwargs,
)
return _response_to_result(response, stop)
def generate(
self,

Loading…
Cancel
Save