[CODE QUALITY] [MEMORY][New DB][SQLite]

pull/328/head
Kye 1 year ago
parent 20ad31543a
commit 51a271172f

@ -1,8 +1,5 @@
from swarms.structs import Agent
from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)
llm = GPT4VisionAPI()

@ -1,4 +1,3 @@
from swarms.models import OpenAIChat
from autotemp import AutoTemp
# Your OpenAI API key

@ -21,5 +21,8 @@ model = Gemini(
)
out = model.chat("Create the code for a react component that displays a name", img=img)
out = model.chat(
"Create the code for a react component that displays a name",
img=img,
)
print(out)

@ -22,5 +22,7 @@ model = Gemini(
# Run the model
out = model.run("Create the code for a react component that displays a name")
out = model.run(
"Create the code for a react component that displays a name"
)
print(out)

@ -2,7 +2,7 @@ import os
import base64
import requests
from dotenv import load_dotenv
from swarms.models import Anthropic, OpenAIChat
from swarms.models import OpenAIChat
from swarms.structs import Agent
# Load environment variables

@ -1,5 +1,4 @@
import os
import subprocess
from dotenv import load_dotenv

@ -20,7 +20,6 @@ from termcolor import colored
from swarms.models import GPT4VisionAPI
from swarms.structs import Agent
from swarms.utils.phoenix_handler import phoenix_trace_decorator
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

@ -1,4 +1,4 @@
from swarms.models.kosmos2 import Kosmos2, Detections
from swarms.models.kosmos2 import Kosmos2
from PIL import Image

@ -1,7 +1,6 @@
import os
from swarms.models import OpenAIChat
from swarms.structs import Agent
from swarms.tools.tool import tool
from dotenv import load_dotenv
load_dotenv()

@ -1,4 +1,5 @@
from swarms.memory.base_vectordb import VectorDatabase
from swarms.memory.short_term_memory import ShortTermMemory
from swarms.memory.sqlite import SQLiteDB
__all__ = ["VectorDatabase", "ShortTermMemory"]
__all__ = ["VectorDatabase", "ShortTermMemory", "SQLiteDB"]

@ -82,7 +82,7 @@ class Qdrant:
f"Collection '{self.collection_name}' already"
" exists."
)
except Exception as e:
except Exception:
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(

@ -25,6 +25,7 @@ class ShortTermMemory(BaseStructure):
"""
def __init__(
self,
return_str: bool = True,

@ -0,0 +1,120 @@
from typing import List, Tuple, Any, Optional
from swarms.memory.base_vectordb import VectorDatabase
try:
import sqlite3
except ImportError:
raise ImportError(
"Please install sqlite3 to use the SQLiteDB class."
)
class SQLiteDB(VectorDatabase):
"""
A reusable class for SQLite database operations with methods for adding,
deleting, updating, and querying data.
Attributes:
db_path (str): The file path to the SQLite database.
"""
def __init__(self, db_path: str):
"""
Initializes the SQLiteDB class with the given database path.
Args:
db_path (str): The file path to the SQLite database.
"""
self.db_path = db_path
def execute_query(
self, query: str, params: Optional[Tuple[Any, ...]] = None
) -> List[Tuple]:
"""
Executes a SQL query and returns fetched results.
Args:
query (str): The SQL query to execute.
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
Returns:
List[Tuple]: The results fetched from the database.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params or ())
return cursor.fetchall()
except Exception as error:
print(f"Error executing query: {error}")
raise error
def add(self, query: str, params: Tuple[Any, ...]) -> None:
"""
Adds a new entry to the database.
Args:
query (str): The SQL query for insertion.
params (Tuple[Any, ...]): The parameters to substitute into the query.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
except Exception as error:
print(f"Error adding new entry: {error}")
raise error
def delete(self, query: str, params: Tuple[Any, ...]) -> None:
"""
Deletes an entry from the database.
Args:
query (str): The SQL query for deletion.
params (Tuple[Any, ...]): The parameters to substitute into the query.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
except Exception as error:
print(f"Error deleting entry: {error}")
raise error
def update(self, query: str, params: Tuple[Any, ...]) -> None:
"""
Updates an entry in the database.
Args:
query (str): The SQL query for updating.
params (Tuple[Any, ...]): The parameters to substitute into the query.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(query, params)
conn.commit()
except Exception as error:
print(f"Error updating entry: {error}")
raise error
def query(
self, query: str, params: Optional[Tuple[Any, ...]] = None
) -> List[Tuple]:
"""
Fetches data from the database based on a query.
Args:
query (str): The SQL query to execute.
params (Tuple[Any, ...], optional): The parameters to substitute into the query.
Returns:
List[Tuple]: The results fetched from the database.
"""
try:
return self.execute_query(query, params)
except Exception as error:
print(f"Error querying database: {error}")
raise error

@ -1,7 +1,7 @@
import wave
from typing import Optional
from swarms.models.base_llm import AbstractLLM
from abc import ABC, abstractmethod
from abc import abstractmethod
class BaseTTSModel(AbstractLLM):

@ -174,7 +174,10 @@ class Gemini(BaseMultiModalModel):
return response.text
else:
response = self.model.generate_content(
prepare_prompt, stream=self.stream, *args, **kwargs
prepare_prompt,
stream=self.stream,
*args,
**kwargs,
)
return response.text
except Exception as error:

@ -1,4 +1,3 @@
import os
from openai import OpenAI
client = OpenAI()

@ -12,7 +12,6 @@ from termcolor import colored
from swarms.memory.base_vectordb import VectorDatabase
from swarms.prompts.agent_system_prompts import (
FLOW_SYSTEM_PROMPT,
AGENT_SYSTEM_PROMPT_3,
agent_system_prompt_2,
)

@ -3,8 +3,6 @@ import queue
import threading
from time import sleep
from typing import Callable, Dict, List, Optional
import asyncio
import concurrent.futures
from termcolor import colored

@ -1,6 +1,6 @@
import json
import os
from abc import ABC, abstractmethod
from abc import ABC
from typing import Optional, Any, Dict, List
from datetime import datetime
import asyncio

@ -1,7 +1,5 @@
from enum import Enum, unique, auto
import abc
import hashlib
import re
from typing import List, Optional
import json
from dataclasses import dataclass, field

@ -487,21 +487,21 @@ def print_action_base(action: Action):
"""
if action.content != "":
logger.typewriter_log(
f"content:", Fore.YELLOW, f"{action.content}"
"content:", Fore.YELLOW, f"{action.content}"
)
logger.typewriter_log(
f"Thought:", Fore.YELLOW, f"{action.thought}"
"Thought:", Fore.YELLOW, f"{action.thought}"
)
if len(action.plan) > 0:
logger.typewriter_log(
f"Plan:",
"Plan:",
Fore.YELLOW,
)
for line in action.plan:
line = line.lstrip("- ")
logger.typewriter_log("- ", Fore.GREEN, line.strip())
logger.typewriter_log(
f"Criticism:", Fore.YELLOW, f"{action.criticism}"
"Criticism:", Fore.YELLOW, f"{action.criticism}"
)
@ -515,15 +515,15 @@ def print_action_tool(action: Action):
Returns:
None
"""
logger.typewriter_log(f"Tool:", Fore.BLUE, f"{action.tool_name}")
logger.typewriter_log("Tool:", Fore.BLUE, f"{action.tool_name}")
logger.typewriter_log(
f"Tool Input:", Fore.BLUE, f"{action.tool_input}"
"Tool Input:", Fore.BLUE, f"{action.tool_input}"
)
output = (
action.tool_output if action.tool_output != "" else "None"
)
logger.typewriter_log(f"Tool Output:", Fore.BLUE, f"{output}")
logger.typewriter_log("Tool Output:", Fore.BLUE, f"{output}")
color = Fore.RED
if action.tool_output_status == ToolCallStatus.ToolCallSuccess:
@ -534,7 +534,7 @@ def print_action_tool(action: Action):
color = Fore.YELLOW
logger.typewriter_log(
f"Tool Call Status:",
"Tool Call Status:",
Fore.BLUE,
f"{color}{action.tool_output_status.name}{Style.RESET_ALL}",
)

@ -9,7 +9,6 @@ from typing import Dict
import boto3
import numpy as np
import pandas as pd
import requests

@ -1,5 +1,4 @@
import sys
import os
try:
import PyPDF2

@ -1,4 +1,3 @@
import pytest
from swarms.memory.short_term_memory import ShortTermMemory
import threading

@ -0,0 +1,104 @@
import pytest
import sqlite3
from swarms.memory.sqlite import SQLiteDB
@pytest.fixture
def db():
conn = sqlite3.connect(":memory:")
conn.execute(
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
)
conn.commit()
return SQLiteDB(":memory:")
def test_add(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.query("SELECT * FROM test")
assert result == [(1, "test")]
def test_delete(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
db.delete("DELETE FROM test WHERE name = ?", ("test",))
result = db.query("SELECT * FROM test")
assert result == []
def test_update(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
db.update(
"UPDATE test SET name = ? WHERE name = ?", ("new", "test")
)
result = db.query("SELECT * FROM test")
assert result == [(1, "new")]
def test_query(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.query("SELECT * FROM test WHERE name = ?", ("test",))
assert result == [(1, "test")]
def test_execute_query(db):
db.add("INSERT INTO test (name) VALUES (?)", ("test",))
result = db.execute_query(
"SELECT * FROM test WHERE name = ?", ("test",)
)
assert result == [(1, "test")]
def test_add_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.add("INSERT INTO test (name) VALUES (?)")
def test_delete_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.delete("DELETE FROM test WHERE name = ?")
def test_update_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.update("UPDATE test SET name = ? WHERE name = ?")
def test_query_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.query("SELECT * FROM test WHERE name = ?")
def test_execute_query_without_params(db):
with pytest.raises(sqlite3.ProgrammingError):
db.execute_query("SELECT * FROM test WHERE name = ?")
def test_add_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.add("INSERT INTO wrong (name) VALUES (?)", ("test",))
def test_delete_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.delete("DELETE FROM wrong WHERE name = ?", ("test",))
def test_update_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.update(
"UPDATE wrong SET name = ? WHERE name = ?",
("new", "test"),
)
def test_query_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.query("SELECT * FROM wrong WHERE name = ?", ("test",))
def test_execute_query_with_wrong_query(db):
with pytest.raises(sqlite3.OperationalError):
db.execute_query(
"SELECT * FROM wrong WHERE name = ?", ("test",)
)

@ -1,5 +1,4 @@
# Import necessary modules and define fixtures if needed
import os
import pytest
import torch
from PIL import Image
@ -156,7 +155,6 @@ def test_clip_inference_performance(
def test_clip_preprocessing_pipelines(
clip_instance, sample_image_path
):
labels = ["label1", "label2"]
image = Image.open(sample_image_path)
# Test preprocessing for training

@ -110,7 +110,7 @@ def test_gemini_init_missing_api_key():
with pytest.raises(
ValueError, match="Please provide a Gemini API key"
):
model = Gemini(gemini_api_key=None)
Gemini(gemini_api_key=None)
# Test Gemini initialization with missing model name
@ -118,7 +118,7 @@ def test_gemini_init_missing_model_name():
with pytest.raises(
ValueError, match="Please provide a model name"
):
model = Gemini(model_name=None)
Gemini(model_name=None)
# Test Gemini run method with empty task

@ -48,7 +48,7 @@ def test_run_success(vision_api):
def test_run_request_error(vision_api):
with patch(
"requests.post", side_effect=RequestException("Request Error")
) as mock_post:
):
with pytest.raises(RequestException):
vision_api.run("What is this?", img)
@ -58,7 +58,7 @@ def test_run_response_error(vision_api):
with patch(
"requests.post",
return_value=Mock(json=lambda: expected_response),
) as mock_post:
):
with pytest.raises(RuntimeError):
vision_api.run("What is this?", img)
@ -153,7 +153,7 @@ async def test_arun_request_error(vision_api):
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=Exception("Request Error"),
) as mock_post:
):
with pytest.raises(Exception):
await vision_api.arun("What is this?", img)
@ -181,7 +181,7 @@ def test_run_many_success(vision_api):
def test_run_many_request_error(vision_api):
with patch(
"requests.post", side_effect=RequestException("Request Error")
) as mock_post:
):
tasks = ["What is this?", "What is that?"]
imgs = [img, img]
with pytest.raises(RequestException):
@ -196,7 +196,7 @@ async def test_arun_json_decode_error(vision_api):
return_value=AsyncMock(
json=AsyncMock(side_effect=ValueError)
),
) as mock_post:
):
with pytest.raises(ValueError):
await vision_api.arun("What is this?", img)
@ -210,7 +210,7 @@ async def test_arun_api_error(vision_api):
return_value=AsyncMock(
json=AsyncMock(return_value=error_response)
),
) as mock_post:
):
with pytest.raises(Exception, match="API Error"):
await vision_api.arun("What is this?", img)
@ -224,7 +224,7 @@ async def test_arun_unexpected_response(vision_api):
return_value=AsyncMock(
json=AsyncMock(return_value=unexpected_response)
),
) as mock_post:
):
with pytest.raises(Exception, match="Unexpected response"):
await vision_api.arun("What is this?", img)
@ -247,6 +247,6 @@ async def test_arun_timeout(vision_api):
"aiohttp.ClientSession.post",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
) as mock_post:
):
with pytest.raises(asyncio.TimeoutError):
await vision_api.arun("What is this?", img)

@ -15,7 +15,7 @@ def mock_multion():
def test_multion_import():
with pytest.raises(ImportError):
import multion
pass
def test_multion_init():

@ -1,4 +1,3 @@
import os
import requests
import pytest
from unittest.mock import patch, Mock

@ -113,7 +113,7 @@ def test_vllm_run_empty_task(vllm_instance):
# Test initialization with invalid parameters
def test_vllm_invalid_init():
with pytest.raises(ValueError):
vllm_instance = vLLM(
vLLM(
model_name=None,
tensor_parallel_size=-1,
trust_remote_code="invalid",

@ -347,7 +347,7 @@ def test_flow_response_filtering(flow_instance):
def test_flow_undo_last(flow_instance):
# Test the undo functionality
response1 = flow_instance.run("Task 1")
response2 = flow_instance.run("Task 2")
flow_instance.run("Task 2")
previous_state, message = flow_instance.undo_last()
assert response1 == previous_state
assert "Restored to" in message
@ -577,7 +577,7 @@ def test_flow_rollback(flow_instance):
# Test rolling back to a previous state
state1 = flow_instance.get_state()
flow_instance.change_prompt("New prompt")
state2 = flow_instance.get_state()
flow_instance.get_state()
flow_instance.rollback_to_state(state1)
assert (
flow_instance.get_current_prompt() == state1["current_prompt"]

@ -23,18 +23,18 @@ def test_autoscaler_init():
assert autoscaler.scale_up_factor == 1
assert autoscaler.idle_threshold == 0.2
assert autoscaler.busy_threshold == 0.7
assert autoscaler.autoscale == True
assert autoscaler.autoscale is True
assert autoscaler.min_agents == 1
assert autoscaler.max_agents == 5
assert autoscaler.custom_scale_strategy == None
assert autoscaler.custom_scale_strategy is None
assert len(autoscaler.agents_pool) == 5
assert autoscaler.task_queue.empty() == True
assert autoscaler.task_queue.empty() is True
def test_autoscaler_add_task():
autoscaler = AutoScaler(initial_agents=5, agent=agent)
autoscaler.add_task("task1")
assert autoscaler.task_queue.empty() == False
assert autoscaler.task_queue.empty() is False
def test_autoscaler_run():
@ -75,7 +75,7 @@ def test_autoscaler_get_agent_by_id():
def test_autoscaler_get_agent_by_id_not_found():
autoscaler = AutoScaler(initial_agents=5, agent=agent)
agent = autoscaler.get_agent_by_id("fake_id")
assert agent == None
assert agent is None
@patch("swarms.swarms.Agent.is_healthy")

@ -1,4 +1,3 @@
import os
from unittest.mock import Mock
import pytest
@ -163,4 +162,4 @@ def test_execute():
agent = Agent()
task = Task(id="5", task="Task5", result=None, agents=[agent])
# Assuming execute method returns True on successful execution
assert task.execute() == True
assert task.execute() is True

@ -26,7 +26,7 @@ def test_collaboration_initialization(collaboration):
assert callable(collaboration.select_next_speaker)
assert collaboration.max_iters == 10
assert collaboration.results == []
assert collaboration.logging == True
assert collaboration.logging is True
def test_reset(collaboration):

@ -1,7 +1,5 @@
import os
import subprocess
import json
import re
import requests
from dotenv import load_dotenv

@ -1,8 +1,6 @@
# Import necessary modules and functions for testing
import functools
import subprocess
import sys
import traceback
import pytest

Loading…
Cancel
Save