diff --git a/playground/demos/chatbot/server/server.py b/playground/demos/chatbot/server/server.py
index 7773fbf7..b718d493 100644
--- a/playground/demos/chatbot/server/server.py
+++ b/playground/demos/chatbot/server/server.py
@@ -1,14 +1,13 @@
 """ Chatbot with RAG Server """
-
 import asyncio
 import logging
 import os
-
+from urllib.parse import urlparse, urljoin
 # import torch
 from contextlib import asynccontextmanager
 from typing import AsyncIterator
 from swarms.structs.agent import Agent
-import tiktoken
+
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -29,7 +28,7 @@ from swarms.prompts.conversational_RAG import (
 )
 from playground.demos.chatbot.server.responses import StreamingResponse
 from playground.demos.chatbot.server.server_models import ChatRequest
-from playground.demos.chatbot.server.vector_store import VectorStorage
+from playground.demos.chatbot.server.vector_storage import RedisVectorStorage
 from swarms.models.popular_llms import OpenAIChatLLM
 
 # Explicitly specify the path to the .env file
@@ -42,13 +41,15 @@ load_dotenv(dotenv_path)
 
 hf_token = os.environ.get(
     "HUGGINFACEHUB_API_KEY"
 )  # Get the Huggingface API Token
+
 uploads = os.environ.get(
     "UPLOADS"
 )  # Directory where user uploads files to be parsed for RAG
-model_dir = os.environ.get("MODEL_DIR")
+
+# model_dir = os.environ.get("MODEL_DIR")
 # hugginface.co model (eg. meta-llama/Llama-2-70b-hf)
-model_name = os.environ.get("MODEL_NAME")
+# model_name = os.environ.get("MODEL_NAME")
 
 # Set OpenAI's API key to 'EMPTY' and API base URL to use vLLM's API server
 # or set them to OpenAI API key and base URL.
@@ -60,8 +61,6 @@ openai_api_base = (
 env_vars = [
     hf_token,
     uploads,
-    model_dir,
-    model_name,
     openai_api_key,
     openai_api_base,
 ]
@@ -77,31 +76,22 @@ useMetal = os.environ.get("USE_METAL", "False") == "True"
 use_gpu = os.environ.get("USE_GPU", "False") == "True"
 
 print(f"Uploads={uploads}")
-print(f"MODEL_DIR={model_dir}")
-print(f"MODEL_NAME={model_name}")
 print(f"USE_METAL={useMetal}")
 print(f"USE_GPU={use_gpu}")
 print(f"OPENAI_API_KEY={openai_api_key}")
 print(f"OPENAI_API_BASE={openai_api_base}")
 
-# update tiktoken to include the model name (avoids warning message)
-tiktoken.model.MODEL_TO_ENCODING.update(
-    {
-        model_name: "cl100k_base",
-    }
-)
+# # update tiktoken to include the model name (avoids warning message)
+# tiktoken.model.MODEL_TO_ENCODING.update(
+#     {
+#         model_name: "cl100k_base",
+#     }
+# )
 
 print("Logging in to huggingface.co...")
 login(token=hf_token)  # login to huggingface.co
 
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """Initializes the vector store in a background task."""
-    asyncio.create_task(vector_store.init_retrievers())
-    yield
-
-
-app = FastAPI(title="Chatbot", lifespan=lifespan)
+app = FastAPI(title="Chatbot")
 router = APIRouter()
 
 current_dir = os.path.dirname(__file__)
@@ -125,11 +115,15 @@ if not os.path.exists(uploads):
     os.makedirs(uploads)
 
 # Initialize the vector store
-vector_store = VectorStorage(directory=uploads, use_gpu=use_gpu)
-
+# Hardcoded for Swarms documentation
+URL = "https://docs.swarms.world/en/latest/"
+vector_store = RedisVectorStorage(use_gpu=use_gpu)
+vector_store.crawl(URL)
+print("Vector storage initialized.")
 
 async def create_chat(
     messages: list[Message],
+    model_name: str,
     prompt: str = QA_PROMPT_TEMPLATE_STR,
 ):
     """Creates the RAG conversational retrieval chain."""
@@ -143,7 +137,6 @@ async def create_chat(
         streaming=True,
     )
 
-    retriever = await vector_store.get_retriever("swarms")
     doc_retrieval_string = ""
     for message in messages:
         if message.role == Role.HUMAN:
@@ -151,11 +144,15 @@ async def create_chat(
         elif message.role == Role.AI:
             doc_retrieval_string += f"{Role.AI}: {message.content}\r\n"
 
-    docs = retriever.invoke(doc_retrieval_string)
+    docs = vector_store.embed(messages[-1].content)
 
     # find {context} in prompt and replace it with the docs page_content.
     # Concatenate the content of all documents
-    context = "\n".join(doc.page_content for doc in docs)
+    context = "\n".join(doc["content"] for doc in docs)
+
+    sources = [urlparse(URL).scheme + "://" + doc["source_url"] for doc in docs]
+
+    print(f"context: {context}")
 
     # Replace {context} in the prompt with the concatenated document content
     prompt = prompt.replace("{context}", context)
@@ -188,7 +185,7 @@ async def create_chat(
         # sop="Calculate the profit for a company.",
         # sop_list=["Calculate the profit for a company."],
         user_name="RAH@EntangleIT.com",
-        docs=[doc.page_content for doc in docs],
+        # docs=[doc["content"] for doc in docs],
         # # docs_folder="docs",
         retry_attempts=3,
         # context_length=1000,
@@ -204,9 +201,17 @@ async def create_chat(
             agent.add_message_to_memory(message.content)
         elif message.role == Role.AI:
             agent.add_message_to_memory(message.content)
+
+    # add docs to short term memory
+    # for data in [doc["content"] for doc in docs]:
+    #     agent.add_message_to_memory(role=Role.HUMAN, content=data)
 
     async for response in agent.run_async(messages[-1].content):
-        yield response
+        res = response
+        res += "\n\nSources:\n"
+        for source in sources:
+            res += source + "\n"
+        yield res
 
     # memory = ConversationBufferMemory(
     #     chat_memory=chat_memory,
@@ -260,7 +265,8 @@ async def chat(request: ChatRequest):
     """ Handles chatbot chat POST requests """
     response = create_chat(
         messages=request.messages,
-        prompt=request.prompt.strip()
+        prompt=request.prompt.strip(),
+        model_name=request.model.id
     )
    # return response
    return StreamingResponse(content=response)
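Note (not part of the patch): the server code above depends on an interface for RedisVectorStorage that the diff never shows. Below is a minimal sketch of that assumed interface, inferred only from the call sites (crawl(URL), and embed(text) returning dicts with "content" and "source_url" keys); the real class in playground/demos/chatbot/server/vector_storage.py may differ:

    from typing import TypedDict


    class DocChunk(TypedDict):
        """Shape assumed by create_chat for each retrieved chunk."""

        content: str  # chunk text joined into the {context} slot of the prompt
        source_url: str  # crawled page the chunk came from, stored without a URL scheme


    class RedisVectorStorage:
        """Sketch of the interface server.py uses; every detail here is an assumption."""

        def __init__(self, use_gpu: bool = False) -> None:
            self.use_gpu = use_gpu

        def crawl(self, url: str) -> None:
            """Crawl the docs site rooted at `url` and index each page in Redis."""

        def embed(self, query: str) -> list[DocChunk]:
            """Embed `query` and return the nearest indexed chunks."""
            return []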
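For manual testing, a hypothetical client for the new streaming endpoint might look like the sketch below. The /chat path, the port, and the exact ChatRequest JSON shape are assumptions; the hunks above only show that the handler reads request.messages, request.prompt, and request.model.id:

    import httpx  # assumed available; any HTTP client with streaming support works

    payload = {
        "model": {"id": "gpt-4o"},  # hypothetical id; passed to create_chat as model_name
        "prompt": "Answer using the context below.\n\n{context}",  # {context} is replaced server-side
        "messages": [{"role": "human", "content": "What is a swarm?"}],
    }

    # Stream the response; the final chunks carry the appended "Sources:" list.
    with httpx.stream("POST", "http://localhost:8888/chat", json=payload, timeout=None) as resp:
        for chunk in resp.iter_text():
            print(chunk, end="", flush=True)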