From afb69288e27d34ffb767b8e5cc65312333e92d25 Mon Sep 17 00:00:00 2001 From: Kye Date: Mon, 9 Oct 2023 19:58:36 -0400 Subject: [PATCH] tests for structs and models Former-commit-id: 2baf79f70ad4fe794e405cd00aeae968ea9214dc --- swarms/structs/link.py | 352 ---------------------------- tests/models/mistral.py | 41 ++++ tests/structs/nonlinear_workflow.py | 63 +++++ tests/structs/workflow.py | 59 +++++ 4 files changed, 163 insertions(+), 352 deletions(-) delete mode 100644 swarms/structs/link.py create mode 100644 tests/models/mistral.py create mode 100644 tests/structs/nonlinear_workflow.py create mode 100644 tests/structs/workflow.py diff --git a/swarms/structs/link.py b/swarms/structs/link.py deleted file mode 100644 index cb1ac2af..00000000 --- a/swarms/structs/link.py +++ /dev/null @@ -1,352 +0,0 @@ -from __future__ import annotations - -"""Links are like Chains from Langlink but more fluid and seamless""" -"""Chain that just formats a prompt and calls an LLM.""" - -import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - -from langlink.callbacks.manager import ( - AsyncCallbackManager, - AsyncCallbackManagerForChainRun, - CallbackManager, - CallbackManagerForChainRun, - Callbacks, -) -from langlink.links.base import Chain -from langlink.load.dump import dumpd -from langlink.prompts.prompt import PromptTemplate -from langlink.pydantic_v1 import Extra, Field -from langlink.schema import ( - BaseLLMOutputParser, - BasePromptTemplate, - LLMResult, - PromptValue, - StrOutputParser, -) -from langlink.schema.language_model import BaseLanguageModel -from langlink.utils.input import get_colored_text - - -class Link(Chain): - """Chain to run queries against LLMs. - - Example: - .. code-block:: python - - from langlink.links import Link - from langlink.llms import OpenAI - from langlink.prompts import PromptTemplate - prompt_template = "Tell me a {adjective} joke" - prompt = PromptTemplate( - input_variables=["adjective"], template=prompt_template - ) - llm = Link(llm=OpenAI(), prompt=prompt) - """ - - @classmethod - def is_lc_serializable(self) -> bool: - return True - - prompt: BasePromptTemplate - """Prompt object to use.""" - - llm: BaseLanguageModel - """Language model to call.""" - - output_key: str = "text" #: :meta private: - - output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser) - """Output parser to use. - Defaults to one that takes the most likely string but does not change it - otherwise.""" - - return_final_only: bool = True - """Whether to return only the final parsed result. Defaults to True. - If false, will return a bunch of extra information about the generation.""" - - llm_kwargs: dict = Field(default_factory=dict) - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Will be whatever keys the prompt expects. - - :meta private: - """ - return self.prompt.input_variables - - @property - def output_keys(self) -> List[str]: - """Will always return text key. - - :meta private: - """ - if self.return_final_only: - return [self.output_key] - else: - return [self.output_key, "full_generation"] - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - response = self.run([inputs], run_manager=run_manager) - return self.create_outputs(response)[0] - - def run( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> LLMResult: - """Generate LLM result from inputs.""" - prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) - return self.llm.run_prompt( - prompts, - stop, - callbacks=run_manager.get_child() if run_manager else None, - **self.llm_kwargs, - ) - - async def arun( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> LLMResult: - """Generate LLM result from inputs.""" - prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager) - return await self.llm.arun_prompt( - prompts, - stop, - callbacks=run_manager.get_child() if run_manager else None, - **self.llm_kwargs, - ) - - def prep_prompts( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Tuple[List[PromptValue], Optional[List[str]]]: - """Prepare prompts from inputs.""" - stop = None - if len(input_list) == 0: - return [], stop - if "stop" in input_list[0]: - stop = input_list[0]["stop"] - prompts = [] - for inputs in input_list: - selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} - prompt = self.prompt.format_prompt(**selected_inputs) - _colored_text = get_colored_text(prompt.to_string(), "green") - _text = "Prompt after formatting:\n" + _colored_text - if run_manager: - run_manager.on_text(_text, end="\n", verbose=self.verbose) - if "stop" in inputs and inputs["stop"] != stop: - raise ValueError( - "If `stop` is present in any inputs, should be present in all." - ) - prompts.append(prompt) - return prompts, stop - - async def aprep_prompts( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Tuple[List[PromptValue], Optional[List[str]]]: - """Prepare prompts from inputs.""" - stop = None - if len(input_list) == 0: - return [], stop - if "stop" in input_list[0]: - stop = input_list[0]["stop"] - prompts = [] - for inputs in input_list: - selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} - prompt = self.prompt.format_prompt(**selected_inputs) - _colored_text = get_colored_text(prompt.to_string(), "green") - _text = "Prompt after formatting:\n" + _colored_text - if run_manager: - await run_manager.on_text(_text, end="\n", verbose=self.verbose) - if "stop" in inputs and inputs["stop"] != stop: - raise ValueError( - "If `stop` is present in any inputs, should be present in all." - ) - prompts.append(prompt) - return prompts, stop - - def apply( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> List[Dict[str, str]]: - """Utilize the LLM run method for speed gains.""" - callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose - ) - run_manager = callback_manager.on_link_start( - dumpd(self), - {"input_list": input_list}, - ) - try: - response = self.run(input_list, run_manager=run_manager) - except BaseException as e: - run_manager.on_link_error(e) - raise e - outputs = self.create_outputs(response) - run_manager.on_link_end({"outputs": outputs}) - return outputs - - async def aapply( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> List[Dict[str, str]]: - """Utilize the LLM run method for speed gains.""" - callback_manager = AsyncCallbackManager.configure( - callbacks, self.callbacks, self.verbose - ) - run_manager = await callback_manager.on_link_start( - dumpd(self), - {"input_list": input_list}, - ) - try: - response = await self.arun(input_list, run_manager=run_manager) - except BaseException as e: - await run_manager.on_link_error(e) - raise e - outputs = self.create_outputs(response) - await run_manager.on_link_end({"outputs": outputs}) - return outputs - - @property - def _run_output_key(self) -> str: - return self.output_key - - def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]: - """Create outputs from response.""" - result = [ - # Get the text of the top rund string. - { - self.output_key: self.output_parser.parse_result(generation), - "full_generation": generation, - } - for generation in llm_result.generations - ] - if self.return_final_only: - result = [{self.output_key: r[self.output_key]} for r in result] - return result - - async def _acall( - self, - inputs: Dict[str, Any], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - response = await self.arun([inputs], run_manager=run_manager) - return self.create_outputs(response)[0] - - def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: - """Format prompt with kwargs and pass to LLM. - - Args: - callbacks: Callbacks to pass to Link - **kwargs: Keys to pass to prompt template. - - Returns: - Completion from LLM. - - Example: - .. code-block:: python - - completion = llm.predict(adjective="funny") - """ - return self(kwargs, callbacks=callbacks)[self.output_key] - - async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: - """Format prompt with kwargs and pass to LLM. - - Args: - callbacks: Callbacks to pass to Link - **kwargs: Keys to pass to prompt template. - - Returns: - Completion from LLM. - - Example: - .. code-block:: python - - completion = llm.predict(adjective="funny") - """ - return (await self.acall(kwargs, callbacks=callbacks))[self.output_key] - - def predict_and_parse( - self, callbacks: Callbacks = None, **kwargs: Any - ) -> Union[str, List[str], Dict[str, Any]]: - """Call predict and then parse the results.""" - warnings.warn( - "The predict_and_parse method is deprecated, " - "instead pass an output parser directly to Link." - ) - result = self.predict(callbacks=callbacks, **kwargs) - if self.prompt.output_parser is not None: - return self.prompt.output_parser.parse(result) - else: - return result - - async def apredict_and_parse( - self, callbacks: Callbacks = None, **kwargs: Any - ) -> Union[str, List[str], Dict[str, str]]: - """Call apredict and then parse the results.""" - warnings.warn( - "The apredict_and_parse method is deprecated, " - "instead pass an output parser directly to Link." - ) - result = await self.apredict(callbacks=callbacks, **kwargs) - if self.prompt.output_parser is not None: - return self.prompt.output_parser.parse(result) - else: - return result - - def apply_and_parse( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> Sequence[Union[str, List[str], Dict[str, str]]]: - """Call apply and then parse the results.""" - warnings.warn( - "The apply_and_parse method is deprecated, " - "instead pass an output parser directly to Link." - ) - result = self.apply(input_list, callbacks=callbacks) - return self._parse_generation(result) - - def _parse_generation( - self, generation: List[Dict[str, str]] - ) -> Sequence[Union[str, List[str], Dict[str, str]]]: - if self.prompt.output_parser is not None: - return [ - self.prompt.output_parser.parse(res[self.output_key]) - for res in generation - ] - else: - return generation - - async def aapply_and_parse( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> Sequence[Union[str, List[str], Dict[str, str]]]: - """Call apply and then parse the results.""" - warnings.warn( - "The aapply_and_parse method is deprecated, " - "instead pass an output parser directly to Link." - ) - result = await self.aapply(input_list, callbacks=callbacks) - return self._parse_generation(result) - - @property - def _link_type(self) -> str: - return "llm_link" - - @classmethod - def from_string(cls, llm: BaseLanguageModel, template: str) -> Link: - """Create Link from LLM and template.""" - prompt_template = PromptTemplate.from_template(template) - return cls(llm=llm, prompt=prompt_template) diff --git a/tests/models/mistral.py b/tests/models/mistral.py new file mode 100644 index 00000000..8296b106 --- /dev/null +++ b/tests/models/mistral.py @@ -0,0 +1,41 @@ +import pytest +from unittest.mock import patch, MagicMock +from swarms.models.mistral import Mistral + +def test_mistral_initialization(): + mistral = Mistral(device="cpu") + assert isinstance(mistral, Mistral) + assert mistral.ai_name == "Node Model Agent" + assert mistral.system_prompt == None + assert mistral.model_name == "mistralai/Mistral-7B-v0.1" + assert mistral.device == "cpu" + assert mistral.use_flash_attention == False + assert mistral.temperature == 1.0 + assert mistral.max_length == 100 + assert mistral.history == [] + +@patch('your_module.AutoModelForCausalLM.from_pretrained') +@patch('your_module.AutoTokenizer.from_pretrained') +def test_mistral_load_model(mock_tokenizer, mock_model): + mistral = Mistral(device="cpu") + mistral.load_model() + mock_model.assert_called_once() + mock_tokenizer.assert_called_once() + +@patch('your_module.Mistral.load_model') +def test_mistral_run(mock_load_model): + mistral = Mistral(device="cpu") + mistral.run("What's the weather in miami") + mock_load_model.assert_called_once() + +@patch('your_module.Mistral.run') +def test_mistral_chat(mock_run): + mistral = Mistral(device="cpu") + mistral.chat("What's the weather in miami") + mock_run.assert_called_once() + +def test_mistral__stream_response(): + mistral = Mistral(device="cpu") + response = "It's sunny in Miami." + tokens = list(mistral._stream_response(response)) + assert tokens == ["It's", "sunny", "in", "Miami."] \ No newline at end of file diff --git a/tests/structs/nonlinear_workflow.py b/tests/structs/nonlinear_workflow.py new file mode 100644 index 00000000..295ec71f --- /dev/null +++ b/tests/structs/nonlinear_workflow.py @@ -0,0 +1,63 @@ +import pytest +from unittest.mock import patch, MagicMock +from swarms.structs.nonlinear_workflow import NonLinearWorkflow, Task + +class MockTask(Task): + def can_execute(self): + return True + + def execute(self): + return "Task executed" + +def test_nonlinearworkflow_initialization(): + agents = MagicMock() + iters_per_task = MagicMock() + workflow = NonLinearWorkflow(agents, iters_per_task) + assert isinstance(workflow, NonLinearWorkflow) + assert workflow.agents == agents + assert workflow.tasks == [] + +def test_nonlinearworkflow_add(): + agents = MagicMock() + iters_per_task = MagicMock() + workflow = NonLinearWorkflow(agents, iters_per_task) + task = MockTask("task1") + workflow.add(task) + assert workflow.tasks == [task] + +@patch('your_module.NonLinearWorkflow.is_finished') +@patch('your_module.NonLinearWorkflow.output_tasks') +def test_nonlinearworkflow_run(mock_output_tasks, mock_is_finished): + agents = MagicMock() + iters_per_task = MagicMock() + workflow = NonLinearWorkflow(agents, iters_per_task) + task = MockTask("task1") + workflow.add(task) + mock_is_finished.return_value = False + mock_output_tasks.return_value = [task] + workflow.run() + assert mock_output_tasks.called + +def test_nonlinearworkflow_output_tasks(): + agents = MagicMock() + iters_per_task = MagicMock() + workflow = NonLinearWorkflow(agents, iters_per_task) + task = MockTask("task1") + workflow.add(task) + assert workflow.output_tasks() == [task] + +def test_nonlinearworkflow_to_graph(): + agents = MagicMock() + iters_per_task = MagicMock() + workflow = NonLinearWorkflow(agents, iters_per_task) + task = MockTask("task1") + workflow.add(task) + assert workflow.to_graph() == {"task1": set()} + +def test_nonlinearworkflow_order_tasks(): + agents = MagicMock() + iters_per_task = MagicMock() + workflow = NonLinearWorkflow(agents, iters_per_task) + task = MockTask("task1") + workflow.add(task) + assert workflow.order_tasks() == [task] \ No newline at end of file diff --git a/tests/structs/workflow.py b/tests/structs/workflow.py new file mode 100644 index 00000000..5d973b18 --- /dev/null +++ b/tests/structs/workflow.py @@ -0,0 +1,59 @@ +import pytest +from unittest.mock import patch, MagicMock +from swarms.structs.workflow import Workflow + +def test_workflow_initialization(): + agent = MagicMock() + workflow = Workflow(agent) + assert isinstance(workflow, Workflow) + assert workflow.agent == agent + assert workflow.tasks == [] + assert workflow.parallel == False + +def test_workflow_add(): + agent = MagicMock() + workflow = Workflow(agent) + task = workflow.add("What's the weather in miami") + assert isinstance(task, Workflow.Task) + assert task.task == "What's the weather in miami" + assert task.parents == [] + assert task.children == [] + assert task.output == None + assert task.structure == workflow + +def test_workflow_first_task(): + agent = MagicMock() + workflow = Workflow(agent) + assert workflow.first_task() == None + workflow.add("What's the weather in miami") + assert workflow.first_task().task == "What's the weather in miami" + +def test_workflow_last_task(): + agent = MagicMock() + workflow = Workflow(agent) + assert workflow.last_task() == None + workflow.add("What's the weather in miami") + assert workflow.last_task().task == "What's the weather in miami" + +@patch('your_module.Workflow.__run_from_task') +def test_workflow_run(mock_run_from_task): + agent = MagicMock() + workflow = Workflow(agent) + workflow.add("What's the weather in miami") + workflow.run() + mock_run_from_task.assert_called_once() + +def test_workflow_context(): + agent = MagicMock() + workflow = Workflow(agent) + task = workflow.add("What's the weather in miami") + assert workflow.context(task) == {"parent_output": None, "parent": None, "child": None} + +@patch('your_module.Workflow.Task.execute') +def test_workflow___run_from_task(mock_execute): + agent = MagicMock() + workflow = Workflow(agent) + task = workflow.add("What's the weather in miami") + mock_execute.return_value = "Sunny" + workflow.__run_from_task(task) + mock_execute.assert_called_once() \ No newline at end of file