[CODE QUALITY]

pull/362/head^2
Kye 1 year ago
parent e70b401b54
commit d34427a4f8

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "3.8.2"
version = "3.8.5"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@ -91,7 +91,6 @@ class AbstractVectorDatabase(ABC):
pass
@abstractmethod
def get(self, query: str):
"""

@ -1,5 +1,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import sys
from typing import (
@ -16,6 +18,7 @@ from typing import (
Optional,
Set,
Tuple,
Type,
Union,
)
@ -23,7 +26,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk
@ -32,7 +35,17 @@ from langchain.utils import (
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
from importlib.metadata import version
@ -41,6 +54,62 @@ from packaging.version import parse
logger = logging.getLogger(__name__)
@functools.lru_cache
def _log_error_once(msg: str) -> None:
"""Log an error once."""
logger.error(msg)
def create_base_retry_decorator(
error_types: List[Type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
_logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
else:
asyncio.run(coro)
except Exception as e:
_log_error_once(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: "retry_base" = retry_if_exception_type(
error_types[0]
)
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(
error
)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(
multiplier=1, min=min_seconds, max=max_seconds
),
retry=retry_instance,
before_sleep=_before_sleep,
)
def is_openai_v1() -> bool:
_version = parse(version("openai"))
return _version.major >= 1
@ -833,7 +902,7 @@ class OpenAIChat(BaseLLM):
"""
client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo-1106"
model_name: str = "gpt-4-1106-preview"
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None

@ -15,11 +15,11 @@ class MultiAgentRag:
db (AbstractVectorDatabase): Database used for querying.
verbose (bool): Flag indicating whether to print verbose output.
"""
agents: List[Agent]
db: AbstractVectorDatabase
verbose: bool = False
def query_database(self, query: str):
"""
Queries the database using the given query string.
@ -36,7 +36,6 @@ class MultiAgentRag:
results.extend(agent_results)
return results
def get_agent_by_id(self, agent_id) -> Optional[Agent]:
"""
Retrieves an agent from the multi-agent RAG by its ID.
@ -53,11 +52,7 @@ class MultiAgentRag:
return None
def add_message(
self,
sender: Agent,
message: str,
*args,
**kwargs
self, sender: Agent, message: str, *args, **kwargs
):
"""
Adds a message to the database.
@ -75,12 +70,7 @@ class MultiAgentRag:
return self.db.add(doc)
def query(
self,
message: str,
*args,
**kwargs
):
def query(self, message: str, *args, **kwargs):
"""
Queries the database using the given message.
@ -93,5 +83,3 @@ class MultiAgentRag:
List: The list of results from the database.
"""
return self.db.query(message)

Loading…
Cancel
Save