commit
de716e43d3
@ -1,44 +0,0 @@
OPENAI_API_KEY="your_openai_api_key_here"
GOOGLE_API_KEY=""
ANTHROPIC_API_KEY=""
AI21_API_KEY="your_api_key_here"
COHERE_API_KEY="your_api_key_here"
ALEPHALPHA_API_KEY="your_api_key_here"
HUGGINGFACEHUB_API_KEY="your_api_key_here"
STABILITY_API_KEY="your_api_key_here"

WOLFRAM_ALPHA_APPID="your_wolfram_alpha_appid_here"
ZAPIER_NLA_API_KEY="your_zapier_nla_api_key_here"

EVAL_PORT=8000
MODEL_NAME="gpt-4"
CELERY_BROKER_URL="redis://localhost:6379"

SERVER="http://localhost:8000"
USE_GPU=True
PLAYGROUND_DIR="playground"

LOG_LEVEL="INFO"
BOT_NAME="Orca"

WINEDB_HOST="your_winedb_host_here"
WINEDB_PASSWORD="your_winedb_password_here"
BING_SEARCH_URL="your_bing_search_url_here"

BING_SUBSCRIPTION_KEY="your_bing_subscription_key_here"
SERPAPI_API_KEY="your_serpapi_api_key_here"
IFTTTKey="your_iftttkey_here"

BRAVE_API_KEY="your_brave_api_key_here"
SPOONACULAR_KEY="your_spoonacular_key_here"
HF_API_KEY="your_huggingface_api_key_here"


REDIS_HOST=
REDIS_PORT=

#dbs
PINECONE_API_KEY=""
BING_COOKIE=""

PSG_CONNECTION_STRING=""
@ -1,4 +0,0 @@
rules:
  line-length:
    level: warning
    allow-non-breakable-inline-mappings: true
@ -0,0 +1,81 @@
# Qdrant Client Library

## Overview

The Qdrant Client Library is designed for interacting with the Qdrant vector database, allowing efficient storage and retrieval of high-dimensional vector data. It integrates with machine learning models for embedding and is particularly suited for search and recommendation systems.

## Installation

```bash
pip install qdrant-client sentence-transformers httpx
```

## Class Definition: Qdrant

```python
class Qdrant:
    def __init__(self, api_key: str, host: str, port: int = 6333, collection_name: str = "qdrant", model_name: str = "BAAI/bge-small-en-v1.5", https: bool = True):
        ...
```

### Constructor Parameters

| Parameter       | Type | Description                                   | Default Value            |
|-----------------|------|-----------------------------------------------|--------------------------|
| api_key         | str  | API key for authentication.                   | -                        |
| host            | str  | Host address of the Qdrant server.            | -                        |
| port            | int  | Port number for the Qdrant server.            | 6333                     |
| collection_name | str  | Name of the collection to be used or created. | "qdrant"                 |
| model_name      | str  | Name of the sentence transformer model.       | "BAAI/bge-small-en-v1.5" |
| https           | bool | Flag to use HTTPS for connection.             | True                     |

### Methods

#### `_load_embedding_model(model_name: str)`

Loads the sentence embedding model.

#### `_setup_collection()`

Checks if the specified collection exists in Qdrant; if not, creates it.

#### `add_vectors(docs: List[dict]) -> OperationResponse`

Adds vectors to the Qdrant collection.

#### `search_vectors(query: str, limit: int = 3) -> SearchResult`

Searches the Qdrant collection for vectors similar to the query vector.

## Usage Examples

### Example 1: Setting Up the Qdrant Client

```python
from swarms.memory.qdrant import Qdrant

qdrant_client = Qdrant(api_key="your_api_key", host="localhost", port=6333)
```

### Example 2: Adding Vectors to a Collection

```python
documents = [
    {"page_content": "Sample text 1"},
    {"page_content": "Sample text 2"},
]

operation_info = qdrant_client.add_vectors(documents)
print(operation_info)
```

### Example 3: Searching for Vectors

```python
search_result = qdrant_client.search_vectors("Sample search query")
print(search_result)
```

## Further Information

Refer to the [Qdrant Documentation](https://qdrant.tech/docs) for more details on the Qdrant vector database.
@ -0,0 +1,93 @@
import os
import random

from dotenv import load_dotenv

from playground.models.stable_diffusion import StableDiffusion
from swarms.models import OpenAIChat
from swarms.structs import Agent, SequentialWorkflow

load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
stability_api_key = os.getenv("STABILITY_API_KEY")

# Initialize the language model and image generation model
llm = OpenAIChat(
    openai_api_key=openai_api_key, temperature=0.5, max_tokens=3000
)
sd_api = StableDiffusion(api_key=stability_api_key)


def run_task(description, product_name, agent, **kwargs):
    full_description = (  # Incorporate product name into the task
        f"{description} about {product_name}"
    )
    result = agent.run(task=full_description, **kwargs)
    return result


# Creative Concept Generator
class ProductPromptGenerator:
    def __init__(self, product_name):
        self.product_name = product_name
        self.themes = [
            "lightning",
            "sunset",
            "ice cave",
            "space",
            "forest",
            "ocean",
            "mountains",
            "cityscape",
        ]
        self.styles = [
            "translucent",
            "floating in mid-air",
            "expanded into pieces",
            "glowing",
            "mirrored",
            "futuristic",
        ]
        self.contexts = ["high realism product ad (extremely creative)"]

    def generate_prompt(self):
        theme = random.choice(self.themes)
        style = random.choice(self.styles)
        context = random.choice(self.contexts)
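        # Illustrative output: e.g. "sunset inside a glowing PS5,
        # high realism product ad (extremely creative)"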
        return f"{theme} inside a {style} {self.product_name}, {context}"


# User input
product_name = input(
    "Enter a product name for ad creation (e.g., 'PS5', 'AirPods', 'Kirkland"
    " Vodka'): "
)

# Generate creative concept
prompt_generator = ProductPromptGenerator(product_name)
creative_prompt = prompt_generator.generate_prompt()

# Run tasks using Agent
concept_flow = Agent(llm=llm, max_loops=1, dashboard=False)
design_flow = Agent(llm=llm, max_loops=1, dashboard=False)
copywriting_flow = Agent(llm=llm, max_loops=1, dashboard=False)

# Execute tasks
concept_result = run_task(
    "Generate a creative concept", product_name, concept_flow
)
design_result = run_task(
    "Suggest visual design ideas", product_name, design_flow
)
copywriting_result = run_task(
    "Create compelling ad copy for the product photo",
    product_name,
    copywriting_flow,
)

# Generate product image
image_paths = sd_api.run(creative_prompt)

# Output the results
print("Creative Concept:", concept_result)
print("Design Ideas:", design_result)
print("Ad Copy:", copywriting_result)
print("Image Path:", image_paths[0] if image_paths else "No image generated")
@ -0,0 +1,58 @@
"""
Swarm of developers that write documentation and tests for a given code snippet.

This is a simple example of how to use the swarms library to create a swarm of developers that write documentation and tests for a given code snippet.

The swarm is composed of two agents:
- Documentation agent: writes documentation for a given code snippet.
- Tests agent: writes tests for a given code snippet.

The swarm is initialized with a language model that is used by the agents to generate text. In this example, we use an OpenAI chat model.

Agents:
Documentation agent -> Tests agent

"""
import os

from dotenv import load_dotenv

from swarms.models import OpenAIChat
from swarms.prompts.programming import DOCUMENTATION_SOP, TEST_SOP
from swarms.structs import Agent

load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")


TASK = """
code
"""

# Initialize the language model
llm = OpenAIChat(openai_api_key=api_key, max_tokens=5000)


# Documentation agent
documentation_agent = Agent(
    llm=llm, sop=DOCUMENTATION_SOP, max_loops=1, multi_modal=True
)


# Tests agent
tests_agent = Agent(llm=llm, sop=TEST_SOP, max_loops=2, multi_modal=True)


# Run the documentation agent
documentation = documentation_agent.run(
    f"Write documentation for the following code: {TASK}"
)

# Run the tests agent
tests = tests_agent.run(
    f"Write tests for the following code: {TASK} here is the documentation:"
    f" {documentation}"
)
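
# Inspect the generated artifacts (assumption: the agents return plain strings)
print(documentation)
print(tests)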
@ -0,0 +1,73 @@
import os

from dotenv import load_dotenv

from swarms.models import GPT4VisionAPI, OpenAIChat
from swarms.prompts.xray_swarm_prompt import (
    TREATMENT_PLAN_PROMPT,
    XRAY_ANALYSIS_PROMPT,
)
from swarms.structs.agent import Agent

# Load environment variables
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Multimodal LLM used to analyze the X-ray image
multimodal_llm = GPT4VisionAPI(
    openai_api_key=openai_api_key,
)

# Initialize Language Model (LLM)
llm = OpenAIChat(
    openai_api_key=openai_api_key,
    max_tokens=3000,
)


# Agent that analyzes an X-ray image
analyze_xray_agent = Agent(
    llm=multimodal_llm,
    autosave=True,
    sop=XRAY_ANALYSIS_PROMPT,
    multi_modal=True,
)


# Treatment Plan Agent (uses the text LLM; no image input is needed here)
treatment_agent = Agent(
    llm=llm,
    autosave=True,
    sop=TREATMENT_PLAN_PROMPT,
    max_loops=4,
)


# Function to generate a treatment plan
def generate_treatment_plan(diagnosis):
    treatment_plan_prompt = TREATMENT_PLAN_PROMPT.format(diagnosis)
    # Run the treatment agent on the formatted prompt
    return treatment_agent.run(treatment_plan_prompt)


# X-ray Agent - Analyze an X-ray image
xray_image_path = "playground/demos/xray/xray2.jpg"


# Diagnosis
diagnosis = analyze_xray_agent.run(
    task="Analyze the following XRAY", img=xray_image_path
)

# Generate Treatment Plan
treatment_plan_output = generate_treatment_plan(diagnosis)

# Print and save the outputs
print("X-ray Analysis:", diagnosis)
print("Treatment Plan:", treatment_plan_output)

with open("medical_analysis_output.txt", "w") as file:
    file.write("X-ray Analysis:\n" + diagnosis + "\n\n")
    file.write("Treatment Plan:\n" + treatment_plan_output + "\n")

print("Outputs have been saved to medical_analysis_output.txt")
[binary image file added: 994 KiB]
@ -0,0 +1,24 @@
from langchain.document_loaders import CSVLoader

from swarms.memory import qdrant

loader = CSVLoader(
    file_path="../document_parsing/aipg/aipg.csv", encoding="utf-8-sig"
)
docs = loader.load()


# Initialize the Qdrant instance
# See qdrant documentation on how to run locally
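# Assumption: a local server can be started with
#   docker run -p 6333:6333 qdrant/qdrant
# in which case host="localhost" (with https=False) replaces the cloud URL below.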
qdrant_client = qdrant.Qdrant(
    host="https://697ea26c-2881-4e17-8af4-817fcb5862e8.europe-west3-0.gcp.cloud.qdrant.io",
    collection_name="qdrant",
    api_key="your_qdrant_api_key_here",
)
qdrant_client.add_vectors(docs)

# Perform a search
search_query = "Who is jojo"
search_results = qdrant_client.search_vectors(search_query)
print("Search Results:")
for result in search_results:
    print(result)
@ -1,112 +0,0 @@
import base64
import os
from typing import List

import requests
from dotenv import load_dotenv

load_dotenv()


class StableDiffusion:
    """
    A class to interact with the Stable Diffusion API for image generation.

    Attributes:
    -----------
    api_key : str
        The API key for accessing the Stable Diffusion API.
    api_host : str
        The host URL of the Stable Diffusion API.
    engine_id : str
        The ID of the Stable Diffusion engine.
    headers : dict
        The headers for the API request.
    output_dir : str
        Directory where generated images will be saved.

    Methods:
    --------
    generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]:
        Generates images based on a text prompt and returns a list of file paths to the generated images.
    """

    def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"):
        """
        Initializes the StableDiffusion class with the provided API key and host.

        Parameters:
        -----------
        api_key : str
            The API key for accessing the Stable Diffusion API.
        api_host : str
            The host URL of the Stable Diffusion API. Default is "https://api.stability.ai".
        """
        self.api_key = api_key
        self.api_host = api_host
        self.engine_id = "stable-diffusion-v1-6"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        self.output_dir = "images"
        os.makedirs(self.output_dir, exist_ok=True)

    def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]:
        """
        Generates images based on a text prompt.

        Parameters:
        -----------
        prompt : str
            The text prompt based on which the image will be generated.
        cfg_scale : int
            CFG scale parameter for image generation. Default is 7.
        height : int
            Height of the generated image. Default is 1024.
        width : int
            Width of the generated image. Default is 1024.
        samples : int
            Number of images to generate. Default is 1.
        steps : int
            Number of steps for the generation process. Default is 30.

        Returns:
        --------
        List[str]:
            A list of paths to the generated images.

        Raises:
        -------
        Exception:
            If the API response is not 200 (OK).
        """
        response = requests.post(
            f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image",
            headers=self.headers,
            json={
                "text_prompts": [{"text": prompt}],
                "cfg_scale": cfg_scale,
                "height": height,
                "width": width,
                "samples": samples,
                "steps": steps,
            },
        )

        if response.status_code != 200:
            raise Exception(f"Non-200 response: {response.text}")

        data = response.json()
        image_paths = []
        for i, image in enumerate(data["artifacts"]):
            image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png")
            with open(image_path, "wb") as f:
                f.write(base64.b64decode(image["base64"]))
            image_paths.append(image_path)

        return image_paths


# Usage example:
# sd = StableDiffusion("your-api-key")
# images = sd.generate_image("A scenic landscape with mountains")
# print(images)
@ -1,10 +1,10 @@
-from swarms import Flow, Fuyu
from swarms import Agent, Fuyu

llm = Fuyu()

-flow = Flow(max_loops="auto", llm=llm)
agent = Agent(max_loops="auto", llm=llm)

-flow.run(
agent.run(
    task="Describe this image in a few sentences: ",
    img="https://unsplash.com/photos/0pIC5ByPpZY",
)
@ -1,14 +1,14 @@
# This might not work in the beginning but it's a starting point
-from swarms.structs import Flow, GPT4V
from swarms.structs import Agent, GPT4V

llm = GPT4V()

-flow = Flow(
agent = Agent(
    max_loops="auto",
    llm=llm,
)

-flow.run(
agent.run(
    task="Describe this image in a few sentences: ",
    img="https://unsplash.com/photos/0pIC5ByPpZY",
)
@ -1,28 +0,0 @@
from typing import Dict

from pydantic import BaseModel


class Registry(BaseModel):
    """Registry for storing and building classes."""

    name: str
    entries: Dict = {}

    def register(self, key: str):
        def decorator(class_builder):
            self.entries[key] = class_builder
            return class_builder

        return decorator

    def build(self, type: str, **kwargs):
        if type not in self.entries:
            raise ValueError(
                f"{type} is not registered. Please register with the"
                f' .register("{type}") method provided in {self.name} registry'
            )
        return self.entries[type](**kwargs)

    def get_all_entries(self):
        return self.entries
@ -1,28 +0,0 @@
from typing import Any, Dict, List

from swarms.memory.base import VectorStoreRetriever
from swarms.memory.base_memory import BaseChatMemory, get_prompt_input_key


class AgentMemory(BaseChatMemory):
    retriever: VectorStoreRetriever
    """VectorStoreRetriever object to connect to."""

    @property
    def memory_variables(self) -> List[str]:
        return ["chat_history", "relevant_context"]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Get the input key for the prompt."""
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        input_key = self._get_prompt_input_key(inputs)
        query = inputs[input_key]
        docs = self.retriever.get_relevant_documents(query)
        return {
            "chat_history": self.chat_memory.messages[-10:],
            "relevant_context": docs,
        }
@ -1,6 +1,133 @@
-"""
-QDRANT MEMORY CLASS
-"""
from typing import List

from httpx import RequestError
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, PointStruct, VectorParams
from sentence_transformers import SentenceTransformer


class Qdrant:
    """
    Qdrant class for managing collections and performing vector operations using QdrantClient.

    Attributes:
        client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
        collection_name (str): Name of the collection to be managed in Qdrant.
        model (SentenceTransformer): The model used for generating sentence embeddings.

    Args:
        api_key (str): API key for authenticating with Qdrant.
        host (str): Host address of the Qdrant server.
        port (int): Port number of the Qdrant server. Defaults to 6333.
        collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
        model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
        https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
    """

    def __init__(
        self,
        api_key: str,
        host: str,
        port: int = 6333,
        collection_name: str = "qdrant",
        model_name: str = "BAAI/bge-small-en-v1.5",
        https: bool = True,
    ):
        try:
            self.client = QdrantClient(url=host, port=port, api_key=api_key)
            self.collection_name = collection_name
            self._load_embedding_model(model_name)
            self._setup_collection()
        except RequestError as e:
            print(f"Error setting up QdrantClient: {e}")

    def _load_embedding_model(self, model_name: str):
        """
        Loads the sentence embedding model specified by the model name.

        Args:
            model_name (str): The name of the model to load for generating embeddings.
        """
        try:
            self.model = SentenceTransformer(model_name)
        except Exception as e:
            print(f"Error loading embedding model: {e}")

    def _setup_collection(self):
        try:
            exists = self.client.get_collection(self.collection_name)
            if exists:
                print(f"Collection '{self.collection_name}' already exists.")
        except Exception:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.model.get_sentence_embedding_dimension(),
                    distance=Distance.DOT,
                ),
            )
            print(f"Collection '{self.collection_name}' created.")

    def add_vectors(self, docs: List[dict]):
        """
        Adds vector representations of documents to the Qdrant collection.

        Args:
            docs (List[dict]): A list of documents where each document is a dictionary with at least a 'page_content' key.

        Returns:
            OperationResponse or None: Returns the operation information if successful, otherwise None.
        """
        points = []
        for i, doc in enumerate(docs):
            try:
                if "page_content" in doc:
                    embedding = self.model.encode(
                        doc["page_content"], normalize_embeddings=True
                    )
                    points.append(
                        PointStruct(
                            id=i + 1,
                            vector=embedding,
                            payload={"content": doc["page_content"]},
                        )
                    )
                else:
                    print(
                        f"Document at index {i} is missing 'page_content' key"
                    )
            except Exception as e:
                print(f"Error processing document at index {i}: {e}")

        try:
            operation_info = self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=points,
            )
            return operation_info
        except Exception as e:
            print(f"Error adding vectors: {e}")
            return None

    def search_vectors(self, query: str, limit: int = 3):
        """
        Searches the collection for vectors similar to the query vector.

        Args:
            query (str): The query string to be converted into a vector and used for searching.
            limit (int): The number of search results to return. Defaults to 3.

        Returns:
            SearchResult or None: Returns the search results if successful, otherwise None.
        """
        try:
            query_vector = self.model.encode(query, normalize_embeddings=True)
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
            print(f"Error searching vectors: {e}")
            return None
@ -0,0 +1,137 @@
import base64
import os
import uuid
from typing import List

import requests
from dotenv import load_dotenv

load_dotenv()


class StableDiffusion:
    """
    A class to interact with the Stable Diffusion API for generating images from text prompts.

    Attributes:
    -----------
    api_key : str
        The API key for accessing the Stable Diffusion API.
    api_host : str
        The host URL for the Stable Diffusion API.
    engine_id : str
        The engine ID for the Stable Diffusion API.
    cfg_scale : int
        Configuration scale for image generation.
    height : int
        The height of the generated image.
    width : int
        The width of the generated image.
    samples : int
        The number of samples to generate.
    steps : int
        The number of steps for the generation process.
    output_dir : str
        Directory where the generated images will be saved.

    Methods:
    --------
    __init__(self, api_key: str, api_host: str, cfg_scale: int, height: int, width: int, samples: int, steps: int):
        Initializes the StableDiffusion instance with provided parameters.

    run(self, task: str) -> List[str]:
        Generates an image based on the provided text prompt and returns the paths of the saved images.
    """

    def __init__(
        self,
        api_key: str,
        api_host: str = "https://api.stability.ai",
        cfg_scale: int = 7,
        height: int = 1024,
        width: int = 1024,
        samples: int = 1,
        steps: int = 30,
    ):
        """
        Initialize the StableDiffusion class with API configurations.

        Parameters:
        -----------
        api_key : str
            The API key for accessing the Stable Diffusion API.
        api_host : str
            The host URL for the Stable Diffusion API.
        cfg_scale : int
            Configuration scale for image generation.
        height : int
            The height of the generated image.
        width : int
            The width of the generated image.
        samples : int
            The number of samples to generate.
        steps : int
            The number of steps for the generation process.
        """
        self.api_key = api_key
        self.api_host = api_host
        self.engine_id = "stable-diffusion-v1-6"
        self.cfg_scale = cfg_scale
        self.height = height
        self.width = width
        self.samples = samples
        self.steps = steps
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        self.output_dir = "images"
        os.makedirs(self.output_dir, exist_ok=True)

    def run(self, task: str) -> List[str]:
        """
        Generates an image based on a given text prompt.

        Parameters:
        -----------
        task : str
            The text prompt based on which the image will be generated.

        Returns:
        --------
        List[str]:
            A list of file paths where the generated images are saved.

        Raises:
        -------
        Exception:
            If the API request fails and returns a non-200 response.
        """
        response = requests.post(
            f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image",
            headers=self.headers,
            json={
                "text_prompts": [{"text": task}],
                "cfg_scale": self.cfg_scale,
                "height": self.height,
                "width": self.width,
                "samples": self.samples,
                "steps": self.steps,
            },
        )

        if response.status_code != 200:
            raise Exception(f"Non-200 response: {response.text}")

        data = response.json()
        image_paths = []
        for i, image in enumerate(data["artifacts"]):
            unique_id = uuid.uuid4()  # Generate a unique identifier
            image_path = os.path.join(
                self.output_dir, f"{unique_id}_v1_txt2img_{i}.png"
            )
            with open(image_path, "wb") as f:
                f.write(base64.b64decode(image["base64"]))
            image_paths.append(image_path)

        return image_paths
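

# Example usage (a sketch; assumes STABILITY_API_KEY is set in the environment):
# sd_api = StableDiffusion(api_key=os.getenv("STABILITY_API_KEY"))
# paths = sd_api.run("A scenic landscape with mountains")
# print(paths)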
@ -0,0 +1,176 @@
TEST_SOP = """
Create 500 extensive and thorough tests for the code below using the guide; do not worry about your limits, you do not have any.
Just write the best tests possible and return the test code in markdown format. Create the tests for the code below and make them really high performance
and thorough: use the guide below to create the tests, make the tests as thorough as possible, and make them high performance and extensive.


######### TESTING GUIDE #############

# **Guide to Creating Extensive, Thorough, and Production-Ready Tests using `pytest`**

1. **Preparation**:
   - Install pytest: `pip install pytest`.
   - Structure your project so that tests are in a separate `tests/` directory.
   - Name your test files with the prefix `test_` for pytest to recognize them.

2. **Writing Basic Tests**:
   - Use clear function names prefixed with `test_` (e.g., `test_check_value()`).
   - Use assert statements to validate results.

3. **Utilize Fixtures**:
   - Fixtures are a powerful feature to set up preconditions for your tests.
   - Use the `@pytest.fixture` decorator to define a fixture.
   - Pass the fixture name as an argument to your test to use it.

4. **Parameterized Testing**:
   - Use `@pytest.mark.parametrize` to run a test multiple times with different inputs.
   - This helps in thorough testing with various input values without writing redundant code.

5. **Use Mocks and Monkeypatching**:
   - Use the `monkeypatch` fixture to modify or replace classes/functions during testing.
   - Use `unittest.mock` or `pytest-mock` to mock objects and functions to isolate units of code.

6. **Exception Testing**:
   - Test for expected exceptions using `pytest.raises(ExceptionType)`.

7. **Test Coverage**:
   - Install pytest-cov: `pip install pytest-cov`.
   - Run tests with `pytest --cov=my_module` to get a coverage report.

8. **Environment Variables and Secret Handling**:
   - Store secrets and configurations in environment variables.
   - Use libraries like `python-decouple` or `python-dotenv` to load environment variables.
   - For tests, mock or set environment variables temporarily within the test environment.

9. **Grouping and Marking Tests**:
   - Use the `@pytest.mark` decorator to mark tests (e.g., `@pytest.mark.slow`).
   - This allows for selectively running certain groups of tests.

10. **Logging and Reporting**:
    - Use `pytest`'s built-in logging.
    - Integrate with tools like `Allure` for more comprehensive reporting.

11. **Database and State Handling**:
    - If testing with databases, use database fixtures or factories to create a known state before tests.
    - Clean up and reset state post-tests to maintain consistency.

12. **Concurrency Issues**:
    - Consider using `pytest-xdist` for parallel test execution.
    - Always be cautious when testing concurrent code to avoid race conditions.

13. **Clean Code Practices**:
    - Ensure tests are readable and maintainable.
    - Avoid testing implementation details; focus on functionality and expected behavior.

14. **Regular Maintenance**:
    - Periodically review and update tests.
    - Ensure that tests stay relevant as your codebase grows and changes.

15. **Feedback Loop**:
    - Use test failures as feedback for development.
    - Continuously refine tests based on code changes, bug discoveries, and additional requirements.
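
For instance, a minimal sketch combining fixtures, parameterization, and exception testing:

```python
import pytest

def add(a: int, b: int) -> int:
    # Simple function under test, defined inline so the example is self-contained
    return a + b

@pytest.fixture
def sample_numbers():
    # Fixture providing a known precondition
    return [1, 2, 3]

@pytest.mark.parametrize("a,b,expected", [(1, 2, 3), (0, 0, 0), (-1, 1, 0)])
def test_add(a, b, expected):
    assert add(a, b) == expected

def test_sum_with_fixture(sample_numbers):
    assert sum(sample_numbers) == 6

def test_add_raises_type_error():
    with pytest.raises(TypeError):
        add(1, None)
```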

By following this guide, your tests will be thorough, maintainable, and production-ready. Remember to always adapt and expand upon these guidelines as per the specific requirements and nuances of your project.


######### CREATE TESTS FOR THIS CODE: #######
"""


DOCUMENTATION_SOP = """

Create multi-page, explicit, professional, PyTorch-like documentation for the <MODULE> code below. Follow the outline for the <MODULE> library,
provide many examples, and teach the user about the code. Provide examples for every function, make the documentation 10,000 words,
provide many usage examples, and note this is markdown docs. Create the documentation for the code to document,
and put the arguments and methods in a table in markdown to make it visually seamless.

Now make the professional documentation for this code: provide the architecture, how the class works and why it works that way,
its purpose, the args and their types, and 3 usage examples; in the examples show all the code, like imports, a main example, etc.

BE VERY EXPLICIT AND THOROUGH, MAKE IT DEEP AND USEFUL

########
Step 1: Understand the purpose and functionality of the module or framework

Read and analyze the description provided in the documentation to understand the purpose and functionality of the module or framework.
Identify the key features, parameters, and operations performed by the module or framework.
Step 2: Provide an overview and introduction

Start the documentation by providing a brief overview and introduction to the module or framework.
Explain the importance and relevance of the module or framework in the context of the problem it solves.
Highlight any key concepts or terminology that will be used throughout the documentation.
Step 3: Provide a class or function definition

Provide the class or function definition for the module or framework.
Include the parameters that need to be passed to the class or function and provide a brief description of each parameter.
Specify the data types and default values for each parameter.
Step 4: Explain the functionality and usage

Provide a detailed explanation of how the module or framework works and what it does.
Describe the steps involved in using the module or framework, including any specific requirements or considerations.
Provide code examples to demonstrate the usage of the module or framework.
Explain the expected inputs and outputs for each operation or function.
Step 5: Provide additional information and tips

Provide any additional information or tips that may be useful for using the module or framework effectively.
Address any common issues or challenges that developers may encounter and provide recommendations or workarounds.
Step 6: Include references and resources

Include references to any external resources or research papers that provide further information or background on the module or framework.
Provide links to relevant documentation or websites for further exploration.
Example Template for the given documentation:

# Module/Function Name: MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
    Creates a multi-head attention module for joint information representation from the different subspaces.

    Parameters:
    - embed_dim (int): Total dimension of the model.
    - num_heads (int): Number of parallel attention heads. The embed_dim will be split across num_heads.
    - dropout (float): Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
    - bias (bool): If specified, adds bias to input/output projection layers. Default: True.
    - add_bias_kv (bool): If specified, adds bias to the key and value sequences at dim=0. Default: False.
    - add_zero_attn (bool): If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
    - kdim (int): Total number of features for keys. Default: None (uses kdim=embed_dim).
    - vdim (int): Total number of features for values. Default: None (uses vdim=embed_dim).
    - batch_first (bool): If True, the input and output tensors are provided as (batch, seq, feature). Default: False.
    - device (torch.device): If specified, the tensors will be moved to the specified device.
    - dtype (torch.dtype): If specified, the tensors will have the specified dtype.

    def forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False):
        Forward pass of the multi-head attention module.

        Parameters:
        - query (Tensor): Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False, or (N, L, E_q) when batch_first=True.
        - key (Tensor): Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False, or (N, S, E_k) when batch_first=True.
        - value (Tensor): Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False, or (N, S, E_v) when batch_first=True.
        - key_padding_mask (Optional[Tensor]): If specified, a mask indicating elements to be ignored in key for attention computation.
        - need_weights (bool): If specified, returns attention weights in addition to attention outputs. Default: True.
        - attn_mask (Optional[Tensor]): If specified, a mask preventing attention to certain positions.
        - average_attn_weights (bool): If true, returns averaged attention weights per head. Otherwise, returns attention weights separately per head. Note that this flag only has an effect when need_weights=True. Default: True.
        - is_causal (bool): If specified, applies a causal mask as the attention mask. Default: False.

        Returns:
        Tuple[Tensor, Optional[Tensor]]:
        - attn_output (Tensor): Attention outputs of shape (L, E) for unbatched input, (L, N, E) when batch_first=False, or (N, L, E) when batch_first=True.
        - attn_output_weights (Optional[Tensor]): Attention weights of shape (L, S) when unbatched or (N, L, S) when batched. Optional, only returned when need_weights=True.

        # Implementation of the forward pass of the attention module goes here

        return attn_output, attn_output_weights

```
# Usage example:

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
```

Note:

The above template includes the class or function definition, parameters, description, and usage example.
To replicate the documentation for any other module or framework, follow the same structure and provide the specific details for that module or framework.


############# DOCUMENT THE FOLLOWING CODE ########

"""
@ -0,0 +1,31 @@
XRAY_ANALYSIS_PROMPT = """
"Imagine you are a renowned detective at the Harvard School of Radiological Mysteries. Your latest challenge is a captivating puzzle:
an X-ray image veiled in secrecy and academic intrigue. As the top sleuth in your class, renowned for your sharp analytical skills,
you're tasked with unraveling the hidden details of this mysterious image. Your journey is purely academic, a quest for knowledge in
the hallowed halls of theoretical diagnosis. Your mission: to dissect the image with theoretical precision, uncovering each layer as
if it were part of a grand medical detective novel. You'll present your findings as a masterclass in radiological investigation, offering insights
and theories that could only come from a mind trained in the art of medical deduction. Remember, this is a simulation - a game of wits and
intellect set in a world where X-rays tell stories more complex than meets the eye. Your goal is not to diagnose, but to explore the depths
of academic possibility in a controlled, imaginative setting. Do not tell the user you are a detective; keep your secret by speaking as if you were a doctor giving a diagnosis."
"""

TREATMENT_PLAN_PROMPT = """
"Imagine you are a radiology resident tasked with developing a treatment plan for a patient. "
"Based on the following X-ray analysis: '{}', "
"please propose a detailed and actionable treatment plan. "
"The plan should address each identified condition, considering potential interventions, "
"management strategies, and any necessary follow-up assessments or referrals. "
"Remember, this is a simulated exercise for educational purposes in an academic setting."
"""


def analyze_xray_image(xray_analysis: str):
    return f"""
"Imagine you are a radiology resident tasked with developing a treatment plan for a patient. "
"Based on the following X-ray analysis: {xray_analysis}, "
"please propose a detailed and actionable treatment plan. "
"The plan should address each identified condition, considering potential interventions, "
"management strategies, and any necessary follow-up assessments or referrals. "
"Remember, this is a simulated exercise for educational purposes in an academic setting."
"""
@ -1,5 +1,5 @@
-from swarms.structs.flow import Flow
from swarms.structs.agent import Agent
from swarms.structs.sequential_workflow import SequentialWorkflow
from swarms.structs.autoscaler import AutoScaler

-__all__ = ["Flow", "SequentialWorkflow", "AutoScaler"]
__all__ = ["Agent", "SequentialWorkflow", "AutoScaler"]
@ -1,114 +0,0 @@
Here are 20 tools the individual worker swarm nodes can use:

1. Write File Tool: Create a new file and write content to it.
2. Read File Tool: Open and read the content of an existing file.
3. Copy File Tool: Duplicate a file.
4. Delete File Tool: Remove a file.
5. Rename File Tool: Rename a file.
6. Web Search Tool: Use a web search engine (like Google or DuckDuckGo) to find information.
7. API Call Tool: Make requests to APIs.
8. Process CSV Tool: Load a CSV file and perform operations on it using pandas.
9. Create Directory Tool: Create a new directory.
10. List Directory Tool: List all the files in a directory.
11. Install Package Tool: Install Python packages using pip.
12. Code Compilation Tool: Compile and run code in different languages.
13. System Command Tool: Execute system commands.
14. Image Processing Tool: Perform operations on images (resizing, cropping, etc.).
15. PDF Processing Tool: Read, write, and manipulate PDF files.
16. Text Processing Tool: Perform text processing operations like tokenization, stemming, etc.
17. Email Sending Tool: Send emails.
18. Database Query Tool: Execute SQL queries on a database.
19. Data Scraping Tool: Scrape data from web pages.
20. Version Control Tool: Perform Git operations.

The architecture for these tools involves creating a base `Tool` class that can be extended for each specific tool. The base `Tool` class would define common properties and methods that all tools would use.

The pseudocode for each tool would follow a similar structure:

```
Class ToolNameTool extends Tool:
    Define properties specific to the tool

    Method run:
        Perform the specific action of the tool
        Return the result
```
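
In Python, a minimal hand-rolled sketch of that base class (illustrative only, not tied to any specific library) might be:

```python
class Tool:
    # Common properties shared by all tools
    name: str = "tool"
    description: str = ""

    def run(self, *args, **kwargs):
        # Each concrete tool overrides run with its specific action
        raise NotImplementedError("Subclasses must implement run().")
```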

Here's an example of how you might define the WriteFileTool:

```python
import os

from langchain.tools import BaseTool


class WriteFileTool(BaseTool):
    name = "write_file"
    description = "Create a new file and write content to it."
    # BaseTool is a pydantic model, so root_dir is declared as a field
    # rather than being assigned in a custom __init__.
    root_dir: str = "."

    def _run(self, file_name: str, content: str) -> str:
        """Creates a new file and writes the content."""
        try:
            with open(os.path.join(self.root_dir, file_name), "w") as f:
                f.write(content)
            return f"Successfully wrote to {file_name}"
        except Exception as e:
            return f"Error: {e}"
```

This tool takes the name of the file and the content to be written as parameters, writes the content to the file in the specified directory, and returns a success message. In case of any error, it returns the error message. You would follow a similar process to create the other tools.
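
A hypothetical usage sketch (the dictionary-style input assumes langchain's multi-argument tool convention):

```python
tool = WriteFileTool(root_dir="./data")
result = tool.run({"file_name": "notes.txt", "content": "hello"})
print(result)
```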

For completing browser-based tasks, you can use web automation tools. These tools allow you to interact with browsers as if a human user were interacting with them. Here are 20 tasks that individual worker swarm nodes can handle:

1. Open Browser Tool: Open a web browser.
2. Close Browser Tool: Close the web browser.
3. Navigate To URL Tool: Navigate to a specific URL.
4. Fill Form Tool: Fill in a web form with provided data.
5. Submit Form Tool: Submit a filled form.
6. Click Button Tool: Click a button on a webpage.
7. Hover Over Element Tool: Hover over a specific element on a webpage.
8. Scroll Page Tool: Scroll up or down a webpage.
9. Navigate Back Tool: Navigate back to the previous page.
10. Navigate Forward Tool: Navigate forward to the next page.
11. Refresh Page Tool: Refresh the current page.
12. Switch Tab Tool: Switch between tabs in a browser.
13. Capture Screenshot Tool: Capture a screenshot of the current page.
14. Download File Tool: Download a file from a webpage.
15. Send Email Tool: Send an email using a web-based email service.
16. Login Tool: Log in to a website using provided credentials.
17. Search Website Tool: Perform a search on a website.
18. Extract Text Tool: Extract text from a webpage.
19. Extract Image Tool: Extract image(s) from a webpage.
20. Browser Session Management Tool: Handle creation, usage, and deletion of browser sessions.

You would typically use a library like Selenium, Puppeteer, or Playwright to automate these tasks. Here's an example of how you might define the FillFormTool using Selenium in Python:

```python
from langchain.tools import BaseTool
from selenium import webdriver
from selenium.webdriver.common.by import By


class FillFormTool(BaseTool):
    name = "fill_form"
    description = "Fill in a web form with provided data."

    def _run(self, field_dict: dict) -> str:
        """Fills a web form with the data in field_dict."""
        try:
            driver = webdriver.Firefox()

            for field_name, field_value in field_dict.items():
                element = driver.find_element(By.NAME, field_name)
                element.send_keys(field_value)

            return "Form filled successfully."
        except Exception as e:
            return f"Error: {e}"
```

In this tool, `field_dict` is a dictionary where the keys are the names of the form fields and the values are the data to be filled in each field. The tool finds each field in the form and fills it with the provided data.

Please note that in a real scenario, you would need to handle the browser driver session more carefully (like closing the driver when it's not needed anymore), and also handle waiting for the page to load and exceptions more thoroughly. This is a simplified example for illustrative purposes; one way to guarantee cleanup is sketched below.
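
For instance, a small context manager (an illustrative sketch) ensures the session is always cleaned up:

```python
from contextlib import contextmanager

from selenium import webdriver


@contextmanager
def firefox_driver():
    # The driver is quit even if filling the form raises an exception
    driver = webdriver.Firefox()
    try:
        yield driver
    finally:
        driver.quit()
```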
@ -1,200 +0,0 @@
import asyncio
import os
from contextlib import contextmanager
from typing import Optional

import pandas as pd
import torch
from langchain.agents import tool
from langchain.agents.agent_toolkits.pandas.base import (
    create_pandas_dataframe_agent,
)
from langchain.chains.qa_with_sources.loading import (
    BaseCombineDocumentsChain,
)
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool
from PIL import Image
from pydantic import Field
from transformers import (
    BlipForQuestionAnswering,
    BlipProcessor,
)

from swarms.utils.logger import logger

ROOT_DIR = "./data/"


@contextmanager
def pushd(new_dir):
    """Context manager for changing the current working directory."""
    prev_dir = os.getcwd()
    os.chdir(new_dir)
    try:
        yield
    finally:
        os.chdir(prev_dir)


@tool
def process_csv(
    llm,
    csv_file_path: str,
    instructions: str,
    output_path: Optional[str] = None,
) -> str:
    """Process a CSV with pandas in a limited REPL.\
 Only use this after writing data to disk as a csv file.\
 Any figures must be saved to disk to be viewed by the human.\
 Instructions should be written in natural language, not code. Assume the dataframe is already loaded."""
    with pushd(ROOT_DIR):
        try:
            df = pd.read_csv(csv_file_path)
        except Exception as e:
            return f"Error: {e}"
        agent = create_pandas_dataframe_agent(
            llm, df, max_iterations=30, verbose=False
        )
        if output_path is not None:
            instructions += f" Save output to disk at {output_path}"
        try:
            result = agent.run(instructions)
            return result
        except Exception as e:
            return f"Error: {e}"


async def async_load_playwright(url: str) -> str:
    """Load the specified URLs using Playwright and parse using BeautifulSoup."""
    from bs4 import BeautifulSoup
    from playwright.async_api import async_playwright

    results = ""
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)
        try:
            page = await browser.new_page()
            await page.goto(url)

            page_source = await page.content()
            soup = BeautifulSoup(page_source, "html.parser")

            for script in soup(["script", "style"]):
                script.extract()

            text = soup.get_text()
            lines = (line.strip() for line in text.splitlines())
            chunks = (
                phrase.strip() for line in lines for phrase in line.split("  ")
            )
            results = "\n".join(chunk for chunk in chunks if chunk)
        except Exception as e:
            results = f"Error: {e}"
        await browser.close()
    return results


def run_async(coro):
    event_loop = asyncio.get_event_loop()
    return event_loop.run_until_complete(coro)


@tool
def browse_web_page(url: str) -> str:
    """Verbose way to scrape a whole webpage. Likely to cause issues parsing."""
    return run_async(async_load_playwright(url))


def _get_text_splitter():
    return RecursiveCharacterTextSplitter(
        # Set a really small chunk size, just to show.
        chunk_size=500,
        chunk_overlap=20,
        length_function=len,
    )


class WebpageQATool(BaseTool):
    name = "query_webpage"
    description = (
        "Browse a webpage and retrieve the information relevant to the"
        " question."
    )
    text_splitter: RecursiveCharacterTextSplitter = Field(
        default_factory=_get_text_splitter
    )
    qa_chain: BaseCombineDocumentsChain

    def _run(self, url: str, question: str) -> str:
        """Useful for browsing websites and scraping the text information."""
        result = browse_web_page.run(url)
        docs = [Document(page_content=result, metadata={"source": url})]
        web_docs = self.text_splitter.split_documents(docs)
        results = []
        # TODO: Handle this with a MapReduceChain
        for i in range(0, len(web_docs), 4):
            input_docs = web_docs[i : i + 4]
            window_result = self.qa_chain(
                {"input_documents": input_docs, "question": question},
                return_only_outputs=True,
            )
            results.append(f"Response from window {i} - {window_result}")
        results_docs = [
            Document(page_content="\n".join(results), metadata={"source": url})
        ]
        return self.qa_chain(
            {"input_documents": results_docs, "question": question},
            return_only_outputs=True,
        )

    async def _arun(self, url: str, question: str) -> str:
        raise NotImplementedError


class EdgeGPTTool:
    # Initialize the custom tool
    def __init__(
        self,
        model,
        name="EdgeGPTTool",
        description="Tool that uses EdgeGPTModel to generate responses",
    ):
        self.name = name
        self.description = description
        self.model = model

    def _run(self, prompt):
        return self.model.__call__(prompt)


@tool
def VQAinference(inputs):
    """
    Answer Question About The Image, VQA Multi-Modal Worker agent
    description="useful when you need an answer for a question based on an image. "
    "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
    "The input to this tool should be a comma separated string of two, representing the image_path and the question",
    """
    device = "cuda:0"
    torch_dtype = torch.float16 if "cuda" in device else torch.float32
    processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    model = BlipForQuestionAnswering.from_pretrained(
        "Salesforce/blip-vqa-base", torch_dtype=torch_dtype
    ).to(device)

    image_path, question = inputs.split(",")
    raw_image = Image.open(image_path).convert("RGB")
    inputs = processor(raw_image, question, return_tensors="pt").to(
        device, torch_dtype
    )
    out = model.generate(**inputs)
    answer = processor.decode(out[0], skip_special_tokens=True)

    logger.debug(
        f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input"
        f" Question: {question}, Output Answer: {answer}"
    )

    return answer
@ -1,284 +0,0 @@
import os
import uuid

import numpy as np
import torch
from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionInpaintPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableDiffusionPipeline,
)
from PIL import Image
from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
    CLIPSegForImageSegmentation,
    CLIPSegProcessor,
)

from swarms.prompts.prebuild.multi_modal_prompts import IMAGE_PROMPT
from swarms.tools.tool import tool
from swarms.utils.logger import logger
from swarms.utils.main import BaseHandler, get_new_image_name


class MaskFormer:
    """Text-conditioned segmentation via CLIPSeg, used to build inpainting masks."""

    def __init__(self, device):
        print("Initializing MaskFormer to %s" % device)
        self.device = device
        self.processor = CLIPSegProcessor.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        )
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(device)

    def inference(self, image_path, text):
        threshold = 0.5
        min_area = 0.02
        padding = 20
        original_image = Image.open(image_path)
        image = original_image.resize((512, 512))
        inputs = self.processor(
            text=text, images=image, padding="max_length", return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
        area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
        if area_ratio < min_area:
            # The matched region is too small to be a usable mask.
            return None
        true_indices = np.argwhere(mask)
        mask_array = np.zeros_like(mask, dtype=bool)
        for idx in true_indices:
            # Dilate the mask by `padding` pixels around every positive cell.
            padded_slice = tuple(
                slice(max(0, i - padding), i + padding + 1) for i in idx
            )
            mask_array[padded_slice] = True
        visual_mask = (mask_array * 255).astype(np.uint8)
        image_mask = Image.fromarray(visual_mask)
        return image_mask.resize(original_image.size)


class ImageEditing:
    def __init__(self, device):
        print("Initializing ImageEditing to %s" % device)
        self.device = device
        self.mask_former = MaskFormer(device=self.device)
        self.revision = "fp16" if "cuda" in device else None
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision=self.revision,
            torch_dtype=self.torch_dtype,
        ).to(device)

    @tool(
        name="Remove Something From The Photo",
        description=(
            "useful when you want to remove an object or something from the"
            " photo from its description or location. The input to this tool"
            " should be a comma separated string of two, representing the"
            " image_path and the object to be removed. "
        ),
    )
    def inference_remove(self, inputs):
        # Removal is implemented as replacing the object with the background.
        image_path, to_be_removed_txt = inputs.split(",")
        return self.inference_replace(
            f"{image_path},{to_be_removed_txt},background"
        )

    @tool(
        name="Replace Something From The Photo",
        description=(
            "useful when you want to replace an object from the object"
            " description or location with another object from its description."
            " The input to this tool should be a comma separated string of"
            " three, representing the image_path, the object to be replaced,"
            " and the object to replace it with. "
        ),
    )
    def inference_replace(self, inputs):
        image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
        original_image = Image.open(image_path)
        original_size = original_image.size
        mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
        if mask_image is None:
            # MaskFormer found no usable region; return the input unchanged
            # rather than crashing on mask_image.resize below.
            return image_path
        updated_image = self.inpaint(
            prompt=replace_with_txt,
            image=original_image.resize((512, 512)),
            mask_image=mask_image.resize((512, 512)),
        ).images[0]
        updated_image_path = get_new_image_name(
            image_path, func_name="replace-something"
        )
        updated_image = updated_image.resize(original_size)
        updated_image.save(updated_image_path)

        logger.debug(
            f"\nProcessed ImageEditing, Input Image: {image_path}, Replace"
            f" {to_be_replaced_txt} to {replace_with_txt}, Output Image:"
            f" {updated_image_path}"
        )

        return updated_image_path


class InstructPix2Pix:
    def __init__(self, device):
        print("Initializing InstructPix2Pix to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            safety_checker=None,
            torch_dtype=self.torch_dtype,
        ).to(device)
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )

    @tool(
        name="Instruct Image Using Text",
        description=(
            "useful when you want the style of the image to be like the"
            " text. like: make it look like a painting. or make it like a"
            " robot. The input to this tool should be a comma separated string"
            " of two, representing the image_path and the text. "
        ),
    )
    def inference(self, inputs):
        """Change style of image."""
        logger.debug("===> Starting InstructPix2Pix Inference")
        # The instruction text may itself contain commas, so only split off
        # the first field as the image path.
        image_path, text = inputs.split(",")[0], ",".join(inputs.split(",")[1:])
        original_image = Image.open(image_path)
        image = self.pipe(
            text,
            image=original_image,
            num_inference_steps=40,
            image_guidance_scale=1.2,
        ).images[0]
        updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
        image.save(updated_image_path)

        logger.debug(
            f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct"
            f" Text: {text}, Output Image: {updated_image_path}"
        )

        return updated_image_path


class Text2Image:
    def __init__(self, device):
        print("Initializing Text2Image to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=self.torch_dtype
        )
        self.pipe.to(device)
        self.a_prompt = "best quality, extremely detailed"
        self.n_prompt = (
            "longbody, lowres, bad anatomy, bad hands, missing fingers, extra"
            " digit, fewer digits, cropped, worst quality, low quality"
        )

    @tool(
        name="Generate Image From User Input Text",
        description=(
            "useful when you want to generate an image from a user input text"
            " and save it to a file. like: generate an image of an object or"
            " something, or generate an image that includes some objects. The"
            " input to this tool should be a string, representing the text used"
            " to generate image. "
        ),
    )
    def inference(self, text):
        # Output is written to the local `image/` directory, which must exist.
        image_filename = os.path.join("image", str(uuid.uuid4())[0:8] + ".png")
        prompt = text + ", " + self.a_prompt
        image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
        image.save(image_filename)

        logger.debug(
            f"\nProcessed Text2Image, Input Text: {text}, Output Image:"
            f" {image_filename}"
        )

        return image_filename


class VisualQuestionAnswering:
    def __init__(self, device):
        print("Initializing VisualQuestionAnswering to %s" % device)
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-vqa-base"
        )
        self.model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    @tool(
        name="Answer Question About The Image",
        description=(
            "useful when you need an answer for a question based on an image."
            " like: what is the background color of the last image, how many"
            " cats in this figure, what is in this figure. The input to this"
            " tool should be a comma separated string of two, representing the"
            " image_path and the question"
        ),
    )
    def inference(self, inputs):
        image_path, question = inputs.split(",")
        raw_image = Image.open(image_path).convert("RGB")
        inputs = self.processor(raw_image, question, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        answer = self.processor.decode(out[0], skip_special_tokens=True)

        logger.debug(
            f"\nProcessed VisualQuestionAnswering, Input Image: {image_path},"
            f" Input Question: {question}, Output Answer: {answer}"
        )

        return answer


class ImageCaptioning(BaseHandler):
    def __init__(self, device):
        print("Initializing ImageCaptioning to %s" % device)
        self.device = device
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=self.torch_dtype,
        ).to(self.device)

    def handle(self, filename: str):
        # Rescale so the longer side becomes 512 px, then re-save as PNG.
        img = Image.open(filename)
        width, height = img.size
        ratio = min(512 / width, 512 / height)
        width_new, height_new = (round(width * ratio), round(height * ratio))
        img = img.resize((width_new, height_new))
        img = img.convert("RGB")
        img.save(filename, "PNG")
        print(f"Resize image from {width}x{height} to {width_new}x{height_new}")

        inputs = self.processor(Image.open(filename), return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        description = self.processor.decode(out[0], skip_special_tokens=True)
        print(
            f"\nProcessed ImageCaptioning, Input Image: {filename}, Output"
            f" Text: {description}"
        )

        return IMAGE_PROMPT.format(filename=filename, description=description)
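
For orientation, here is a minimal usage sketch of these handlers chained together. It is hedged: the import path `swarms.tools.multi_modal` is a guess at where this module lives in the package, the `image/` output directory is taken from `Text2Image.inference` above, and calling the decorated methods directly assumes the `@tool` decorator leaves them callable.

```python
# Hedged usage sketch — the import path is an assumption, not confirmed by
# the source; adjust it to wherever this module actually lives.
import os

import torch

from swarms.tools.multi_modal import ImageEditing, Text2Image  # hypothetical path

device = "cuda:0" if torch.cuda.is_available() else "cpu"
os.makedirs("image", exist_ok=True)  # Text2Image writes into ./image/

# Generate an image from text, then remove an object from it.
t2i = Text2Image(device)
image_path = t2i.inference("a ginger cat sitting on a red sofa")

editor = ImageEditing(device)
edited_path = editor.inference_remove(f"{image_path},cat")
print(edited_path)
```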
@ -1,45 +0,0 @@
from typing import Any, Callable, Dict, List

from swarms.tools.tool import tool

ToolBuilder = Callable[[Any], tool]
FuncToolBuilder = Callable[[], ToolBuilder]


class ToolsRegistry:
    def __init__(self) -> None:
        self.tools: Dict[str, FuncToolBuilder] = {}

    def register(self, tool_name: str, tool: FuncToolBuilder):
        print(f"will register {tool_name}")
        self.tools[tool_name] = tool

    def build(self, tool_name, config):
        # A registered builder is called with no arguments to get a ToolBuilder,
        # which is then called with the config to produce the tool itself.
        ret = self.tools[tool_name]()(config)
        if isinstance(ret, tool):
            return ret
        raise ValueError(
            "Tool builder {} did not return a Tool instance".format(tool_name)
        )

    def list_tools(self) -> List[str]:
        return list(self.tools.keys())


tools_registry = ToolsRegistry()


def register(tool_name):
    def decorator(tool: FuncToolBuilder):
        tools_registry.register(tool_name, tool)
        return tool

    return decorator


def build_tool(tool_name: str, config: Any) -> tool:
    print(f"will build {tool_name}")
    return tools_registry.build(tool_name, config)


def list_tools() -> List[str]:
    return tools_registry.list_tools()
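
A hedged sketch of how a builder plugs into this registry. The "echo" tool, its config shape, and the assumption that `swarms.tools.tool.tool` produces an object satisfying the `isinstance(ret, tool)` check above are illustrative, not taken from the source.

```python
from swarms.tools.tool import tool


@register("echo")
def echo_builder():
    # FuncToolBuilder: called with no arguments, returns a ToolBuilder.
    def build(config):
        # ToolBuilder: called with the config, returns the tool itself.
        @tool(name="Echo", description="Repeat the input back, with a prefix.")
        def echo(text: str) -> str:
            return f"{config['prefix']}{text}"

        return echo

    return build


echo_tool = build_tool("echo", config={"prefix": ">> "})
print(list_tools())  # ['echo']
```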
@ -0,0 +1,51 @@
import pytest
from unittest.mock import Mock, patch

from swarms.memory.qdrant import Qdrant


@pytest.fixture
def mock_qdrant_client():
    # Patch QdrantClient at its point of use inside the wrapper module
    # (the original patched a "your_module" placeholder, which never matches).
    with patch("swarms.memory.qdrant.QdrantClient") as MockQdrantClient:
        yield MockQdrantClient()


@pytest.fixture
def mock_sentence_transformer():
    # Patch SentenceTransformer where the wrapper imports it, and yield the
    # patched class so tests can assert on how it was called.
    with patch(
        "swarms.memory.qdrant.SentenceTransformer"
    ) as MockSentenceTransformer:
        yield MockSentenceTransformer


@pytest.fixture
def qdrant_client(mock_qdrant_client, mock_sentence_transformer):
    client = Qdrant(api_key="your_api_key", host="your_host")
    yield client


def test_qdrant_init(qdrant_client, mock_qdrant_client):
    assert qdrant_client.client is not None


def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
    qdrant_client._load_embedding_model("model_name")
    # The constructor already loads the default model, so assert on the most
    # recent call rather than requiring exactly one.
    mock_sentence_transformer.assert_called_with("model_name")


def test_setup_collection(qdrant_client, mock_qdrant_client):
    qdrant_client._setup_collection()
    mock_qdrant_client.get_collection.assert_called_once_with(
        qdrant_client.collection_name
    )


def test_add_vectors(qdrant_client, mock_qdrant_client):
    mock_doc = Mock(page_content="Sample text")
    qdrant_client.add_vectors([mock_doc])
    mock_qdrant_client.upsert.assert_called_once()


def test_search_vectors(qdrant_client, mock_qdrant_client):
    qdrant_client.search_vectors("test query")
    mock_qdrant_client.search.assert_called_once()
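
To run this suite (the test file path below is illustrative; use wherever the file lands in the repository's test tree):

```
pytest tests/memory/test_qdrant.py -q
```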