[CODE QUALITY] [MEMORY][New DB][SQLite]

pull/328/head
Kye 1 year ago
parent 20ad31543a
commit 51a271172f

@ -1,8 +1,5 @@
from swarms.structs import Agent from swarms.structs import Agent
from swarms.models.gpt4_vision_api import GPT4VisionAPI from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)
llm = GPT4VisionAPI() llm = GPT4VisionAPI()

@ -1,4 +1,3 @@
from swarms.models import OpenAIChat
from autotemp import AutoTemp from autotemp import AutoTemp
# Your OpenAI API key # Your OpenAI API key

@ -21,5 +21,8 @@ model = Gemini(
) )
out = model.chat("Create the code for a react component that displays a name", img=img) out = model.chat(
"Create the code for a react component that displays a name",
img=img,
)
print(out) print(out)

@ -22,5 +22,7 @@ model = Gemini(
# Run the model # Run the model
out = model.run("Create the code for a react component that displays a name") out = model.run(
"Create the code for a react component that displays a name"
)
print(out) print(out)

@ -2,7 +2,7 @@ import os
import base64 import base64
import requests import requests
from dotenv import load_dotenv from dotenv import load_dotenv
from swarms.models import Anthropic, OpenAIChat from swarms.models import OpenAIChat
from swarms.structs import Agent from swarms.structs import Agent
# Load environment variables # Load environment variables

@ -1,5 +1,4 @@
import os import os
import subprocess
from dotenv import load_dotenv from dotenv import load_dotenv

@ -20,7 +20,6 @@ from termcolor import colored
from swarms.models import GPT4VisionAPI from swarms.models import GPT4VisionAPI
from swarms.structs import Agent from swarms.structs import Agent
from swarms.utils.phoenix_handler import phoenix_trace_decorator
load_dotenv() load_dotenv()
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")

@ -1,4 +1,4 @@
from swarms.models.kosmos2 import Kosmos2, Detections from swarms.models.kosmos2 import Kosmos2
from PIL import Image from PIL import Image

@ -1,7 +1,6 @@
import os import os
from swarms.models import OpenAIChat from swarms.models import OpenAIChat
from swarms.structs import Agent from swarms.structs import Agent
from swarms.tools.tool import tool
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()

@ -1,4 +1,5 @@
from swarms.memory.base_vectordb import VectorDatabase from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.short_term_memory import ShortTermMemory from swarms.memory.short_term_memory import ShortTermMemory
from swarms.memory.sqlite import SQLiteDB
__all__ = ["VectorDatabase", "ShortTermMemory"] __all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB"]

@ -82,7 +82,7 @@ class Qdrant:
f"Collection '{self.collection_name}' already" f"Collection '{self.collection_name}' already"
" exists." " exists."
) )
except Exception as e: except Exception:
self.client.create_collection( self.client.create_collection(
collection_name=self.collection_name, collection_name=self.collection_name,
vectors_config=VectorParams( vectors_config=VectorParams(

@ -25,6 +25,7 @@ class ShortTermMemory(BaseStructure):
""" """
def __init__( def __init__(
self, self,
return_str: bool = True, return_str: bool = True,

@ -0,0 +1,120 @@
from typing import List, Tuple, Any, Optional
from swarms.memory.base_vectordb import VectorDatabase
try:
import sqlite3
except ImportError:
raise ImportError(
"Please install sqlite3 to use the SQLiteDB class."
)
class SQLiteDB(VectorDatabase):
"""
A reusable class for SQLite database operations with methods for adding,
deleting, updating, and querying data.
Attributes:
db_path (str): The file path to the SQLite database.
"""
def __init__(self, db_path: str):
"""
Initializes the SQLiteDB class with the given database path.
Args:
db_path (str): The file path to the SQLite database.
"""
self.db_path = db_path
def execute_query(
self, query: str, params: Optional[Tuple[Any, ...]] = None
) -> List[Tuple]:
"""
Executes a SQL query and returns fetched results.
Args:
query (str): The SQL query to execute.
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
Returns:
List[Tuple]: The results fetched from the database.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params or ())
return cursor.fetchall()
except Exception as error:
print(f"Error executing query: {error}")
raise error
def add(self, query: str, params: Tuple[Any, ...]) -> None:
"""
Adds a new entry to the database.
Args:
query (str): The SQL query for insertion.
params (Tuple[Any, ...]): The parameters to substitute into the query.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
except Exception as error:
print(f"Error adding new entry: {error}")
raise error
def delete(self, query: str, params: Tuple[Any, ...]) -> None:
"""
Deletes an entry from the database.
Args:
query (str): The SQL query for deletion.
params (Tuple[Any, ...]): The parameters to substitute into the query.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
except Exception as error:
print(f"Error deleting entry: {error}")
raise error
def update(self, query: str, params: Tuple[Any, ...]) -> None:
"""
Updates an entry in the database.
Args:
query (str): The SQL query for updating.
params (Tuple[Any, ...]): The parameters to substitute into the query.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
except Exception as error:
print(f"Error updating entry: {error}")
raise error
def query(
self, query: str, params: Optional[Tuple[Any, ...]] = None
) -> List[Tuple]:
"""
Fetches data from the database based on a query.
Args:
query (str): The SQL query to execute.
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
Returns:
List[Tuple]: The results fetched from the database.
"""
try:
return self.execute_query(query, params)
except Exception as error:
print(f"Error querying database: {error}")
raise error

@ -1,7 +1,7 @@
import wave import wave
from typing import Optional from typing import Optional
from swarms.models.base_llm import AbstractLLM from swarms.models.base_llm import AbstractLLM
from abc import ABC, abstractmethod from abc import abstractmethod
class BaseTTSModel(AbstractLLM): class BaseTTSModel(AbstractLLM):

@ -174,7 +174,10 @@ class Gemini(BaseMultiModalModel):
return response.text return response.text
else: else:
response = self.model.generate_content( response = self.model.generate_content(
prepare_prompt, stream=self.stream, *args, **kwargs prepare_prompt,
stream=self.stream,
*args,
**kwargs,
) )
return response.text return response.text
except Exception as error: except Exception as error:

@ -1,4 +1,3 @@
import os
from openai import OpenAI from openai import OpenAI
client = OpenAI() client = OpenAI()

@ -12,7 +12,6 @@ from termcolor import colored
from swarms.memory.base_vectordb import VectorDatabase from swarms.memory.base_vectordb import VectorDatabase
from swarms.prompts.agent_system_prompts import ( from swarms.prompts.agent_system_prompts import (
FLOW_SYSTEM_PROMPT,
AGENT_SYSTEM_PROMPT_3, AGENT_SYSTEM_PROMPT_3,
agent_system_prompt_2, agent_system_prompt_2,
) )

@ -3,8 +3,6 @@ import queue
import threading import threading
from time import sleep from time import sleep
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
import asyncio
import concurrent.futures
from termcolor import colored from termcolor import colored

@ -1,6 +1,6 @@
import json import json
import os import os
from abc import ABC, abstractmethod from abc import ABC
from typing import Optional, Any, Dict, List from typing import Optional, Any, Dict, List
from datetime import datetime from datetime import datetime
import asyncio import asyncio

@ -1,7 +1,5 @@
from enum import Enum, unique, auto from enum import Enum, unique, auto
import abc import abc
import hashlib
import re
from typing import List, Optional from typing import List, Optional
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field

@ -487,21 +487,21 @@ def print_action_base(action: Action):
""" """
if action.content != "": if action.content != "":
logger.typewriter_log( logger.typewriter_log(
f"content:", Fore.YELLOW, f"{action.content}" "content:", Fore.YELLOW, f"{action.content}"
) )
logger.typewriter_log( logger.typewriter_log(
f"Thought:", Fore.YELLOW, f"{action.thought}" "Thought:", Fore.YELLOW, f"{action.thought}"
) )
if len(action.plan) > 0: if len(action.plan) > 0:
logger.typewriter_log( logger.typewriter_log(
f"Plan:", "Plan:",
Fore.YELLOW, Fore.YELLOW,
) )
for line in action.plan: for line in action.plan:
line = line.lstrip("- ") line = line.lstrip("- ")
logger.typewriter_log("- ", Fore.GREEN, line.strip()) logger.typewriter_log("- ", Fore.GREEN, line.strip())
logger.typewriter_log( logger.typewriter_log(
f"Criticism:", Fore.YELLOW, f"{action.criticism}" "Criticism:", Fore.YELLOW, f"{action.criticism}"
) )
@ -515,15 +515,15 @@ def print_action_tool(action: Action):
Returns: Returns:
None None
""" """
logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}") logger.typewriter_log("Tool:", Fore.BLUE, f"{action.tool_name}")
logger.typewriter_log( logger.typewriter_log(
f"Tool Input:", Fore.BLUE, f"{action.tool_input}" "Tool Input:", Fore.BLUE, f"{action.tool_input}"
) )
output = ( output = (
action.tool_output if action.tool_output != "" else "None" action.tool_output if action.tool_output != "" else "None"
) )
logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}") logger.typewriter_log("Tool Output:", Fore.BLUE, f"{output}")
color = Fore.RED color = Fore.RED
if action.tool_output_status == ToolCallStatus.ToolCallSuccess: if action.tool_output_status == ToolCallStatus.ToolCallSuccess:
@ -534,7 +534,7 @@ def print_action_tool(action: Action):
color = Fore.YELLOW color = Fore.YELLOW
logger.typewriter_log( logger.typewriter_log(
f"Tool Call Status:", "Tool Call Status:",
Fore.BLUE, Fore.BLUE,
f"{color}{action.tool_output_status.name}{Style.RESET_ALL}", f"{color}{action.tool_output_status.name}{Style.RESET_ALL}",
) )

@ -9,7 +9,6 @@ from typing import Dict
import boto3 import boto3
import numpy as np import numpy as np
import pandas as pd
import requests import requests

@ -1,5 +1,4 @@
import sys import sys
import os
try: try:
import PyPDF2 import PyPDF2

@ -1,4 +1,3 @@
import pytest
from swarms.memory.short_term_memory import ShortTermMemory from swarms.memory.short_term_memory import ShortTermMemory
import threading import threading

@ -0,0 +1,104 @@
import pytest
import sqlite3
from swarms.memory.sqlite import SQLiteDB
@pytest.fixture
def db():
conn = sqlite3.connect(":memory:")
conn.execute(
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
)
conn.commit()
return SQLiteDB(":memory:")
def test_add(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.query("SELECT * FROM test")
assert result == [(1, "test")]
def test_delete(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
db.delete("DELETE FROM test WHERE name = ?", ("test",))
result = db.query("SELECT * FROM test")
assert result == []
def test_update(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
db.update(
"UPDATE test SET name = ? WHERE name = ?", ("new", "test")
)
result = db.query("SELECT * FROM test")
assert result == [(1, "new")]
def test_query(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.query("SELECT * FROM test WHERE name = ?", ("test",))
assert result == [(1, "test")]
def test_execute_query(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.execute_query(
"SELECT * FROM test WHERE name = ?", ("test",)
)
assert result == [(1, "test")]
def test_add_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.add("INSERT INTO test (name) VALUES (?)")
def test_delete_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.delete("DELETE FROM test WHERE name = ?")
def test_update_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.update("UPDATE test SET name = ? WHERE name = ?")
def test_query_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.query("SELECT * FROM test WHERE name = ?")
def test_execute_query_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.execute_query("SELECT * FROM test WHERE name = ?")
def test_add_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.add("INSERT INTO wrong (name) VALUES (?)", ("test",))
def test_delete_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.delete("DELETE FROM wrong WHERE name = ?", ("test",))
def test_update_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.update(
"UPDATE wrong SET name = ? WHERE name = ?",
("new", "test"),
)
def test_query_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.query("SELECT * FROM wrong WHERE name = ?", ("test",))
def test_execute_query_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.execute_query(
"SELECT * FROM wrong WHERE name = ?", ("test",)
)

@ -1,5 +1,4 @@
# Import necessary modules and define fixtures if needed # Import necessary modules and define fixtures if needed
import os
import pytest import pytest
import torch import torch
from PIL import Image from PIL import Image
@ -156,7 +155,6 @@ def test_clip_inference_performance(
def test_clip_preprocessing_pipelines( def test_clip_preprocessing_pipelines(
clip_instance, sample_image_path clip_instance, sample_image_path
): ):
labels = ["label1", "label2"]
image = Image.open(sample_image_path) image = Image.open(sample_image_path)
# Test preprocessing for training # Test preprocessing for training

@ -110,7 +110,7 @@ def test_gemini_init_missing_api_key():
with pytest.raises( with pytest.raises(
ValueError, match="Please provide a Gemini API key" ValueError, match="Please provide a Gemini API key"
): ):
model = Gemini(gemini_api_key=None) Gemini(gemini_api_key=None)
# Test Gemini initialization with missing model name # Test Gemini initialization with missing model name
@ -118,7 +118,7 @@ def test_gemini_init_missing_model_name():
with pytest.raises( with pytest.raises(
ValueError, match="Please provide a model name" ValueError, match="Please provide a model name"
): ):
model = Gemini(model_name=None) Gemini(model_name=None)
# Test Gemini run method with empty task # Test Gemini run method with empty task

@ -48,7 +48,7 @@ def test_run_success(vision_api):
def test_run_request_error(vision_api): def test_run_request_error(vision_api):
with patch( with patch(
"requests.post", side_effect=RequestException("Request Error") "requests.post", side_effect=RequestException("Request Error")
) as mock_post: ):
with pytest.raises(RequestException): with pytest.raises(RequestException):
vision_api.run("What is this?", img) vision_api.run("What is this?", img)
@ -58,7 +58,7 @@ def test_run_response_error(vision_api):
with patch( with patch(
"requests.post", "requests.post",
return_value=Mock(json=lambda: expected_response), return_value=Mock(json=lambda: expected_response),
) as mock_post: ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
vision_api.run("What is this?", img) vision_api.run("What is this?", img)
@ -153,7 +153,7 @@ async def test_arun_request_error(vision_api):
"aiohttp.ClientSession.post", "aiohttp.ClientSession.post",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=Exception("Request Error"), side_effect=Exception("Request Error"),
) as mock_post: ):
with pytest.raises(Exception): with pytest.raises(Exception):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -181,7 +181,7 @@ def test_run_many_success(vision_api):
def test_run_many_request_error(vision_api): def test_run_many_request_error(vision_api):
with patch( with patch(
"requests.post", side_effect=RequestException("Request Error") "requests.post", side_effect=RequestException("Request Error")
) as mock_post: ):
tasks = ["What is this?", "What is that?"] tasks = ["What is this?", "What is that?"]
imgs = [img, img] imgs = [img, img]
with pytest.raises(RequestException): with pytest.raises(RequestException):
@ -196,7 +196,7 @@ async def test_arun_json_decode_error(vision_api):
return_value=AsyncMock( return_value=AsyncMock(
json=AsyncMock(side_effect=ValueError) json=AsyncMock(side_effect=ValueError)
), ),
) as mock_post: ):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -210,7 +210,7 @@ async def test_arun_api_error(vision_api):
return_value=AsyncMock( return_value=AsyncMock(
json=AsyncMock(return_value=error_response) json=AsyncMock(return_value=error_response)
), ),
) as mock_post: ):
with pytest.raises(Exception, match="API Error"): with pytest.raises(Exception, match="API Error"):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -224,7 +224,7 @@ async def test_arun_unexpected_response(vision_api):
return_value=AsyncMock( return_value=AsyncMock(
json=AsyncMock(return_value=unexpected_response) json=AsyncMock(return_value=unexpected_response)
), ),
) as mock_post: ):
with pytest.raises(Exception, match="Unexpected response"): with pytest.raises(Exception, match="Unexpected response"):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)
@ -247,6 +247,6 @@ async def test_arun_timeout(vision_api):
"aiohttp.ClientSession.post", "aiohttp.ClientSession.post",
new_callable=AsyncMock, new_callable=AsyncMock,
side_effect=asyncio.TimeoutError, side_effect=asyncio.TimeoutError,
) as mock_post: ):
with pytest.raises(asyncio.TimeoutError): with pytest.raises(asyncio.TimeoutError):
await vision_api.arun("What is this?", img) await vision_api.arun("What is this?", img)

@ -15,7 +15,7 @@ def mock_multion():
def test_multion_import(): def test_multion_import():
with pytest.raises(ImportError): with pytest.raises(ImportError):
import multion pass
def test_multion_init(): def test_multion_init():

@ -1,4 +1,3 @@
import os
import requests import requests
import pytest import pytest
from unittest.mock import patch, Mock from unittest.mock import patch, Mock

@ -113,7 +113,7 @@ def test_vllm_run_empty_task(vllm_instance):
# Test initialization with invalid parameters # Test initialization with invalid parameters
def test_vllm_invalid_init(): def test_vllm_invalid_init():
with pytest.raises(ValueError): with pytest.raises(ValueError):
vllm_instance = vLLM( vLLM(
model_name=None, model_name=None,
tensor_parallel_size=-1, tensor_parallel_size=-1,
trust_remote_code="invalid", trust_remote_code="invalid",

@ -347,7 +347,7 @@ def test_flow_response_filtering(flow_instance):
def test_flow_undo_last(flow_instance): def test_flow_undo_last(flow_instance):
# Test the undo functionality # Test the undo functionality
response1 = flow_instance.run("Task 1") response1 = flow_instance.run("Task 1")
response2 = flow_instance.run("Task 2") flow_instance.run("Task 2")
previous_state, message = flow_instance.undo_last() previous_state, message = flow_instance.undo_last()
assert response1 == previous_state assert response1 == previous_state
assert "Restored to" in message assert "Restored to" in message
@ -577,7 +577,7 @@ def test_flow_rollback(flow_instance):
# Test rolling back to a previous state # Test rolling back to a previous state
state1 = flow_instance.get_state() state1 = flow_instance.get_state()
flow_instance.change_prompt("New prompt") flow_instance.change_prompt("New prompt")
state2 = flow_instance.get_state() flow_instance.get_state()
flow_instance.rollback_to_state(state1) flow_instance.rollback_to_state(state1)
assert ( assert (
flow_instance.get_current_prompt() == state1["current_prompt"] flow_instance.get_current_prompt() == state1["current_prompt"]

@ -23,18 +23,18 @@ def test_autoscaler_init():
assert autoscaler.scale_up_factor == 1 assert autoscaler.scale_up_factor == 1
assert autoscaler.idle_threshold == 0.2 assert autoscaler.idle_threshold == 0.2
assert autoscaler.busy_threshold == 0.7 assert autoscaler.busy_threshold == 0.7
assert autoscaler.autoscale == True assert autoscaler.autoscale is True
assert autoscaler.min_agents == 1 assert autoscaler.min_agents == 1
assert autoscaler.max_agents == 5 assert autoscaler.max_agents == 5
assert autoscaler.custom_scale_strategy == None assert autoscaler.custom_scale_strategy is None
assert len(autoscaler.agents_pool) == 5 assert len(autoscaler.agents_pool) == 5
assert autoscaler.task_queue.empty() == True assert autoscaler.task_queue.empty() is True
def test_autoscaler_add_task(): def test_autoscaler_add_task():
autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler = AutoScaler(initial_agents=5, agent=agent)
autoscaler.add_task("task1") autoscaler.add_task("task1")
assert autoscaler.task_queue.empty() == False assert autoscaler.task_queue.empty() is False
def test_autoscaler_run(): def test_autoscaler_run():
@ -75,7 +75,7 @@ def test_autoscaler_get_agent_by_id():
def test_autoscaler_get_agent_by_id_not_found(): def test_autoscaler_get_agent_by_id_not_found():
autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler = AutoScaler(initial_agents=5, agent=agent)
agent = autoscaler.get_agent_by_id("fake_id") agent = autoscaler.get_agent_by_id("fake_id")
assert agent == None assert agent is None
@patch("swarms.swarms.Agent.is_healthy") @patch("swarms.swarms.Agent.is_healthy")

@ -1,4 +1,3 @@
import os
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
@ -163,4 +162,4 @@ def test_execute():
agent = Agent() agent = Agent()
task = Task(id="5", task="Task5", result=None, agents=[agent]) task = Task(id="5", task="Task5", result=None, agents=[agent])
# Assuming execute method returns True on successful execution # Assuming execute method returns True on successful execution
assert task.execute() == True assert task.execute() is True

@ -26,7 +26,7 @@ def test_collaboration_initialization(collaboration):
assert callable(collaboration.select_next_speaker) assert callable(collaboration.select_next_speaker)
assert collaboration.max_iters == 10 assert collaboration.max_iters == 10
assert collaboration.results == [] assert collaboration.results == []
assert collaboration.logging == True assert collaboration.logging is True
def test_reset(collaboration): def test_reset(collaboration):

@ -1,7 +1,5 @@
import os import os
import subprocess import subprocess
import json
import re
import requests import requests
from dotenv import load_dotenv from dotenv import load_dotenv

@ -1,8 +1,6 @@
# Import necessary modules and functions for testing # Import necessary modules and functions for testing
import functools
import subprocess import subprocess
import sys import sys
import traceback
import pytest import pytest

Loading…
Cancel
Save