Merge d4176729f3
into 09159905b0
commit
2b6ad4e36b
@ -0,0 +1,22 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = E203, W503
|
||||
exclude =
|
||||
.git,
|
||||
__pycache__,
|
||||
build,
|
||||
dist,
|
||||
*.egg-info,
|
||||
.eggs,
|
||||
.tox,
|
||||
.venv,
|
||||
venv,
|
||||
.env,
|
||||
.pytest_cache,
|
||||
.coverage,
|
||||
htmlcov,
|
||||
.mypy_cache,
|
||||
.ruff_cache
|
||||
per-file-ignores =
|
||||
__init__.py: F401
|
||||
max-complexity = 10
|
@ -0,0 +1,106 @@
|
||||
from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig
|
||||
|
||||
def main():
|
||||
# Initialize TypeDB wrapper with custom configuration
|
||||
config = TypeDBConfig(
|
||||
uri="localhost:1729",
|
||||
database="swarms_example",
|
||||
username="admin",
|
||||
password="password"
|
||||
)
|
||||
|
||||
# Define schema for a simple knowledge graph
|
||||
schema = """
|
||||
define
|
||||
person sub entity,
|
||||
owns name: string,
|
||||
owns age: long,
|
||||
plays role;
|
||||
|
||||
role sub entity,
|
||||
owns title: string,
|
||||
owns department: string;
|
||||
|
||||
works_at sub relation,
|
||||
relates person,
|
||||
relates role;
|
||||
"""
|
||||
|
||||
# Example data insertion
|
||||
insert_queries = [
|
||||
"""
|
||||
insert
|
||||
$p isa person, has name "John Doe", has age 30;
|
||||
$r isa role, has title "Software Engineer", has department "Engineering";
|
||||
(person: $p, role: $r) isa works_at;
|
||||
""",
|
||||
"""
|
||||
insert
|
||||
$p isa person, has name "Jane Smith", has age 28;
|
||||
$r isa role, has title "Data Scientist", has department "Data Science";
|
||||
(person: $p, role: $r) isa works_at;
|
||||
"""
|
||||
]
|
||||
|
||||
# Example queries
|
||||
query_queries = [
|
||||
# Get all people
|
||||
"match $p isa person; get;",
|
||||
|
||||
# Get people in Engineering department
|
||||
"""
|
||||
match
|
||||
$p isa person;
|
||||
$r isa role, has department "Engineering";
|
||||
(person: $p, role: $r) isa works_at;
|
||||
get $p;
|
||||
""",
|
||||
|
||||
# Get people with their roles
|
||||
"""
|
||||
match
|
||||
$p isa person, has name $n;
|
||||
$r isa role, has title $t;
|
||||
(person: $p, role: $r) isa works_at;
|
||||
get $n, $t;
|
||||
"""
|
||||
]
|
||||
|
||||
try:
|
||||
with TypeDBWrapper(config) as db:
|
||||
# Define schema
|
||||
print("Defining schema...")
|
||||
db.define_schema(schema)
|
||||
|
||||
# Insert data
|
||||
print("\nInserting data...")
|
||||
for query in insert_queries:
|
||||
db.insert_data(query)
|
||||
|
||||
# Query data
|
||||
print("\nQuerying data...")
|
||||
for i, query in enumerate(query_queries, 1):
|
||||
print(f"\nQuery {i}:")
|
||||
results = db.query_data(query)
|
||||
print(f"Results: {results}")
|
||||
|
||||
# Example of deleting data
|
||||
print("\nDeleting data...")
|
||||
delete_query = """
|
||||
match
|
||||
$p isa person, has name "John Doe";
|
||||
delete $p;
|
||||
"""
|
||||
db.delete_data(delete_query)
|
||||
|
||||
# Verify deletion
|
||||
print("\nVerifying deletion...")
|
||||
verify_query = "match $p isa person, has name $n; get $n;"
|
||||
results = db.query_data(verify_query)
|
||||
print(f"Remaining people: {results}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,4 @@
|
||||
[pyupgrade]
|
||||
py3-plus = True
|
||||
py39-plus = True
|
||||
keep-runtime-typing = True
|
@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python3
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def run_command(command: list[str], cwd: Path) -> bool:
|
||||
"""Run a command and return True if successful."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command,
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error running {' '.join(command)}:")
|
||||
print(e.stdout)
|
||||
print(e.stderr, file=sys.stderr)
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all code quality checks."""
|
||||
root_dir = Path(__file__).parent.parent
|
||||
success = True
|
||||
|
||||
# Run flake8
|
||||
print("\nRunning flake8...")
|
||||
if not run_command(["flake8", "swarms", "tests"], root_dir):
|
||||
success = False
|
||||
|
||||
# Run pyupgrade
|
||||
print("\nRunning pyupgrade...")
|
||||
if not run_command(["pyupgrade", "--py39-plus", "swarms", "tests"], root_dir):
|
||||
success = False
|
||||
|
||||
# Run black
|
||||
print("\nRunning black...")
|
||||
if not run_command(["black", "--check", "swarms", "tests"], root_dir):
|
||||
success = False
|
||||
|
||||
# Run ruff
|
||||
print("\nRunning ruff...")
|
||||
if not run_command(["ruff", "check", "swarms", "tests"], root_dir):
|
||||
success = False
|
||||
|
||||
if not success:
|
||||
print("\nCode quality checks failed. Please fix the issues and try again.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("\nAll code quality checks passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,32 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
class ToolAgentError(Exception):
|
||||
"""Base exception for all tool agent errors."""
|
||||
def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(self.message)
|
||||
|
||||
class ToolExecutionError(ToolAgentError):
|
||||
"""Raised when a tool fails to execute."""
|
||||
def __init__(self, tool_name: str, error: Exception, details: Optional[Dict[str, Any]] = None):
|
||||
message = f"Failed to execute tool '{tool_name}': {str(error)}"
|
||||
super().__init__(message, details)
|
||||
|
||||
class ToolValidationError(ToolAgentError):
|
||||
"""Raised when tool parameters fail validation."""
|
||||
def __init__(self, tool_name: str, param_name: str, error: str, details: Optional[Dict[str, Any]] = None):
|
||||
message = f"Validation error for tool '{tool_name}' parameter '{param_name}': {error}"
|
||||
super().__init__(message, details)
|
||||
|
||||
class ToolNotFoundError(ToolAgentError):
|
||||
"""Raised when a requested tool is not found."""
|
||||
def __init__(self, tool_name: str, details: Optional[Dict[str, Any]] = None):
|
||||
message = f"Tool '{tool_name}' not found"
|
||||
super().__init__(message, details)
|
||||
|
||||
class ToolParameterError(ToolAgentError):
|
||||
"""Raised when tool parameters are invalid."""
|
||||
def __init__(self, tool_name: str, error: str, details: Optional[Dict[str, Any]] = None):
|
||||
message = f"Invalid parameters for tool '{tool_name}': {error}"
|
||||
super().__init__(message, details)
|
@ -0,0 +1,168 @@
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from loguru import logger
|
||||
from typedb.client import TypeDB, SessionType, TransactionType
|
||||
from typedb.api.connection.transaction import Transaction
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
|
||||
@dataclass
|
||||
class TypeDBConfig:
|
||||
"""Configuration for TypeDB connection."""
|
||||
uri: str = "localhost:1729"
|
||||
database: str = "swarms"
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
timeout: int = 30
|
||||
|
||||
class TypeDBWrapper:
|
||||
"""
|
||||
A wrapper class for TypeDB that provides graph database operations for Swarms.
|
||||
This class handles connection, schema management, and data operations.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[TypeDBConfig] = None):
|
||||
"""
|
||||
Initialize the TypeDB wrapper with the given configuration.
|
||||
Args:
|
||||
config (Optional[TypeDBConfig]): Configuration for TypeDB connection.
|
||||
"""
|
||||
self.config = config or TypeDBConfig()
|
||||
self.client = None
|
||||
self.session = None
|
||||
self._connect()
|
||||
|
||||
def _connect(self) -> None:
|
||||
"""Establish connection to TypeDB."""
|
||||
try:
|
||||
self.client = TypeDB.core_client(self.config.uri)
|
||||
if self.config.username and self.config.password:
|
||||
self.session = self.client.session(
|
||||
self.config.database,
|
||||
SessionType.DATA,
|
||||
self.config.username,
|
||||
self.config.password
|
||||
)
|
||||
else:
|
||||
self.session = self.client.session(
|
||||
self.config.database,
|
||||
SessionType.DATA
|
||||
)
|
||||
logger.info(f"Connected to TypeDB at {self.config.uri}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to TypeDB: {e}")
|
||||
raise
|
||||
|
||||
def _ensure_connection(self) -> None:
|
||||
"""Ensure connection is active, reconnect if necessary."""
|
||||
if not self.session or not self.session.is_open():
|
||||
self._connect()
|
||||
|
||||
def define_schema(self, schema: str) -> None:
|
||||
"""
|
||||
Define the database schema.
|
||||
Args:
|
||||
schema (str): TypeQL schema definition.
|
||||
"""
|
||||
try:
|
||||
with self.session.transaction(TransactionType.WRITE) as transaction:
|
||||
transaction.query.define(schema)
|
||||
transaction.commit()
|
||||
logger.info("Schema defined successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to define schema: {e}")
|
||||
raise
|
||||
|
||||
def insert_data(self, query: str) -> None:
|
||||
"""
|
||||
Insert data using TypeQL query.
|
||||
Args:
|
||||
query (str): TypeQL insert query.
|
||||
"""
|
||||
try:
|
||||
with self.session.transaction(TransactionType.WRITE) as transaction:
|
||||
transaction.query.insert(query)
|
||||
transaction.commit()
|
||||
logger.info("Data inserted successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert data: {e}")
|
||||
raise
|
||||
|
||||
def query_data(self, query: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query data using TypeQL query.
|
||||
Args:
|
||||
query (str): TypeQL query.
|
||||
Returns:
|
||||
List[Dict[str, Any]]: Query results.
|
||||
"""
|
||||
try:
|
||||
with self.session.transaction(TransactionType.READ) as transaction:
|
||||
result = transaction.query.get(query)
|
||||
return [self._convert_concept_to_dict(concept) for concept in result]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query data: {e}")
|
||||
raise
|
||||
|
||||
def _convert_concept_to_dict(self, concept: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a TypeDB concept to a dictionary.
|
||||
Args:
|
||||
concept: TypeDB concept.
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the concept.
|
||||
"""
|
||||
try:
|
||||
if hasattr(concept, "get_type"):
|
||||
concept_type = concept.get_type()
|
||||
if hasattr(concept, "get_value"):
|
||||
return {
|
||||
"type": concept_type.get_label_name(),
|
||||
"value": concept.get_value()
|
||||
}
|
||||
elif hasattr(concept, "get_attributes"):
|
||||
return {
|
||||
"type": concept_type.get_label_name(),
|
||||
"attributes": {
|
||||
attr.get_type().get_label_name(): attr.get_value()
|
||||
for attr in concept.get_attributes()
|
||||
}
|
||||
}
|
||||
return {"type": "unknown", "value": str(concept)}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert concept to dict: {e}")
|
||||
return {"type": "error", "value": str(e)}
|
||||
|
||||
def delete_data(self, query: str) -> None:
|
||||
"""
|
||||
Delete data using TypeQL query.
|
||||
Args:
|
||||
query (str): TypeQL delete query.
|
||||
"""
|
||||
try:
|
||||
with self.session.transaction(TransactionType.WRITE) as transaction:
|
||||
transaction.query.delete(query)
|
||||
transaction.commit()
|
||||
logger.info("Data deleted successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete data: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the TypeDB connection."""
|
||||
try:
|
||||
if self.session:
|
||||
self.session.close()
|
||||
if self.client:
|
||||
self.client.close()
|
||||
logger.info("TypeDB connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close TypeDB connection: {e}")
|
||||
raise
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit."""
|
||||
self.close()
|
@ -0,0 +1,134 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from swarms.structs.swarm_arange import SwarmRearrange
|
||||
from swarms import Agent
|
||||
from swarm_models import OpenAIChat
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a mock agent for testing."""
|
||||
return Mock(spec=Agent)
|
||||
|
||||
@pytest.fixture
|
||||
def swarm_rearrange(mock_agent):
|
||||
"""Create a SwarmRearrange instance with mock agent."""
|
||||
return SwarmRearrange(
|
||||
id="test_id",
|
||||
name="TestSwarm",
|
||||
description="Test swarm for testing",
|
||||
swarms=[mock_agent],
|
||||
flow="Agent1 -> Agent2",
|
||||
max_loops=2,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
def test_initialization(swarm_rearrange):
|
||||
"""Test SwarmRearrange initialization."""
|
||||
assert swarm_rearrange.id == "test_id"
|
||||
assert swarm_rearrange.name == "TestSwarm"
|
||||
assert swarm_rearrange.description == "Test swarm for testing"
|
||||
assert len(swarm_rearrange.swarms) == 1
|
||||
assert swarm_rearrange.flow == "Agent1 -> Agent2"
|
||||
assert swarm_rearrange.max_loops == 2
|
||||
assert swarm_rearrange.verbose is True
|
||||
|
||||
def test_reliability_checks_empty_swarms():
|
||||
"""Test reliability checks with empty swarms."""
|
||||
with pytest.raises(ValueError, match="No swarms found in the swarm."):
|
||||
SwarmRearrange(swarms=[], flow="test")
|
||||
|
||||
def test_reliability_checks_empty_flow():
|
||||
"""Test reliability checks with empty flow."""
|
||||
with pytest.raises(ValueError, match="No flow found in the swarm."):
|
||||
SwarmRearrange(swarms=[Mock()], flow="")
|
||||
|
||||
def test_reliability_checks_invalid_max_loops():
|
||||
"""Test reliability checks with invalid max_loops."""
|
||||
with pytest.raises(ValueError, match="Max loops must be a positive integer."):
|
||||
SwarmRearrange(swarms=[Mock()], flow="test", max_loops=0)
|
||||
|
||||
def test_add_swarm(swarm_rearrange, mock_agent):
|
||||
"""Test adding a new swarm."""
|
||||
new_agent = Mock(spec=Agent)
|
||||
swarm_rearrange.add_swarm(new_agent)
|
||||
assert len(swarm_rearrange.swarms) == 2
|
||||
assert new_agent in swarm_rearrange.swarms.values()
|
||||
|
||||
def test_remove_swarm(swarm_rearrange, mock_agent):
|
||||
"""Test removing a swarm."""
|
||||
swarm_rearrange.remove_swarm(mock_agent.name)
|
||||
assert len(swarm_rearrange.swarms) == 0
|
||||
assert mock_agent.name not in swarm_rearrange.swarms
|
||||
|
||||
def test_add_swarms(swarm_rearrange):
|
||||
"""Test adding multiple swarms."""
|
||||
new_agents = [Mock(spec=Agent) for _ in range(3)]
|
||||
swarm_rearrange.add_swarms(new_agents)
|
||||
assert len(swarm_rearrange.swarms) == 4
|
||||
for agent in new_agents:
|
||||
assert agent in swarm_rearrange.swarms.values()
|
||||
|
||||
def test_track_history(swarm_rearrange, mock_agent):
|
||||
"""Test tracking swarm history."""
|
||||
result = "Test result"
|
||||
swarm_rearrange.track_history(mock_agent.name, result)
|
||||
assert result in swarm_rearrange.swarm_history[mock_agent.name]
|
||||
|
||||
def test_set_custom_flow(swarm_rearrange):
|
||||
"""Test setting custom flow."""
|
||||
new_flow = "Agent1, Agent2 -> Agent3"
|
||||
swarm_rearrange.set_custom_flow(new_flow)
|
||||
assert swarm_rearrange.flow == new_flow
|
||||
|
||||
def test_context_manager(swarm_rearrange):
|
||||
"""Test context manager functionality."""
|
||||
with swarm_rearrange as db:
|
||||
assert db == swarm_rearrange
|
||||
# Verify cleanup was performed
|
||||
assert not swarm_rearrange.session.is_open()
|
||||
|
||||
def test_error_handling(swarm_rearrange):
|
||||
"""Test error handling in various operations."""
|
||||
# Test invalid flow pattern
|
||||
with pytest.raises(ValueError):
|
||||
swarm_rearrange.set_custom_flow("Invalid -> Flow -> Pattern")
|
||||
|
||||
# Test removing non-existent swarm
|
||||
with pytest.raises(KeyError):
|
||||
swarm_rearrange.remove_swarm("NonExistentSwarm")
|
||||
|
||||
def test_thread_safety(swarm_rearrange):
|
||||
"""Test thread safety of operations."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
def add_swarm_thread():
|
||||
for i in range(10):
|
||||
new_agent = Mock(spec=Agent)
|
||||
new_agent.name = f"Agent{i}"
|
||||
swarm_rearrange.add_swarm(new_agent)
|
||||
time.sleep(0.1)
|
||||
|
||||
def remove_swarm_thread():
|
||||
for i in range(10):
|
||||
try:
|
||||
swarm_rearrange.remove_swarm(f"Agent{i}")
|
||||
except KeyError:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
|
||||
# Create and start threads
|
||||
threads = [
|
||||
threading.Thread(target=add_swarm_thread),
|
||||
threading.Thread(target=remove_swarm_thread)
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify no data corruption occurred
|
||||
assert isinstance(swarm_rearrange.swarms, dict)
|
||||
assert isinstance(swarm_rearrange.swarm_history, dict)
|
@ -0,0 +1,129 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from swarms.utils.typedb_wrapper import TypeDBWrapper, TypeDBConfig
|
||||
|
||||
@pytest.fixture
|
||||
def mock_typedb():
|
||||
"""Mock TypeDB client and session."""
|
||||
with patch('swarms.utils.typedb_wrapper.TypeDB') as mock_typedb:
|
||||
mock_client = Mock()
|
||||
mock_session = Mock()
|
||||
mock_typedb.core_client.return_value = mock_client
|
||||
mock_client.session.return_value = mock_session
|
||||
yield mock_typedb, mock_client, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def typedb_wrapper(mock_typedb):
|
||||
"""Create a TypeDBWrapper instance with mocked dependencies."""
|
||||
config = TypeDBConfig(
|
||||
uri="test:1729",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
return TypeDBWrapper(config)
|
||||
|
||||
def test_initialization(typedb_wrapper):
|
||||
"""Test TypeDBWrapper initialization."""
|
||||
assert typedb_wrapper.config.uri == "test:1729"
|
||||
assert typedb_wrapper.config.database == "test_db"
|
||||
assert typedb_wrapper.config.username == "test_user"
|
||||
assert typedb_wrapper.config.password == "test_pass"
|
||||
|
||||
def test_connect(typedb_wrapper, mock_typedb):
|
||||
"""Test connection to TypeDB."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
typedb_wrapper._connect()
|
||||
|
||||
mock_typedb.core_client.assert_called_once_with("test:1729")
|
||||
mock_client.session.assert_called_once_with(
|
||||
"test_db",
|
||||
"DATA",
|
||||
"test_user",
|
||||
"test_pass"
|
||||
)
|
||||
|
||||
def test_define_schema(typedb_wrapper, mock_typedb):
|
||||
"""Test schema definition."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
schema = "define person sub entity;"
|
||||
|
||||
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||
mock_transaction.return_value.__enter__.return_value.query.define.return_value = None
|
||||
typedb_wrapper.define_schema(schema)
|
||||
|
||||
mock_transaction.assert_called_once_with("WRITE")
|
||||
mock_transaction.return_value.__enter__.return_value.query.define.assert_called_once_with(schema)
|
||||
|
||||
def test_insert_data(typedb_wrapper, mock_typedb):
|
||||
"""Test data insertion."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
query = "insert $p isa person;"
|
||||
|
||||
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||
mock_transaction.return_value.__enter__.return_value.query.insert.return_value = None
|
||||
typedb_wrapper.insert_data(query)
|
||||
|
||||
mock_transaction.assert_called_once_with("WRITE")
|
||||
mock_transaction.return_value.__enter__.return_value.query.insert.assert_called_once_with(query)
|
||||
|
||||
def test_query_data(typedb_wrapper, mock_typedb):
|
||||
"""Test data querying."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
query = "match $p isa person; get;"
|
||||
mock_result = [Mock()]
|
||||
|
||||
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||
mock_transaction.return_value.__enter__.return_value.query.get.return_value = mock_result
|
||||
result = typedb_wrapper.query_data(query)
|
||||
|
||||
mock_transaction.assert_called_once_with("READ")
|
||||
mock_transaction.return_value.__enter__.return_value.query.get.assert_called_once_with(query)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_delete_data(typedb_wrapper, mock_typedb):
|
||||
"""Test data deletion."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
query = "match $p isa person; delete $p;"
|
||||
|
||||
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||
mock_transaction.return_value.__enter__.return_value.query.delete.return_value = None
|
||||
typedb_wrapper.delete_data(query)
|
||||
|
||||
mock_transaction.assert_called_once_with("WRITE")
|
||||
mock_transaction.return_value.__enter__.return_value.query.delete.assert_called_once_with(query)
|
||||
|
||||
def test_close(typedb_wrapper, mock_typedb):
|
||||
"""Test connection closing."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
typedb_wrapper.close()
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
def test_context_manager(typedb_wrapper, mock_typedb):
|
||||
"""Test context manager functionality."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
|
||||
with typedb_wrapper as db:
|
||||
assert db == typedb_wrapper
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
def test_error_handling(typedb_wrapper, mock_typedb):
|
||||
"""Test error handling."""
|
||||
mock_typedb, mock_client, mock_session = mock_typedb
|
||||
|
||||
# Test connection error
|
||||
mock_typedb.core_client.side_effect = Exception("Connection failed")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
typedb_wrapper._connect()
|
||||
assert "Connection failed" in str(exc_info.value)
|
||||
|
||||
# Test query error
|
||||
with patch.object(typedb_wrapper.session, 'transaction') as mock_transaction:
|
||||
mock_transaction.return_value.__enter__.return_value.query.get.side_effect = Exception("Query failed")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
typedb_wrapper.query_data("test query")
|
||||
assert "Query failed" in str(exc_info.value)
|
Loading…
Reference in new issue