[CODE QUALITY]

pull/362/head^2
Kye 1 year ago
parent e70b401b54
commit d34427a4f8

@ -7,4 +7,4 @@ model = Kosmos()
out = model.run("Analyze the reciepts in this image", "docs.jpg")
# Print the output
print(out)
print(out)

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "3.8.2"
version = "3.8.5"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@ -156,4 +156,4 @@ class AbstractDatabase(ABC):
"""
pass
pass

@ -91,7 +91,6 @@ class AbstractVectorDatabase(ABC):
pass
@abstractmethod
def get(self, query: str):
"""
@ -139,4 +138,4 @@ class AbstractVectorDatabase(ABC):
"""
pass
pass

@ -48,7 +48,7 @@ from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402
from swarms.models.llava import LavaMultiModal # noqa: E402
from swarms.models.qwen import QwenVLMultiModal # noqa: E402
from swarms.models.clipq import CLIPQ # noqa: E402
from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.fuyu import Fuyu # noqa: E402
# from swarms.models.dalle3 import Dalle3

@ -1,5 +1,7 @@
from __future__ import annotations
import asyncio
import functools
import logging
import sys
from typing import (
@ -16,6 +18,7 @@ from typing import (
Optional,
Set,
Tuple,
Type,
Union,
)
@ -23,7 +26,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk
@ -32,7 +35,17 @@ from langchain.utils import (
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
from importlib.metadata import version
@ -41,6 +54,62 @@ from packaging.version import parse
logger = logging.getLogger(__name__)
@functools.lru_cache
def _log_error_once(msg: str) -> None:
"""Log an error once."""
logger.error(msg)
def create_base_retry_decorator(
error_types: List[Type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
_logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
else:
asyncio.run(coro)
except Exception as e:
_log_error_once(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: "retry_base" = retry_if_exception_type(
error_types[0]
)
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(
error
)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(
multiplier=1, min=min_seconds, max=max_seconds
),
retry=retry_instance,
before_sleep=_before_sleep,
)
def is_openai_v1() -> bool:
_version = parse(version("openai"))
return _version.major >= 1
@ -833,7 +902,7 @@ class OpenAIChat(BaseLLM):
"""
client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo-1106"
model_name: str = "gpt-4-1106-preview"
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None

@ -75,7 +75,7 @@ class Conversation(BaseStructure):
self.autosave = autosave
self.save_filepath = save_filepath
self.conversation_history = []
# If system prompt is not None, add it to the conversation history
if self.system_prompt:
self.add("system", self.system_prompt)

@ -9,24 +9,24 @@ from swarms.structs.agent import Agent
class MultiAgentRag:
"""
Represents a multi-agent RAG (Relational Agent Graph) structure.
Attributes:
agents (List[Agent]): List of agents in the multi-agent RAG.
db (AbstractVectorDatabase): Database used for querying.
verbose (bool): Flag indicating whether to print verbose output.
"""
agents: List[Agent]
db: AbstractVectorDatabase
verbose: bool = False
def query_database(self, query: str):
"""
Queries the database using the given query string.
Args:
query (str): The query string.
Returns:
List: The list of results from the database.
"""
@ -35,15 +35,14 @@ class MultiAgentRag:
agent_results = agent.long_term_memory_prompt(query)
results.extend(agent_results)
return results
def get_agent_by_id(self, agent_id) -> Optional[Agent]:
"""
Retrieves an agent from the multi-agent RAG by its ID.
Args:
agent_id: The ID of the agent to retrieve.
Returns:
Agent or None: The agent with the specified ID, or None if not found.
"""
@ -51,47 +50,36 @@ class MultiAgentRag:
if agent.agent_id == agent_id:
return agent
return None
def add_message(
self,
sender: Agent,
message: str,
*args,
**kwargs
self, sender: Agent, message: str, *args, **kwargs
):
"""
Adds a message to the database.
Args:
sender (Agent): The agent sending the message.
message (str): The message to add.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
int: The ID of the added message.
"""
doc = f"{sender.ai_name}: {message}"
return self.db.add(doc)
def query(
self,
message: str,
*args,
**kwargs
):
def query(self, message: str, *args, **kwargs):
"""
Queries the database using the given message.
Args:
message (str): The message to query.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
List: The list of results from the database.
"""
return self.db.query(message)

Loading…
Cancel
Save