50+ tests for cohere

pull/128/head
Kye 1 year ago
parent 4197920802
commit bdc7337b2c

@ -0,0 +1,6 @@
from swarms.models.cohere_chat import Cohere
cohere = Cohere(model="command-light", cohere_api_key="")
out = cohere("Hello, how are you?")

@ -48,6 +48,7 @@ attrs = "*"
ggl = "*" ggl = "*"
ratelimit = "*" ratelimit = "*"
beautifulsoup4 = "*" beautifulsoup4 = "*"
cohere = "*"
huggingface-hub = "*" huggingface-hub = "*"
pydantic = "*" pydantic = "*"
tenacity = "*" tenacity = "*"

@ -64,6 +64,7 @@ webdataset
yapf yapf
autopep8 autopep8
dalle3 dalle3
cohere
torchvision torchvision
rich rich

@ -1,335 +0,0 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import (
BaseChatModel,
_agenerate_from_stream,
_generate_from_stream,
)
from langchain.llms.cohere import BaseCohere
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
def get_role(message: BaseMessage) -> str:
"""Get the role of the message.
Args:
message: The message.
Returns:
The role of the message.
Raises:
ValueError: If the message is of an unknown type.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "User"
elif isinstance(message, AIMessage):
return "Chatbot"
elif isinstance(message, SystemMessage):
return "System"
else:
raise ValueError(f"Got unknown type {message}")
def get_cohere_chat_request(
messages: List[BaseMessage],
*,
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
Args:
messages: The messages.
connectors: The connectors.
**kwargs: The keyword arguments.
Returns:
The request for the Cohere chat API.
"""
documents = (
None
if "source_documents" not in kwargs
else [
{
"snippet": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(kwargs["source_documents"])
]
)
kwargs.pop("source_documents", None)
maybe_connectors = connectors if documents is None else None
# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
prompt_truncation = (
"AUTO" if documents is not None or connectors is not None else None
)
return {
"message": messages[0].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[1:]
],
"documents": documents,
"connectors": maybe_connectors,
"prompt_truncation": prompt_truncation,
**kwargs,
}
class CohereChat(BaseChatModel, BaseCohere):
"""`Cohere` chat large language models.
To use, you should have the ``cohere`` python package installed, and the
environment variable ``COHERE_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
from swarms.models.cohere import CohereChat, HumanMessage
chat = CohereChat(model="foo")
result = chat([HumanMessage(content="Hello")])
print(result.content)
"""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "cohere-chat"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"temperature": self.temperature,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""
Stream the output
Args:
messages: The messages.
stop: The stop tokens.
run_manager: The callback manager.
**kwargs: The keyword arguments.
"""
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = self.client.chat(**request, stream=True)
for data in stream:
if data.event_type == "text-generation":
delta = data.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
"""
Stream generations from the model.
Args:
messages: The messages.
stop: The stop tokens.
run_manager: The callback manager.
**kwargs: The keyword arguments.
Yields:
The generations.
Examples:
.. code-block:: python
async for generation in model._astream(messages):
print(generation.message.content)
"""
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = await self.async_client.chat(**request, stream=True)
async for data in stream:
if data.event_type == "text-generation":
delta = data.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta)
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""
return {
"documents": response.documents,
"citations": response.citations,
"search_results": response.search_results,
"search_queries": response.search_queries,
"token_count": response.token_count,
}
def _run(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""
Run the model.
Args:
messages: The messages.
stop: The stop tokens.
run_manager: The callback manager.
**kwargs: The keyword arguments.
Returns:
The result.
Examples:
.. code-block:: python
result = model._run(messages)
print(result.content)
"""
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
def __call__(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""
__Call__ the model.
Args:
messages: The messages.
stop: The stop tokens.
run_manager: The callback manager.
**kwargs: The keyword arguments.
Returns:
The result.
"""
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
async def _arun(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""
Asynchronously run the model.
Args:
messages: The messages.
stop: The stop tokens.
run_manager: The callback manager.
**kwargs: The keyword arguments.
Returns:
The result.
Examples:
.. code-block:: python
result = await model._arun(messages)
print(result.content)
"""
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await _agenerate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request, stream=False)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
return len(self.client.tokenize(text).tokens)

@ -0,0 +1,247 @@
import logging
from typing import Any, Callable, Dict, List, Optional
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.load.serializable import Serializable
from pydantic import Extra, Field, root_validator
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
import cohere
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(retry_if_exception_type(cohere.error.CohereError)),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(llm, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return llm.client.generate(**kwargs)
return _completion_with_retry(**kwargs)
def acompletion_with_retry(llm, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await llm.async_client.generate(**kwargs)
return _completion_with_retry(**kwargs)
class BaseCohere(Serializable):
"""Base class for Cohere models."""
client: Any #: :meta private:
async_client: Any #: :meta private:
model: Optional[str] = Field(default=None, description="Model name to use.")
"""Model name to use."""
temperature: float = 0.75
"""A non-negative float that tunes the degree of randomness in generation."""
cohere_api_key: Optional[str] = None
stop: Optional[List[str]] = None
streaming: bool = Field(default=False)
"""Whether to stream the results."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
import cohere
except ImportError:
raise ImportError(
"Could not import cohere python package. "
"Please install it with `pip install cohere`."
)
else:
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, client_name=client_name
)
return values
class Cohere(LLM, BaseCohere):
"""Cohere large language models.
To use, you should have the ``cohere`` python package installed, and the
environment variable ``COHERE_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms import Cohere
cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
"""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
k: int = 0
"""Number of most likely tokens to consider at each step."""
p: int = 1
"""Total probability mass of tokens to consider at each step."""
frequency_penalty: float = 0.0
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
presence_penalty: float = 0.0
"""Penalizes repeated tokens. Between 0 and 1."""
truncate: Optional[str] = None
"""Specify how the client handles inputs longer than the maximum token
length: Truncate from START, END or NONE"""
max_retries: int = 10
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"k": self.k,
"p": self.p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"truncate": self.truncate,
}
@property
def lc_secrets(self) -> Dict[str, str]:
return {"cohere_api_key": "COHERE_API_KEY"}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "cohere"
def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
params = self._default_params
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
params["stop_sequences"] = self.stop
else:
params["stop_sequences"] = stop
return {**params, **kwargs}
def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
text = response.generations[0].text
# If stop tokens are provided, Cohere's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them.
if stop:
text = enforce_stop_tokens(text, stop)
return text
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Cohere's generate endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = cohere("Tell me a joke.")
"""
params = self._invocation_params(stop, **kwargs)
response = completion_with_retry(
self, model=self.model, prompt=prompt, **params
)
_stop = params.get("stop_sequences")
return self._process_response(response, _stop)
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Async call out to Cohere's generate endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = await cohere("Tell me a joke.")
"""
params = self._invocation_params(stop, **kwargs)
response = await acompletion_with_retry(
self, model=self.model, prompt=prompt, **params
)
_stop = params.get("stop_sequences")
return self._process_response(response, _stop)

@ -9,6 +9,7 @@ load_dotenv()
PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING") PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING")
def test_init(): def test_init():
with patch("sqlalchemy.create_engine") as MockEngine: with patch("sqlalchemy.create_engine") as MockEngine:
store = PgVectorVectorStore( store = PgVectorVectorStore(

@ -142,32 +142,39 @@ class MockAnthropicResponse:
def __init__(self): def __init__(self):
self.completion = "Mocked Response from Anthropic" self.completion = "Mocked Response from Anthropic"
def test_anthropic_instance_creation(anthropic_instance): def test_anthropic_instance_creation(anthropic_instance):
assert isinstance(anthropic_instance, Anthropic) assert isinstance(anthropic_instance, Anthropic)
def test_anthropic_call_method(anthropic_instance): def test_anthropic_call_method(anthropic_instance):
response = anthropic_instance("What is the meaning of life?") response = anthropic_instance("What is the meaning of life?")
assert response == "Mocked Response from Anthropic" assert response == "Mocked Response from Anthropic"
def test_anthropic_stream_method(anthropic_instance): def test_anthropic_stream_method(anthropic_instance):
generator = anthropic_instance.stream("Write a story.") generator = anthropic_instance.stream("Write a story.")
for token in generator: for token in generator:
assert isinstance(token, str) assert isinstance(token, str)
def test_anthropic_async_call_method(anthropic_instance): def test_anthropic_async_call_method(anthropic_instance):
response = anthropic_instance.async_call("Tell me a joke.") response = anthropic_instance.async_call("Tell me a joke.")
assert response == "Mocked Response from Anthropic" assert response == "Mocked Response from Anthropic"
def test_anthropic_async_stream_method(anthropic_instance): def test_anthropic_async_stream_method(anthropic_instance):
async_generator = anthropic_instance.async_stream("Translate to French.") async_generator = anthropic_instance.async_stream("Translate to French.")
for token in async_generator: for token in async_generator:
assert isinstance(token, str) assert isinstance(token, str)
def test_anthropic_get_num_tokens(anthropic_instance): def test_anthropic_get_num_tokens(anthropic_instance):
text = "This is a test sentence." text = "This is a test sentence."
num_tokens = anthropic_instance.get_num_tokens(text) num_tokens = anthropic_instance.get_num_tokens(text)
assert num_tokens > 0 assert num_tokens > 0
# Add more test cases to cover other functionalities and edge cases of the Anthropic class # Add more test cases to cover other functionalities and edge cases of the Anthropic class
@ -177,47 +184,55 @@ def test_anthropic_wrap_prompt(anthropic_instance):
assert wrapped_prompt.startswith(anthropic_instance.HUMAN_PROMPT) assert wrapped_prompt.startswith(anthropic_instance.HUMAN_PROMPT)
assert wrapped_prompt.endswith(anthropic_instance.AI_PROMPT) assert wrapped_prompt.endswith(anthropic_instance.AI_PROMPT)
def test_anthropic_convert_prompt(anthropic_instance): def test_anthropic_convert_prompt(anthropic_instance):
prompt = "What is the meaning of life?" prompt = "What is the meaning of life?"
converted_prompt = anthropic_instance.convert_prompt(prompt) converted_prompt = anthropic_instance.convert_prompt(prompt)
assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT) assert converted_prompt.startswith(anthropic_instance.HUMAN_PROMPT)
assert converted_prompt.endswith(anthropic_instance.AI_PROMPT) assert converted_prompt.endswith(anthropic_instance.AI_PROMPT)
def test_anthropic_call_with_stop(anthropic_instance): def test_anthropic_call_with_stop(anthropic_instance):
response = anthropic_instance("Translate to French.", stop=["stop1", "stop2"]) response = anthropic_instance("Translate to French.", stop=["stop1", "stop2"])
assert response == "Mocked Response from Anthropic" assert response == "Mocked Response from Anthropic"
def test_anthropic_stream_with_stop(anthropic_instance): def test_anthropic_stream_with_stop(anthropic_instance):
generator = anthropic_instance.stream("Write a story.", stop=["stop1", "stop2"]) generator = anthropic_instance.stream("Write a story.", stop=["stop1", "stop2"])
for token in generator: for token in generator:
assert isinstance(token, str) assert isinstance(token, str)
def test_anthropic_async_call_with_stop(anthropic_instance): def test_anthropic_async_call_with_stop(anthropic_instance):
response = anthropic_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"]) response = anthropic_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"])
assert response == "Mocked Response from Anthropic" assert response == "Mocked Response from Anthropic"
def test_anthropic_async_stream_with_stop(anthropic_instance): def test_anthropic_async_stream_with_stop(anthropic_instance):
async_generator = anthropic_instance.async_stream("Translate to French.", stop=["stop1", "stop2"]) async_generator = anthropic_instance.async_stream(
"Translate to French.", stop=["stop1", "stop2"]
)
for token in async_generator: for token in async_generator:
assert isinstance(token, str) assert isinstance(token, str)
def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance): def test_anthropic_get_num_tokens_with_count_tokens(anthropic_instance):
anthropic_instance.count_tokens = Mock(return_value=10) anthropic_instance.count_tokens = Mock(return_value=10)
text = "This is a test sentence." text = "This is a test sentence."
num_tokens = anthropic_instance.get_num_tokens(text) num_tokens = anthropic_instance.get_num_tokens(text)
assert num_tokens == 10 assert num_tokens == 10
def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance): def test_anthropic_get_num_tokens_without_count_tokens(anthropic_instance):
del anthropic_instance.count_tokens del anthropic_instance.count_tokens
with pytest.raises(NameError): with pytest.raises(NameError):
text = "This is a test sentence." text = "This is a test sentence."
anthropic_instance.get_num_tokens(text) anthropic_instance.get_num_tokens(text)
def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance): def test_anthropic_wrap_prompt_without_human_ai_prompt(anthropic_instance):
del anthropic_instance.HUMAN_PROMPT del anthropic_instance.HUMAN_PROMPT
del anthropic_instance.AI_PROMPT del anthropic_instance.AI_PROMPT
prompt = "What is the meaning of life?" prompt = "What is the meaning of life?"
with pytest.raises(NameError): with pytest.raises(NameError):
anthropic_instance._wrap_prompt(prompt) anthropic_instance._wrap_prompt(prompt)

@ -0,0 +1,655 @@
import os
from unittest.mock import Mock, patch
import pytest
from cohere.models.response import GenerationChunk
from dotenv import load_dotenv
from swarms.models.cohere_chat import BaseCohere, Cohere
# Load the environment variables
load_dotenv()
api_key = os.getenv("COHERE_API_KEY")
@pytest.fixture
def cohere_instance():
return Cohere(cohere_api_key=api_key)
def test_cohere_wrap_prompt(cohere_instance):
prompt = "What is the meaning of life?"
wrapped_prompt = cohere_instance._wrap_prompt(prompt)
assert wrapped_prompt.startswith(cohere_instance.HUMAN_PROMPT)
assert wrapped_prompt.endswith(cohere_instance.AI_PROMPT)
def test_cohere_convert_prompt(cohere_instance):
prompt = "What is the meaning of life?"
converted_prompt = cohere_instance.convert_prompt(prompt)
assert converted_prompt.startswith(cohere_instance.HUMAN_PROMPT)
assert converted_prompt.endswith(cohere_instance.AI_PROMPT)
def test_cohere_call_with_stop(cohere_instance):
response = cohere_instance("Translate to French.", stop=["stop1", "stop2"])
assert response == "Mocked Response from Cohere"
def test_cohere_stream_with_stop(cohere_instance):
generator = cohere_instance.stream("Write a story.", stop=["stop1", "stop2"])
for token in generator:
assert isinstance(token, str)
def test_cohere_async_call_with_stop(cohere_instance):
response = cohere_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"])
assert response == "Mocked Response from Cohere"
def test_cohere_async_stream_with_stop(cohere_instance):
async_generator = cohere_instance.async_stream(
"Translate to French.", stop=["stop1", "stop2"]
)
for token in async_generator:
assert isinstance(token, str)
def test_cohere_get_num_tokens_with_count_tokens(cohere_instance):
cohere_instance.count_tokens = Mock(return_value=10)
text = "This is a test sentence."
num_tokens = cohere_instance.get_num_tokens(text)
assert num_tokens == 10
def test_cohere_get_num_tokens_without_count_tokens(cohere_instance):
del cohere_instance.count_tokens
with pytest.raises(NameError):
text = "This is a test sentence."
cohere_instance.get_num_tokens(text)
def test_cohere_wrap_prompt_without_human_ai_prompt(cohere_instance):
del cohere_instance.HUMAN_PROMPT
del cohere_instance.AI_PROMPT
prompt = "What is the meaning of life?"
with pytest.raises(NameError):
cohere_instance._wrap_prompt(prompt)
def test_base_cohere_import():
with patch.dict("sys.modules", {"cohere": None}):
with pytest.raises(ImportError):
pass
def test_base_cohere_validate_environment():
values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"}
validated_values = BaseCohere.validate_environment(values)
assert "client" in validated_values
assert "async_client" in validated_values
def test_base_cohere_validate_environment_without_cohere():
values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"}
with patch.dict("sys.modules", {"cohere": None}):
with pytest.raises(ImportError):
BaseCohere.validate_environment(values)
# Test cases for benchmarking generations with various models
def test_cohere_generate_with_command_light(cohere_instance):
cohere_instance.model = "command-light"
response = cohere_instance("Generate text with Command Light model.")
assert response.startswith("Generated text with Command Light model")
def test_cohere_generate_with_command(cohere_instance):
cohere_instance.model = "command"
response = cohere_instance("Generate text with Command model.")
assert response.startswith("Generated text with Command model")
def test_cohere_generate_with_base_light(cohere_instance):
cohere_instance.model = "base-light"
response = cohere_instance("Generate text with Base Light model.")
assert response.startswith("Generated text with Base Light model")
def test_cohere_generate_with_base(cohere_instance):
cohere_instance.model = "base"
response = cohere_instance("Generate text with Base model.")
assert response.startswith("Generated text with Base model")
def test_cohere_generate_with_embed_english_v2(cohere_instance):
cohere_instance.model = "embed-english-v2.0"
response = cohere_instance("Generate embeddings with English v2.0 model.")
assert response.startswith("Generated embeddings with English v2.0 model")
def test_cohere_generate_with_embed_english_light_v2(cohere_instance):
cohere_instance.model = "embed-english-light-v2.0"
response = cohere_instance("Generate embeddings with English Light v2.0 model.")
assert response.startswith("Generated embeddings with English Light v2.0 model")
def test_cohere_generate_with_embed_multilingual_v2(cohere_instance):
cohere_instance.model = "embed-multilingual-v2.0"
response = cohere_instance("Generate embeddings with Multilingual v2.0 model.")
assert response.startswith("Generated embeddings with Multilingual v2.0 model")
def test_cohere_generate_with_embed_english_v3(cohere_instance):
cohere_instance.model = "embed-english-v3.0"
response = cohere_instance("Generate embeddings with English v3.0 model.")
assert response.startswith("Generated embeddings with English v3.0 model")
def test_cohere_generate_with_embed_english_light_v3(cohere_instance):
cohere_instance.model = "embed-english-light-v3.0"
response = cohere_instance("Generate embeddings with English Light v3.0 model.")
assert response.startswith("Generated embeddings with English Light v3.0 model")
def test_cohere_generate_with_embed_multilingual_v3(cohere_instance):
cohere_instance.model = "embed-multilingual-v3.0"
response = cohere_instance("Generate embeddings with Multilingual v3.0 model.")
assert response.startswith("Generated embeddings with Multilingual v3.0 model")
def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance):
cohere_instance.model = "embed-multilingual-light-v3.0"
response = cohere_instance(
"Generate embeddings with Multilingual Light v3.0 model."
)
assert response.startswith(
"Generated embeddings with Multilingual Light v3.0 model"
)
# Add more test cases to benchmark other models and functionalities
def test_cohere_call_with_command_model(cohere_instance):
cohere_instance.model = "command"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_base_model(cohere_instance):
cohere_instance.model = "base"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_embed_english_v2_model(cohere_instance):
cohere_instance.model = "embed-english-v2.0"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_embed_english_v3_model(cohere_instance):
cohere_instance.model = "embed-english-v3.0"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v2.0"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v3.0"
response = cohere_instance("Translate to French.")
assert isinstance(response, str)
def test_cohere_call_with_invalid_model(cohere_instance):
cohere_instance.model = "invalid-model"
with pytest.raises(ValueError):
response = cohere_instance("Translate to French.")
def test_cohere_call_with_long_prompt(cohere_instance):
prompt = "This is a very long prompt. " * 100
response = cohere_instance(prompt)
assert isinstance(response, str)
def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance):
cohere_instance.max_tokens = 10
prompt = "This is a test prompt that will exceed the max tokens limit."
with pytest.raises(ValueError):
response = cohere_instance(prompt)
def test_cohere_stream_with_command_model(cohere_instance):
cohere_instance.model = "command"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_stream_with_base_model(cohere_instance):
cohere_instance.model = "base"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_stream_with_embed_english_v2_model(cohere_instance):
cohere_instance.model = "embed-english-v2.0"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_stream_with_embed_english_v3_model(cohere_instance):
cohere_instance.model = "embed-english-v3.0"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v2.0"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v3.0"
generator = cohere_instance.stream("Write a story.")
for token in generator:
assert isinstance(token, str)
def test_cohere_async_call_with_command_model(cohere_instance):
cohere_instance.model = "command"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_base_model(cohere_instance):
cohere_instance.model = "base"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_english_v2_model(cohere_instance):
cohere_instance.model = "embed-english-v2.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_english_v3_model(cohere_instance):
cohere_instance.model = "embed-english-v3.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v2.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v3.0"
response = cohere_instance.async_call("Translate to French.")
assert isinstance(response, str)
def test_cohere_async_stream_with_command_model(cohere_instance):
cohere_instance.model = "command"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_base_model(cohere_instance):
cohere_instance.model = "base"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance):
cohere_instance.model = "embed-english-v2.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance):
cohere_instance.model = "embed-english-v3.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v2.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance):
cohere_instance.model = "embed-multilingual-v3.0"
async_generator = cohere_instance.async_stream("Write a story.")
for token in async_generator:
assert isinstance(token, str)
def test_cohere_custom_configuration(cohere_instance):
# Test customizing Cohere configurations
cohere_instance.model = "base"
cohere_instance.temperature = 0.5
cohere_instance.max_tokens = 100
cohere_instance.k = 1
cohere_instance.p = 0.8
cohere_instance.frequency_penalty = 0.2
cohere_instance.presence_penalty = 0.4
response = cohere_instance("Customize configurations.")
assert isinstance(response, str)
def test_cohere_api_error_handling(cohere_instance):
# Test error handling when the API key is invalid
cohere_instance.model = "base"
cohere_instance.cohere_api_key = "invalid-api-key"
with pytest.raises(Exception):
response = cohere_instance("Error handling with invalid API key.")
def test_cohere_async_api_error_handling(cohere_instance):
# Test async error handling when the API key is invalid
cohere_instance.model = "base"
cohere_instance.cohere_api_key = "invalid-api-key"
with pytest.raises(Exception):
response = cohere_instance.async_call("Error handling with invalid API key.")
def test_cohere_stream_api_error_handling(cohere_instance):
# Test error handling in streaming mode when the API key is invalid
cohere_instance.model = "base"
cohere_instance.cohere_api_key = "invalid-api-key"
with pytest.raises(Exception):
generator = cohere_instance.stream("Error handling with invalid API key.")
for token in generator:
pass
def test_cohere_streaming_mode(cohere_instance):
# Test the streaming mode for large text generation
cohere_instance.model = "base"
cohere_instance.streaming = True
prompt = "Generate a lengthy text using streaming mode."
generator = cohere_instance.stream(prompt)
for token in generator:
assert isinstance(token, str)
def test_cohere_streaming_mode_async(cohere_instance):
# Test the async streaming mode for large text generation
cohere_instance.model = "base"
cohere_instance.streaming = True
prompt = "Generate a lengthy text using async streaming mode."
async_generator = cohere_instance.async_stream(prompt)
for token in async_generator:
assert isinstance(token, str)
def test_cohere_representation_model_embedding(cohere_instance):
# Test using the Representation model for text embedding
cohere_instance.model = "embed-english-v3.0"
embedding = cohere_instance.embed("Generate an embedding for this text.")
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_classification(cohere_instance):
# Test using the Representation model for text classification
cohere_instance.model = "embed-english-v3.0"
classification = cohere_instance.classify("Classify this text.")
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_language_detection(cohere_instance):
# Test using the Representation model for language detection
cohere_instance.model = "embed-english-v3.0"
language = cohere_instance.detect_language("Detect the language of this text.")
assert isinstance(language, str)
def test_cohere_representation_model_max_tokens_limit_exceeded(cohere_instance):
# Test handling max tokens limit exceeded error
cohere_instance.model = "embed-english-v3.0"
cohere_instance.max_tokens = 10
prompt = "This is a test prompt that will exceed the max tokens limit."
with pytest.raises(ValueError):
embedding = cohere_instance.embed(prompt)
# Add more production-grade test cases based on real-world scenarios
def test_cohere_representation_model_multilingual_embedding(cohere_instance):
# Test using the Representation model for multilingual text embedding
cohere_instance.model = "embed-multilingual-v3.0"
embedding = cohere_instance.embed("Generate multilingual embeddings.")
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_multilingual_classification(cohere_instance):
# Test using the Representation model for multilingual text classification
cohere_instance.model = "embed-multilingual-v3.0"
classification = cohere_instance.classify("Classify multilingual text.")
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_multilingual_language_detection(cohere_instance):
# Test using the Representation model for multilingual language detection
cohere_instance.model = "embed-multilingual-v3.0"
language = cohere_instance.detect_language(
"Detect the language of multilingual text."
)
assert isinstance(language, str)
def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded(
cohere_instance,
):
# Test handling max tokens limit exceeded error for multilingual model
cohere_instance.model = "embed-multilingual-v3.0"
cohere_instance.max_tokens = 10
prompt = "This is a test prompt that will exceed the max tokens limit for multilingual model."
with pytest.raises(ValueError):
embedding = cohere_instance.embed(prompt)
def test_cohere_representation_model_multilingual_light_embedding(cohere_instance):
# Test using the Representation model for multilingual light text embedding
cohere_instance.model = "embed-multilingual-light-v3.0"
embedding = cohere_instance.embed("Generate multilingual light embeddings.")
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_multilingual_light_classification(cohere_instance):
# Test using the Representation model for multilingual light text classification
cohere_instance.model = "embed-multilingual-light-v3.0"
classification = cohere_instance.classify("Classify multilingual light text.")
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_multilingual_light_language_detection(
cohere_instance,
):
# Test using the Representation model for multilingual light language detection
cohere_instance.model = "embed-multilingual-light-v3.0"
language = cohere_instance.detect_language(
"Detect the language of multilingual light text."
)
assert isinstance(language, str)
def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded(
cohere_instance,
):
# Test handling max tokens limit exceeded error for multilingual light model
cohere_instance.model = "embed-multilingual-light-v3.0"
cohere_instance.max_tokens = 10
prompt = "This is a test prompt that will exceed the max tokens limit for multilingual light model."
with pytest.raises(ValueError):
embedding = cohere_instance.embed(prompt)
def test_cohere_command_light_model(cohere_instance):
# Test using the Command Light model for text generation
cohere_instance.model = "command-light"
response = cohere_instance("Generate text using Command Light model.")
assert isinstance(response, str)
def test_cohere_base_light_model(cohere_instance):
# Test using the Base Light model for text generation
cohere_instance.model = "base-light"
response = cohere_instance("Generate text using Base Light model.")
assert isinstance(response, str)
def test_cohere_generate_summarize_endpoint(cohere_instance):
# Test using the Co.summarize() endpoint for text summarization
cohere_instance.model = "command"
response = cohere_instance.summarize("Summarize this text.")
assert isinstance(response, str)
def test_cohere_representation_model_english_embedding(cohere_instance):
# Test using the Representation model for English text embedding
cohere_instance.model = "embed-english-v3.0"
embedding = cohere_instance.embed("Generate English embeddings.")
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_english_classification(cohere_instance):
# Test using the Representation model for English text classification
cohere_instance.model = "embed-english-v3.0"
classification = cohere_instance.classify("Classify English text.")
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_english_language_detection(cohere_instance):
# Test using the Representation model for English language detection
cohere_instance.model = "embed-english-v3.0"
language = cohere_instance.detect_language("Detect the language of English text.")
assert isinstance(language, str)
def test_cohere_representation_model_english_max_tokens_limit_exceeded(cohere_instance):
# Test handling max tokens limit exceeded error for English model
cohere_instance.model = "embed-english-v3.0"
cohere_instance.max_tokens = 10
prompt = (
"This is a test prompt that will exceed the max tokens limit for English model."
)
with pytest.raises(ValueError):
embedding = cohere_instance.embed(prompt)
def test_cohere_representation_model_english_light_embedding(cohere_instance):
# Test using the Representation model for English light text embedding
cohere_instance.model = "embed-english-light-v3.0"
embedding = cohere_instance.embed("Generate English light embeddings.")
assert isinstance(embedding, list)
assert len(embedding) > 0
def test_cohere_representation_model_english_light_classification(cohere_instance):
# Test using the Representation model for English light text classification
cohere_instance.model = "embed-english-light-v3.0"
classification = cohere_instance.classify("Classify English light text.")
assert isinstance(classification, dict)
assert "class" in classification
assert "score" in classification
def test_cohere_representation_model_english_light_language_detection(cohere_instance):
# Test using the Representation model for English light language detection
cohere_instance.model = "embed-english-light-v3.0"
language = cohere_instance.detect_language(
"Detect the language of English light text."
)
assert isinstance(language, str)
def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
cohere_instance,
):
# Test handling max tokens limit exceeded error for English light model
cohere_instance.model = "embed-english-light-v3.0"
cohere_instance.max_tokens = 10
prompt = "This is a test prompt that will exceed the max tokens limit for English light model."
with pytest.raises(ValueError):
embedding = cohere_instance.embed(prompt)
def test_cohere_command_model(cohere_instance):
# Test using the Command model for text generation
cohere_instance.model = "command"
response = cohere_instance("Generate text using the Command model.")
assert isinstance(response, str)
# Add more production-grade test cases based on real-world scenarios
def test_cohere_invalid_model(cohere_instance):
# Test using an invalid model name
cohere_instance.model = "invalid-model"
with pytest.raises(ValueError):
response = cohere_instance("Generate text using an invalid model.")
def test_cohere_streaming_generation(cohere_instance):
# Test streaming generation with the Command model
cohere_instance.model = "command"
prompt = "Generate text using streaming."
chunks = list(cohere_instance.stream(prompt))
assert isinstance(chunks, list)
assert len(chunks) > 0
assert all(isinstance(chunk, GenerationChunk) for chunk in chunks)
def test_cohere_base_model_generation_with_max_tokens(cohere_instance):
# Test generating text using the base model with a specified max_tokens limit
cohere_instance.model = "base"
cohere_instance.max_tokens = 20
prompt = "Generate text with max_tokens limit."
response = cohere_instance(prompt)
assert len(response.split()) <= 20
def test_cohere_command_light_generation_with_stop(cohere_instance):
# Test generating text using the command-light model with stop words
cohere_instance.model = "command-light"
prompt = "Generate text with stop words."
stop = ["stop", "words"]
response = cohere_instance(prompt, stop=stop)
assert all(word not in response for word in stop)

@ -8,6 +8,7 @@ from swarms.swarms.flow import GroupChat, GroupChatManager
llm = OpenAIChat() llm = OpenAIChat()
llm2 = Anthropic() llm2 = Anthropic()
# Mock the OpenAI class for testing # Mock the OpenAI class for testing
class MockOpenAI: class MockOpenAI:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -126,6 +127,7 @@ def test_groupchat_manager_initialization(agent1, agent2):
assert manager.groupchat == groupchat assert manager.groupchat == groupchat
assert manager.selector == selector assert manager.selector == selector
# Test case to ensure GroupChatManager generates a reply from an agent # Test case to ensure GroupChatManager generates a reply from an agent
def test_groupchat_manager_generate_reply(): def test_groupchat_manager_generate_reply():
# Create a GroupChat with two agents # Create a GroupChat with two agents

Loading…
Cancel
Save