From 30633b83163d7c113111799bc903374dec186e28 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 26 Sep 2023 12:14:07 -0400 Subject: [PATCH] godmode example Former-commit-id: 33dec6c0ed2db64340b409bdd11ab8d255133187 --- example_godmode.py | 18 +- swarms/__init__.py | 7 +- swarms/models/palm.py | 352 ++++++++++++++++------------------ swarms/swarms/god_mode.py | 20 +- swarms/swarms/simple_swarm.py | 2 +- swarms/tools/autogpt.py | 4 +- 6 files changed, 193 insertions(+), 210 deletions(-) diff --git a/example_godmode.py b/example_godmode.py index 2525cab2..b1cbe42e 100644 --- a/example_godmode.py +++ b/example_godmode.py @@ -1,10 +1,22 @@ -from swarms import GodMode +from langchain.llms import GooglePalm, OpenAIChat + +from swarms.swarms.god_mode import Anthropic, GodMode + +claude = Anthropic(anthropic_api_key="") +palm = GooglePalm(google_api_key="") +gpt = OpenAIChat( + openai_api_key="" +) # Usage -llms = [Anthropic(model="", anthropic_api_key="my-api-key") for _ in range(5)] +llms = [ + claude, + palm, + gpt +] god_mode = GodMode(llms) -task = f"{anthropic.HUMAN_PROMPT} What are the biggest risks facing humanity?{anthropic.AI_PROMPT}" +task = f"What are the biggest risks facing humanity?" god_mode.print_responses(task) \ No newline at end of file diff --git a/swarms/__init__.py b/swarms/__init__.py index 02495710..8b24e4c6 100644 --- a/swarms/__init__.py +++ b/swarms/__init__.py @@ -6,16 +6,15 @@ print(logo2) # worker # from swarms.workers.worker_node import WorkerNode +# from swarms.workers.worker import Worker #boss # from swarms.boss.boss_node import Boss #models from swarms.models.anthropic import Anthropic - -# from swarms.models.palm import GooglePalm +from swarms.models.palm import GooglePalm from swarms.models.petals import Petals -from swarms.workers.worker import Worker #from swarms.models.openai import OpenAIChat #structs @@ -32,6 +31,6 @@ from swarms.swarms.multi_agent_debate import MultiAgentDebate #agents from swarms.swarms.profitpilot import ProfitPilot -from swarms.aot import AoTAgent +from swarms.agents.aot_agent import AOTAgent from swarms.agents.multi_modal_agent import MultiModalVisualAgent from swarms.agents.omni_modal_agent import OmniModalAgent \ No newline at end of file diff --git a/swarms/models/palm.py b/swarms/models/palm.py index 20eafd61..86b0dc85 100644 --- a/swarms/models/palm.py +++ b/swarms/models/palm.py @@ -1,189 +1,163 @@ -# from __future__ import annotations - -# import logging -# from swarms.utils.logger import logger -# from typing import Any, Callable, Dict, List, Optional - -# from pydantic import BaseModel, model_validator -# from tenacity import ( -# before_sleep_log, -# retry, -# retry_if_exception_type, -# stop_after_attempt, -# wait_exponential, -# ) - -# import google.generativeai as palm - - -# class GooglePalmError(Exception): -# """Error raised when there is an issue with the Google PaLM API.""" - -# def _truncate_at_stop_tokens( -# text: str, -# stop: Optional[List[str]], -# ) -> str: -# """Truncates text at the earliest stop token found.""" -# if stop is None: -# return text - -# for stop_token in stop: -# stop_token_idx = text.find(stop_token) -# if stop_token_idx != -1: -# text = text[:stop_token_idx] -# return text - - -# def _response_to_result(response: palm.types.ChatResponse, stop: Optional[List[str]]) -> Dict[str, Any]: -# """Convert a PaLM chat response to a result dictionary.""" -# result = { -# "id": response.id, -# "created": response.created, -# "model": response.model, -# "usage": { -# "prompt_tokens": response.usage.prompt_tokens, -# "completion_tokens": response.usage.completion_tokens, -# "total_tokens": response.usage.total_tokens, -# }, -# "choices": [], -# } -# for choice in response.choices: -# result["choices"].append({ -# "text": _truncate_at_stop_tokens(choice.text, stop), -# "index": choice.index, -# "finish_reason": choice.finish_reason, -# }) -# return result - -# def _messages_to_prompt_dict(messages: List[Dict[str, Any]]) -> Dict[str, Any]: -# """Convert a list of message dictionaries to a prompt dictionary.""" -# prompt = {"messages": []} -# for message in messages: -# prompt["messages"].append({ -# "role": message["role"], -# "content": message["content"], -# }) -# return prompt - - -# def _create_retry_decorator() -> Callable[[Any], Any]: -# """Create a retry decorator with exponential backoff.""" -# return retry( -# retry=retry_if_exception_type(GooglePalmError), -# stop=stop_after_attempt(5), -# wait=wait_exponential(multiplier=1, min=2, max=30), -# before_sleep=before_sleep_log(logger, logging.DEBUG), -# reraise=True, -# ) - - -# ####################### => main class -# class GooglePalm(BaseModel): -# """Wrapper around Google's PaLM Chat API.""" - -# client: Any #: :meta private: -# model_name: str = "models/chat-bison-001" -# google_api_key: Optional[str] = None -# temperature: Optional[float] = None -# top_p: Optional[float] = None -# top_k: Optional[int] = None -# n: int = 1 - -# @model_validator(mode="pre") -# def validate_environment(cls, values: Dict) -> Dict: -# # Same as before -# pass - -# def chat_with_retry(self, **kwargs: Any) -> Any: -# """Use tenacity to retry the completion call.""" -# retry_decorator = _create_retry_decorator() - -# @retry_decorator -# def _chat_with_retry(**kwargs: Any) -> Any: -# return self.client.chat(**kwargs) - -# return _chat_with_retry(**kwargs) - -# async def achat_with_retry(self, **kwargs: Any) -> Any: -# """Use tenacity to retry the async completion call.""" -# retry_decorator = _create_retry_decorator() - -# @retry_decorator -# async def _achat_with_retry(**kwargs: Any) -> Any: -# return await self.client.chat_async(**kwargs) - -# return await _achat_with_retry(**kwargs) - -# def __call__( -# self, -# messages: List[Dict[str, Any]], -# stop: Optional[List[str]] = None, -# **kwargs: Any, -# ) -> Dict[str, Any]: -# prompt = _messages_to_prompt_dict(messages) - -# response: palm.types.ChatResponse = self.chat_with_retry( -# model=self.model_name, -# prompt=prompt, -# temperature=self.temperature, -# top_p=self.top_p, -# top_k=self.top_k, -# candidate_count=self.n, -# **kwargs, -# ) - -# return _response_to_result(response, stop) - -# def generate( -# self, -# messages: List[Dict[str, Any]], -# stop: Optional[List[str]] = None, -# **kwargs: Any, -# ) -> Dict[str, Any]: -# prompt = _messages_to_prompt_dict(messages) - -# response: palm.types.ChatResponse = self.chat_with_retry( -# model=self.model_name, -# prompt=prompt, -# temperature=self.temperature, -# top_p=self.top_p, -# top_k=self.top_k, -# candidate_count=self.n, -# **kwargs, -# ) - -# return _response_to_result(response, stop) - -# async def _agenerate( -# self, -# messages: List[Dict[str, Any]], -# stop: Optional[List[str]] = None, -# **kwargs: Any, -# ) -> Dict[str, Any]: -# prompt = _messages_to_prompt_dict(messages) - -# response: palm.types.ChatResponse = await self.achat_with_retry( -# model=self.model_name, -# prompt=prompt, -# temperature=self.temperature, -# top_p=self.top_p, -# top_k=self.top_k, -# candidate_count=self.n, -# ) - -# return _response_to_result(response, stop) - -# @property -# def _identifying_params(self) -> Dict[str, Any]: -# """Get the identifying parameters.""" -# return { -# "model_name": self.model_name, -# "temperature": self.temperature, -# "top_p": self.top_p, -# "top_k": self.top_k, -# "n": self.n, -# } - -# @property -# def _llm_type(self) -> str: -# return "google-palm-chat" \ No newline at end of file +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms import BaseLLM +from langchain.pydantic_v1 import BaseModel, root_validator +from langchain.schema import Generation, LLMResult +from langchain.utils import get_from_dict_or_env +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator() -> Callable[[Any], Any]: + """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions""" + try: + import google.api_core.exceptions + except ImportError: + raise ImportError( + "Could not import google-api-core python package. " + "Please install it with `pip install google-api-core`." + ) + + multiplier = 2 + min_seconds = 1 + max_seconds = 60 + max_retries = 10 + + return retry( + reraise=True, + stop=stop_after_attempt(max_retries), + wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) + | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable) + | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator() + + @retry_decorator + def _generate_with_retry(**kwargs: Any) -> Any: + return llm.client.generate_text(**kwargs) + + return _generate_with_retry(**kwargs) + + +def _strip_erroneous_leading_spaces(text: str) -> str: + """Strip erroneous leading spaces from text. + + The PaLM API will sometimes erroneously return a single leading space in all + lines > 1. This function strips that space. + """ + has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:]) + if has_leading_space: + return text.replace("\n ", "\n") + else: + return text + + +class GooglePalm(BaseLLM, BaseModel): + """Google PaLM models.""" + + client: Any #: :meta private: + google_api_key: Optional[str] + model_name: str = "models/text-bison-001" + """Model name to use.""" + temperature: float = 0.7 + """Run inference with this temperature. Must by in the closed interval + [0.0, 1.0].""" + top_p: Optional[float] = None + """Decode using nucleus sampling: consider the smallest set of tokens whose + probability sum is at least top_p. Must be in the closed interval [0.0, 1.0].""" + top_k: Optional[int] = None + """Decode using top-k sampling: consider the set of top_k most probable tokens. + Must be positive.""" + max_output_tokens: Optional[int] = None + """Maximum number of tokens to include in a candidate. Must be greater than zero. + If unset, will default to 64.""" + n: int = 1 + """Number of chat completions to generate for each prompt. Note that the API may + not return the full n completions if duplicates are generated.""" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate api key, python package exists.""" + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + try: + import google.generativeai as genai + + genai.configure(api_key=google_api_key) + except ImportError: + raise ImportError( + "Could not import google-generativeai python package. " + "Please install it with `pip install google-generativeai`." + ) + + values["client"] = genai + + if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: + raise ValueError("temperature must be in the range [0.0, 1.0]") + + if values["top_p"] is not None and not 0 <= values["top_p"] <= 1: + raise ValueError("top_p must be in the range [0.0, 1.0]") + + if values["top_k"] is not None and values["top_k"] <= 0: + raise ValueError("top_k must be positive") + + if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0: + raise ValueError("max_output_tokens must be greater than zero") + + return values + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + generations = [] + for prompt in prompts: + completion = generate_with_retry( + self, + model=self.model_name, + prompt=prompt, + stop_sequences=stop, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + max_output_tokens=self.max_output_tokens, + candidate_count=self.n, + **kwargs, + ) + + prompt_generations = [] + for candidate in completion.candidates: + raw_text = candidate["output"] + stripped_text = _strip_erroneous_leading_spaces(raw_text) + prompt_generations.append(Generation(text=stripped_text)) + generations.append(prompt_generations) + + return LLMResult(generations=generations) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "google_palm" \ No newline at end of file diff --git a/swarms/swarms/god_mode.py b/swarms/swarms/god_mode.py index 81e560e0..00f32be7 100644 --- a/swarms/swarms/god_mode.py +++ b/swarms/swarms/god_mode.py @@ -1,8 +1,7 @@ from concurrent.futures import ThreadPoolExecutor from termcolor import colored from tabulate import tabulate -import anthropic -from langchain.llms import Anthropic + class GodMode: def __init__(self, llms): @@ -18,11 +17,12 @@ class GodMode: table = [] for i, response in enumerate(responses): table.append([f"LLM {i+1}", response]) - print(colored(tabulate(table, headers=["LLM", "Response"], tablefmt="pretty"), "cyan")) - -# Usage -llms = [Anthropic(model="", anthropic_api_key="my-api-key") for _ in range(5)] - -god_mode = GodMode(llms) -task = f"{anthropic.HUMAN_PROMPT} What are the biggest risks facing humanity?{anthropic.AI_PROMPT}" -god_mode.print_responses(task) \ No newline at end of file + print( + colored( + tabulate( + table, + headers=["LLM", "Response"], + tablefmt="pretty" + ), "cyan" + ) + ) diff --git a/swarms/swarms/simple_swarm.py b/swarms/swarms/simple_swarm.py index 70b793c2..579c367c 100644 --- a/swarms/swarms/simple_swarm.py +++ b/swarms/swarms/simple_swarm.py @@ -1,4 +1,4 @@ -from swarms.worker.worker import Worker +from swarms.workers.worker import Worker class SimpleSwarm: def __init__( diff --git a/swarms/tools/autogpt.py b/swarms/tools/autogpt.py index 17ecc47a..2ed041c7 100644 --- a/swarms/tools/autogpt.py +++ b/swarms/tools/autogpt.py @@ -23,7 +23,6 @@ from swarms.utils.logger import logger from langchain.tools.file_management.write import WriteFileTool from langchain.tools.file_management.read import ReadFileTool -llm = ChatOpenAI(model_name="gpt-4", temperature=1.0) @contextmanager @@ -128,7 +127,7 @@ class WebpageQATool(BaseTool): async def _arun(self, url: str, question: str) -> str: raise NotImplementedError - +llm = ChatOpenAI(model_name="gpt-4", temperature=1.0) query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm)) # !pip install duckduckgo_search @@ -142,7 +141,6 @@ query_website_tool = WebpageQATool(qa_chain=load_qa_with_sources_chain(llm)) # code_intepret = CodeInterpreter() import interpreter - @tool def compile(task: str): """