From d34427a4f856b9e757f07e10d78ae23843bc282d Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 28 Jan 2024 09:53:53 -0500 Subject: [PATCH] [CODE QUALITY] --- playground/models/kosmos.py | 2 +- pyproject.toml | 2 +- swarms/memory/base_db.py | 2 +- swarms/memory/base_vectordatabase.py | 3 +- swarms/models/__init__.py | 2 +- swarms/models/openai_models.py | 73 +++++++++++++++++++++++++++- swarms/structs/conversation.py | 2 +- swarms/structs/multi_agent_rag.py | 46 +++++++----------- 8 files changed, 94 insertions(+), 38 deletions(-) diff --git a/playground/models/kosmos.py b/playground/models/kosmos.py index 3d0f1dd2..dbfd108f 100644 --- a/playground/models/kosmos.py +++ b/playground/models/kosmos.py @@ -7,4 +7,4 @@ model = Kosmos() out = model.run("Analyze the reciepts in this image", "docs.jpg") # Print the output -print(out) \ No newline at end of file +print(out) diff --git a/pyproject.toml b/pyproject.toml index cd5f9f74..31b72863 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "3.8.2" +version = "3.8.5" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/memory/base_db.py b/swarms/memory/base_db.py index bb0a2961..0501def7 100644 --- a/swarms/memory/base_db.py +++ b/swarms/memory/base_db.py @@ -156,4 +156,4 @@ class AbstractDatabase(ABC): """ - pass \ No newline at end of file + pass diff --git a/swarms/memory/base_vectordatabase.py b/swarms/memory/base_vectordatabase.py index 734c872a..06f42007 100644 --- a/swarms/memory/base_vectordatabase.py +++ b/swarms/memory/base_vectordatabase.py @@ -91,7 +91,6 @@ class AbstractVectorDatabase(ABC): pass - @abstractmethod def get(self, query: str): """ @@ -139,4 +138,4 @@ class AbstractVectorDatabase(ABC): """ - pass \ No newline at end of file + pass diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index a8fb119a..dfeb9cfe 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -48,7 +48,7 @@ from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402 from swarms.models.llava import LavaMultiModal # noqa: E402 from swarms.models.qwen import QwenVLMultiModal # noqa: E402 from swarms.models.clipq import CLIPQ # noqa: E402 -from swarms.models.kosmos_two import Kosmos # noqa: E402 +from swarms.models.kosmos_two import Kosmos # noqa: E402 from swarms.models.fuyu import Fuyu # noqa: E402 # from swarms.models.dalle3 import Dalle3 diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index f13657dc..b1aa0117 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +import functools import logging import sys from typing import ( @@ -16,6 +18,7 @@ from typing import ( Optional, Set, Tuple, + Type, Union, ) @@ -23,7 +26,7 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain.llms.base import BaseLLM, create_base_retry_decorator +from langchain.llms.base import BaseLLM from langchain.pydantic_v1 import Field, root_validator from langchain.schema import Generation, LLMResult from langchain.schema.output import GenerationChunk @@ -32,7 +35,17 @@ from langchain.utils import ( get_pydantic_field_names, ) from langchain.utils.utils import build_extra_kwargs +from tenacity import ( + RetryCallState, + before_sleep_log, + retry, + retry_base, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) +logger = logging.getLogger(__name__) from importlib.metadata import version @@ -41,6 +54,62 @@ from packaging.version import parse logger = logging.getLogger(__name__) +@functools.lru_cache +def _log_error_once(msg: str) -> None: + """Log an error once.""" + logger.error(msg) + + +def create_base_retry_decorator( + error_types: List[Type[BaseException]], + max_retries: int = 1, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + """Create a retry decorator for a given LLM and provided list of error types.""" + + _logging = before_sleep_log(logger, logging.WARNING) + + def _before_sleep(retry_state: RetryCallState) -> None: + _logging(retry_state) + if run_manager: + if isinstance(run_manager, AsyncCallbackManagerForLLMRun): + coro = run_manager.on_retry(retry_state) + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(coro) + else: + asyncio.run(coro) + except Exception as e: + _log_error_once(f"Error in on_retry: {e}") + else: + run_manager.on_retry(retry_state) + return None + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + retry_instance: "retry_base" = retry_if_exception_type( + error_types[0] + ) + for error in error_types[1:]: + retry_instance = retry_instance | retry_if_exception_type( + error + ) + return retry( + reraise=True, + stop=stop_after_attempt(max_retries), + wait=wait_exponential( + multiplier=1, min=min_seconds, max=max_seconds + ), + retry=retry_instance, + before_sleep=_before_sleep, + ) + + def is_openai_v1() -> bool: _version = parse(version("openai")) return _version.major >= 1 @@ -833,7 +902,7 @@ class OpenAIChat(BaseLLM): """ client: Any #: :meta private: - model_name: str = "gpt-3.5-turbo-1106" + model_name: str = "gpt-4-1106-preview" model_kwargs: Dict[str, Any] = Field(default_factory=dict) openai_api_key: Optional[str] = None openai_api_base: Optional[str] = None diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 9a2224a4..a59e4cf9 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -75,7 +75,7 @@ class Conversation(BaseStructure): self.autosave = autosave self.save_filepath = save_filepath self.conversation_history = [] - + # If system prompt is not None, add it to the conversation history if self.system_prompt: self.add("system", self.system_prompt) diff --git a/swarms/structs/multi_agent_rag.py b/swarms/structs/multi_agent_rag.py index 7b51332e..91d8c39d 100644 --- a/swarms/structs/multi_agent_rag.py +++ b/swarms/structs/multi_agent_rag.py @@ -9,24 +9,24 @@ from swarms.structs.agent import Agent class MultiAgentRag: """ Represents a multi-agent RAG (Relational Agent Graph) structure. - + Attributes: agents (List[Agent]): List of agents in the multi-agent RAG. db (AbstractVectorDatabase): Database used for querying. verbose (bool): Flag indicating whether to print verbose output. """ + agents: List[Agent] db: AbstractVectorDatabase verbose: bool = False - - + def query_database(self, query: str): """ Queries the database using the given query string. - + Args: query (str): The query string. - + Returns: List: The list of results from the database. """ @@ -35,15 +35,14 @@ class MultiAgentRag: agent_results = agent.long_term_memory_prompt(query) results.extend(agent_results) return results - - + def get_agent_by_id(self, agent_id) -> Optional[Agent]: """ Retrieves an agent from the multi-agent RAG by its ID. - + Args: agent_id: The ID of the agent to retrieve. - + Returns: Agent or None: The agent with the specified ID, or None if not found. """ @@ -51,47 +50,36 @@ class MultiAgentRag: if agent.agent_id == agent_id: return agent return None - + def add_message( - self, - sender: Agent, - message: str, - *args, - **kwargs + self, sender: Agent, message: str, *args, **kwargs ): """ Adds a message to the database. - + Args: sender (Agent): The agent sending the message. message (str): The message to add. *args: Additional positional arguments. **kwargs: Additional keyword arguments. - + Returns: int: The ID of the added message. """ doc = f"{sender.ai_name}: {message}" - + return self.db.add(doc) - - def query( - self, - message: str, - *args, - **kwargs - ): + + def query(self, message: str, *args, **kwargs): """ Queries the database using the given message. - + Args: message (str): The message to query. *args: Additional positional arguments. **kwargs: Additional keyword arguments. - + Returns: List: The list of results from the database. """ return self.db.query(message) - -