Start work on lints

pull/443/head
Wyatt Stanke 9 months ago
parent f7b00466e2
commit f5a56bef97
No known key found for this signature in database
GPG Key ID: CE6BA5FFF135536D

@ -12,7 +12,7 @@
# For more information on Codacy Analysis CLI in general, see # For more information on Codacy Analysis CLI in general, see
# https://github.com/codacy/codacy-analysis-cli. # https://github.com/codacy/codacy-analysis-cli.
name: Codacy Security Scan name: Codacy
on: on:
push: push:
branches: ["master"] branches: ["master"]

@ -1,5 +1,5 @@
--- ---
name: Docs WorkAgent name: Documentation
on: on:
push: push:
branches: branches:
@ -18,9 +18,3 @@ jobs:
- run: pip install mkdocs-glightbox - run: pip install mkdocs-glightbox
- run: pip install "mkdocstrings[python]" - run: pip install "mkdocstrings[python]"
- run: mkdocs gh-deploy --force - run: mkdocs gh-deploy --force
preview:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: swarms

@ -0,0 +1,16 @@
name: Documentation Links
on:
pull_request_target:
types:
- opened
permissions:
pull-requests: write
jobs:
documentation-links:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: "swarms"

@ -4,11 +4,11 @@
# You can adjust the behavior by modifying this file. # You can adjust the behavior by modifying this file.
# For more information, see: # For more information, see:
# https://github.com/actions/stale # https://github.com/actions/stale
name: Mark stale issues and pull requests name: Stale
on: on:
schedule: schedule:
# Scheduled to run at 1.30 UTC everyday # Scheduled to run at 1.30 UTC everyday
- cron: '30 1 * * *' - cron: "0 0 * * *"
jobs: jobs:
stale: stale:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -44,6 +44,6 @@ jobs:
Please open a new pull request if you need further assistance. Thanks! Please open a new pull request if you need further assistance. Thanks!
# Label that can be assigned to issues to exclude them from being marked as stale # Label that can be assigned to issues to exclude them from being marked as stale
exempt-issue-labels: 'override-stale' exempt-issue-labels: "override-stale"
# Label that can be assigned to PRs to exclude them from being marked as stale # Label that can be assigned to PRs to exclude them from being marked as stale
exempt-pr-labels: "override-stale" exempt-pr-labels: "override-stale"

@ -1,5 +1,5 @@
--- ---
name: Welcome WorkAgent name: Welcome
on: on:
issues: issues:
types: [opened] types: [opened]
@ -14,7 +14,9 @@ jobs:
- uses: actions/first-interaction@v1.3.0 - uses: actions/first-interaction@v1.3.0
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
issue-message: "Hello there, thank you for opening an Issue ! 🙏🏻 The team issue-message:
"Hello there, thank you for opening an Issue ! 🙏🏻 The team
was notified and they will get back to you asap." was notified and they will get back to you asap."
pr-message: "Hello there, thank you for opening an PR ! 🙏🏻 The team was pr-message:
"Hello there, thank you for opening an PR ! 🙏🏻 The team was
notified and they will get back to you asap." notified and they will get back to you asap."

@ -1,6 +1,6 @@
from swarms import Agent, Anthropic from swarms import Agent, Anthropic
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
agent_name="Transcript Generator", agent_name="Transcript Generator",
agent_description=( agent_description=(

@ -1,6 +1,6 @@
from swarms import Agent, OpenAIChat from swarms import Agent, OpenAIChat
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=OpenAIChat(), llm=OpenAIChat(),
max_loops="auto", max_loops="auto",

@ -10,7 +10,7 @@ def search_api(query: str, max_results: int = 10):
return f"Search API: {query} -> {max_results} results" return f"Search API: {query} -> {max_results} results"
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
agent_name="Youtube Transcript Generator", agent_name="Youtube Transcript Generator",
agent_description=( agent_description=(

@ -21,7 +21,7 @@ llm = GPT4VisionAPI(
task = "What is the color of the object?" task = "What is the color of the object?"
img = "images/swarms.jpeg" img = "images/swarms.jpeg"
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops="auto", max_loops="auto",

@ -11,7 +11,7 @@ task = (
) )
img = "assembly_line.jpg" img = "assembly_line.jpg"
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops=1, max_loops=1,

@ -9,7 +9,7 @@ llm = GPT4VisionAPI()
task = "What is the color of the object?" task = "What is the color of the object?"
img = "images/swarms.jpeg" img = "images/swarms.jpeg"
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1, sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,

@ -6,7 +6,7 @@ llm = GPT4VisionAPI()
task = "What is the color of the object?" task = "What is the color of the object?"
img = "images/swarms.jpeg" img = "images/swarms.jpeg"
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops="auto", max_loops="auto",

@ -22,7 +22,7 @@ llm = GPT4VisionAPI(
task = "This is an eye test. What do you see?" task = "This is an eye test. What do you see?"
img = "playground/demos/multi_modal_chain_of_thought/eyetest.jpg" img = "playground/demos/multi_modal_chain_of_thought/eyetest.jpg"
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops=2, max_loops=2,

@ -19,7 +19,7 @@ llm = HuggingfaceLLM(
temperature=0.5, temperature=0.5,
) )
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops="auto", max_loops="auto",

@ -26,7 +26,7 @@ print(
f" {sys.stderr}" f" {sys.stderr}"
) )
## Initialize the workflow # Initialize the workflow
agent = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True) agent = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
# Run the workflow on a task # Run the workflow on a task

@ -17,7 +17,7 @@ llm = OpenAIChat(
openai_api_key=api_key, openai_api_key=api_key,
) )
## Initialize the workflow # Initialize the workflow
agent = Agent(llm=llm, max_loops=1, agent_name="Social Media Manager") agent = Agent(llm=llm, max_loops=1, agent_name="Social Media Manager")
agent2 = Agent(llm=llm, max_loops=1, agent_name=" Product Manager") agent2 = Agent(llm=llm, max_loops=1, agent_name=" Product Manager")
agent3 = Agent(llm=llm, max_loops=1, agent_name="SEO Manager") agent3 = Agent(llm=llm, max_loops=1, agent_name="SEO Manager")

@ -11,7 +11,7 @@ llm = OpenAIChat(
# max_tokens=100, # max_tokens=100,
) )
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops=2, max_loops=2,

@ -26,7 +26,7 @@ llm = OpenAIChat(
max_tokens=1000, max_tokens=1000,
) )
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops=4, max_loops=4,

@ -66,7 +66,7 @@ llm = OpenAIChat(
) )
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
agent_name="Research Agent", agent_name="Research Agent",
llm=llm, llm=llm,

@ -20,7 +20,7 @@ llm = OpenAIChat(
) )
## Initialize the workflow # Initialize the workflow
agent = Agent(llm=llm, max_loops=1, dashboard=True) agent = Agent(llm=llm, max_loops=1, dashboard=True)

@ -1,6 +1,6 @@
from swarms import Agent, AzureOpenAI from swarms import Agent, AzureOpenAI
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=AzureOpenAI(), llm=AzureOpenAI(),
max_loops="auto", max_loops="auto",

@ -10,7 +10,7 @@ class ExampleLLM(AbstractLLM):
pass pass
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=ExampleLLM(), llm=ExampleLLM(),
max_loops="auto", max_loops="auto",

@ -1,6 +1,6 @@
from swarms import Agent, OpenAIChat from swarms import Agent, OpenAIChat
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=OpenAIChat(), llm=OpenAIChat(),
max_loops=1, max_loops=1,

@ -31,7 +31,7 @@ together_llm = TogetherLLM(
together_api_key=os.getenv("TOGETHER_API_KEY"), max_tokens=3000 together_api_key=os.getenv("TOGETHER_API_KEY"), max_tokens=3000
) )
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=anthropic, llm=anthropic,
max_loops=1, max_loops=1,

@ -26,7 +26,7 @@ def search_api(query: str) -> str:
print(f"Searching API for {query}") print(f"Searching API for {query}")
## Initialize the workflow # Initialize the workflow
agent = Agent( agent = Agent(
llm=llm, llm=llm,
max_loops=5, max_loops=5,

@ -71,13 +71,14 @@ pandas = "^2.2.2"
fastapi = "^0.110.1" fastapi = "^0.110.1"
[tool.ruff] [tool.ruff]
line-length = 127 line-length = 128
[tool.ruff.lint] [tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "W", "E501", "I", "UP"] select = ["E", "F", "W", "I", "UP"]
ignore = [] ignore = []
fixable = ["ALL"] fixable = ["ALL"]
unfixable = [] unfixable = []
preview = true
[tool.black] [tool.black]
line-length = 70 line-length = 70

@ -1,4 +1,4 @@
###### VERISON2 # VERISON2
import inspect import inspect
import os import os
import threading import threading

@ -34,7 +34,7 @@ commands: {
""" """
########### FEW SHOT EXAMPLES ################ # FEW SHOT EXAMPLES ################
SCENARIOS = """ SCENARIOS = """
commands: { commands: {
"tools": { "tools": {

@ -17,7 +17,8 @@ class GraphWorkflow(BaseStructure):
connect(from_node, to_node): Connects two nodes in the graph. connect(from_node, to_node): Connects two nodes in the graph.
set_entry_point(node_name): Sets the entry point node for the workflow. set_entry_point(node_name): Sets the entry point node for the workflow.
add_edge(from_node, to_node): Adds an edge between two nodes in the graph. add_edge(from_node, to_node): Adds an edge between two nodes in the graph.
add_conditional_edges(from_node, condition, edge_dict): Adds conditional edges from a node to multiple nodes based on a condition. add_conditional_edges(from_node, condition, edge_dict):
Adds conditional edges from a node to multiple nodes based on a condition.
run(): Runs the workflow and returns the graph. run(): Runs the workflow and returns the graph.
Examples: Examples:
@ -126,15 +127,11 @@ class GraphWorkflow(BaseStructure):
if from_node in self.graph: if from_node in self.graph:
for condition_value, to_node in edge_dict.items(): for condition_value, to_node in edge_dict.items():
if to_node in self.graph: if to_node in self.graph:
self.graph[from_node]["edges"][ self.graph[from_node]["edges"][to_node] = condition
to_node
] = condition
else: else:
raise ValueError("Node does not exist in graph") raise ValueError("Node does not exist in graph")
else: else:
raise ValueError( raise ValueError(f"Node {from_node} does not exist in graph")
f"Node {from_node} does not exist in graph"
)
def run(self): def run(self):
""" """
@ -160,9 +157,7 @@ class GraphWorkflow(BaseStructure):
ValueError: _description_ ValueError: _description_
""" """
if node_name not in self.graph: if node_name not in self.graph:
raise ValueError( raise ValueError(f"Node {node_name} does not exist in graph")
f"Node {node_name} does not exist in graph"
)
def _check_nodes_exist(self, from_node, to_node): def _check_nodes_exist(self, from_node, to_node):
""" """

@ -20,9 +20,7 @@ def _hash(input: str):
return hex_dig return hex_dig
def msg_hash( def msg_hash(agent: Agent, content: str, turn: int, msg_type: str = "text"):
agent: Agent, content: str, turn: int, msg_type: str = "text"
):
""" """
Generate a hash value for a message. Generate a hash value for a message.
@ -37,8 +35,7 @@ def msg_hash(
""" """
time = time_ns() time = time_ns()
return _hash( return _hash(
f"agent: {agent.agent_name}\ncontent: {content}\ntimestamp:" f"agent: {agent.agent_name}\ncontent: {content}\ntimestamp:" f" {str(time)}\nturn: {turn}\nmsg_type: {msg_type}"
f" {str(time)}\nturn: {turn}\nmsg_type: {msg_type}"
) )
@ -67,11 +64,17 @@ class MessagePool:
>>> message_pool.add(agent=agent2, content="Hello, agent1!", turn=1) >>> message_pool.add(agent=agent2, content="Hello, agent1!", turn=1)
>>> message_pool.add(agent=agent3, content="Hello, agent1!", turn=1) >>> message_pool.add(agent=agent3, content="Hello, agent1!", turn=1)
>>> message_pool.get_all_messages() >>> message_pool.get_all_messages()
[{'agent': Agent(agent_name='agent1'), 'content': 'Hello, agent2!', 'turn': 1, 'visible_to': 'all', 'logged': True}, {'agent': Agent(agent_name='agent2'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}, {'agent': Agent(agent_name='agent3'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}] [{'agent': Agent(agent_name='agent1'), 'content': 'Hello, agent2!', 'turn': 1, 'visible_to': 'all', 'logged': True},
{'agent': Agent(agent_name='agent2'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True},
{'agent': Agent(agent_name='agent3'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}]
>>> message_pool.get_visible_messages(agent=agent1, turn=1) >>> message_pool.get_visible_messages(agent=agent1, turn=1)
[{'agent': Agent(agent_name='agent1'), 'content': 'Hello, agent2!', 'turn': 1, 'visible_to': 'all', 'logged': True}, {'agent': Agent(agent_name='agent2'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}, {'agent': Agent(agent_name='agent3'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}] [{'agent': Agent(agent_name='agent1'), 'content': 'Hello, agent2!', 'turn': 1, 'visible_to': 'all', 'logged': True},
{'agent': Agent(agent_name='agent2'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True},
{'agent': Agent(agent_name='agent3'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}]
>>> message_pool.get_visible_messages(agent=agent2, turn=1) >>> message_pool.get_visible_messages(agent=agent2, turn=1)
[{'agent': Agent(agent_name='agent1'), 'content': 'Hello, agent2!', 'turn': 1, 'visible_to': 'all', 'logged': True}, {'agent': Agent(agent_name='agent2'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}, {'agent': Agent(agent_name='agent3'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}] [{'agent': Agent(agent_name='agent1'), 'content': 'Hello, agent2!', 'turn': 1, 'visible_to': 'all', 'logged': True},
{'agent': Agent(agent_name='agent2'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True},
{'agent': Agent(agent_name='agent3'), 'content': 'Hello, agent1!', 'turn': 1, 'visible_to': 'all', 'logged': True}]
""" """
def __init__( def __init__(
@ -98,9 +101,7 @@ class MessagePool:
logger.info("MessagePool initialized") logger.info("MessagePool initialized")
logger.info(f"Number of agents: {len(agents)}") logger.info(f"Number of agents: {len(agents)}")
logger.info( logger.info(f"Agents: {[agent.agent_name for agent in agents]}")
f"Agents: {[agent.agent_name for agent in agents]}"
)
logger.info(f"moderator: {moderator.agent_name} is available") logger.info(f"moderator: {moderator.agent_name} is available")
logger.info(f"Number of turns: {turns}") logger.info(f"Number of turns: {turns}")
@ -187,18 +188,11 @@ class MessagePool:
List[Dict]: The list of visible messages. List[Dict]: The list of visible messages.
""" """
# Get the messages before the current turn # Get the messages before the current turn
prev_messages = [ prev_messages = [message for message in self.messages if message["turn"] < turn]
message
for message in self.messages
if message["turn"] < turn
]
visible_messages = [] visible_messages = []
for message in prev_messages: for message in prev_messages:
if ( if message["visible_to"] == "all" or agent.agent_name in message["visible_to"]:
message["visible_to"] == "all"
or agent.agent_name in message["visible_to"]
):
visible_messages.append(message) visible_messages.append(message)
return visible_messages return visible_messages

@ -7,20 +7,18 @@ from swarms.prompts.meta_system_prompt import (
) )
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
meta_prompter_llm = OpenAIChat( meta_prompter_llm = OpenAIChat(system_prompt=str(meta_system_prompt_generator))
system_prompt=str(meta_system_prompt_generator)
)
def meta_system_prompt( def meta_system_prompt(agent: Union[Agent, AbstractLLM], system_prompt: str) -> str:
agent: Union[Agent, AbstractLLM], system_prompt: str
) -> str:
""" """
Generates a meta system prompt for the given agent using the provided system prompt. Generates a meta system prompt for the given agent using the provided system prompt.
Args: Args:
agent (Union[Agent, AbstractLLM]): The agent or LLM (Language Learning Model) for which the meta system prompt is generated. agent (Union[Agent, AbstractLLM]):
system_prompt (str): The system prompt used to generate the meta system prompt. The agent or LLM (Language Learning Model) for which the meta system prompt is generated.
system_prompt (str):
The system prompt used to generate the meta system prompt.
Returns: Returns:
str: The generated meta system prompt. str: The generated meta system prompt.

@ -17,7 +17,8 @@ def scrape_tool_func_docs(fn: Callable) -> str:
fn (Callable): The function to scrape. fn (Callable): The function to scrape.
Returns: Returns:
str: A string containing the function's name, documentation string, and a list of its parameters. Each parameter is represented as a line containing the parameter's name, default value, and annotation. str: A string containing the function's name, documentation string, and a list of its parameters.
Each parameter is represented as a line containing the parameter's name, default value, and annotation.
""" """
try: try:
# If the function is a tool, get the original function # If the function is a tool, get the original function
@ -34,10 +35,7 @@ def scrape_tool_func_docs(fn: Callable) -> str:
f" {param.annotation if param.annotation is not param.empty else 'None'}" f" {param.annotation if param.annotation is not param.empty else 'None'}"
) )
parameters_str = "\n".join(parameters) parameters_str = "\n".join(parameters)
return ( return f"Function: {fn.__name__}\nDocstring:" f" {inspect.getdoc(fn)}\nParameters:\n{parameters_str}"
f"Function: {fn.__name__}\nDocstring:"
f" {inspect.getdoc(fn)}\nParameters:\n{parameters_str}"
)
except Exception as error: except Exception as error:
print( print(
colored( colored(

@ -18,7 +18,8 @@ def load_model_torch(
model_path (str): Path to the saved model file. model_path (str): Path to the saved model file.
device (torch.device): Device to move the model to. device (torch.device): Device to move the model to.
model (nn.Module): The model architecture, if the model file only contains the state dictionary. model (nn.Module): The model architecture, if the model file only contains the state dictionary.
strict (bool): Whether to strictly enforce that the keys in the state dictionary match the keys returned by the model's `state_dict()` function. strict (bool): Whether to strictly enforce that the keys in the state dictionary match
the keys returned by the model's `state_dict()` function.
map_location (callable): A function to remap the storage locations of the loaded model. map_location (callable): A function to remap the storage locations of the loaded model.
*args: Additional arguments to pass to `torch.load`. *args: Additional arguments to pass to `torch.load`.
**kwargs: Additional keyword arguments to pass to `torch.load`. **kwargs: Additional keyword arguments to pass to `torch.load`.
@ -31,15 +32,11 @@ def load_model_torch(
RuntimeError: If there is an error while loading the model. RuntimeError: If there is an error while loading the model.
""" """
if device is None: if device is None:
device = torch.device( device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"cuda" if torch.cuda.is_available() else "cpu"
)
try: try:
if model is None: if model is None:
model = torch.load( model = torch.load(model_path, map_location=map_location, *args, **kwargs)
model_path, map_location=map_location, *args, **kwargs
)
else: else:
model.load_state_dict( model.load_state_dict(
torch.load( torch.load(

@ -18,10 +18,7 @@ def llm_instance():
# Test for instantiation and attributes # Test for instantiation and attributes
def test_llm_initialization(llm_instance): def test_llm_initialization(llm_instance):
assert ( assert llm_instance.model_id == "NousResearch/Nous-Hermes-2-Vision-Alpha"
llm_instance.model_id
== "NousResearch/Nous-Hermes-2-Vision-Alpha"
)
assert llm_instance.max_length == 500 assert llm_instance.max_length == 500
# ... add more assertions for all default attributes # ... add more assertions for all default attributes
@ -88,15 +85,11 @@ def test_llm_memory_consumption(llm_instance):
) )
def test_llm_initialization_params(model_id, max_length): def test_llm_initialization_params(model_id, max_length):
if max_length: if max_length:
instance = HuggingfaceLLM( instance = HuggingfaceLLM(model_id=model_id, max_length=max_length)
model_id=model_id, max_length=max_length
)
assert instance.max_length == max_length assert instance.max_length == max_length
else: else:
instance = HuggingfaceLLM(model_id=model_id) instance = HuggingfaceLLM(model_id=model_id)
assert ( assert instance.max_length == 500 # Assuming 500 is the default max_length
instance.max_length == 500
) # Assuming 500 is the default max_length
# Test for setting an invalid device # Test for setting an invalid device
@ -144,9 +137,7 @@ def test_llm_run_output_length(mock_run, llm_instance):
# Test the tokenizer handling special tokens correctly # Test the tokenizer handling special tokens correctly
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.encode")
@patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode") @patch("swarms.models.huggingface.HuggingfaceLLM._tokenizer.decode")
def test_llm_tokenizer_special_tokens( def test_llm_tokenizer_special_tokens(mock_decode, mock_encode, llm_instance):
mock_decode, mock_encode, llm_instance
):
mock_encode.return_value = "encoded input with special tokens" mock_encode.return_value = "encoded input with special tokens"
mock_decode.return_value = "decoded output with special tokens" mock_decode.return_value = "decoded output with special tokens"
result = llm_instance.run("test task with special tokens") result = llm_instance.run("test task with special tokens")
@ -172,9 +163,7 @@ def test_llm_response_time(mock_run, llm_instance):
start_time = time.time() start_time = time.time()
llm_instance.run("test task for response time") llm_instance.run("test task for response time")
end_time = time.time() end_time = time.time()
assert ( assert end_time - start_time < 1 # Assuming the response should be faster than 1 second
end_time - start_time < 1
) # Assuming the response should be faster than 1 second
# Test the logging of a warning for long inputs # Test the logging of a warning for long inputs
@ -197,13 +186,9 @@ def test_llm_run_model_exception(mock_generate, llm_instance):
# Test the behavior when GPU is forced but not available # Test the behavior when GPU is forced but not available
@patch("torch.cuda.is_available", return_value=False) @patch("torch.cuda.is_available", return_value=False)
def test_llm_force_gpu_when_unavailable( def test_llm_force_gpu_when_unavailable(mock_is_available, llm_instance):
mock_is_available, llm_instance
):
with pytest.raises(EnvironmentError): with pytest.raises(EnvironmentError):
llm_instance.set_device( llm_instance.set_device("cuda") # Attempt to set CUDA when it's not available
"cuda"
) # Attempt to set CUDA when it's not available
# Test for proper cleanup after model use (releasing resources) # Test for proper cleanup after model use (releasing resources)
@ -221,9 +206,7 @@ def test_llm_multilingual_input(mock_run, llm_instance):
mock_run.return_value = "mocked multilingual output" mock_run.return_value = "mocked multilingual output"
multilingual_input = "Bonjour, ceci est un test multilingue." multilingual_input = "Bonjour, ceci est un test multilingue."
result = llm_instance.run(multilingual_input) result = llm_instance.run(multilingual_input)
assert isinstance( assert isinstance(result, str) # Simple check to ensure output is string type
result, str
) # Simple check to ensure output is string type
# Test caching mechanism to prevent re-running the same inputs # Test caching mechanism to prevent re-running the same inputs
@ -238,5 +221,7 @@ def test_llm_caching_mechanism(mock_run, llm_instance):
assert first_run_result == second_run_result assert first_run_result == second_run_result
# These tests are provided as examples. In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class. # These tests are provided as examples.
# For instance, "mock_model.delete.assert_called_once()" and similar lines are based on hypothetical methods and behaviors that you need to replace with actual implementations. # In real-world scenarios, you will need to adapt these tests to the actual logic of your `HuggingfaceLLM` class.
# For instance, "mock_model.delete.assert_called_once()" and similar lines are based on hypothetical methods and behaviors
# that you need to replace with actual implementations.

Loading…
Cancel
Save