parent
7c328b5009
commit
30633b8316
@ -1,10 +1,22 @@
|
|||||||
from swarms import GodMode
|
from langchain.llms import GooglePalm, OpenAIChat
|
||||||
|
|
||||||
|
from swarms.swarms.god_mode import Anthropic, GodMode
|
||||||
|
|
||||||
|
claude = Anthropic(anthropic_api_key="")
|
||||||
|
palm = GooglePalm(google_api_key="")
|
||||||
|
gpt = OpenAIChat(
|
||||||
|
openai_api_key=""
|
||||||
|
)
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
llms = [Anthropic(model="<model_name>", anthropic_api_key="my-api-key") for _ in range(5)]
|
llms = [
|
||||||
|
claude,
|
||||||
|
palm,
|
||||||
|
gpt
|
||||||
|
]
|
||||||
|
|
||||||
god_mode = GodMode(llms)
|
god_mode = GodMode(llms)
|
||||||
|
|
||||||
task = f"{anthropic.HUMAN_PROMPT} What are the biggest risks facing humanity?{anthropic.AI_PROMPT}"
|
task = f"What are the biggest risks facing humanity?"
|
||||||
|
|
||||||
god_mode.print_responses(task)
|
god_mode.print_responses(task)
|
@ -1,189 +1,163 @@
|
|||||||
# from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# import logging
|
import logging
|
||||||
# from swarms.utils.logger import logger
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
# from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
# from pydantic import BaseModel, model_validator
|
from langchain.llms import BaseLLM
|
||||||
# from tenacity import (
|
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||||
# before_sleep_log,
|
from langchain.schema import Generation, LLMResult
|
||||||
# retry,
|
from langchain.utils import get_from_dict_or_env
|
||||||
# retry_if_exception_type,
|
from tenacity import (
|
||||||
# stop_after_attempt,
|
before_sleep_log,
|
||||||
# wait_exponential,
|
retry,
|
||||||
# )
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
# import google.generativeai as palm
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
# class GooglePalmError(Exception):
|
logger = logging.getLogger(__name__)
|
||||||
# """Error raised when there is an issue with the Google PaLM API."""
|
|
||||||
|
|
||||||
# def _truncate_at_stop_tokens(
|
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||||
# text: str,
|
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||||
# stop: Optional[List[str]],
|
try:
|
||||||
# ) -> str:
|
import google.api_core.exceptions
|
||||||
# """Truncates text at the earliest stop token found."""
|
except ImportError:
|
||||||
# if stop is None:
|
raise ImportError(
|
||||||
# return text
|
"Could not import google-api-core python package. "
|
||||||
|
"Please install it with `pip install google-api-core`."
|
||||||
# for stop_token in stop:
|
)
|
||||||
# stop_token_idx = text.find(stop_token)
|
|
||||||
# if stop_token_idx != -1:
|
multiplier = 2
|
||||||
# text = text[:stop_token_idx]
|
min_seconds = 1
|
||||||
# return text
|
max_seconds = 60
|
||||||
|
max_retries = 10
|
||||||
|
|
||||||
# def _response_to_result(response: palm.types.ChatResponse, stop: Optional[List[str]]) -> Dict[str, Any]:
|
return retry(
|
||||||
# """Convert a PaLM chat response to a result dictionary."""
|
reraise=True,
|
||||||
# result = {
|
stop=stop_after_attempt(max_retries),
|
||||||
# "id": response.id,
|
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||||
# "created": response.created,
|
retry=(
|
||||||
# "model": response.model,
|
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||||
# "usage": {
|
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||||
# "prompt_tokens": response.usage.prompt_tokens,
|
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||||
# "completion_tokens": response.usage.completion_tokens,
|
),
|
||||||
# "total_tokens": response.usage.total_tokens,
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
# },
|
)
|
||||||
# "choices": [],
|
|
||||||
# }
|
|
||||||
# for choice in response.choices:
|
def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
|
||||||
# result["choices"].append({
|
"""Use tenacity to retry the completion call."""
|
||||||
# "text": _truncate_at_stop_tokens(choice.text, stop),
|
retry_decorator = _create_retry_decorator()
|
||||||
# "index": choice.index,
|
|
||||||
# "finish_reason": choice.finish_reason,
|
@retry_decorator
|
||||||
# })
|
def _generate_with_retry(**kwargs: Any) -> Any:
|
||||||
# return result
|
return llm.client.generate_text(**kwargs)
|
||||||
|
|
||||||
# def _messages_to_prompt_dict(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
return _generate_with_retry(**kwargs)
|
||||||
# """Convert a list of message dictionaries to a prompt dictionary."""
|
|
||||||
# prompt = {"messages": []}
|
|
||||||
# for message in messages:
|
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||||
# prompt["messages"].append({
|
"""Strip erroneous leading spaces from text.
|
||||||
# "role": message["role"],
|
|
||||||
# "content": message["content"],
|
The PaLM API will sometimes erroneously return a single leading space in all
|
||||||
# })
|
lines > 1. This function strips that space.
|
||||||
# return prompt
|
"""
|
||||||
|
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
|
||||||
|
if has_leading_space:
|
||||||
# def _create_retry_decorator() -> Callable[[Any], Any]:
|
return text.replace("\n ", "\n")
|
||||||
# """Create a retry decorator with exponential backoff."""
|
else:
|
||||||
# return retry(
|
return text
|
||||||
# retry=retry_if_exception_type(GooglePalmError),
|
|
||||||
# stop=stop_after_attempt(5),
|
|
||||||
# wait=wait_exponential(multiplier=1, min=2, max=30),
|
class GooglePalm(BaseLLM, BaseModel):
|
||||||
# before_sleep=before_sleep_log(logger, logging.DEBUG),
|
"""Google PaLM models."""
|
||||||
# reraise=True,
|
|
||||||
# )
|
client: Any #: :meta private:
|
||||||
|
google_api_key: Optional[str]
|
||||||
|
model_name: str = "models/text-bison-001"
|
||||||
# ####################### => main class
|
"""Model name to use."""
|
||||||
# class GooglePalm(BaseModel):
|
temperature: float = 0.7
|
||||||
# """Wrapper around Google's PaLM Chat API."""
|
"""Run inference with this temperature. Must by in the closed interval
|
||||||
|
[0.0, 1.0]."""
|
||||||
# client: Any #: :meta private:
|
top_p: Optional[float] = None
|
||||||
# model_name: str = "models/chat-bison-001"
|
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||||
# google_api_key: Optional[str] = None
|
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||||
# temperature: Optional[float] = None
|
top_k: Optional[int] = None
|
||||||
# top_p: Optional[float] = None
|
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||||
# top_k: Optional[int] = None
|
Must be positive."""
|
||||||
# n: int = 1
|
max_output_tokens: Optional[int] = None
|
||||||
|
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
||||||
# @model_validator(mode="pre")
|
If unset, will default to 64."""
|
||||||
# def validate_environment(cls, values: Dict) -> Dict:
|
n: int = 1
|
||||||
# # Same as before
|
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||||
# pass
|
not return the full n completions if duplicates are generated."""
|
||||||
|
|
||||||
# def chat_with_retry(self, **kwargs: Any) -> Any:
|
@root_validator()
|
||||||
# """Use tenacity to retry the completion call."""
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
# retry_decorator = _create_retry_decorator()
|
"""Validate api key, python package exists."""
|
||||||
|
google_api_key = get_from_dict_or_env(
|
||||||
# @retry_decorator
|
values, "google_api_key", "GOOGLE_API_KEY"
|
||||||
# def _chat_with_retry(**kwargs: Any) -> Any:
|
)
|
||||||
# return self.client.chat(**kwargs)
|
try:
|
||||||
|
import google.generativeai as genai
|
||||||
# return _chat_with_retry(**kwargs)
|
|
||||||
|
genai.configure(api_key=google_api_key)
|
||||||
# async def achat_with_retry(self, **kwargs: Any) -> Any:
|
except ImportError:
|
||||||
# """Use tenacity to retry the async completion call."""
|
raise ImportError(
|
||||||
# retry_decorator = _create_retry_decorator()
|
"Could not import google-generativeai python package. "
|
||||||
|
"Please install it with `pip install google-generativeai`."
|
||||||
# @retry_decorator
|
)
|
||||||
# async def _achat_with_retry(**kwargs: Any) -> Any:
|
|
||||||
# return await self.client.chat_async(**kwargs)
|
values["client"] = genai
|
||||||
|
|
||||||
# return await _achat_with_retry(**kwargs)
|
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||||
|
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||||
# def __call__(
|
|
||||||
# self,
|
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||||
# messages: List[Dict[str, Any]],
|
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||||
# stop: Optional[List[str]] = None,
|
|
||||||
# **kwargs: Any,
|
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||||
# ) -> Dict[str, Any]:
|
raise ValueError("top_k must be positive")
|
||||||
# prompt = _messages_to_prompt_dict(messages)
|
|
||||||
|
if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
|
||||||
# response: palm.types.ChatResponse = self.chat_with_retry(
|
raise ValueError("max_output_tokens must be greater than zero")
|
||||||
# model=self.model_name,
|
|
||||||
# prompt=prompt,
|
return values
|
||||||
# temperature=self.temperature,
|
|
||||||
# top_p=self.top_p,
|
def _generate(
|
||||||
# top_k=self.top_k,
|
self,
|
||||||
# candidate_count=self.n,
|
prompts: List[str],
|
||||||
# **kwargs,
|
stop: Optional[List[str]] = None,
|
||||||
# )
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
# return _response_to_result(response, stop)
|
) -> LLMResult:
|
||||||
|
generations = []
|
||||||
# def generate(
|
for prompt in prompts:
|
||||||
# self,
|
completion = generate_with_retry(
|
||||||
# messages: List[Dict[str, Any]],
|
self,
|
||||||
# stop: Optional[List[str]] = None,
|
model=self.model_name,
|
||||||
# **kwargs: Any,
|
prompt=prompt,
|
||||||
# ) -> Dict[str, Any]:
|
stop_sequences=stop,
|
||||||
# prompt = _messages_to_prompt_dict(messages)
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
# response: palm.types.ChatResponse = self.chat_with_retry(
|
top_k=self.top_k,
|
||||||
# model=self.model_name,
|
max_output_tokens=self.max_output_tokens,
|
||||||
# prompt=prompt,
|
candidate_count=self.n,
|
||||||
# temperature=self.temperature,
|
**kwargs,
|
||||||
# top_p=self.top_p,
|
)
|
||||||
# top_k=self.top_k,
|
|
||||||
# candidate_count=self.n,
|
prompt_generations = []
|
||||||
# **kwargs,
|
for candidate in completion.candidates:
|
||||||
# )
|
raw_text = candidate["output"]
|
||||||
|
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
||||||
# return _response_to_result(response, stop)
|
prompt_generations.append(Generation(text=stripped_text))
|
||||||
|
generations.append(prompt_generations)
|
||||||
# async def _agenerate(
|
|
||||||
# self,
|
return LLMResult(generations=generations)
|
||||||
# messages: List[Dict[str, Any]],
|
|
||||||
# stop: Optional[List[str]] = None,
|
@property
|
||||||
# **kwargs: Any,
|
def _llm_type(self) -> str:
|
||||||
# ) -> Dict[str, Any]:
|
"""Return type of llm."""
|
||||||
# prompt = _messages_to_prompt_dict(messages)
|
return "google_palm"
|
||||||
|
|
||||||
# response: palm.types.ChatResponse = await self.achat_with_retry(
|
|
||||||
# model=self.model_name,
|
|
||||||
# prompt=prompt,
|
|
||||||
# temperature=self.temperature,
|
|
||||||
# top_p=self.top_p,
|
|
||||||
# top_k=self.top_k,
|
|
||||||
# candidate_count=self.n,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return _response_to_result(response, stop)
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def _identifying_params(self) -> Dict[str, Any]:
|
|
||||||
# """Get the identifying parameters."""
|
|
||||||
# return {
|
|
||||||
# "model_name": self.model_name,
|
|
||||||
# "temperature": self.temperature,
|
|
||||||
# "top_p": self.top_p,
|
|
||||||
# "top_k": self.top_k,
|
|
||||||
# "n": self.n,
|
|
||||||
# }
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def _llm_type(self) -> str:
|
|
||||||
# return "google-palm-chat"
|
|
Loading…
Reference in new issue