parent
7c328b5009
commit
30633b8316
@ -1,10 +1,22 @@
|
||||
from swarms import GodMode
|
||||
from langchain.llms import GooglePalm, OpenAIChat
|
||||
|
||||
from swarms.swarms.god_mode import Anthropic, GodMode
|
||||
|
||||
claude = Anthropic(anthropic_api_key="")
|
||||
palm = GooglePalm(google_api_key="")
|
||||
gpt = OpenAIChat(
|
||||
openai_api_key=""
|
||||
)
|
||||
|
||||
# Usage
|
||||
llms = [Anthropic(model="<model_name>", anthropic_api_key="my-api-key") for _ in range(5)]
|
||||
llms = [
|
||||
claude,
|
||||
palm,
|
||||
gpt
|
||||
]
|
||||
|
||||
god_mode = GodMode(llms)
|
||||
|
||||
task = f"{anthropic.HUMAN_PROMPT} What are the biggest risks facing humanity?{anthropic.AI_PROMPT}"
|
||||
task = f"What are the biggest risks facing humanity?"
|
||||
|
||||
god_mode.print_responses(task)
|
@ -1,189 +1,163 @@
|
||||
# from __future__ import annotations
|
||||
|
||||
# import logging
|
||||
# from swarms.utils.logger import logger
|
||||
# from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
# from pydantic import BaseModel, model_validator
|
||||
# from tenacity import (
|
||||
# before_sleep_log,
|
||||
# retry,
|
||||
# retry_if_exception_type,
|
||||
# stop_after_attempt,
|
||||
# wait_exponential,
|
||||
# )
|
||||
|
||||
# import google.generativeai as palm
|
||||
|
||||
|
||||
# class GooglePalmError(Exception):
|
||||
# """Error raised when there is an issue with the Google PaLM API."""
|
||||
|
||||
# def _truncate_at_stop_tokens(
|
||||
# text: str,
|
||||
# stop: Optional[List[str]],
|
||||
# ) -> str:
|
||||
# """Truncates text at the earliest stop token found."""
|
||||
# if stop is None:
|
||||
# return text
|
||||
|
||||
# for stop_token in stop:
|
||||
# stop_token_idx = text.find(stop_token)
|
||||
# if stop_token_idx != -1:
|
||||
# text = text[:stop_token_idx]
|
||||
# return text
|
||||
|
||||
|
||||
# def _response_to_result(response: palm.types.ChatResponse, stop: Optional[List[str]]) -> Dict[str, Any]:
|
||||
# """Convert a PaLM chat response to a result dictionary."""
|
||||
# result = {
|
||||
# "id": response.id,
|
||||
# "created": response.created,
|
||||
# "model": response.model,
|
||||
# "usage": {
|
||||
# "prompt_tokens": response.usage.prompt_tokens,
|
||||
# "completion_tokens": response.usage.completion_tokens,
|
||||
# "total_tokens": response.usage.total_tokens,
|
||||
# },
|
||||
# "choices": [],
|
||||
# }
|
||||
# for choice in response.choices:
|
||||
# result["choices"].append({
|
||||
# "text": _truncate_at_stop_tokens(choice.text, stop),
|
||||
# "index": choice.index,
|
||||
# "finish_reason": choice.finish_reason,
|
||||
# })
|
||||
# return result
|
||||
|
||||
# def _messages_to_prompt_dict(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
# """Convert a list of message dictionaries to a prompt dictionary."""
|
||||
# prompt = {"messages": []}
|
||||
# for message in messages:
|
||||
# prompt["messages"].append({
|
||||
# "role": message["role"],
|
||||
# "content": message["content"],
|
||||
# })
|
||||
# return prompt
|
||||
|
||||
|
||||
# def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
# """Create a retry decorator with exponential backoff."""
|
||||
# return retry(
|
||||
# retry=retry_if_exception_type(GooglePalmError),
|
||||
# stop=stop_after_attempt(5),
|
||||
# wait=wait_exponential(multiplier=1, min=2, max=30),
|
||||
# before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
# reraise=True,
|
||||
# )
|
||||
|
||||
|
||||
# ####################### => main class
|
||||
# class GooglePalm(BaseModel):
|
||||
# """Wrapper around Google's PaLM Chat API."""
|
||||
|
||||
# client: Any #: :meta private:
|
||||
# model_name: str = "models/chat-bison-001"
|
||||
# google_api_key: Optional[str] = None
|
||||
# temperature: Optional[float] = None
|
||||
# top_p: Optional[float] = None
|
||||
# top_k: Optional[int] = None
|
||||
# n: int = 1
|
||||
|
||||
# @model_validator(mode="pre")
|
||||
# def validate_environment(cls, values: Dict) -> Dict:
|
||||
# # Same as before
|
||||
# pass
|
||||
|
||||
# def chat_with_retry(self, **kwargs: Any) -> Any:
|
||||
# """Use tenacity to retry the completion call."""
|
||||
# retry_decorator = _create_retry_decorator()
|
||||
|
||||
# @retry_decorator
|
||||
# def _chat_with_retry(**kwargs: Any) -> Any:
|
||||
# return self.client.chat(**kwargs)
|
||||
|
||||
# return _chat_with_retry(**kwargs)
|
||||
|
||||
# async def achat_with_retry(self, **kwargs: Any) -> Any:
|
||||
# """Use tenacity to retry the async completion call."""
|
||||
# retry_decorator = _create_retry_decorator()
|
||||
|
||||
# @retry_decorator
|
||||
# async def _achat_with_retry(**kwargs: Any) -> Any:
|
||||
# return await self.client.chat_async(**kwargs)
|
||||
|
||||
# return await _achat_with_retry(**kwargs)
|
||||
|
||||
# def __call__(
|
||||
# self,
|
||||
# messages: List[Dict[str, Any]],
|
||||
# stop: Optional[List[str]] = None,
|
||||
# **kwargs: Any,
|
||||
# ) -> Dict[str, Any]:
|
||||
# prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
# response: palm.types.ChatResponse = self.chat_with_retry(
|
||||
# model=self.model_name,
|
||||
# prompt=prompt,
|
||||
# temperature=self.temperature,
|
||||
# top_p=self.top_p,
|
||||
# top_k=self.top_k,
|
||||
# candidate_count=self.n,
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
# return _response_to_result(response, stop)
|
||||
|
||||
# def generate(
|
||||
# self,
|
||||
# messages: List[Dict[str, Any]],
|
||||
# stop: Optional[List[str]] = None,
|
||||
# **kwargs: Any,
|
||||
# ) -> Dict[str, Any]:
|
||||
# prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
# response: palm.types.ChatResponse = self.chat_with_retry(
|
||||
# model=self.model_name,
|
||||
# prompt=prompt,
|
||||
# temperature=self.temperature,
|
||||
# top_p=self.top_p,
|
||||
# top_k=self.top_k,
|
||||
# candidate_count=self.n,
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
# return _response_to_result(response, stop)
|
||||
|
||||
# async def _agenerate(
|
||||
# self,
|
||||
# messages: List[Dict[str, Any]],
|
||||
# stop: Optional[List[str]] = None,
|
||||
# **kwargs: Any,
|
||||
# ) -> Dict[str, Any]:
|
||||
# prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
# response: palm.types.ChatResponse = await self.achat_with_retry(
|
||||
# model=self.model_name,
|
||||
# prompt=prompt,
|
||||
# temperature=self.temperature,
|
||||
# top_p=self.top_p,
|
||||
# top_k=self.top_k,
|
||||
# candidate_count=self.n,
|
||||
# )
|
||||
|
||||
# return _response_to_result(response, stop)
|
||||
|
||||
# @property
|
||||
# def _identifying_params(self) -> Dict[str, Any]:
|
||||
# """Get the identifying parameters."""
|
||||
# return {
|
||||
# "model_name": self.model_name,
|
||||
# "temperature": self.temperature,
|
||||
# "top_p": self.top_p,
|
||||
# "top_k": self.top_k,
|
||||
# "n": self.n,
|
||||
# }
|
||||
|
||||
# @property
|
||||
# def _llm_type(self) -> str:
|
||||
# return "google-palm-chat"
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import BaseLLM
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
try:
|
||||
import google.api_core.exceptions
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import google-api-core python package. "
|
||||
"Please install it with `pip install google-api-core`."
|
||||
)
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _generate_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.generate_text(**kwargs)
|
||||
|
||||
return _generate_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
"""Strip erroneous leading spaces from text.
|
||||
|
||||
The PaLM API will sometimes erroneously return a single leading space in all
|
||||
lines > 1. This function strips that space.
|
||||
"""
|
||||
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
|
||||
if has_leading_space:
|
||||
return text.replace("\n ", "\n")
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
class GooglePalm(BaseLLM, BaseModel):
|
||||
"""Google PaLM models."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
google_api_key: Optional[str]
|
||||
model_name: str = "models/text-bison-001"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
"""Run inference with this temperature. Must by in the closed interval
|
||||
[0.0, 1.0]."""
|
||||
top_p: Optional[float] = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
max_output_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
||||
If unset, will default to 64."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists."""
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import google-generativeai python package. "
|
||||
"Please install it with `pip install google-generativeai`."
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
|
||||
raise ValueError("max_output_tokens must be greater than zero")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
completion = generate_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stop_sequences=stop,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
max_output_tokens=self.max_output_tokens,
|
||||
candidate_count=self.n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
prompt_generations = []
|
||||
for candidate in completion.candidates:
|
||||
raw_text = candidate["output"]
|
||||
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||
prompt_generations.append(Generation(text=stripped_text))
|
||||
generations.append(prompt_generations)
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "google_palm"
|
Loading…
Reference in new issue