"""Hierarchical Quantum-Inspired Distributed RAG (HQD-RAG) system.

Provides the ``HQDRAG`` class: quantum-inspired document embeddings indexed
with FAISS, hierarchical clustering for corpus organization, and
reliability-verified retrieval. See the ``__main__`` block for example usage.
"""
import pickle
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import faiss
import numpy as np
from loguru import logger
from scipy.sparse import csr_matrix
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
@dataclass
class Document:
    """Represents a document in the HQD-RAG system.

    Attributes:
        id (str): Unique identifier for the document
        content (str): Raw text content of the document
        embedding (Optional[np.ndarray]): Quantum-inspired embedding vector
        cluster_id (Optional[int]): ID of the cluster this document belongs to
    """

    # Unique identifier (caller-supplied or a generated UUID string).
    id: str
    # Raw text content of the document.
    content: str
    # Embedding vector; None until computed by the indexing pipeline.
    embedding: Optional[np.ndarray] = None
    # Hierarchical-cluster label; None until clustering has run.
    cluster_id: Optional[int] = None
class HQDRAG:
|
|
"""
|
|
Hierarchical Quantum-Inspired Distributed RAG (HQD-RAG) System
|
|
|
|
A production-grade implementation of the HQD-RAG algorithm for ultra-fast
|
|
and reliable document retrieval. Uses quantum-inspired embeddings and
|
|
hierarchical clustering for efficient search.
|
|
|
|
Attributes:
|
|
embedding_dim (int): Dimension of the quantum-inspired embeddings
|
|
num_clusters (int): Number of hierarchical clusters
|
|
similarity_threshold (float): Threshold for quantum similarity matching
|
|
reliability_threshold (float): Threshold for reliability verification
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
embedding_dim: int = 768,
|
|
num_clusters: int = 128,
|
|
similarity_threshold: float = 0.75,
|
|
reliability_threshold: float = 0.85,
|
|
model_name: str = "all-MiniLM-L6-v2"
|
|
):
|
|
"""Initialize the HQD-RAG system.
|
|
|
|
Args:
|
|
embedding_dim: Dimension of document embeddings
|
|
num_clusters: Number of clusters for hierarchical organization
|
|
similarity_threshold: Minimum similarity score for retrieval
|
|
reliability_threshold: Minimum reliability score for verification
|
|
model_name: Name of the sentence transformer model to use
|
|
"""
|
|
logger.info(f"Initializing HQD-RAG with {embedding_dim} dimensions")
|
|
|
|
self.embedding_dim = embedding_dim
|
|
self.num_clusters = num_clusters
|
|
self.similarity_threshold = similarity_threshold
|
|
self.reliability_threshold = reliability_threshold
|
|
|
|
# Initialize components
|
|
self.documents: Dict[str, Document] = {}
|
|
self.encoder = SentenceTransformer(model_name)
|
|
self.index = faiss.IndexFlatIP(embedding_dim) # Inner product index
|
|
self.clustering = AgglomerativeClustering(
|
|
n_clusters=num_clusters,
|
|
metric='euclidean',
|
|
linkage='ward'
|
|
)
|
|
|
|
# Thread safety
|
|
self._lock = threading.Lock()
|
|
self._executor = ThreadPoolExecutor(max_workers=4)
|
|
|
|
logger.info("HQD-RAG system initialized successfully")
|
|
|
|
def _compute_quantum_embedding(self, text: str) -> np.ndarray:
|
|
"""Compute quantum-inspired embedding for text.
|
|
|
|
Args:
|
|
text: Input text to embed
|
|
|
|
Returns:
|
|
Quantum-inspired embedding vector
|
|
"""
|
|
# Get base embedding
|
|
base_embedding = self.encoder.encode([text])[0]
|
|
|
|
# Apply quantum-inspired transformation
|
|
# Simulate superposition by adding phase components
|
|
phase = np.exp(2j * np.pi * np.random.random(self.embedding_dim))
|
|
quantum_embedding = base_embedding * phase
|
|
|
|
# Normalize to unit length
|
|
return quantum_embedding / np.linalg.norm(quantum_embedding)
|
|
|
|
def _verify_reliability(self, doc: Document, query_embedding: np.ndarray) -> float:
|
|
"""Verify the reliability of a document match.
|
|
|
|
Args:
|
|
doc: Document to verify
|
|
query_embedding: Query embedding vector
|
|
|
|
Returns:
|
|
Reliability score between 0 and 1
|
|
"""
|
|
if doc.embedding is None:
|
|
return 0.0
|
|
|
|
# Compute consistency score
|
|
consistency = np.abs(np.dot(doc.embedding, query_embedding))
|
|
|
|
# Add quantum noise resistance check
|
|
noise = np.random.normal(0, 0.1, self.embedding_dim)
|
|
noisy_query = query_embedding + noise
|
|
noisy_query = noisy_query / np.linalg.norm(noisy_query)
|
|
noise_resistance = np.abs(np.dot(doc.embedding, noisy_query))
|
|
|
|
return (consistency + noise_resistance) / 2
|
|
|
|
def add(self, content: str, doc_id: Optional[str] = None) -> str:
|
|
"""Add a document to the system.
|
|
|
|
Args:
|
|
content: Document text content
|
|
doc_id: Optional custom document ID
|
|
|
|
Returns:
|
|
Document ID
|
|
"""
|
|
doc_id = doc_id or str(uuid.uuid4())
|
|
|
|
with self._lock:
|
|
try:
|
|
# Compute embedding
|
|
embedding = self._compute_quantum_embedding(content)
|
|
|
|
# Create document
|
|
doc = Document(
|
|
id=doc_id,
|
|
content=content,
|
|
embedding=embedding
|
|
)
|
|
|
|
# Add to storage
|
|
self.documents[doc_id] = doc
|
|
self.index.add(embedding.reshape(1, -1))
|
|
|
|
# Update clustering
|
|
self._update_clusters()
|
|
|
|
logger.info(f"Successfully added document {doc_id}")
|
|
return doc_id
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error adding document: {str(e)}")
|
|
raise
|
|
|
|
def query(
|
|
self,
|
|
query: str,
|
|
k: int = 5,
|
|
return_scores: bool = False
|
|
) -> Union[List[str], List[tuple[str, float]]]:
|
|
"""Query the system for relevant documents.
|
|
|
|
Args:
|
|
query: Query text
|
|
k: Number of results to return
|
|
return_scores: Whether to return similarity scores
|
|
|
|
Returns:
|
|
List of document IDs or (document ID, score) tuples
|
|
"""
|
|
try:
|
|
# Compute query embedding
|
|
query_embedding = self._compute_quantum_embedding(query)
|
|
|
|
# Search index
|
|
scores, indices = self.index.search(
|
|
query_embedding.reshape(1, -1),
|
|
k * 2 # Get extra results for reliability filtering
|
|
)
|
|
|
|
results = []
|
|
for score, idx in zip(scores[0], indices[0]):
|
|
# Get document
|
|
doc_id = list(self.documents.keys())[idx]
|
|
doc = self.documents[doc_id]
|
|
|
|
# Verify reliability
|
|
reliability = self._verify_reliability(doc, query_embedding)
|
|
|
|
if reliability >= self.reliability_threshold:
|
|
results.append((doc_id, float(score)))
|
|
|
|
if len(results) >= k:
|
|
break
|
|
|
|
logger.info(f"Query returned {len(results)} results")
|
|
|
|
if return_scores:
|
|
return results
|
|
return [doc_id for doc_id, _ in results]
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing query: {str(e)}")
|
|
raise
|
|
|
|
def update(self, doc_id: str, new_content: str) -> None:
|
|
"""Update an existing document.
|
|
|
|
Args:
|
|
doc_id: ID of document to update
|
|
new_content: New document content
|
|
"""
|
|
with self._lock:
|
|
try:
|
|
if doc_id not in self.documents:
|
|
raise KeyError(f"Document {doc_id} not found")
|
|
|
|
# Remove old embedding
|
|
old_doc = self.documents[doc_id]
|
|
if old_doc.embedding is not None:
|
|
self.index.remove_ids(np.array([list(self.documents.keys()).index(doc_id)]))
|
|
|
|
# Compute new embedding
|
|
new_embedding = self._compute_quantum_embedding(new_content)
|
|
|
|
# Update document
|
|
self.documents[doc_id] = Document(
|
|
id=doc_id,
|
|
content=new_content,
|
|
embedding=new_embedding
|
|
)
|
|
|
|
# Add new embedding
|
|
self.index.add(new_embedding.reshape(1, -1))
|
|
|
|
# Update clustering
|
|
self._update_clusters()
|
|
|
|
logger.info(f"Successfully updated document {doc_id}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error updating document: {str(e)}")
|
|
raise
|
|
|
|
def delete(self, doc_id: str) -> None:
|
|
"""Delete a document from the system.
|
|
|
|
Args:
|
|
doc_id: ID of document to delete
|
|
"""
|
|
with self._lock:
|
|
try:
|
|
if doc_id not in self.documents:
|
|
raise KeyError(f"Document {doc_id} not found")
|
|
|
|
# Remove from index
|
|
idx = list(self.documents.keys()).index(doc_id)
|
|
self.index.remove_ids(np.array([idx]))
|
|
|
|
# Remove from storage
|
|
del self.documents[doc_id]
|
|
|
|
# Update clustering
|
|
self._update_clusters()
|
|
|
|
logger.info(f"Successfully deleted document {doc_id}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting document: {str(e)}")
|
|
raise
|
|
|
|
def _update_clusters(self) -> None:
|
|
"""Update hierarchical document clusters."""
|
|
if len(self.documents) < 2:
|
|
return
|
|
|
|
# Get all embeddings
|
|
embeddings = np.vstack([
|
|
doc.embedding for doc in self.documents.values()
|
|
if doc.embedding is not None
|
|
])
|
|
|
|
# Update clustering
|
|
clusters = self.clustering.fit_predict(embeddings)
|
|
|
|
# Assign cluster IDs
|
|
for doc, cluster_id in zip(self.documents.values(), clusters):
|
|
doc.cluster_id = int(cluster_id)
|
|
|
|
def save(self, path: Union[str, Path]) -> None:
|
|
"""Save the system state to disk.
|
|
|
|
Args:
|
|
path: Path to save directory
|
|
"""
|
|
path = Path(path)
|
|
path.mkdir(parents=True, exist_ok=True)
|
|
|
|
try:
|
|
# Save documents
|
|
with open(path / "documents.pkl", "wb") as f:
|
|
pickle.dump(self.documents, f)
|
|
|
|
# Save index
|
|
faiss.write_index(self.index, str(path / "index.faiss"))
|
|
|
|
logger.info(f"Successfully saved system state to {path}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error saving system state: {str(e)}")
|
|
raise
|
|
|
|
def load(self, path: Union[str, Path]) -> None:
|
|
"""Load the system state from disk.
|
|
|
|
Args:
|
|
path: Path to save directory
|
|
"""
|
|
path = Path(path)
|
|
|
|
try:
|
|
# Load documents
|
|
with open(path / "documents.pkl", "rb") as f:
|
|
self.documents = pickle.load(f)
|
|
|
|
# Load index
|
|
self.index = faiss.read_index(str(path / "index.faiss"))
|
|
|
|
logger.info(f"Successfully loaded system state from {path}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading system state: {str(e)}")
|
|
raise
|
|
|
|
# Example usage:
if __name__ == "__main__":
    # Route log output to a rotating daily file kept for one week.
    logger.add(
        "hqd_rag.log",
        rotation="1 day",
        retention="1 week",
        level="INFO"
    )

    # Build the retrieval system with default settings.
    rag = HQDRAG()

    # Index a few sample documents, collecting the generated ids.
    sample_texts = [
        "The quick brown fox jumps over the lazy dog",
        "Machine learning is a subset of artificial intelligence",
        "Python is a popular programming language"
    ]
    doc_ids = [rag.add(text) for text in sample_texts]

    # Run a query and show each hit with its similarity score.
    results = rag.query("What is machine learning?", return_scores=True)
    print("Query results:", results)

    # # Update a document
    # rag.update(doc_ids[0], "The fast brown fox jumps over the sleepy dog")

    # # Delete a document
    # rag.delete(doc_ids[-1])

    # # Save state
    # rag.save("hqd_rag_state")