[CODE QUALITY]

pull/362/head^2
Kye 1 year ago
parent e70b401b54
commit d34427a4f8

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "3.8.2" version = "3.8.5"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]

@ -91,7 +91,6 @@ class AbstractVectorDatabase(ABC):
pass pass
@abstractmethod @abstractmethod
def get(self, query: str): def get(self, query: str):
""" """

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import functools
import logging import logging
import sys import sys
from typing import ( from typing import (
@ -16,6 +18,7 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
Type,
Union, Union,
) )
@ -23,7 +26,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms.base import BaseLLM, create_base_retry_decorator from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
@ -32,7 +35,17 @@ from langchain.utils import (
get_pydantic_field_names, get_pydantic_field_names,
) )
from langchain.utils.utils import build_extra_kwargs from langchain.utils.utils import build_extra_kwargs
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
from importlib.metadata import version from importlib.metadata import version
@ -41,6 +54,62 @@ from packaging.version import parse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@functools.lru_cache
def _log_error_once(msg: str) -> None:
"""Log an error once."""
logger.error(msg)
def create_base_retry_decorator(
error_types: List[Type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
_logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
else:
asyncio.run(coro)
except Exception as e:
_log_error_once(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: "retry_base" = retry_if_exception_type(
error_types[0]
)
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(
error
)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(
multiplier=1, min=min_seconds, max=max_seconds
),
retry=retry_instance,
before_sleep=_before_sleep,
)
def is_openai_v1() -> bool: def is_openai_v1() -> bool:
_version = parse(version("openai")) _version = parse(version("openai"))
return _version.major >= 1 return _version.major >= 1
@ -833,7 +902,7 @@ class OpenAIChat(BaseLLM):
""" """
client: Any #: :meta private: client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo-1106" model_name: str = "gpt-4-1106-preview"
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None openai_api_base: Optional[str] = None

@ -15,11 +15,11 @@ class MultiAgentRag:
db (AbstractVectorDatabase): Database used for querying. db (AbstractVectorDatabase): Database used for querying.
verbose (bool): Flag indicating whether to print verbose output. verbose (bool): Flag indicating whether to print verbose output.
""" """
agents: List[Agent] agents: List[Agent]
db: AbstractVectorDatabase db: AbstractVectorDatabase
verbose: bool = False verbose: bool = False
def query_database(self, query: str): def query_database(self, query: str):
""" """
Queries the database using the given query string. Queries the database using the given query string.
@ -36,7 +36,6 @@ class MultiAgentRag:
results.extend(agent_results) results.extend(agent_results)
return results return results
def get_agent_by_id(self, agent_id) -> Optional[Agent]: def get_agent_by_id(self, agent_id) -> Optional[Agent]:
""" """
Retrieves an agent from the multi-agent RAG by its ID. Retrieves an agent from the multi-agent RAG by its ID.
@ -53,11 +52,7 @@ class MultiAgentRag:
return None return None
def add_message( def add_message(
self, self, sender: Agent, message: str, *args, **kwargs
sender: Agent,
message: str,
*args,
**kwargs
): ):
""" """
Adds a message to the database. Adds a message to the database.
@ -75,12 +70,7 @@ class MultiAgentRag:
return self.db.add(doc) return self.db.add(doc)
def query( def query(self, message: str, *args, **kwargs):
self,
message: str,
*args,
**kwargs
):
""" """
Queries the database using the given message. Queries the database using the given message.
@ -93,5 +83,3 @@ class MultiAgentRag:
List: The list of results from the database. List: The list of results from the database.
""" """
return self.db.query(message) return self.db.query(message)

Loading…
Cancel
Save