From a04c43d7ea294e05d7fb749e14faeb25a8ec005e Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 8 Aug 2023 11:54:20 -0400 Subject: [PATCH] call functions for all llms to make it easer for call Former-commit-id: d3b9d912adaee06ddf002959a75274ace12817bc --- swarms/agents/models/anthropic.py | 13 +++++++++++++ swarms/agents/models/huggingface.py | 13 +++++++++++++ swarms/agents/models/openai.py | 1 - swarms/agents/models/palm.py | 20 ++++++++++++++++++++ 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/swarms/agents/models/anthropic.py b/swarms/agents/models/anthropic.py index 803127a9..47b43db8 100644 --- a/swarms/agents/models/anthropic.py +++ b/swarms/agents/models/anthropic.py @@ -30,6 +30,19 @@ class Anthropic: return d def generate(self, prompt, stop=None): + """Call out to Anthropic's completion endpoint.""" + stop = stop or [] + params = self._default_params() + headers = {"Authorization": f"Bearer {self.anthropic_api_key}"} + data = { + "prompt": prompt, + "stop_sequences": stop, + **params + } + response = requests.post(f"{self.anthropic_api_url}/completions", headers=headers, json=data, timeout=self.default_request_timeout) + return response.json().get("completion") + + def __call__(self, prompt, stop=None): """Call out to Anthropic's completion endpoint.""" stop = stop or [] params = self._default_params() diff --git a/swarms/agents/models/huggingface.py b/swarms/agents/models/huggingface.py index a69a4efd..8f1cf5cc 100644 --- a/swarms/agents/models/huggingface.py +++ b/swarms/agents/models/huggingface.py @@ -27,6 +27,19 @@ class HuggingFaceLLM: except Exception as e: self.logger.error(f"Failed to load the model or the tokenizer: {e}") raise + + def __call__(self, prompt_text: str, max_length: int = None): + max_length = max_length if max_length else self.max_length + try: + inputs = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device) + with torch.no_grad(): + outputs = self.model.generate(inputs, max_length=max_length, do_sample=True) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + except Exception as e: + self.logger.error(f"Failed to generate the text: {e}") + raise + + def generate(self, prompt_text: str, max_length: int = None): max_length = max_length if max_length else self.max_length try: diff --git a/swarms/agents/models/openai.py b/swarms/agents/models/openai.py index 80bcbcbe..8edd5cd3 100644 --- a/swarms/agents/models/openai.py +++ b/swarms/agents/models/openai.py @@ -97,4 +97,3 @@ class OpenAI: #async # async_responses = asyncio.run(chat.ask_multiple(['id1', 'id2'], "How is {id}")) # print(async_responses) -# \ No newline at end of file diff --git a/swarms/agents/models/palm.py b/swarms/agents/models/palm.py index b91af5d5..a7ee6be2 100644 --- a/swarms/agents/models/palm.py +++ b/swarms/agents/models/palm.py @@ -112,6 +112,26 @@ class GooglePalm(BaseModel): return await self.client.chat_async(**kwargs) return await _achat_with_retry(**kwargs) + + def __call__( + self, + messages: List[Dict[str, Any]], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + prompt = _messages_to_prompt_dict(messages) + + response: genai.types.ChatResponse = self.chat_with_retry( + model=self.model_name, + prompt=prompt, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + candidate_count=self.n, + **kwargs, + ) + + return _response_to_result(response, stop) def generate( self,