code quality

pull/209/head
Kye 1 year ago
parent a4d953ec08
commit 4671f2facb

@@ -1,44 +0,0 @@
OPENAI_API_KEY="your_openai_api_key_here"
GOOGLE_API_KEY=""
ANTHROPIC_API_KEY=""
AI21_API_KEY="your_api_key_here"
COHERE_API_KEY="your_api_key_here"
ALEPHALPHA_API_KEY="your_api_key_here"
HUGGINFACEHUB_API_KEY="your_api_key_here"
STABILITY_API_KEY="your_api_key_here"
WOLFRAM_ALPHA_APPID="your_wolfram_alpha_appid_here"
ZAPIER_NLA_API_KEY="your_zapier_nla_api_key_here"
EVAL_PORT=8000
MODEL_NAME="gpt-4"
CELERY_BROKER_URL="redis://localhost:6379"
SERVER="http://localhost:8000"
USE_GPU=True
PLAYGROUND_DIR="playground"
LOG_LEVEL="INFO"
BOT_NAME="Orca"
WINEDB_HOST="your_winedb_host_here"
WINEDB_PASSWORD="your_winedb_password_here"
BING_SEARCH_URL="your_bing_search_url_here"
BING_SUBSCRIPTION_KEY="your_bing_subscription_key_here"
SERPAPI_API_KEY="your_serpapi_api_key_here"
IFTTTKey="your_iftttkey_here"
BRAVE_API_KEY="your_brave_api_key_here"
SPOONACULAR_KEY="your_spoonacular_key_here"
HF_API_KEY="your_huggingface_api_key_here"
REDIS_HOST=
REDIS_PORT=
#dbs
PINECONE_API_KEY=""
BING_COOKIE=""
PSG_CONNECTION_STRING=""
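
A minimal sketch of how a playground script might load a few of these variables, assuming python-dotenv is installed and the template above has been copied to a local .env file:

import os

from dotenv import load_dotenv

# Read the local .env file (copied from the template above) into the process environment.
load_dotenv()

openai_api_key = os.getenv("OPENAI_API_KEY")
log_level = os.getenv("LOG_LEVEL", "INFO")  # fall back to INFO when unset

if openai_api_key is None:
    raise RuntimeError("OPENAI_API_KEY is not set; fill in the .env file first.")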

@@ -13,8 +13,8 @@ img = "images/swarms.jpeg"
## Initialize the workflow
agent = Agent(
    llm=llm,
    max_loops="auto",
    sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)

agent.run(task=task, img=img)

@@ -10,11 +10,16 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
stability_api_key = os.getenv("STABILITY_API_KEY")

# Initialize the language model and image generation model
llm = OpenAIChat(
    openai_api_key=openai_api_key, temperature=0.5, max_tokens=3000
)
sd_api = StableDiffusion(api_key=stability_api_key)


def run_task(description, product_name, agent, **kwargs):
    full_description = (  # Incorporate product name into the task
        f"{description} about {product_name}"
    )
    result = agent.run(task=full_description, **kwargs)
    return result
@@ -23,8 +28,24 @@ def run_task(description, product_name, agent, **kwargs):
class ProductPromptGenerator:
    def __init__(self, product_name):
        self.product_name = product_name
        self.themes = [
            "lightning",
            "sunset",
            "ice cave",
            "space",
            "forest",
            "ocean",
            "mountains",
            "cityscape",
        ]
        self.styles = [
            "translucent",
            "floating in mid-air",
            "expanded into pieces",
            "glowing",
            "mirrored",
            "futuristic",
        ]
        self.contexts = ["high realism product ad (extremely creative)"]

    def generate_prompt(self):
@@ -33,8 +54,12 @@ class ProductPromptGenerator:
        context = random.choice(self.contexts)
        return f"{theme} inside a {style} {self.product_name}, {context}"


# User input
product_name = input(
    "Enter a product name for ad creation (e.g., 'PS5', 'AirPods', 'Kirkland"
    " Vodka'): "
)

# Generate creative concept
prompt_generator = ProductPromptGenerator(product_name)
@@ -46,9 +71,17 @@ design_flow = Agent(llm=llm, max_loops=1, dashboard=False)
copywriting_flow = Agent(llm=llm, max_loops=1, dashboard=False)

# Execute tasks
concept_result = run_task(
    "Generate a creative concept", product_name, concept_flow
)
design_result = run_task(
    "Suggest visual design ideas", product_name, design_flow
)
copywriting_result = run_task(
    "Create compelling ad copy for the product photo",
    product_name,
    copywriting_flow,
)

# Generate product image
image_paths = sd_api.run(creative_prompt)

@@ -33,28 +33,17 @@ code
"""

# Initialize the language model
llm = OpenAIChat(openai_api_key=api_key, max_tokens=5000)

# Documentation agent
documentation_agent = Agent(
    llm=llm, sop=DOCUMENTATION_SOP, max_loops=1, multi_modal=True
)

# Tests agent
tests_agent = Agent(llm=llm, sop=TEST_SOP, max_loops=2, multi_modal=True)

# Run the documentation agent
@@ -64,5 +53,6 @@ documentation = documentation_agent.run(
# Run the tests agent
tests = tests_agent.run(
    f"Write tests for the following code:{TASK} here is the documentation:"
    f" {documentation}"
)

@@ -21,15 +21,15 @@ from swarms.models import GPT4VisionAPI
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

llm = GPT4VisionAPI(openai_api_key=api_key)

assembly_line = "playground/demos/swarm_of_mma_manufacturing/assembly_line.jpg"
red_robots = "playground/demos/swarm_of_mma_manufacturing/red_robots.jpg"
robots = "playground/demos/swarm_of_mma_manufacturing/robots.jpg"
tesla_assembly_line = (
    "playground/demos/swarm_of_mma_manufacturing/tesla_assembly.jpg"
)

# Define detailed prompts for each agent
@@ -73,55 +73,37 @@ efficiency_prompt = tasks["efficiency"]
# Health security agent
health_security_agent = Agent(
    llm=llm, sop_list=health_safety_prompt, max_loops=2, multi_modal=True
)

# Quality control agent
productivity_check_agent = Agent(
    llm=llm, sop=productivity_prompt, max_loops=2, multi_modal=True
)

# Security agent
security_check_agent = Agent(
    llm=llm, sop=security_prompt, max_loops=2, multi_modal=True
)

# Efficiency agent
efficiency_check_agent = Agent(
    llm=llm, sop=efficiency_prompt, max_loops=2, multi_modal=True
)

# Add the first task to the health_security_agent
health_check = health_security_agent.run(
    "Analyze the safety of this factory", robots
)

# Add the third task to the productivity_check_agent
productivity_check = productivity_check_agent.run(health_check, assembly_line)

# Add the fourth task to the security_check_agent
security_check = security_check_agent.add(productivity_check, red_robots)

# Add the fifth task to the efficiency_check_agent
efficiency_check = efficiency_check_agent.run(
    security_check, tesla_assembly_line
)

@@ -1,13 +1,19 @@
from langchain.document_loaders import CSVLoader

from swarms.memory import qdrant

loader = CSVLoader(
    file_path="../document_parsing/aipg/aipg.csv", encoding="utf-8-sig"
)
docs = loader.load()

# Initialize the Qdrant instance
# See qdrant documentation on how to run locally
qdrant_client = qdrant.Qdrant(
    host="https://697ea26c-2881-4e17-8af4-817fcb5862e8.europe-west3-0.gcp.cloud.qdrant.io",
    collection_name="qdrant",
    api_key="BhG2_yINqNU-aKovSEBadn69Zszhbo5uaqdJ6G_qDkdySjAljvuPqQ",
)
qdrant_client.add_vectors(docs)

# Perform a search

@@ -4,24 +4,33 @@ from httpx import RequestError
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct


class Qdrant:
    def __init__(
        self,
        api_key: str,
        host: str,
        port: int = 6333,
        collection_name: str = "qdrant",
        model_name: str = "BAAI/bge-small-en-v1.5",
        https: bool = True,
    ):
        """
        Qdrant class for managing collections and performing vector operations using QdrantClient.

        Attributes:
            client (QdrantClient): The Qdrant client for interacting with the Qdrant server.
            collection_name (str): Name of the collection to be managed in Qdrant.
            model (SentenceTransformer): The model used for generating sentence embeddings.

        Args:
            api_key (str): API key for authenticating with Qdrant.
            host (str): Host address of the Qdrant server.
            port (int): Port number of the Qdrant server. Defaults to 6333.
            collection_name (str): Name of the collection to be used or created. Defaults to "qdrant".
            model_name (str): Name of the model to be used for embeddings. Defaults to "BAAI/bge-small-en-v1.5".
            https (bool): Flag to indicate if HTTPS should be used. Defaults to True.
        """
        try:
            self.client = QdrantClient(url=host, port=port, api_key=api_key)
            self.collection_name = collection_name
@@ -50,7 +59,10 @@ class Qdrant:
        except Exception as e:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.model.get_sentence_embedding_dimension(),
                    distance=Distance.DOT,
                ),
            )
            print(f"Collection '{self.collection_name}' created.")
@@ -67,11 +79,21 @@ class Qdrant:
        points = []
        for i, doc in enumerate(docs):
            try:
                if "page_content" in doc:
                    embedding = self.model.encode(
                        doc["page_content"], normalize_embeddings=True
                    )
                    points.append(
                        PointStruct(
                            id=i + 1,
                            vector=embedding,
                            payload={"content": doc["page_content"]},
                        )
                    )
                else:
                    print(
                        f"Document at index {i} is missing 'page_content' key"
                    )
            except Exception as e:
                print(f"Error processing document at index {i}: {e}")
@@ -102,7 +124,7 @@ class Qdrant:
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                limit=limit,
            )
            return search_result
        except Exception as e:
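
A minimal usage sketch for this wrapper, assuming a reachable Qdrant instance; the host, API key, and sample document below are placeholders, and the document shape follows the "page_content" check in add_vectors above:

from swarms.memory.qdrant import Qdrant

# Placeholders: point these at a real Qdrant deployment.
memory = Qdrant(
    api_key="YOUR_QDRANT_API_KEY",
    host="https://your-cluster.cloud.qdrant.io",
    collection_name="qdrant",
)

# add_vectors embeds and upserts items that expose a "page_content" field.
memory.add_vectors([{"page_content": "Swarms is a multi-agent framework."}])

# search_vectors embeds the query and returns the closest points.
results = memory.search_vectors("multi-agent framework")
print(results)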

@@ -15,8 +15,8 @@ from termcolor import colored
class BaseMultiModalModel:
    """
    Base class for multimodal models

    Args:
        model_name (Optional[str], optional): Model name. Defaults to None.
        temperature (Optional[int], optional): Temperature. Defaults to 0.5.
@@ -28,7 +28,7 @@ class BaseMultiModalModel:
        device (Optional[str], optional): Device. Defaults to "cuda".
        max_new_tokens (Optional[int], optional): Max new tokens. Defaults to 500.
        retries (Optional[int], optional): Retries. Defaults to 3.

    Examples:
        >>> from swarms.models.base_multimodal_model import BaseMultiModalModel
        >>> model = BaseMultiModalModel()
@@ -54,8 +54,9 @@ class BaseMultiModalModel:
        >>> model.unique_chat_history()
        >>> model.clear_chat_history()
        >>> model.get_img_from_web("https://www.google.com/images/branding/googlelogo/")
    """

    def __init__(
        self,
        model_name: Optional[str],

@@ -95,7 +95,13 @@ class GPT4VisionAPI:
        pass

    # Function to handle vision tasks
    def run(
        self,
        task: Optional[str] = None,
        img: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Run the model."""
        try:
            base64_image = self.encode_image(img)
@@ -286,14 +292,14 @@ class GPT4VisionAPI:
    ):
        """
        Run the model on multiple tasks and images all at once using concurrent

        Args:
            tasks (List[str]): List of tasks
            imgs (List[str]): List of image paths

        Returns:
            List[str]: List of responses

        """
        # Instantiate the thread pool executor
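
A minimal sketch of calling the reformatted run signature above, assuming OPENAI_API_KEY is set; the task text is illustrative and the image path is one of the manufacturing demo paths used earlier:

import os

from dotenv import load_dotenv

from swarms.models import GPT4VisionAPI

load_dotenv()

llm = GPT4VisionAPI(openai_api_key=os.getenv("OPENAI_API_KEY"))

# task and img mirror the keyword arguments in the run signature above.
analysis = llm.run(
    task="Analyze the safety of this factory",
    img="playground/demos/swarm_of_mma_manufacturing/robots.jpg",
)
print(analysis)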

@@ -7,6 +7,7 @@ from typing import List
load_dotenv()


class StableDiffusion:
    """
    A class to interact with the Stable Diffusion API for generating images from text prompts.
@@ -41,7 +42,16 @@ class StableDiffusion:
        Generates an image based on the provided text prompt and returns the paths of the saved images.
    """

    def __init__(
        self,
        api_key: str,
        api_host: str = "https://api.stability.ai",
        cfg_scale: int = 7,
        height: int = 1024,
        width: int = 1024,
        samples: int = 1,
        steps: int = 30,
    ):
        """
        Initialize the StableDiffusion class with API configurations.
@@ -73,7 +83,7 @@ class StableDiffusion:
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        self.output_dir = "images"
        os.makedirs(self.output_dir, exist_ok=True)
@@ -117,7 +127,9 @@ class StableDiffusion:
        image_paths = []
        for i, image in enumerate(data["artifacts"]):
            unique_id = uuid.uuid4()  # Generate a unique identifier
            image_path = os.path.join(
                self.output_dir, f"{unique_id}_v1_txt2img_{i}.png"
            )
            with open(image_path, "wb") as f:
                f.write(base64.b64decode(image["base64"]))
            image_paths.append(image_path)
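
A minimal sketch of the text-to-image call, assuming STABILITY_API_KEY is set; the module path in the import is an assumption, while the constructor arguments mirror the defaults shown above and run() returns saved image paths as in the ad-generation demo:

import os

from dotenv import load_dotenv

from swarms.models.stable_diffusion import StableDiffusion  # assumed module path

load_dotenv()

sd_api = StableDiffusion(
    api_key=os.getenv("STABILITY_API_KEY"),
    height=1024,
    width=1024,
    samples=1,
)

# run() generates images for the prompt and returns the saved file paths.
image_paths = sd_api.run("lightning inside a translucent PS5, high realism product ad")
print(image_paths)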

@@ -1,4 +1,3 @@
TEST_SOP = """
Create 500 extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any
just write the best tests possible and return the test code in markdown format. Create the tests for the code below and make it really high performance
@@ -174,4 +173,4 @@ To replicate the documentation for any other module or framework, follow the sam
############# DOCUMENT THE FOLLOWING CODE ########
"""

@@ -620,7 +620,9 @@ class Agent:
        # If autosave is enabled then save the state
        if self.autosave:
            save_path = self.saved_state_path or "flow_state.json"
            print(
                colored(f"Autosaving agent state to {save_path}", "green")
            )
            self.save_state(save_path)

        # If return history is enabled then return the response and history

@@ -140,7 +140,9 @@ class SequentialWorkflow:
        """
        # If the agent is a Agent instance, we include the task in kwargs for Agent.run()
        if isinstance(agent, Agent):
            kwargs["task"] = (
                task  # Set the task as a keyword argument for Agent
            )

            # Append the task to the tasks list
            if self.img:
@@ -156,7 +158,10 @@ class SequentialWorkflow:
            else:
                self.tasks.append(
                    Task(
                        description=task,
                        agent=agent,
                        args=list(args),
                        kwargs=kwargs,
                    )
                )
@@ -448,7 +453,9 @@ class SequentialWorkflow:
            )
        else:
            # If it's not a Agent instance, call the agent directly
            task.result = await task.agent(
                *task.args, **task.kwargs
            )

        # Pass the result as an argument to the next task if it exists
        next_task_index = self.tasks.index(task) + 1

@@ -6,35 +6,46 @@ from swarms.memory.qdrant import Qdrant
@pytest.fixture
def mock_qdrant_client():
    with patch("your_module.QdrantClient") as MockQdrantClient:
        yield MockQdrantClient()


@pytest.fixture
def mock_sentence_transformer():
    with patch(
        "sentence_transformers.SentenceTransformer"
    ) as MockSentenceTransformer:
        yield MockSentenceTransformer()


@pytest.fixture
def qdrant_client(mock_qdrant_client, mock_sentence_transformer):
    client = Qdrant(api_key="your_api_key", host="your_host")
    yield client


def test_qdrant_init(qdrant_client, mock_qdrant_client):
    assert qdrant_client.client is not None


def test_load_embedding_model(qdrant_client, mock_sentence_transformer):
    qdrant_client._load_embedding_model("model_name")
    mock_sentence_transformer.assert_called_once_with("model_name")


def test_setup_collection(qdrant_client, mock_qdrant_client):
    qdrant_client._setup_collection()
    mock_qdrant_client.get_collection.assert_called_once_with(
        qdrant_client.collection_name
    )


def test_add_vectors(qdrant_client, mock_qdrant_client):
    mock_doc = Mock(page_content="Sample text")
    qdrant_client.add_vectors([mock_doc])
    mock_qdrant_client.upsert.assert_called_once()


def test_search_vectors(qdrant_client, mock_qdrant_client):
    qdrant_client.search_vectors("test query")
    mock_qdrant_client.search.assert_called_once()

@@ -1168,7 +1168,9 @@ def test_flow_from_llm_and_template_file():
    llm_instance = mocked_llm  # Replace with your LLM class
    template_file = "template.txt"  # Create a template file for testing

    flow_instance = Agent.from_llm_and_template_file(
        llm_instance, template_file
    )

    assert isinstance(flow_instance, Agent)
