You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarms/agents/models/palm.py

169 lines
5.0 KiB

from __future__ import annotations
import logging
from swarms.utils.logger import logger
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
import google.generativeai as genai
# --- Helpers ---
class ChatGooglePalmError(Exception):
    """Exception type used to signal a problem with the Google PaLM API."""
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""Truncates text at the earliest stop token found."""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> Dict[str, Any]:
    """Flatten a PaLM chat response into a plain result dictionary.

    Args:
        response: The chat response object to read from.
        stop: Optional stop sequences; each choice's text is passed through
            ``_truncate_at_stop_tokens`` with these.

    Returns:
        A dict with the keys ``id``, ``created``, ``model``, ``usage``
        (prompt/completion/total token counts) and ``choices`` (one dict per
        entry in ``response.choices`` holding ``text``, ``index`` and
        ``finish_reason``).
    """
    # NOTE(review): the attributes read below (id, created, model, usage,
    # choices) should be confirmed against the actual ChatResponse type.
    summary: Dict[str, Any] = {
        "id": response.id,
        "created": response.created,
        "model": response.model,
    }

    token_counts = response.usage
    summary["usage"] = {
        "prompt_tokens": token_counts.prompt_tokens,
        "completion_tokens": token_counts.completion_tokens,
        "total_tokens": token_counts.total_tokens,
    }

    summary["choices"] = [
        {
            "text": _truncate_at_stop_tokens(candidate.text, stop),
            "index": candidate.index,
            "finish_reason": candidate.finish_reason,
        }
        for candidate in response.choices
    ]
    return summary
def _messages_to_prompt_dict(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Convert a list of message dictionaries to a prompt dictionary."""
prompt = {"messages": []}
for message in messages:
prompt["messages"].append({
"role": message["role"],
"content": message["content"],
})
return prompt
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Build a tenacity retry decorator for PaLM chat calls.

    The decorator retries only when ``ChatGooglePalmError`` is raised, stops
    after 5 attempts, waits with exponential backoff bounded by ``min=2`` and
    ``max=30``, logs at DEBUG level before each sleep, and re-raises the last
    exception once attempts are exhausted.

    Returns:
        The configured ``tenacity.retry`` decorator.
    """
    # NOTE(review): nothing visible in this module raises ChatGooglePalmError,
    # so confirm the wrapped client surfaces failures as this exception type;
    # otherwise no retry will ever be triggered.
    retry_condition = retry_if_exception_type(ChatGooglePalmError)
    attempt_limit = stop_after_attempt(5)
    backoff = wait_exponential(multiplier=1, min=2, max=30)
    log_before_sleep = before_sleep_log(logger, logging.DEBUG)

    return retry(
        retry=retry_condition,
        stop=attempt_limit,
        wait=backoff,
        before_sleep=log_before_sleep,
        reraise=True,
    )
# --- Main class ---
class ChatGooglePalm(BaseModel):
    """Wrapper around Google's PaLM Chat API.

    The sampling fields (``temperature``, ``top_p``, ``top_k``, ``n``) are
    forwarded unchanged to the client's ``chat`` / ``chat_async`` call on
    every request.
    """

    # Object exposing ``chat`` and ``chat_async``. When it is not supplied,
    # the validator below sets it to the ``google.generativeai`` module.
    client: Any  #: :meta private:
    # Model identifier passed as ``model=`` to the client.
    model_name: str = "models/chat-bison-001"
    # API key handed to ``genai.configure`` when the default client is used.
    google_api_key: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    # Passed as ``candidate_count=`` to the client.
    n: int = 1

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Set up the default client and return the validated values.

        A root validator must return the values dict; returning ``None``
        makes the model impossible to construct.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same dict, with ``client`` populated when it was not given.
        """
        # An explicitly injected client is left untouched so callers can
        # supply their own object (e.g. a stub) without configuring genai.
        if values.get("client") is None:
            # NOTE(review): with api_key=None, genai.configure is expected to
            # fall back to the GOOGLE_API_KEY environment variable — confirm
            # against the installed google.generativeai version.
            genai.configure(api_key=values.get("google_api_key"))
            values["client"] = genai
        return values

    def chat_with_retry(self, **kwargs: Any) -> Any:
        """Call ``self.client.chat`` under the module's retry policy.

        Args:
            **kwargs: Forwarded verbatim to ``self.client.chat``.

        Returns:
            Whatever ``self.client.chat`` returns.
        """
        retry_decorator = _create_retry_decorator()

        @retry_decorator
        def _chat_with_retry(**kwargs: Any) -> Any:
            return self.client.chat(**kwargs)

        return _chat_with_retry(**kwargs)

    async def achat_with_retry(self, **kwargs: Any) -> Any:
        """Await ``self.client.chat_async`` under the module's retry policy.

        Args:
            **kwargs: Forwarded verbatim to ``self.client.chat_async``.

        Returns:
            Whatever ``self.client.chat_async`` resolves to.
        """
        retry_decorator = _create_retry_decorator()

        @retry_decorator
        async def _achat_with_retry(**kwargs: Any) -> Any:
            return await self.client.chat_async(**kwargs)

        return await _achat_with_retry(**kwargs)

    def generate(
        self,
        messages: List[Dict[str, Any]],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Send ``messages`` to the chat client and return a result dict.

        Args:
            messages: Message dicts, each providing ``role`` and ``content``.
            stop: Optional stop sequences used to truncate each choice's text.
            **kwargs: Extra keyword arguments forwarded to the client call.

        Returns:
            The dictionary produced by ``_response_to_result``.
        """
        prompt = _messages_to_prompt_dict(messages)
        response: genai.types.ChatResponse = self.chat_with_retry(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            **kwargs,
        )
        return _response_to_result(response, stop)

    async def _agenerate(
        self,
        messages: List[Dict[str, Any]],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Async counterpart of ``generate``.

        Args:
            messages: Message dicts, each providing ``role`` and ``content``.
            stop: Optional stop sequences used to truncate each choice's text.
            **kwargs: Extra keyword arguments forwarded to the client call.

        Returns:
            The dictionary produced by ``_response_to_result``.
        """
        prompt = _messages_to_prompt_dict(messages)
        response: genai.types.ChatResponse = await self.achat_with_retry(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            # Forward caller overrides, matching ``generate``; they were
            # previously accepted but silently dropped.
            **kwargs,
        )
        return _response_to_result(response, stop)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this model type."""
        return "google-palm-chat"