Ruff lint fixes

pull/570/head
Richard Hein 8 months ago
parent de8edca6d4
commit d30f5f8259

@@ -1,15 +1,12 @@
""" Customized Langchain StreamingResponse for Server-Side Events (SSE) """
import asyncio
from functools import partial
from typing import Any, AsyncIterator
from fastapi import status
from sse_starlette import ServerSentEvent
from sse_starlette.sse import EventSourceResponse, ensure_bytes
from starlette.types import Send
class StreamingResponse(EventSourceResponse):
"""`Response` class for streaming server-sent events.
@@ -52,7 +49,11 @@ class StreamingResponse(EventSourceResponse):
chunk = ensure_bytes(data, self.sep)
print(f"chunk: {chunk.decode()}")
await send(
{"type": "http.response.body", "body": chunk, "more_body": True}
{
"type": "http.response.body",
"body": chunk,
"more_body": True
}
)
except Exception as e:
print(f"body iterator error: {e}")
@@ -71,7 +72,13 @@ class StreamingResponse(EventSourceResponse):
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})
await send(
{
"type": "http.response.body",
"body": b"",
"more_body": False
}
)
def enable_compression(self, force: bool = False):
raise NotImplementedError
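Reviewer note: a minimal usage sketch for the reformatted StreamingResponse.
The standalone app and toy generator below are hypothetical, not part of
this diff; in server.py the agent's run_async output is passed as `content`.

    from fastapi import FastAPI
    from playground.demos.chatbot.server.responses import StreamingResponse

    demo_app = FastAPI()

    async def token_stream():
        # Toy generator standing in for the agent's streaming output.
        for token in ("Hello", " ", "world"):
            yield token

    @demo_app.get("/demo-stream")
    async def demo_stream():
        # StreamingResponse extends sse_starlette's EventSourceResponse,
        # so each yielded string is encoded and sent as an SSE chunk.
        return StreamingResponse(content=token_stream())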

@@ -1,36 +1,24 @@
""" Chatbot with RAG Server """
import asyncio
import logging
import os
from urllib.parse import urlparse, urljoin
# import torch
from contextlib import asynccontextmanager
from typing import AsyncIterator
from urllib.parse import urlparse
from swarms.structs.agent import Agent
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.responses import JSONResponse
from fastapi.routing import APIRouter
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from swarms.prompts.chat_prompt import Message, Role
from swarms.prompts.conversational_RAG import (
B_INST,
B_SYS,
CONDENSE_PROMPT_TEMPLATE,
DOCUMENT_PROMPT_TEMPLATE,
E_INST,
E_SYS,
QA_PROMPT_TEMPLATE_STR,
)
from swarms.prompts.conversational_RAG import QA_PROMPT_TEMPLATE_STR
from playground.demos.chatbot.server.responses import StreamingResponse
from playground.demos.chatbot.server.server_models import ChatRequest
from playground.demos.chatbot.server.vector_storage import RedisVectorStorage
from swarms.models.popular_llms import OpenAIChatLLM
logging.basicConfig(level=logging.ERROR)
# Explicitly specify the path to the .env file
# Two folders above the current file's directory
dotenv_path = os.path.join(
@@ -68,7 +56,8 @@ missing_vars = [var for var in env_vars if not var]
if missing_vars:
print(
f"Error: The following environment variables are not set: {', '.join(missing_vars)}"
"Error: The following environment variables are not set: "
+ ", ".join(missing_vars)
)
exit(1)
@@ -80,17 +69,10 @@ print(f"USE_METAL={useMetal}")
print(f"USE_GPU={use_gpu}")
print(f"OPENAI_API_KEY={openai_api_key}")
print(f"OPENAI_API_BASE={openai_api_base}")
# # update tiktoken to include the model name (avoids warning message)
# tiktoken.model.MODEL_TO_ENCODING.update(
# {
# model_name: "cl100k_base",
# }
# )
print("Logging in to huggingface.co...")
login(token=hf_token) # login to huggingface.co
app = FastAPI(title="Chatbot")
router = APIRouter()
@@ -121,6 +103,7 @@ vector_store = RedisVectorStorage(use_gpu=use_gpu)
vector_store.crawl(URL)
print("Vector storage initialized.")
async def create_chat(
messages: list[Message],
model_name: str,
@@ -146,22 +129,10 @@ async def create_chat(
docs = vector_store.embed(messages[-1].content)
# find {context} in prompt and replace it with the docs page_content.
# Concatenate the content of all documents
context = "\n".join(doc["content"] for doc in docs)
sources = [urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs]
print(f"context: {context}")
# Replace {context} in the prompt with the concatenated document content
prompt = prompt.replace("{context}", context)
# Replace {chat_history} in the prompt with doc_retrieval_string
prompt = prompt.replace("{chat_history}", doc_retrieval_string)
# Replace {question} in the prompt with the last message.
prompt = prompt.replace("{question}", messages[-1].content)
sources = [
urlparse(URL).scheme + "://" + doc["source_url"]
for doc in docs
]
# Initialize the agent
agent = Agent(
@@ -196,6 +167,7 @@ async def create_chat(
# agent_ops_on=True,
)
# add chat history messages to short term memory
for message in messages[:-1]:
if message.role == Role.HUMAN:
agent.add_message_to_memory(message.content)
@@ -203,8 +175,8 @@ async def create_chat(
agent.add_message_to_memory(message.content)
# add docs to short term memory
# for data in [doc["content"] for doc in docs]:
# agent.add_message_to_memory(role=Role.HUMAN, content=data)
for data in [doc["content"] for doc in docs]:
agent.add_message_to_memory(role=Role.HUMAN, content=data)
async for response in agent.run_async(messages[-1].content):
res = response
@@ -213,48 +185,6 @@ async def create_chat(
res += source + "\n"
yield res
# memory = ConversationBufferMemory(
# chat_memory=chat_memory,
# memory_key="chat_history",
# input_key="question",
# output_key="answer",
# return_messages=True,
# )
# question_generator = LLMChain(
# llm=llm,
# prompt=CONDENSE_PROMPT_TEMPLATE,
# memory=memory,
# verbose=True,
# output_key="answer",
# )
# stuff_chain = LLMChain(
# llm=llm,
# prompt=prompt,
# verbose=True,
# output_key="answer",
# )
# doc_chain = StuffDocumentsChain(
# llm_chain=stuff_chain,
# document_variable_name="context",
# document_prompt=DOCUMENT_PROMPT_TEMPLATE,
# verbose=True,
# output_key="answer",
# memory=memory,
# )
# return ConversationalRetrievalChain(
# combine_docs_chain=doc_chain,
# memory=memory,
# retriever=retriever,
# question_generator=question_generator,
# return_generated_question=False,
# return_source_documents=True,
# output_key="answer",
# verbose=True,
# )
@app.post(
"/chat",
@@ -268,25 +198,8 @@ async def chat(request: ChatRequest):
prompt=request.prompt.strip(),
model_name=request.model.id
)
# return response
return StreamingResponse(content=response)
# json_config = {
# "question": request.messages[-1].content,
# "chat_history": [
# message.content for message in request.messages[:-1]
# ],
# # "callbacks": [
# # StreamingStdOutCallbackHandler(),
# # TokenStreamingCallbackHandler(output_key="answer"),
# # SourceDocumentsStreamingCallbackHandler(),
# # ],
# }
# return LangchainStreamingResponse(
# chain=chain,
# config=json_config,
# run_mode="async"
# )
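Reviewer note: a hypothetical client sketch for exercising the streaming
/chat endpoint. Host, port, and the payload's field values are assumptions
inferred from the fields server.py reads, not part of this diff.

    import asyncio
    import httpx

    async def stream_chat(payload: dict) -> None:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST", "http://localhost:8000/chat",
                json=payload, timeout=None,
            ) as response:
                # Print chunks as the server streams them back.
                async for chunk in response.aiter_text():
                    print(chunk, end="", flush=True)

    asyncio.run(stream_chat({
        "id": "conv-1",
        "model": {"id": "NousResearch/Meta-Llama-3-8B-Instruct"},
        "messages": [{"role": "human", "content": "What is Swarms?"}],
        "prompt": "You are a helpful assistant.",
    }))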
@app.get("/")
def root():
@@ -294,22 +207,6 @@ def root():
return {"message": "Swarms Chatbot API"}
@app.get("/favicon.ico")
def favicon():
""" Returns a favicon """
file_name = "favicon.ico"
file_path = os.path.join(app.root_path, "static", file_name)
return FileResponse(
path=file_path,
headers={
"Content-Disposition": "attachment; filename=" + file_name
},
)
logging.basicConfig(level=logging.ERROR)
@app.exception_handler(HTTPException)
async def http_exception_handler(r: Request, exc: HTTPException):
"""Log and return exception details in response."""

@@ -5,6 +5,7 @@ from strenum import StrEnum
from pydantic import BaseModel
from swarms.prompts import QA_PROMPT_TEMPLATE_STR as DefaultSystemPrompt
class AIModel(BaseModel):
""" Defines the model a user selected. """
id: str
@@ -40,7 +41,7 @@ class Message(BaseModel):
class ChatRequest(BaseModel):
""" The model for a ChatRequest expected by the Chatbot Chat POST endpoint. """
""" The model for a ChatRequest for theChatbot Chat POST endpoint"""
id: str
model: AIModel = AIModel(
id="NousResearch/Meta-Llama-3-8B-Instruct",

@@ -9,12 +9,18 @@ from redisvl.schema import IndexSchema
from redisvl.query.filter import Tag
from redisvl.query import VectorQuery, FilterQuery
class RedisVectorStorage:
""" Provides vector storage database operations using Redis """
def __init__(self, context: str="swarms", use_gpu=False, overwrite=False):
def __init__(self,
context: str = "swarms",
use_gpu=False,
overwrite=False):
self.use_gpu = use_gpu
self.context = context
# Initialize the FirecrawlApp with your API key
# Or use the default local Firecrawl instance
self.app = FirecrawlApp(
api_key="EMPTY",
api_url="http://localhost:3002") # EMPTY for localhost
@@ -22,7 +28,7 @@ class RedisVectorStorage:
# Connect to the local Redis server
self.redis_client = redis.Redis(host='localhost', port=6379, db=0)
# Initialize the Cohere text vectorizer
# Initialize the Hugging Face text vectorizer
self.vectorizer = HFTextVectorizer()
index_name = self.context
@@ -85,7 +91,8 @@ class RedisVectorStorage:
return parsed_url.netloc == '' or parsed_url.netloc == base_domain
def split_markdown_content(self, markdown_text, max_length=5000):
""" Split markdown content into chunks of max 5000 characters at natural breakpoints """
""" Split markdown content into chunks of max 5000 characters at
natural breakpoints """
paragraphs = markdown_text.split('\n\n') # Split by paragraphs
chunks = []
current_chunk = ''
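Reviewer note: the rest of split_markdown_content falls outside this hunk's
context; a sketch of how such paragraph-based chunking typically finishes
(an assumption, since the remaining lines are not shown here):

    for paragraph in paragraphs:
        # Greedily pack paragraphs until the next one would exceed max_length.
        if len(current_chunk) + len(paragraph) + 2 <= max_length:
            current_chunk += paragraph + '\n\n'
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = paragraph + '\n\n'
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks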
@@ -117,17 +124,26 @@ class RedisVectorStorage:
def store_chunks_in_redis(self, url, chunks):
""" Store chunks and their embeddings in Redis """
parsed_url = urlparse(url)
trimmed_url = parsed_url.netloc + parsed_url.path # Remove scheme (http:// or https://)
# Remove scheme (http:// or https://)
trimmed_url = parsed_url.netloc + parsed_url.path
data = []
for i, chunk in enumerate(chunks):
embedding = self.vectorizer.embed(chunk, input_type="search_document", as_buffer=True)
embedding = self.vectorizer.embed(
chunk,
input_type="search_document",
as_buffer=True)
# Prepare the data to be stored in Redis
data.append({
"id": f"{trimmed_url}::chunk::{i+1}",
"content": chunk,
"content_embedding": embedding,
"source_url": trimmed_url
})
# Store the data in Redis
self.index.load(data)
print(f"Stored {len(chunks)} chunks for URL {url} in Redis.")
@@ -142,10 +158,13 @@ class RedisVectorStorage:
continue
parsed_url = urlparse(url)
trimmed_url = parsed_url.netloc + parsed_url.path # Remove scheme (http:// or https://)
# Remove scheme (http:// or https://)
trimmed_url = parsed_url.netloc + parsed_url.path
# Check if the URL has already been processed
t = Tag("id") == f"{trimmed_url}::chunk::1" # Use the original URL format
# Use the original URL format
t = Tag("id") == f"{trimmed_url}::chunk::1"
# Use a simple filter query instead of a vector query
filter_query = FilterQuery(filter_expression=t)
@@ -166,7 +185,7 @@ class RedisVectorStorage:
}
crawl_result = []
if self.is_internal_link(url, base_domain) and not url in visited:
if self.is_internal_link(url, base_domain) and url not in visited:
crawl_result.append(self.app.scrape_url(url, params=params))
visited.add(url)
@@ -217,10 +236,12 @@ class RedisVectorStorage:
results = self.index.query(vector_query)
return results
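Reviewer note: a sketch of how the vector_query consumed by
self.index.query() above is typically built with redisvl; the field names
are assumptions based on the chunk layout in store_chunks_in_redis:

    from redisvl.query import VectorQuery

    vector_query = VectorQuery(
        vector=self.vectorizer.embed("What is Swarms?"),
        vector_field_name="content_embedding",
        return_fields=["id", "content", "source_url"],
        num_results=5,
    )
    results = self.index.query(vector_query)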
if __name__ == "__main__":
storage = RedisVectorStorage(overwrite=False)
storage.crawl("https://docs.swarms.world/en/latest/")
responses = storage.embed("What is Swarms, and how do I install swarms?", 5)
responses = storage.embed(
"What is Swarms, and how do I install swarms?", 5)
for response in responses:
encoded_id = response['id'] # Access the 'id' field directly
source_url = response['source_url']

@@ -1,276 +0,0 @@
""" Vector storage with RAG (Retrieval Augmented Generation) support for Markdown."""
import asyncio
import glob
import json
import os
from datetime import datetime
from typing import Dict
from chromadb.config import Settings
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.schema import BaseRetriever
from langchain.storage import LocalFileStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from playground.demos.chatbot.server.async_parent_document_retriever import \
AsyncParentDocumentRetriever
STORE_TYPE = "local" # "redis" or "local"
class VectorStorage:
"""Vector storage class handles loading documents from a given directory."""
def __init__(self, directory, use_gpu=False):
self.embeddings = HuggingFaceBgeEmbeddings(
cache_folder="./.embeddings",
model_name="BAAI/bge-large-en",
model_kwargs={"device": "cuda" if use_gpu else "cpu"},
encode_kwargs={"normalize_embeddings": True},
query_instruction="Represent this sentence for searching relevant passages: ",
)
self.directory = directory
self.child_splitter = RecursiveCharacterTextSplitter(
chunk_size=200, chunk_overlap=20
)
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=200
)
if STORE_TYPE == "redis":
from langchain_community.storage import RedisStore
from langchain_community.storage.redis import get_client
username = r"username"
password = r"password"
client = get_client(
redis_url=f"redis://{username}:{password}@localhost:6239"
)
self.store = RedisStore(client=client)
else:
self.store = LocalFileStore(root_path="./.parent_documents")
self.settings = Settings(
persist_directory="./.chroma_db",
is_persistent=True,
anonymized_telemetry=False,
)
# create a new vectorstore or get an existing one, with default collection
self.vectorstore = self.get_vector_store()
self.client = self.vectorstore._client
self.retrievers: Dict[str, BaseRetriever] = {}
# default retriever for when no collection title is specified
self.retrievers["swarms"] = self.vectorstore.as_retriever()
async def init_retrievers(self, directories: list[str] | None = None):
"""Initializes the vector storage retrievers."""
start_time = datetime.now()
print(f"Start vectorstore initialization time: {start_time}")
# for each subdirectory in the directory, create a new collection if it doesn't exist
dirs = directories or os.listdir(self.directory)
# make sure the subdir is not a file on MacOS (which has a hidden .DS_Store file)
dirs = [
subdir
for subdir in dirs
if not os.path.isfile(f"{self.directory}/{subdir}")
]
print(f"{len(dirs)} subdirectories to load: {dirs}")
self.retrievers[self.directory] = await self.init_retriever(
self.directory
)
end_time = datetime.now()
print("Vectorstore initialization complete.")
print(f"Vectorstore initialization end time: {end_time}")
print(f"Total time taken: {end_time - start_time}")
return self.retrievers
async def init_retriever(self, subdir: str) -> BaseRetriever:
""" Initialize each retriever. """
# Ensure only one concurrent call executes this method at a time;
# the lock is stored on the instance so every call shares it.
if not hasattr(self, "_init_lock"):
self._init_lock = asyncio.Lock()
async with self._init_lock:
subdir_start_time = datetime.now()
print(f"Start {subdir} processing time: {subdir_start_time}")
# get all existing collections
collections = self.client.list_collections()
print(f"Existing collections: {collections}")
# Initialize an empty list to hold the documents
documents = []
# Define the maximum number of files to load at a time
max_files = 1000
# Load existing metadata
metadata_file = f"{self.directory}/metadata.json"
metadata = {
"processDate": str(datetime.now()),
"processed_files": [],
}
processed_files = set() # Track processed files
if os.path.isfile(metadata_file):
with open(
metadata_file, "r",
) as metadata_file_handle:
metadata = dict[str, str](json.load(metadata_file_handle))
processed_files = {
entry["file"]
for entry in metadata.get("processed_files", [])
}
# Get a list of all files in the directory and exclude processed files
all_files = [
file
for file in glob.glob(
f"{self.directory}/**/*.md", recursive=True
)
if file not in processed_files
]
print(
f"Loading {len(all_files)} documents for title version {subdir}."
)
# Load files in chunks of max_files
for i in range(0, len(all_files), max_files):
chunks_start_time = datetime.now()
chunk_files = all_files[i : i + max_files]
for file in chunk_files:
loader = UnstructuredMarkdownLoader(
file, mode="single", strategy="fast"
)
print(f"Loaded {file} in {subdir} ...")
documents.extend(loader.load())
# Record the file as processed in metadata
metadata["processed_files"].append(
{"file": file, "processed_at": str(datetime.now())}
)
print(
f"Creating new collection for {self.directory}..."
)
# Create or get the collection
collection = self.client.create_collection(
name=self.directory,
get_or_create=True,
metadata={"processDate": metadata["processDate"]},
)
# Reload vectorstore based on collection
vectorstore = self.get_vector_store(
collection_name=collection.name
)
# Create a new parent document retriever
retriever = AsyncParentDocumentRetriever(
docstore=self.store,
vectorstore=vectorstore,
child_splitter=self.child_splitter,
parent_splitter=self.parent_splitter,
)
# force reload of collection to make sure we don't have
# the default langchain collection
collection = self.client.get_collection(
name=self.directory
)
vectorstore = self.get_vector_store(
collection_name=self.directory
)
# Add documents to the collection and docstore
print(
f"Adding {len(documents)} documents to collection..."
)
add_docs_start_time = datetime.now()
await retriever.aadd_documents(
documents=documents, add_to_docstore=True
)
add_docs_end_time = datetime.now()
total_time = add_docs_end_time - add_docs_start_time
print(
f"Adding {len(documents)} documents to collection took: {total_time}"
)
documents = [] # clear documents list for next chunk
# Save metadata to the metadata.json file
with open(
metadata_file, "w"
) as metadata_file_handle: # Changed variable name here
json.dump(metadata, metadata_file_handle, indent=4)
print(
f"Loaded {len(documents)} documents for directory '{subdir}'."
)
chunks_end_time = datetime.now()
chunk_time = chunks_end_time - chunks_start_time
print(
f"{max_files} markdown file chunks processing time: {chunk_time}"
)
subdir_end_time = datetime.now()
print(
f"Subdir {subdir} processing end time: {subdir_end_time}"
)
print(f"Time taken: {subdir_end_time - subdir_start_time}")
# Reload vectorstore based on collection to pass to parent doc retriever
# collection = self.client.get_collection(name=self.directory)
vectorstore = self.get_vector_store()
retriever = AsyncParentDocumentRetriever(
docstore=self.store,
vectorstore=vectorstore,
child_splitter=self.child_splitter,
parent_splitter=self.parent_splitter,
)
return retriever
def get_vector_store(self, collection_name: str | None = None) -> Chroma:
""" get a specific vector store for a collection """
if collection_name is None or collection_name in ("", "None"):
collection_name = "swarms"
print("collection_name: " + collection_name)
vectorstore = Chroma(
client_settings=self.settings,
embedding_function=self.embeddings,
collection_name=collection_name,
)
return vectorstore
def list_collections(self):
""" Get a list of all collections in the vectorstore """
vectorstore = Chroma(
client_settings=self.settings,
embedding_function=self.embeddings,
)
return vectorstore._client.list_collections()
async def get_retriever(self, collection_name: str | None = None):
""" get a specific retriever for a collection in the vectorstore """
if self.retrievers is None:
self.retrievers = await self.init_retrievers()
if (
collection_name is None
or collection_name == ""
or collection_name == "None"
):
name = "swarms"
else:
name = collection_name
try:
retriever = self.retrievers[name]
except KeyError:
print(f"Retriever for {name} not found, using default...")
retriever = self.retrievers[
"swarms"
]
return retriever