[CODE QUALITY]

pull/362/head^2
Kye 1 year ago
parent e70b401b54
commit d34427a4f8

@ -7,4 +7,4 @@ model = Kosmos()
out = model.run("Analyze the reciepts in this image", "docs.jpg") out = model.run("Analyze the reciepts in this image", "docs.jpg")
# Print the output # Print the output
print(out) print(out)

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "3.8.2" version = "3.8.5"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]

@ -156,4 +156,4 @@ class AbstractDatabase(ABC):
""" """
pass pass

@ -91,7 +91,6 @@ class AbstractVectorDatabase(ABC):
pass pass
@abstractmethod @abstractmethod
def get(self, query: str): def get(self, query: str):
""" """
@ -139,4 +138,4 @@ class AbstractVectorDatabase(ABC):
""" """
pass pass

@ -48,7 +48,7 @@ from swarms.models.vip_llava import VipLlavaMultiModal # noqa: E402
from swarms.models.llava import LavaMultiModal # noqa: E402 from swarms.models.llava import LavaMultiModal # noqa: E402
from swarms.models.qwen import QwenVLMultiModal # noqa: E402 from swarms.models.qwen import QwenVLMultiModal # noqa: E402
from swarms.models.clipq import CLIPQ # noqa: E402 from swarms.models.clipq import CLIPQ # noqa: E402
from swarms.models.kosmos_two import Kosmos # noqa: E402 from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.fuyu import Fuyu # noqa: E402 from swarms.models.fuyu import Fuyu # noqa: E402
# from swarms.models.dalle3 import Dalle3 # from swarms.models.dalle3 import Dalle3

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import functools
import logging import logging
import sys import sys
from typing import ( from typing import (
@ -16,6 +18,7 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
Type,
Union, Union,
) )
@ -23,7 +26,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms.base import BaseLLM, create_base_retry_decorator from langchain.llms.base import BaseLLM
from langchain.pydantic_v1 import Field, root_validator from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
@ -32,7 +35,17 @@ from langchain.utils import (
get_pydantic_field_names, get_pydantic_field_names,
) )
from langchain.utils.utils import build_extra_kwargs from langchain.utils.utils import build_extra_kwargs
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
from importlib.metadata import version from importlib.metadata import version
@ -41,6 +54,62 @@ from packaging.version import parse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@functools.lru_cache
def _log_error_once(msg: str) -> None:
"""Log an error once."""
logger.error(msg)
def create_base_retry_decorator(
error_types: List[Type[BaseException]],
max_retries: int = 1,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
_logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None:
_logging(retry_state)
if run_manager:
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
coro = run_manager.on_retry(retry_state)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
else:
asyncio.run(coro)
except Exception as e:
_log_error_once(f"Error in on_retry: {e}")
else:
run_manager.on_retry(retry_state)
return None
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
retry_instance: "retry_base" = retry_if_exception_type(
error_types[0]
)
for error in error_types[1:]:
retry_instance = retry_instance | retry_if_exception_type(
error
)
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(
multiplier=1, min=min_seconds, max=max_seconds
),
retry=retry_instance,
before_sleep=_before_sleep,
)
def is_openai_v1() -> bool: def is_openai_v1() -> bool:
_version = parse(version("openai")) _version = parse(version("openai"))
return _version.major >= 1 return _version.major >= 1
@ -833,7 +902,7 @@ class OpenAIChat(BaseLLM):
""" """
client: Any #: :meta private: client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo-1106" model_name: str = "gpt-4-1106-preview"
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None openai_api_base: Optional[str] = None

@ -75,7 +75,7 @@ class Conversation(BaseStructure):
self.autosave = autosave self.autosave = autosave
self.save_filepath = save_filepath self.save_filepath = save_filepath
self.conversation_history = [] self.conversation_history = []
# If system prompt is not None, add it to the conversation history # If system prompt is not None, add it to the conversation history
if self.system_prompt: if self.system_prompt:
self.add("system", self.system_prompt) self.add("system", self.system_prompt)

@ -9,24 +9,24 @@ from swarms.structs.agent import Agent
class MultiAgentRag: class MultiAgentRag:
""" """
Represents a multi-agent RAG (Relational Agent Graph) structure. Represents a multi-agent RAG (Relational Agent Graph) structure.
Attributes: Attributes:
agents (List[Agent]): List of agents in the multi-agent RAG. agents (List[Agent]): List of agents in the multi-agent RAG.
db (AbstractVectorDatabase): Database used for querying. db (AbstractVectorDatabase): Database used for querying.
verbose (bool): Flag indicating whether to print verbose output. verbose (bool): Flag indicating whether to print verbose output.
""" """
agents: List[Agent] agents: List[Agent]
db: AbstractVectorDatabase db: AbstractVectorDatabase
verbose: bool = False verbose: bool = False
def query_database(self, query: str): def query_database(self, query: str):
""" """
Queries the database using the given query string. Queries the database using the given query string.
Args: Args:
query (str): The query string. query (str): The query string.
Returns: Returns:
List: The list of results from the database. List: The list of results from the database.
""" """
@ -35,15 +35,14 @@ class MultiAgentRag:
agent_results = agent.long_term_memory_prompt(query) agent_results = agent.long_term_memory_prompt(query)
results.extend(agent_results) results.extend(agent_results)
return results return results
def get_agent_by_id(self, agent_id) -> Optional[Agent]: def get_agent_by_id(self, agent_id) -> Optional[Agent]:
""" """
Retrieves an agent from the multi-agent RAG by its ID. Retrieves an agent from the multi-agent RAG by its ID.
Args: Args:
agent_id: The ID of the agent to retrieve. agent_id: The ID of the agent to retrieve.
Returns: Returns:
Agent or None: The agent with the specified ID, or None if not found. Agent or None: The agent with the specified ID, or None if not found.
""" """
@ -51,47 +50,36 @@ class MultiAgentRag:
if agent.agent_id == agent_id: if agent.agent_id == agent_id:
return agent return agent
return None return None
def add_message( def add_message(
self, self, sender: Agent, message: str, *args, **kwargs
sender: Agent,
message: str,
*args,
**kwargs
): ):
""" """
Adds a message to the database. Adds a message to the database.
Args: Args:
sender (Agent): The agent sending the message. sender (Agent): The agent sending the message.
message (str): The message to add. message (str): The message to add.
*args: Additional positional arguments. *args: Additional positional arguments.
**kwargs: Additional keyword arguments. **kwargs: Additional keyword arguments.
Returns: Returns:
int: The ID of the added message. int: The ID of the added message.
""" """
doc = f"{sender.ai_name}: {message}" doc = f"{sender.ai_name}: {message}"
return self.db.add(doc) return self.db.add(doc)
def query( def query(self, message: str, *args, **kwargs):
self,
message: str,
*args,
**kwargs
):
""" """
Queries the database using the given message. Queries the database using the given message.
Args: Args:
message (str): The message to query. message (str): The message to query.
*args: Additional positional arguments. *args: Additional positional arguments.
**kwargs: Additional keyword arguments. **kwargs: Additional keyword arguments.
Returns: Returns:
List: The list of results from the database. List: The list of results from the database.
""" """
return self.db.query(message) return self.db.query(message)

Loading…
Cancel
Save