You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/swarms/models/palm.py

189 lines
5.8 KiB

# from __future__ import annotations
# import logging
# from swarms.utils.logger import logger
# from typing import Any, Callable, Dict, List, Optional
# from pydantic import BaseModel, model_validator
# from tenacity import (
#     before_sleep_log,
#     retry,
#     retry_if_exception_type,
#     stop_after_attempt,
#     wait_exponential,
# )
# import google.generativeai as palm
# class GooglePalmError(Exception):
#     """Error raised when there is an issue with the Google PaLM API."""
# def _truncate_at_stop_tokens(
#     text: str,
#     stop: Optional[List[str]],
# ) -> str:
#     """Truncates text at the earliest stop token found."""
#     if stop is None:
#         return text
#     for stop_token in stop:
#         stop_token_idx = text.find(stop_token)
#         if stop_token_idx != -1:
#             text = text[:stop_token_idx]
#     return text
# def _response_to_result(response: palm.types.ChatResponse, stop: Optional[List[str]]) -> Dict[str, Any]:
#     """Convert a PaLM chat response to a result dictionary."""
#     # NOTE(review): this reads OpenAI-style attributes (id, created, usage,
#     # choices[].text). Confirm that palm.types.ChatResponse exposes them; the
#     # PaLM chat API appears to return `candidates` / `last` instead.
#     result = {
#         "id": response.id,
#         "created": response.created,
#         "model": response.model,
#         "usage": {
#             "prompt_tokens": response.usage.prompt_tokens,
#             "completion_tokens": response.usage.completion_tokens,
#             "total_tokens": response.usage.total_tokens,
#         },
#         "choices": [],
#     }
#     for choice in response.choices:
#         result["choices"].append({
#             "text": _truncate_at_stop_tokens(choice.text, stop),
#             "index": choice.index,
#             "finish_reason": choice.finish_reason,
#         })
#     return result
# def _messages_to_prompt_dict(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
#     """Convert a list of message dictionaries to a prompt dictionary."""
#     prompt = {"messages": []}
#     for message in messages:
#         prompt["messages"].append({
#             "role": message["role"],
#             "content": message["content"],
#         })
#     return prompt
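# def _create_retry_decorator() -> Callable[[Any], Any]:
#     """Create a retry decorator with exponential backoff."""
#     # NOTE(review): nothing in this module raises GooglePalmError, so as
#     # written this predicate never matches and no call is ever retried.
#     # Either wrap client errors in GooglePalmError or widen the predicate.
#     return retry(
#         retry=retry_if_exception_type(GooglePalmError),
#         stop=stop_after_attempt(5),
#         wait=wait_exponential(multiplier=1, min=2, max=30),
#         before_sleep=before_sleep_log(logger, logging.DEBUG),
#         reraise=True,
#     )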
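# ####################### => main class
# class GooglePalm(BaseModel):
#     """Wrapper around Google's PaLM Chat API."""
#     client: Any #: :meta private:
#     model_name: str = "models/chat-bison-001"
#     google_api_key: Optional[str] = None
#     temperature: Optional[float] = None
#     top_p: Optional[float] = None
#     top_k: Optional[int] = None
#     n: int = 1
#     # NOTE(review): pydantic v2's model_validator takes mode="before",
#     # "after" or "wrap"; "pre" is the v1 root_validator spelling and is not
#     # a valid mode here. A "before" model validator is normally also
#     # declared as a @classmethod.
#     @model_validator(mode="pre")
#     def validate_environment(cls, values: Dict) -> Dict:
#         # TODO: the validation body was never ported (presumably API-key
#         # lookup and construction of `client`; confirm against the original
#         # implementation). Note that `pass` returns None, not the `values`
#         # dict the annotation promises.
#         pass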
#     def chat_with_retry(self, **kwargs: Any) -> Any:
#         """Use tenacity to retry the completion call."""
#         retry_decorator = _create_retry_decorator()
#         @retry_decorator
#         def _chat_with_retry(**kwargs: Any) -> Any:
#             return self.client.chat(**kwargs)
#         return _chat_with_retry(**kwargs)
#     async def achat_with_retry(self, **kwargs: Any) -> Any:
#         """Use tenacity to retry the async completion call."""
#         retry_decorator = _create_retry_decorator()
#         @retry_decorator
#         async def _achat_with_retry(**kwargs: Any) -> Any:
#             return await self.client.chat_async(**kwargs)
#         return await _achat_with_retry(**kwargs)
#     def __call__(
#         self,
#         messages: List[Dict[str, Any]],
#         stop: Optional[List[str]] = None,
#         **kwargs: Any,
#     ) -> Dict[str, Any]:
#         # NOTE(review): this body is identical to generate(); consider
#         # `return self.generate(messages, stop, **kwargs)` to avoid drift.
#         prompt = _messages_to_prompt_dict(messages)
#         response: palm.types.ChatResponse = self.chat_with_retry(
#             model=self.model_name,
#             prompt=prompt,
#             temperature=self.temperature,
#             top_p=self.top_p,
#             top_k=self.top_k,
#             candidate_count=self.n,
#             **kwargs,
#         )
#         return _response_to_result(response, stop)
#     def generate(
#         self,
#         messages: List[Dict[str, Any]],
#         stop: Optional[List[str]] = None,
#         **kwargs: Any,
#     ) -> Dict[str, Any]:
#         prompt = _messages_to_prompt_dict(messages)
#         response: palm.types.ChatResponse = self.chat_with_retry(
#             model=self.model_name,
#             prompt=prompt,
#             temperature=self.temperature,
#             top_p=self.top_p,
#             top_k=self.top_k,
#             candidate_count=self.n,
#             **kwargs,
#         )
#         return _response_to_result(response, stop)
#     async def _agenerate(
#         self,
#         messages: List[Dict[str, Any]],
#         stop: Optional[List[str]] = None,
#         **kwargs: Any,
#     ) -> Dict[str, Any]:
#         prompt = _messages_to_prompt_dict(messages)
#         # NOTE(review): unlike __call__/generate, this call does not forward
#         # **kwargs, so extra options are silently dropped on the async path.
#         response: palm.types.ChatResponse = await self.achat_with_retry(
#             model=self.model_name,
#             prompt=prompt,
#             temperature=self.temperature,
#             top_p=self.top_p,
#             top_k=self.top_k,
#             candidate_count=self.n,
#         )
#         return _response_to_result(response, stop)
#     @property
#     def _identifying_params(self) -> Dict[str, Any]:
#         """Get the identifying parameters."""
#         return {
#             "model_name": self.model_name,
#             "temperature": self.temperature,
#             "top_p": self.top_p,
#             "top_k": self.top_k,
#             "n": self.n,
#         }
#     @property
#     def _llm_type(self) -> str:
#         return "google-palm-chat"