From bf9a747fa340bf4d1cb362ce33bb5e43217bb4f6 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 11 Nov 2023 14:24:26 -0500 Subject: [PATCH] 50+ tests for cohere Former-commit-id: bdc7337b2c0f0bca3fd9e61191808d4e96ddd3dd --- playground/models/cohere_example.py | 6 + pyproject.toml | 1 + requirements.txt | 1 + swarms/models/cohere.py | 335 -------------- swarms/models/cohere_chat.py | 247 +++++++++++ tests/memory/pg.py | 1 + tests/models/anthropic.py | 21 +- tests/models/cohere.py | 655 ++++++++++++++++++++++++++++ tests/swarms/groupchat.py | 2 + 9 files changed, 931 insertions(+), 338 deletions(-) create mode 100644 playground/models/cohere_example.py delete mode 100644 swarms/models/cohere.py create mode 100644 swarms/models/cohere_chat.py create mode 100644 tests/models/cohere.py diff --git a/playground/models/cohere_example.py b/playground/models/cohere_example.py new file mode 100644 index 00000000..eb389db0 --- /dev/null +++ b/playground/models/cohere_example.py @@ -0,0 +1,6 @@ +from swarms.models.cohere_chat import Cohere + + +cohere = Cohere(model="command-light", cohere_api_key="") + +out = cohere("Hello, how are you?") diff --git a/pyproject.toml b/pyproject.toml index 0f5a37fd..145594ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ attrs = "*" ggl = "*" ratelimit = "*" beautifulsoup4 = "*" +cohere = "*" huggingface-hub = "*" pydantic = "*" tenacity = "*" diff --git a/requirements.txt b/requirements.txt index e1148c30..56cdfd20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -64,6 +64,7 @@ webdataset yapf autopep8 dalle3 +cohere torchvision rich diff --git a/swarms/models/cohere.py b/swarms/models/cohere.py deleted file mode 100644 index a4ba75c5..00000000 --- a/swarms/models/cohere.py +++ /dev/null @@ -1,335 +0,0 @@ -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import ( - BaseChatModel, - _agenerate_from_stream, - _generate_from_stream, -) -from langchain.llms.cohere import BaseCohere -from langchain.schema.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) -from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult - - -def get_role(message: BaseMessage) -> str: - """Get the role of the message. - - Args: - message: The message. - - Returns: - The role of the message. - - Raises: - ValueError: If the message is of an unknown type. - """ - if isinstance(message, ChatMessage) or isinstance(message, HumanMessage): - return "User" - elif isinstance(message, AIMessage): - return "Chatbot" - elif isinstance(message, SystemMessage): - return "System" - else: - raise ValueError(f"Got unknown type {message}") - - -def get_cohere_chat_request( - messages: List[BaseMessage], - *, - connectors: Optional[List[Dict[str, str]]] = None, - **kwargs: Any, -) -> Dict[str, Any]: - """Get the request for the Cohere chat API. - - Args: - messages: The messages. - connectors: The connectors. - **kwargs: The keyword arguments. - - Returns: - The request for the Cohere chat API. - """ - documents = ( - None - if "source_documents" not in kwargs - else [ - { - "snippet": doc.page_content, - "id": doc.metadata.get("id") or f"doc-{str(i)}", - } - for i, doc in enumerate(kwargs["source_documents"]) - ] - ) - kwargs.pop("source_documents", None) - maybe_connectors = connectors if documents is None else None - - # by enabling automatic prompt truncation, the probability of request failure is - # reduced with minimal impact on response quality - prompt_truncation = ( - "AUTO" if documents is not None or connectors is not None else None - ) - - return { - "message": messages[0].content, - "chat_history": [ - {"role": get_role(x), "message": x.content} for x in messages[1:] - ], - "documents": documents, - "connectors": maybe_connectors, - "prompt_truncation": prompt_truncation, - **kwargs, - } - - -class CohereChat(BaseChatModel, BaseCohere): - """`Cohere` chat large language models. - - To use, you should have the ``cohere`` python package installed, and the - environment variable ``COHERE_API_KEY`` set with your API key, or pass - it as a named parameter to the constructor. - - Example: - .. code-block:: python - - from swarms.models.cohere import CohereChat, HumanMessage - - chat = CohereChat(model="foo") - result = chat([HumanMessage(content="Hello")]) - print(result.content) - """ - - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "cohere-chat" - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling Cohere API.""" - return { - "temperature": self.temperature, - } - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - return {**{"model": self.model}, **self._default_params} - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - """ - Stream the output - - Args: - messages: The messages. - stop: The stop tokens. - run_manager: The callback manager. - **kwargs: The keyword arguments. - - """ - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - stream = self.client.chat(**request, stream=True) - - for data in stream: - if data.event_type == "text-generation": - delta = data.text - yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - run_manager.on_llm_new_token(delta) - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - """ - Stream generations from the model. - - Args: - messages: The messages. - stop: The stop tokens. - run_manager: The callback manager. - **kwargs: The keyword arguments. - - Yields: - The generations. - - Examples: - .. code-block:: python - - async for generation in model._astream(messages): - print(generation.message.content) - """ - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - stream = await self.async_client.chat(**request, stream=True) - - async for data in stream: - if data.event_type == "text-generation": - delta = data.text - yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - await run_manager.on_llm_new_token(delta) - - def _get_generation_info(self, response: Any) -> Dict[str, Any]: - """Get the generation info from cohere API response.""" - return { - "documents": response.documents, - "citations": response.citations, - "search_results": response.search_results, - "search_queries": response.search_queries, - "token_count": response.token_count, - } - - def _run( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """ - Run the model. - - Args: - messages: The messages. - stop: The stop tokens. - run_manager: The callback manager. - **kwargs: The keyword arguments. - - Returns: - The result. - - Examples: - .. code-block:: python - - result = model._run(messages) - print(result.content) - """ - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return _generate_from_stream(stream_iter) - - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - response = self.client.chat(**request) - - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) - return ChatResult( - generations=[ - ChatGeneration(message=message, generation_info=generation_info) - ] - ) - - def __call__( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """ - __Call__ the model. - - Args: - messages: The messages. - stop: The stop tokens. - run_manager: The callback manager. - **kwargs: The keyword arguments. - - Returns: - The result. - """ - if self.streaming: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return _generate_from_stream(stream_iter) - - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - response = self.client.chat(**request) - - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) - return ChatResult( - generations=[ - ChatGeneration(message=message, generation_info=generation_info) - ] - ) - - async def _arun( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - """ - Asynchronously run the model. - - Args: - messages: The messages. - stop: The stop tokens. - run_manager: The callback manager. - **kwargs: The keyword arguments. - - Returns: - The result. - - Examples: - .. code-block:: python - - result = await model._arun(messages) - print(result.content) - - """ - if self.streaming: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await _agenerate_from_stream(stream_iter) - - request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - response = self.client.chat(**request, stream=False) - - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) - return ChatResult( - generations=[ - ChatGeneration(message=message, generation_info=generation_info) - ] - ) - - def get_num_tokens(self, text: str) -> int: - """Calculate number of tokens.""" - return len(self.client.tokenize(text).tokens) \ No newline at end of file diff --git a/swarms/models/cohere_chat.py b/swarms/models/cohere_chat.py new file mode 100644 index 00000000..c583b827 --- /dev/null +++ b/swarms/models/cohere_chat.py @@ -0,0 +1,247 @@ +import logging +from typing import Any, Callable, Dict, List, Optional + +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.llms.base import LLM +from langchain.llms.utils import enforce_stop_tokens +from langchain.load.serializable import Serializable +from pydantic import Extra, Field, root_validator +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(llm) -> Callable[[Any], Any]: + import cohere + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=(retry_if_exception_type(cohere.error.CohereError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def completion_with_retry(llm, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return llm.client.generate(**kwargs) + + return _completion_with_retry(**kwargs) + + +def acompletion_with_retry(llm, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + return await llm.async_client.generate(**kwargs) + + return _completion_with_retry(**kwargs) + + +class BaseCohere(Serializable): + """Base class for Cohere models.""" + + client: Any #: :meta private: + async_client: Any #: :meta private: + model: Optional[str] = Field(default=None, description="Model name to use.") + """Model name to use.""" + + temperature: float = 0.75 + """A non-negative float that tunes the degree of randomness in generation.""" + + cohere_api_key: Optional[str] = None + + stop: Optional[List[str]] = None + + streaming: bool = Field(default=False) + """Whether to stream the results.""" + + user_agent: str = "langchain" + """Identifier for the application making the request.""" + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import cohere + except ImportError: + raise ImportError( + "Could not import cohere python package. " + "Please install it with `pip install cohere`." + ) + else: + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) + client_name = values["user_agent"] + values["client"] = cohere.Client(cohere_api_key, client_name=client_name) + values["async_client"] = cohere.AsyncClient( + cohere_api_key, client_name=client_name + ) + return values + + +class Cohere(LLM, BaseCohere): + """Cohere large language models. + + To use, you should have the ``cohere`` python package installed, and the + environment variable ``COHERE_API_KEY`` set with your API key, or pass + it as a named parameter to the constructor. + + Example: + .. code-block:: python + + from langchain.llms import Cohere + + cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key") + """ + + max_tokens: int = 256 + """Denotes the number of tokens to predict per generation.""" + + k: int = 0 + """Number of most likely tokens to consider at each step.""" + + p: int = 1 + """Total probability mass of tokens to consider at each step.""" + + frequency_penalty: float = 0.0 + """Penalizes repeated tokens according to frequency. Between 0 and 1.""" + + presence_penalty: float = 0.0 + """Penalizes repeated tokens. Between 0 and 1.""" + + truncate: Optional[str] = None + """Specify how the client handles inputs longer than the maximum token + length: Truncate from START, END or NONE""" + + max_retries: int = 10 + """Maximum number of retries to make when generating.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Cohere API.""" + return { + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "k": self.k, + "p": self.p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "truncate": self.truncate, + } + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"cohere_api_key": "COHERE_API_KEY"} + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return {**{"model": self.model}, **self._default_params} + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "cohere" + + def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict: + params = self._default_params + if self.stop is not None and stop is not None: + raise ValueError("`stop` found in both the input and default params.") + elif self.stop is not None: + params["stop_sequences"] = self.stop + else: + params["stop_sequences"] = stop + return {**params, **kwargs} + + def _process_response(self, response: Any, stop: Optional[List[str]]) -> str: + text = response.generations[0].text + # If stop tokens are provided, Cohere's endpoint returns them. + # In order to make this consistent with other endpoints, we strip them. + if stop: + text = enforce_stop_tokens(text, stop) + return text + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to Cohere's generate endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = cohere("Tell me a joke.") + """ + params = self._invocation_params(stop, **kwargs) + response = completion_with_retry( + self, model=self.model, prompt=prompt, **params + ) + _stop = params.get("stop_sequences") + return self._process_response(response, _stop) + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Async call out to Cohere's generate endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = await cohere("Tell me a joke.") + """ + params = self._invocation_params(stop, **kwargs) + response = await acompletion_with_retry( + self, model=self.model, prompt=prompt, **params + ) + _stop = params.get("stop_sequences") + return self._process_response(response, _stop) diff --git a/tests/memory/pg.py b/tests/memory/pg.py index f639e6c2..e7b0587d 100644 --- a/tests/memory/pg.py +++ b/tests/memory/pg.py @@ -9,6 +9,7 @@ load_dotenv() PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING") + def test_init(): with patch("sqlalchemy.create_engine") as MockEngine: store = PgVectorVectorStore( diff --git a/tests/models/anthropic.py b/tests/models/anthropic.py index feb703a6..e2447614 100644 --- a/tests/models/anthropic.py +++ b/tests/models/anthropic.py @@ -142,32 +142,39 @@ class MockAnthropicResponse: def __init__(self): self.completion = "Mocked Response from Anthropic" + def test_anthropic_instance_creation(anthropic_instance): assert isinstance(anthropic_instance, Anthropic) + def test_anthropic_call_method(anthropic_instance): response = anthropic_instance("What is the meaning of life?") assert response == "Mocked Response from Anthropic" + def test_anthropic_stream_method(anthropic_instance): generator = anthropic_instance.stream("Write a story.") for token in generator: assert isinstance(token, str) + def test_anthropic_async_call_method(anthropic_instance): response = anthropic_instance.async_call("Tell me a joke.") assert response == "Mocked Response from Anthropic" + def test_anthropic_async_stream_method(anthropic_instance): async_generator = anthropic_instance.async_stream("Translate to French.") for token in async_generator: assert isinstance(token, str) + def test_anthropic_get_num_tokens(anthropic_instance): text = "This is a test sentence." num_tokens = anthropic_instance.get_num_tokens(text) assert num_tokens > 0 + # Add more test cases to cover other functionalities and edge cases of the Anthropic class @@ -177,47 +184,55 @@ def test_anthropic_wrap_prompt(anthropic_instance): assert wrapped_prompt.startswith(anthropic_instance.HUMAN_PROMPT) assert wrapped_prompt.endswith(anthropic_instance.AI_PROMPT) + def test_anthropic_convert_prompt(anthropic_instance): prompt = "What is the meaning of life?" converted_prompt = anthropic_instance.convert_prompt(prompt) assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) + def test_anthropic_call_with_stop(anthropic_instance): response = anthropic_instance("Translate to French.", stop=["stop1", "stop2"]) assert response == "Mocked Response from Anthropic" + def test_anthropic_stream_with_stop(anthropic_instance): generator = anthropic_instance.stream("Write a story.", stop=["stop1", "stop2"]) for token in generator: assert isinstance(token, str) + def test_anthropic_async_call_with_stop(anthropic_instance): response = anthropic_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) assert response == "Mocked Response from Anthropic" + def test_anthropic_async_stream_with_stop(anthropic_instance): - async_generator = anthropic_instance.async_stream("Translate to French.", stop=["stop1", "stop2"]) + async_generator = anthropic_instance.async_stream( + "Translate to French.", stop=["stop1", "stop2"] + ) for token in async_generator: assert isinstance(token, str) + def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance): anthropic_instance.count_tokens = Mock(return_value=10) text = "This is a test sentence." num_tokens = anthropic_instance.get_num_tokens(text) assert num_tokens == 10 + def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance): del anthropic_instance.count_tokens with pytest.raises(NameError): text = "This is a test sentence." anthropic_instance.get_num_tokens(text) + def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance): del anthropic_instance.HUMAN_PROMPT del anthropic_instance.AI_PROMPT prompt = "What is the meaning of life?" with pytest.raises(NameError): anthropic_instance._wrap_prompt(prompt) - - diff --git a/tests/models/cohere.py b/tests/models/cohere.py new file mode 100644 index 00000000..17bc2ddc --- /dev/null +++ b/tests/models/cohere.py @@ -0,0 +1,655 @@ +import os +from unittest.mock import Mock, patch + +import pytest +from cohere.models.response import GenerationChunk +from dotenv import load_dotenv + +from swarms.models.cohere_chat import BaseCohere, Cohere + +# Load the environment variables +load_dotenv() +api_key = os.getenv("COHERE_API_KEY") + + +@pytest.fixture +def cohere_instance(): + return Cohere(cohere_api_key=api_key) + + +def test_cohere_wrap_prompt(cohere_instance): + prompt = "What is the meaning of life?" + wrapped_prompt = cohere_instance._wrap_prompt(prompt) + assert wrapped_prompt.startswith(cohere_instance.HUMAN_PROMPT) + assert wrapped_prompt.endswith(cohere_instance.AI_PROMPT) + + +def test_cohere_convert_prompt(cohere_instance): + prompt = "What is the meaning of life?" + converted_prompt = cohere_instance.convert_prompt(prompt) + assert converted_prompt.startswith(cohere_instance.HUMAN_PROMPT) + assert converted_prompt.endswith(cohere_instance.AI_PROMPT) + + +def test_cohere_call_with_stop(cohere_instance): + response = cohere_instance("Translate to French.", stop=["stop1", "stop2"]) + assert response == "Mocked Response from Cohere" + + +def test_cohere_stream_with_stop(cohere_instance): + generator = cohere_instance.stream("Write a story.", stop=["stop1", "stop2"]) + for token in generator: + assert isinstance(token, str) + + +def test_cohere_async_call_with_stop(cohere_instance): + response = cohere_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) + assert response == "Mocked Response from Cohere" + + +def test_cohere_async_stream_with_stop(cohere_instance): + async_generator = cohere_instance.async_stream( + "Translate to French.", stop=["stop1", "stop2"] + ) + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_get_num_tokens_with_count_tokens(cohere_instance): + cohere_instance.count_tokens = Mock(return_value=10) + text = "This is a test sentence." + num_tokens = cohere_instance.get_num_tokens(text) + assert num_tokens == 10 + + +def test_cohere_get_num_tokens_without_count_tokens(cohere_instance): + del cohere_instance.count_tokens + with pytest.raises(NameError): + text = "This is a test sentence." + cohere_instance.get_num_tokens(text) + + +def test_cohere_wrap_prompt_without_human_ai_prompt(cohere_instance): + del cohere_instance.HUMAN_PROMPT + del cohere_instance.AI_PROMPT + prompt = "What is the meaning of life?" + with pytest.raises(NameError): + cohere_instance._wrap_prompt(prompt) + + +def test_base_cohere_import(): + with patch.dict("sys.modules", {"cohere": None}): + with pytest.raises(ImportError): + pass + + +def test_base_cohere_validate_environment(): + values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"} + validated_values = BaseCohere.validate_environment(values) + assert "client" in validated_values + assert "async_client" in validated_values + + +def test_base_cohere_validate_environment_without_cohere(): + values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"} + with patch.dict("sys.modules", {"cohere": None}): + with pytest.raises(ImportError): + BaseCohere.validate_environment(values) + + +# Test cases for benchmarking generations with various models +def test_cohere_generate_with_command_light(cohere_instance): + cohere_instance.model = "command-light" + response = cohere_instance("Generate text with Command Light model.") + assert response.startswith("Generated text with Command Light model") + + +def test_cohere_generate_with_command(cohere_instance): + cohere_instance.model = "command" + response = cohere_instance("Generate text with Command model.") + assert response.startswith("Generated text with Command model") + + +def test_cohere_generate_with_base_light(cohere_instance): + cohere_instance.model = "base-light" + response = cohere_instance("Generate text with Base Light model.") + assert response.startswith("Generated text with Base Light model") + + +def test_cohere_generate_with_base(cohere_instance): + cohere_instance.model = "base" + response = cohere_instance("Generate text with Base model.") + assert response.startswith("Generated text with Base model") + + +def test_cohere_generate_with_embed_english_v2(cohere_instance): + cohere_instance.model = "embed-english-v2.0" + response = cohere_instance("Generate embeddings with English v2.0 model.") + assert response.startswith("Generated embeddings with English v2.0 model") + + +def test_cohere_generate_with_embed_english_light_v2(cohere_instance): + cohere_instance.model = "embed-english-light-v2.0" + response = cohere_instance("Generate embeddings with English Light v2.0 model.") + assert response.startswith("Generated embeddings with English Light v2.0 model") + + +def test_cohere_generate_with_embed_multilingual_v2(cohere_instance): + cohere_instance.model = "embed-multilingual-v2.0" + response = cohere_instance("Generate embeddings with Multilingual v2.0 model.") + assert response.startswith("Generated embeddings with Multilingual v2.0 model") + + +def test_cohere_generate_with_embed_english_v3(cohere_instance): + cohere_instance.model = "embed-english-v3.0" + response = cohere_instance("Generate embeddings with English v3.0 model.") + assert response.startswith("Generated embeddings with English v3.0 model") + + +def test_cohere_generate_with_embed_english_light_v3(cohere_instance): + cohere_instance.model = "embed-english-light-v3.0" + response = cohere_instance("Generate embeddings with English Light v3.0 model.") + assert response.startswith("Generated embeddings with English Light v3.0 model") + + +def test_cohere_generate_with_embed_multilingual_v3(cohere_instance): + cohere_instance.model = "embed-multilingual-v3.0" + response = cohere_instance("Generate embeddings with Multilingual v3.0 model.") + assert response.startswith("Generated embeddings with Multilingual v3.0 model") + + +def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance): + cohere_instance.model = "embed-multilingual-light-v3.0" + response = cohere_instance( + "Generate embeddings with Multilingual Light v3.0 model." + ) + assert response.startswith( + "Generated embeddings with Multilingual Light v3.0 model" + ) + + +# Add more test cases to benchmark other models and functionalities + + +def test_cohere_call_with_command_model(cohere_instance): + cohere_instance.model = "command" + response = cohere_instance("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_call_with_base_model(cohere_instance): + cohere_instance.model = "base" + response = cohere_instance("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_call_with_embed_english_v2_model(cohere_instance): + cohere_instance.model = "embed-english-v2.0" + response = cohere_instance("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_call_with_embed_english_v3_model(cohere_instance): + cohere_instance.model = "embed-english-v3.0" + response = cohere_instance("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v2.0" + response = cohere_instance("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v3.0" + response = cohere_instance("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_call_with_invalid_model(cohere_instance): + cohere_instance.model = "invalid-model" + with pytest.raises(ValueError): + response = cohere_instance("Translate to French.") + + +def test_cohere_call_with_long_prompt(cohere_instance): + prompt = "This is a very long prompt. " * 100 + response = cohere_instance(prompt) + assert isinstance(response, str) + + +def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance): + cohere_instance.max_tokens = 10 + prompt = "This is a test prompt that will exceed the max tokens limit." + with pytest.raises(ValueError): + response = cohere_instance(prompt) + + +def test_cohere_stream_with_command_model(cohere_instance): + cohere_instance.model = "command" + generator = cohere_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + + +def test_cohere_stream_with_base_model(cohere_instance): + cohere_instance.model = "base" + generator = cohere_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + + +def test_cohere_stream_with_embed_english_v2_model(cohere_instance): + cohere_instance.model = "embed-english-v2.0" + generator = cohere_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + + +def test_cohere_stream_with_embed_english_v3_model(cohere_instance): + cohere_instance.model = "embed-english-v3.0" + generator = cohere_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + + +def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v2.0" + generator = cohere_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + + +def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v3.0" + generator = cohere_instance.stream("Write a story.") + for token in generator: + assert isinstance(token, str) + + +def test_cohere_async_call_with_command_model(cohere_instance): + cohere_instance.model = "command" + response = cohere_instance.async_call("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_async_call_with_base_model(cohere_instance): + cohere_instance.model = "base" + response = cohere_instance.async_call("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_async_call_with_embed_english_v2_model(cohere_instance): + cohere_instance.model = "embed-english-v2.0" + response = cohere_instance.async_call("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_async_call_with_embed_english_v3_model(cohere_instance): + cohere_instance.model = "embed-english-v3.0" + response = cohere_instance.async_call("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v2.0" + response = cohere_instance.async_call("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v3.0" + response = cohere_instance.async_call("Translate to French.") + assert isinstance(response, str) + + +def test_cohere_async_stream_with_command_model(cohere_instance): + cohere_instance.model = "command" + async_generator = cohere_instance.async_stream("Write a story.") + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_async_stream_with_base_model(cohere_instance): + cohere_instance.model = "base" + async_generator = cohere_instance.async_stream("Write a story.") + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance): + cohere_instance.model = "embed-english-v2.0" + async_generator = cohere_instance.async_stream("Write a story.") + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance): + cohere_instance.model = "embed-english-v3.0" + async_generator = cohere_instance.async_stream("Write a story.") + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v2.0" + async_generator = cohere_instance.async_stream("Write a story.") + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance): + cohere_instance.model = "embed-multilingual-v3.0" + async_generator = cohere_instance.async_stream("Write a story.") + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_custom_configuration(cohere_instance): + # Test customizing Cohere configurations + cohere_instance.model = "base" + cohere_instance.temperature = 0.5 + cohere_instance.max_tokens = 100 + cohere_instance.k = 1 + cohere_instance.p = 0.8 + cohere_instance.frequency_penalty = 0.2 + cohere_instance.presence_penalty = 0.4 + response = cohere_instance("Customize configurations.") + assert isinstance(response, str) + + +def test_cohere_api_error_handling(cohere_instance): + # Test error handling when the API key is invalid + cohere_instance.model = "base" + cohere_instance.cohere_api_key = "invalid-api-key" + with pytest.raises(Exception): + response = cohere_instance("Error handling with invalid API key.") + + +def test_cohere_async_api_error_handling(cohere_instance): + # Test async error handling when the API key is invalid + cohere_instance.model = "base" + cohere_instance.cohere_api_key = "invalid-api-key" + with pytest.raises(Exception): + response = cohere_instance.async_call("Error handling with invalid API key.") + + +def test_cohere_stream_api_error_handling(cohere_instance): + # Test error handling in streaming mode when the API key is invalid + cohere_instance.model = "base" + cohere_instance.cohere_api_key = "invalid-api-key" + with pytest.raises(Exception): + generator = cohere_instance.stream("Error handling with invalid API key.") + for token in generator: + pass + + +def test_cohere_streaming_mode(cohere_instance): + # Test the streaming mode for large text generation + cohere_instance.model = "base" + cohere_instance.streaming = True + prompt = "Generate a lengthy text using streaming mode." + generator = cohere_instance.stream(prompt) + for token in generator: + assert isinstance(token, str) + + +def test_cohere_streaming_mode_async(cohere_instance): + # Test the async streaming mode for large text generation + cohere_instance.model = "base" + cohere_instance.streaming = True + prompt = "Generate a lengthy text using async streaming mode." + async_generator = cohere_instance.async_stream(prompt) + for token in async_generator: + assert isinstance(token, str) + + +def test_cohere_representation_model_embedding(cohere_instance): + # Test using the Representation model for text embedding + cohere_instance.model = "embed-english-v3.0" + embedding = cohere_instance.embed("Generate an embedding for this text.") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + +def test_cohere_representation_model_classification(cohere_instance): + # Test using the Representation model for text classification + cohere_instance.model = "embed-english-v3.0" + classification = cohere_instance.classify("Classify this text.") + assert isinstance(classification, dict) + assert "class" in classification + assert "score" in classification + + +def test_cohere_representation_model_language_detection(cohere_instance): + # Test using the Representation model for language detection + cohere_instance.model = "embed-english-v3.0" + language = cohere_instance.detect_language("Detect the language of this text.") + assert isinstance(language, str) + + +def test_cohere_representation_model_max_tokens_limit_exceeded(cohere_instance): + # Test handling max tokens limit exceeded error + cohere_instance.model = "embed-english-v3.0" + cohere_instance.max_tokens = 10 + prompt = "This is a test prompt that will exceed the max tokens limit." + with pytest.raises(ValueError): + embedding = cohere_instance.embed(prompt) + + +# Add more production-grade test cases based on real-world scenarios + + +def test_cohere_representation_model_multilingual_embedding(cohere_instance): + # Test using the Representation model for multilingual text embedding + cohere_instance.model = "embed-multilingual-v3.0" + embedding = cohere_instance.embed("Generate multilingual embeddings.") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + +def test_cohere_representation_model_multilingual_classification(cohere_instance): + # Test using the Representation model for multilingual text classification + cohere_instance.model = "embed-multilingual-v3.0" + classification = cohere_instance.classify("Classify multilingual text.") + assert isinstance(classification, dict) + assert "class" in classification + assert "score" in classification + + +def test_cohere_representation_model_multilingual_language_detection(cohere_instance): + # Test using the Representation model for multilingual language detection + cohere_instance.model = "embed-multilingual-v3.0" + language = cohere_instance.detect_language( + "Detect the language of multilingual text." + ) + assert isinstance(language, str) + + +def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded( + cohere_instance, +): + # Test handling max tokens limit exceeded error for multilingual model + cohere_instance.model = "embed-multilingual-v3.0" + cohere_instance.max_tokens = 10 + prompt = "This is a test prompt that will exceed the max tokens limit for multilingual model." + with pytest.raises(ValueError): + embedding = cohere_instance.embed(prompt) + + +def test_cohere_representation_model_multilingual_light_embedding(cohere_instance): + # Test using the Representation model for multilingual light text embedding + cohere_instance.model = "embed-multilingual-light-v3.0" + embedding = cohere_instance.embed("Generate multilingual light embeddings.") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + +def test_cohere_representation_model_multilingual_light_classification(cohere_instance): + # Test using the Representation model for multilingual light text classification + cohere_instance.model = "embed-multilingual-light-v3.0" + classification = cohere_instance.classify("Classify multilingual light text.") + assert isinstance(classification, dict) + assert "class" in classification + assert "score" in classification + + +def test_cohere_representation_model_multilingual_light_language_detection( + cohere_instance, +): + # Test using the Representation model for multilingual light language detection + cohere_instance.model = "embed-multilingual-light-v3.0" + language = cohere_instance.detect_language( + "Detect the language of multilingual light text." + ) + assert isinstance(language, str) + + +def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded( + cohere_instance, +): + # Test handling max tokens limit exceeded error for multilingual light model + cohere_instance.model = "embed-multilingual-light-v3.0" + cohere_instance.max_tokens = 10 + prompt = "This is a test prompt that will exceed the max tokens limit for multilingual light model." + with pytest.raises(ValueError): + embedding = cohere_instance.embed(prompt) + + +def test_cohere_command_light_model(cohere_instance): + # Test using the Command Light model for text generation + cohere_instance.model = "command-light" + response = cohere_instance("Generate text using Command Light model.") + assert isinstance(response, str) + + +def test_cohere_base_light_model(cohere_instance): + # Test using the Base Light model for text generation + cohere_instance.model = "base-light" + response = cohere_instance("Generate text using Base Light model.") + assert isinstance(response, str) + + +def test_cohere_generate_summarize_endpoint(cohere_instance): + # Test using the Co.summarize() endpoint for text summarization + cohere_instance.model = "command" + response = cohere_instance.summarize("Summarize this text.") + assert isinstance(response, str) + + +def test_cohere_representation_model_english_embedding(cohere_instance): + # Test using the Representation model for English text embedding + cohere_instance.model = "embed-english-v3.0" + embedding = cohere_instance.embed("Generate English embeddings.") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + +def test_cohere_representation_model_english_classification(cohere_instance): + # Test using the Representation model for English text classification + cohere_instance.model = "embed-english-v3.0" + classification = cohere_instance.classify("Classify English text.") + assert isinstance(classification, dict) + assert "class" in classification + assert "score" in classification + + +def test_cohere_representation_model_english_language_detection(cohere_instance): + # Test using the Representation model for English language detection + cohere_instance.model = "embed-english-v3.0" + language = cohere_instance.detect_language("Detect the language of English text.") + assert isinstance(language, str) + + +def test_cohere_representation_model_english_max_tokens_limit_exceeded(cohere_instance): + # Test handling max tokens limit exceeded error for English model + cohere_instance.model = "embed-english-v3.0" + cohere_instance.max_tokens = 10 + prompt = ( + "This is a test prompt that will exceed the max tokens limit for English model." + ) + with pytest.raises(ValueError): + embedding = cohere_instance.embed(prompt) + + +def test_cohere_representation_model_english_light_embedding(cohere_instance): + # Test using the Representation model for English light text embedding + cohere_instance.model = "embed-english-light-v3.0" + embedding = cohere_instance.embed("Generate English light embeddings.") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + +def test_cohere_representation_model_english_light_classification(cohere_instance): + # Test using the Representation model for English light text classification + cohere_instance.model = "embed-english-light-v3.0" + classification = cohere_instance.classify("Classify English light text.") + assert isinstance(classification, dict) + assert "class" in classification + assert "score" in classification + + +def test_cohere_representation_model_english_light_language_detection(cohere_instance): + # Test using the Representation model for English light language detection + cohere_instance.model = "embed-english-light-v3.0" + language = cohere_instance.detect_language( + "Detect the language of English light text." + ) + assert isinstance(language, str) + + +def test_cohere_representation_model_english_light_max_tokens_limit_exceeded( + cohere_instance, +): + # Test handling max tokens limit exceeded error for English light model + cohere_instance.model = "embed-english-light-v3.0" + cohere_instance.max_tokens = 10 + prompt = "This is a test prompt that will exceed the max tokens limit for English light model." + with pytest.raises(ValueError): + embedding = cohere_instance.embed(prompt) + + +def test_cohere_command_model(cohere_instance): + # Test using the Command model for text generation + cohere_instance.model = "command" + response = cohere_instance("Generate text using the Command model.") + assert isinstance(response, str) + + +# Add more production-grade test cases based on real-world scenarios + + +def test_cohere_invalid_model(cohere_instance): + # Test using an invalid model name + cohere_instance.model = "invalid-model" + with pytest.raises(ValueError): + response = cohere_instance("Generate text using an invalid model.") + + +def test_cohere_streaming_generation(cohere_instance): + # Test streaming generation with the Command model + cohere_instance.model = "command" + prompt = "Generate text using streaming." + chunks = list(cohere_instance.stream(prompt)) + assert isinstance(chunks, list) + assert len(chunks) > 0 + assert all(isinstance(chunk, GenerationChunk) for chunk in chunks) + + +def test_cohere_base_model_generation_with_max_tokens(cohere_instance): + # Test generating text using the base model with a specified max_tokens limit + cohere_instance.model = "base" + cohere_instance.max_tokens = 20 + prompt = "Generate text with max_tokens limit." + response = cohere_instance(prompt) + assert len(response.split()) <= 20 + + +def test_cohere_command_light_generation_with_stop(cohere_instance): + # Test generating text using the command-light model with stop words + cohere_instance.model = "command-light" + prompt = "Generate text with stop words." + stop = ["stop", "words"] + response = cohere_instance(prompt, stop=stop) + assert all(word not in response for word in stop) diff --git a/tests/swarms/groupchat.py b/tests/swarms/groupchat.py index 68609e31..f81c415a 100644 --- a/tests/swarms/groupchat.py +++ b/tests/swarms/groupchat.py @@ -8,6 +8,7 @@ from swarms.swarms.flow import GroupChat, GroupChatManager llm = OpenAIChat() llm2 = Anthropic() + # Mock the OpenAI class for testing class MockOpenAI: def __init__(self, *args, **kwargs): @@ -126,6 +127,7 @@ def test_groupchat_manager_initialization(agent1, agent2): assert manager.groupchat == groupchat assert manager.selector == selector + # Test case to ensure GroupChatManager generates a reply from an agent def test_groupchat_manager_generate_reply(): # Create a GroupChat with two agents