swarms/playground/demos/chatbot/server/vector_storage.py

""" Vector storage using Firecrawl and Redis """
import re
from urllib.parse import urlparse, urljoin
import redis
from firecrawl import FirecrawlApp
from redisvl.utils.vectorize import HFTextVectorizer
from redisvl.index import SearchIndex
from redisvl.schema import IndexSchema
from redisvl.query.filter import Tag
from redisvl.query import VectorQuery, FilterQuery


class RedisVectorStorage:
    """ Provides vector storage database operations using Redis """

    def __init__(self,
                 context: str = "swarms",
                 use_gpu=False,
                 overwrite=False):
        self.use_gpu = use_gpu
        self.context = context

        # Initialize the FirecrawlApp with your API key,
        # or use the default local Firecrawl instance
        self.app = FirecrawlApp(
            api_key="EMPTY",
            api_url="http://localhost:3002")  # EMPTY for localhost

        # Connect to the local Redis server
        self.redis_client = redis.Redis(host='localhost', port=6379, db=0)

        # Initialize the HuggingFace text vectorizer
        self.vectorizer = HFTextVectorizer()

        index_name = self.context
        schema = IndexSchema.from_dict({
            "index": {
                "name": index_name,
            },
            "fields": [
                {
                    "name": "id",
                    "type": "tag",
                    "attrs": {
                        "sortable": True
                    }
                },
                {
                    "name": "content",
                    "type": "text",
                    "attrs": {
                        "sortable": True
                    }
                },
                {
                    "name": "content_embedding",
                    "type": "vector",
                    "attrs": {
                        "dims": self.vectorizer.dims,
                        "distance_metric": "cosine",
                        "algorithm": "hnsw",
                        "datatype": "float32"
                    }
                },
                {
                    "name": "source_url",
                    "type": "text",
                    "attrs": {
                        "sortable": True
                    }
                }
            ]
        })

        self.schema = schema
        self.index = SearchIndex(self.schema, self.redis_client)
        self.index.create(overwrite=overwrite, drop=overwrite)
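        # The content_embedding field above is an HNSW vector index using
        # cosine distance, sized to the vectorizer's embedding dimensions.
        # Passing overwrite=True (and drop=True) recreates the index and
        # discards previously stored data; the default overwrite=False keeps
        # an existing index so crawled pages persist between runs.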

    def extract_markdown_links(self, markdown_text):
        """ Extract Markdown links from the given markdown text """
        pattern = r'\[([^\]]+)\]\(([^)]+?)(?:\s+"[^"]*")?\)'
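        # For example, '[Docs](https://docs.swarms.world "Swarms docs")' matches
        # with the URL group "https://docs.swarms.world"; an optional quoted
        # title after the URL is ignored.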
        links = re.findall(pattern, markdown_text)
        urls = [link[1] for link in links]
        return urls

    def is_internal_link(self, url: str, base_domain: str):
        """ Check if a URL is internal to the initial domain """
        if url == '\\' or url.startswith("mailto"):
            return False
        parsed_url = urlparse(url)
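        # Relative links have an empty netloc and therefore count as internal;
        # absolute links are internal only when their domain matches base_domain.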
        return parsed_url.netloc == '' or parsed_url.netloc == base_domain

    def split_markdown_content(self, markdown_text, max_length=5000):
        """ Split markdown content into chunks of at most max_length
        characters (default 5000) at natural breakpoints """
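        # Chunking strategy: paragraphs are accumulated until adding the next
        # one would exceed max_length; a paragraph that is itself longer than
        # max_length is re-split on sentence boundaries, and any sentence still
        # longer than max_length is hard-cut as a last resort.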
        paragraphs = markdown_text.split('\n\n')  # Split by paragraphs
        chunks = []
        current_chunk = ''

        for paragraph in paragraphs:
            if len(current_chunk) + len(paragraph) > max_length:
                chunks.append(current_chunk)
                current_chunk = paragraph
            else:
                current_chunk += '\n\n' + paragraph

            if len(paragraph) > max_length:
                sentences = re.split(r'(?<=[.!?])\s+', paragraph)
                for sentence in sentences:
                    if len(sentence) > max_length:
                        chunks.append(sentence[:max_length])
                        current_chunk = sentence[max_length:]
                    elif len(current_chunk) + len(sentence) > max_length:
                        chunks.append(current_chunk)
                        current_chunk = sentence
                    else:
                        current_chunk += ' ' + sentence

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def store_chunks_in_redis(self, url, chunks):
        """ Store chunks and their embeddings in Redis """
        parsed_url = urlparse(url)
        # Remove the scheme (http:// or https://)
        trimmed_url = parsed_url.netloc + parsed_url.path

        data = []
        for i, chunk in enumerate(chunks):
            embedding = self.vectorizer.embed(
                chunk,
                input_type="search_document",
                as_buffer=True)
            # Prepare the data to be stored in Redis
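            # Each chunk id combines the trimmed URL with a 1-based counter,
            # e.g. "docs.swarms.world/en/latest/::chunk::1"; crawl_iterative
            # later checks this id to skip pages that are already indexed.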
            data.append({
                "id": f"{trimmed_url}::chunk::{i+1}",
                "content": chunk,
                "content_embedding": embedding,
                "source_url": trimmed_url
            })

        # Store the data in Redis
        self.index.load(data)
        print(f"Stored {len(chunks)} chunks for URL {url} in Redis.")

    def crawl_iterative(self, start_url, base_domain):
        """ Iteratively crawl a URL and its Markdown links """
        visited = set()
        stack = [start_url]

        while stack:
            url = stack.pop()
            if url in visited:
                continue

            parsed_url = urlparse(url)
            # Remove the scheme (http:// or https://)
            trimmed_url = parsed_url.netloc + parsed_url.path

            # Check if the URL has already been processed,
            # using the original URL format
            t = Tag("id") == f"{trimmed_url}::chunk::1"
            # Use a simple filter query instead of a vector query
            filter_query = FilterQuery(filter_expression=t)
            results = self.index.query(filter_query)
            if results:
                print(f"URL {url} has already been processed. Skipping.")
                visited.add(url)
                continue

            print(f"Crawling URL: {url}")
            params = {
                'pageOptions': {
                    'onlyMainContent': False,
                    'fetchPageContent': True,
                    'includeHTML': False,
                }
            }

            crawl_result = []
            if self.is_internal_link(url, base_domain) and url not in visited:
                crawl_result.append(self.app.scrape_url(url, params=params))
                visited.add(url)

            for result in crawl_result:
                markdown_content = result["markdown"]
                result_url = result["metadata"]["sourceURL"]
                print("Markdown sourceURL: " + result_url)
                # print("Content:\n\n")
                # print(markdown_content)
                # print("\n\n")

                # Split markdown content into natural chunks
                chunks = self.split_markdown_content(markdown_content)

                # Store the chunks and their embeddings in Redis
                self.store_chunks_in_redis(result_url, chunks)

                links = self.extract_markdown_links(markdown_content)
                print("Extracted Links:", links)

                for link in links:
                    absolute_link = urljoin(result_url, link)
                    if self.is_internal_link(absolute_link, base_domain):
                        if absolute_link not in visited:
                            stack.append(absolute_link)
                            print("Appended link: " + absolute_link)
                    else:
                        visited.add(absolute_link)

    def crawl(self, crawl_url: str):
        """ Start the iterative crawling from the initial URL """
        base_domain = urlparse(crawl_url).netloc
        self.crawl_iterative(crawl_url, base_domain)

    def embed(self, query: str, num_results: int = 3):
        """ Embed a string and perform a Redis vector database query """
        query_embedding = self.vectorizer.embed(query)

        vector_query = VectorQuery(
            vector=query_embedding,
            vector_field_name="content_embedding",
            num_results=num_results,
            return_fields=["id", "content", "source_url"],
            return_score=True
        )
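        # VectorQuery runs a K-nearest-neighbour search over the
        # content_embedding field; with return_score=True each result should
        # also include the vector distance alongside the requested fields.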
        results = self.index.query(vector_query)
        return results


if __name__ == "__main__":
    storage = RedisVectorStorage(overwrite=False)
    storage.crawl("https://docs.swarms.world/en/latest/")

    responses = storage.embed(
        "What is Swarms, and how do I install swarms?", 5)
    for response in responses:
        encoded_id = response['id']  # Access the 'id' field directly
        source_url = response['source_url']
        print(f"Decoded ID: {encoded_id}, Source URL: {source_url}")