parent
bab2835472
commit
bf9a747fa3
@ -0,0 +1,6 @@
|
||||
from swarms.models.cohere_chat import Cohere
|
||||
|
||||
|
||||
cohere = Cohere(model="command-light", cohere_api_key="")
|
||||
|
||||
out = cohere("Hello, how are you?")
|
@ -1,335 +0,0 @@
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import (
|
||||
BaseChatModel,
|
||||
_agenerate_from_stream,
|
||||
_generate_from_stream,
|
||||
)
|
||||
from langchain.llms.cohere import BaseCohere
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
|
||||
def get_role(message: BaseMessage) -> str:
|
||||
"""Get the role of the message.
|
||||
|
||||
Args:
|
||||
message: The message.
|
||||
|
||||
Returns:
|
||||
The role of the message.
|
||||
|
||||
Raises:
|
||||
ValueError: If the message is of an unknown type.
|
||||
"""
|
||||
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
|
||||
return "User"
|
||||
elif isinstance(message, AIMessage):
|
||||
return "Chatbot"
|
||||
elif isinstance(message, SystemMessage):
|
||||
return "System"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def get_cohere_chat_request(
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
connectors: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the request for the Cohere chat API.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
connectors: The connectors.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
The request for the Cohere chat API.
|
||||
"""
|
||||
documents = (
|
||||
None
|
||||
if "source_documents" not in kwargs
|
||||
else [
|
||||
{
|
||||
"snippet": doc.page_content,
|
||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
}
|
||||
for i, doc in enumerate(kwargs["source_documents"])
|
||||
]
|
||||
)
|
||||
kwargs.pop("source_documents", None)
|
||||
maybe_connectors = connectors if documents is None else None
|
||||
|
||||
# by enabling automatic prompt truncation, the probability of request failure is
|
||||
# reduced with minimal impact on response quality
|
||||
prompt_truncation = (
|
||||
"AUTO" if documents is not None or connectors is not None else None
|
||||
)
|
||||
|
||||
return {
|
||||
"message": messages[0].content,
|
||||
"chat_history": [
|
||||
{"role": get_role(x), "message": x.content} for x in messages[1:]
|
||||
],
|
||||
"documents": documents,
|
||||
"connectors": maybe_connectors,
|
||||
"prompt_truncation": prompt_truncation,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
class CohereChat(BaseChatModel, BaseCohere):
|
||||
"""`Cohere` chat large language models.
|
||||
|
||||
To use, you should have the ``cohere`` python package installed, and the
|
||||
environment variable ``COHERE_API_KEY`` set with your API key, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from swarms.models.cohere import CohereChat, HumanMessage
|
||||
|
||||
chat = CohereChat(model="foo")
|
||||
result = chat([HumanMessage(content="Hello")])
|
||||
print(result.content)
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "cohere-chat"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""
|
||||
Stream the output
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
stop: The stop tokens.
|
||||
run_manager: The callback manager.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
"""
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
stream = self.client.chat(**request, stream=True)
|
||||
|
||||
for data in stream:
|
||||
if data.event_type == "text-generation":
|
||||
delta = data.text
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
"""
|
||||
Stream generations from the model.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
stop: The stop tokens.
|
||||
run_manager: The callback manager.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Yields:
|
||||
The generations.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
async for generation in model._astream(messages):
|
||||
print(generation.message.content)
|
||||
"""
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
stream = await self.async_client.chat(**request, stream=True)
|
||||
|
||||
async for data in stream:
|
||||
if data.event_type == "text-generation":
|
||||
delta = data.text
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta)
|
||||
|
||||
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
|
||||
"""Get the generation info from cohere API response."""
|
||||
return {
|
||||
"documents": response.documents,
|
||||
"citations": response.citations,
|
||||
"search_results": response.search_results,
|
||||
"search_queries": response.search_queries,
|
||||
"token_count": response.token_count,
|
||||
}
|
||||
|
||||
def _run(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""
|
||||
Run the model.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
stop: The stop tokens.
|
||||
run_manager: The callback manager.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
result = model._run(messages)
|
||||
print(result.content)
|
||||
"""
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return _generate_from_stream(stream_iter)
|
||||
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""
|
||||
__Call__ the model.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
stop: The stop tokens.
|
||||
run_manager: The callback manager.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result.
|
||||
"""
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return _generate_from_stream(stream_iter)
|
||||
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""
|
||||
Asynchronously run the model.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
stop: The stop tokens.
|
||||
run_manager: The callback manager.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
result = await model._arun(messages)
|
||||
print(result.content)
|
||||
|
||||
"""
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await _agenerate_from_stream(stream_iter)
|
||||
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request, stream=False)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
return len(self.client.tokenize(text).tokens)
|
@ -0,0 +1,247 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.load.serializable import Serializable
|
||||
from pydantic import Extra, Field, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
|
||||
import cohere
|
||||
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(llm.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(retry_if_exception_type(cohere.error.CohereError)),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def completion_with_retry(llm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.generate(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def acompletion_with_retry(llm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return await llm.async_client.generate(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
class BaseCohere(Serializable):
|
||||
"""Base class for Cohere models."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
async_client: Any #: :meta private:
|
||||
model: Optional[str] = Field(default=None, description="Model name to use.")
|
||||
"""Model name to use."""
|
||||
|
||||
temperature: float = 0.75
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
|
||||
cohere_api_key: Optional[str] = None
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
streaming: bool = Field(default=False)
|
||||
"""Whether to stream the results."""
|
||||
|
||||
user_agent: str = "langchain"
|
||||
"""Identifier for the application making the request."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import cohere
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import cohere python package. "
|
||||
"Please install it with `pip install cohere`."
|
||||
)
|
||||
else:
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
client_name = values["user_agent"]
|
||||
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
|
||||
values["async_client"] = cohere.AsyncClient(
|
||||
cohere_api_key, client_name=client_name
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class Cohere(LLM, BaseCohere):
|
||||
"""Cohere large language models.
|
||||
|
||||
To use, you should have the ``cohere`` python package installed, and the
|
||||
environment variable ``COHERE_API_KEY`` set with your API key, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Cohere
|
||||
|
||||
cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
|
||||
"""
|
||||
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
k: int = 0
|
||||
"""Number of most likely tokens to consider at each step."""
|
||||
|
||||
p: int = 1
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
|
||||
frequency_penalty: float = 0.0
|
||||
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
|
||||
|
||||
presence_penalty: float = 0.0
|
||||
"""Penalizes repeated tokens. Between 0 and 1."""
|
||||
|
||||
truncate: Optional[str] = None
|
||||
"""Specify how the client handles inputs longer than the maximum token
|
||||
length: Truncate from START, END or NONE"""
|
||||
|
||||
max_retries: int = 10
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"k": self.k,
|
||||
"p": self.p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"truncate": self.truncate,
|
||||
}
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"cohere_api_key": "COHERE_API_KEY"}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "cohere"
|
||||
|
||||
def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
|
||||
params = self._default_params
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
params["stop_sequences"] = self.stop
|
||||
else:
|
||||
params["stop_sequences"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
|
||||
text = response.generations[0].text
|
||||
# If stop tokens are provided, Cohere's endpoint returns them.
|
||||
# In order to make this consistent with other endpoints, we strip them.
|
||||
if stop:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to Cohere's generate endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = cohere("Tell me a joke.")
|
||||
"""
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
response = completion_with_retry(
|
||||
self, model=self.model, prompt=prompt, **params
|
||||
)
|
||||
_stop = params.get("stop_sequences")
|
||||
return self._process_response(response, _stop)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Async call out to Cohere's generate endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = await cohere("Tell me a joke.")
|
||||
"""
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
response = await acompletion_with_retry(
|
||||
self, model=self.model, prompt=prompt, **params
|
||||
)
|
||||
_stop = params.get("stop_sequences")
|
||||
return self._process_response(response, _stop)
|
@ -0,0 +1,655 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from cohere.models.response import GenerationChunk
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from swarms.models.cohere_chat import BaseCohere, Cohere
|
||||
|
||||
# Load the environment variables
|
||||
load_dotenv()
|
||||
api_key = os.getenv("COHERE_API_KEY")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cohere_instance():
|
||||
return Cohere(cohere_api_key=api_key)
|
||||
|
||||
|
||||
def test_cohere_wrap_prompt(cohere_instance):
|
||||
prompt = "What is the meaning of life?"
|
||||
wrapped_prompt = cohere_instance._wrap_prompt(prompt)
|
||||
assert wrapped_prompt.startswith(cohere_instance.HUMAN_PROMPT)
|
||||
assert wrapped_prompt.endswith(cohere_instance.AI_PROMPT)
|
||||
|
||||
|
||||
def test_cohere_convert_prompt(cohere_instance):
|
||||
prompt = "What is the meaning of life?"
|
||||
converted_prompt = cohere_instance.convert_prompt(prompt)
|
||||
assert converted_prompt.startswith(cohere_instance.HUMAN_PROMPT)
|
||||
assert converted_prompt.endswith(cohere_instance.AI_PROMPT)
|
||||
|
||||
|
||||
def test_cohere_call_with_stop(cohere_instance):
|
||||
response = cohere_instance("Translate to French.", stop=["stop1", "stop2"])
|
||||
assert response == "Mocked Response from Cohere"
|
||||
|
||||
|
||||
def test_cohere_stream_with_stop(cohere_instance):
|
||||
generator = cohere_instance.stream("Write a story.", stop=["stop1", "stop2"])
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_stop(cohere_instance):
|
||||
response = cohere_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"])
|
||||
assert response == "Mocked Response from Cohere"
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_stop(cohere_instance):
|
||||
async_generator = cohere_instance.async_stream(
|
||||
"Translate to French.", stop=["stop1", "stop2"]
|
||||
)
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_get_num_tokens_with_count_tokens(cohere_instance):
|
||||
cohere_instance.count_tokens = Mock(return_value=10)
|
||||
text = "This is a test sentence."
|
||||
num_tokens = cohere_instance.get_num_tokens(text)
|
||||
assert num_tokens == 10
|
||||
|
||||
|
||||
def test_cohere_get_num_tokens_without_count_tokens(cohere_instance):
|
||||
del cohere_instance.count_tokens
|
||||
with pytest.raises(NameError):
|
||||
text = "This is a test sentence."
|
||||
cohere_instance.get_num_tokens(text)
|
||||
|
||||
|
||||
def test_cohere_wrap_prompt_without_human_ai_prompt(cohere_instance):
|
||||
del cohere_instance.HUMAN_PROMPT
|
||||
del cohere_instance.AI_PROMPT
|
||||
prompt = "What is the meaning of life?"
|
||||
with pytest.raises(NameError):
|
||||
cohere_instance._wrap_prompt(prompt)
|
||||
|
||||
|
||||
def test_base_cohere_import():
|
||||
with patch.dict("sys.modules", {"cohere": None}):
|
||||
with pytest.raises(ImportError):
|
||||
pass
|
||||
|
||||
|
||||
def test_base_cohere_validate_environment():
|
||||
values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"}
|
||||
validated_values = BaseCohere.validate_environment(values)
|
||||
assert "client" in validated_values
|
||||
assert "async_client" in validated_values
|
||||
|
||||
|
||||
def test_base_cohere_validate_environment_without_cohere():
|
||||
values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"}
|
||||
with patch.dict("sys.modules", {"cohere": None}):
|
||||
with pytest.raises(ImportError):
|
||||
BaseCohere.validate_environment(values)
|
||||
|
||||
|
||||
# Test cases for benchmarking generations with various models
|
||||
def test_cohere_generate_with_command_light(cohere_instance):
|
||||
cohere_instance.model = "command-light"
|
||||
response = cohere_instance("Generate text with Command Light model.")
|
||||
assert response.startswith("Generated text with Command Light model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_command(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance("Generate text with Command model.")
|
||||
assert response.startswith("Generated text with Command model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_base_light(cohere_instance):
|
||||
cohere_instance.model = "base-light"
|
||||
response = cohere_instance("Generate text with Base Light model.")
|
||||
assert response.startswith("Generated text with Base Light model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_base(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
response = cohere_instance("Generate text with Base model.")
|
||||
assert response.startswith("Generated text with Base model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_v2(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
response = cohere_instance("Generate embeddings with English v2.0 model.")
|
||||
assert response.startswith("Generated embeddings with English v2.0 model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_light_v2(cohere_instance):
|
||||
cohere_instance.model = "embed-english-light-v2.0"
|
||||
response = cohere_instance("Generate embeddings with English Light v2.0 model.")
|
||||
assert response.startswith("Generated embeddings with English Light v2.0 model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_multilingual_v2(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
response = cohere_instance("Generate embeddings with Multilingual v2.0 model.")
|
||||
assert response.startswith("Generated embeddings with Multilingual v2.0 model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
response = cohere_instance("Generate embeddings with English v3.0 model.")
|
||||
assert response.startswith("Generated embeddings with English v3.0 model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_english_light_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
response = cohere_instance("Generate embeddings with English Light v3.0 model.")
|
||||
assert response.startswith("Generated embeddings with English Light v3.0 model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_multilingual_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
response = cohere_instance("Generate embeddings with Multilingual v3.0 model.")
|
||||
assert response.startswith("Generated embeddings with Multilingual v3.0 model")
|
||||
|
||||
|
||||
def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
response = cohere_instance(
|
||||
"Generate embeddings with Multilingual Light v3.0 model."
|
||||
)
|
||||
assert response.startswith(
|
||||
"Generated embeddings with Multilingual Light v3.0 model"
|
||||
)
|
||||
|
||||
|
||||
# Add more test cases to benchmark other models and functionalities
|
||||
|
||||
|
||||
def test_cohere_call_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_english_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_english_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
response = cohere_instance("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_invalid_model(cohere_instance):
|
||||
cohere_instance.model = "invalid-model"
|
||||
with pytest.raises(ValueError):
|
||||
response = cohere_instance("Translate to French.")
|
||||
|
||||
|
||||
def test_cohere_call_with_long_prompt(cohere_instance):
|
||||
prompt = "This is a very long prompt. " * 100
|
||||
response = cohere_instance(prompt)
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance):
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = "This is a test prompt that will exceed the max tokens limit."
|
||||
with pytest.raises(ValueError):
|
||||
response = cohere_instance(prompt)
|
||||
|
||||
|
||||
def test_cohere_stream_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_english_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_english_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
generator = cohere_instance.stream("Write a story.")
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_english_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_english_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
response = cohere_instance.async_call("Translate to French.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_command_model(cohere_instance):
|
||||
cohere_instance.model = "command"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_base_model(cohere_instance):
|
||||
cohere_instance.model = "base"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v2.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v2.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance):
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
async_generator = cohere_instance.async_stream("Write a story.")
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_custom_configuration(cohere_instance):
|
||||
# Test customizing Cohere configurations
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.temperature = 0.5
|
||||
cohere_instance.max_tokens = 100
|
||||
cohere_instance.k = 1
|
||||
cohere_instance.p = 0.8
|
||||
cohere_instance.frequency_penalty = 0.2
|
||||
cohere_instance.presence_penalty = 0.4
|
||||
response = cohere_instance("Customize configurations.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_api_error_handling(cohere_instance):
|
||||
# Test error handling when the API key is invalid
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||
with pytest.raises(Exception):
|
||||
response = cohere_instance("Error handling with invalid API key.")
|
||||
|
||||
|
||||
def test_cohere_async_api_error_handling(cohere_instance):
|
||||
# Test async error handling when the API key is invalid
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||
with pytest.raises(Exception):
|
||||
response = cohere_instance.async_call("Error handling with invalid API key.")
|
||||
|
||||
|
||||
def test_cohere_stream_api_error_handling(cohere_instance):
|
||||
# Test error handling in streaming mode when the API key is invalid
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||
with pytest.raises(Exception):
|
||||
generator = cohere_instance.stream("Error handling with invalid API key.")
|
||||
for token in generator:
|
||||
pass
|
||||
|
||||
|
||||
def test_cohere_streaming_mode(cohere_instance):
|
||||
# Test the streaming mode for large text generation
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.streaming = True
|
||||
prompt = "Generate a lengthy text using streaming mode."
|
||||
generator = cohere_instance.stream(prompt)
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_streaming_mode_async(cohere_instance):
|
||||
# Test the async streaming mode for large text generation
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.streaming = True
|
||||
prompt = "Generate a lengthy text using async streaming mode."
|
||||
async_generator = cohere_instance.async_stream(prompt)
|
||||
for token in async_generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_embedding(cohere_instance):
|
||||
# Test using the Representation model for text embedding
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
embedding = cohere_instance.embed("Generate an embedding for this text.")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_classification(cohere_instance):
|
||||
# Test using the Representation model for text classification
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
classification = cohere_instance.classify("Classify this text.")
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_language_detection(cohere_instance):
|
||||
# Test using the Representation model for language detection
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
language = cohere_instance.detect_language("Detect the language of this text.")
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_max_tokens_limit_exceeded(cohere_instance):
|
||||
# Test handling max tokens limit exceeded error
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = "This is a test prompt that will exceed the max tokens limit."
|
||||
with pytest.raises(ValueError):
|
||||
embedding = cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
# Add more production-grade test cases based on real-world scenarios
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_embedding(cohere_instance):
|
||||
# Test using the Representation model for multilingual text embedding
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
embedding = cohere_instance.embed("Generate multilingual embeddings.")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_classification(cohere_instance):
|
||||
# Test using the Representation model for multilingual text classification
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
classification = cohere_instance.classify("Classify multilingual text.")
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_language_detection(cohere_instance):
|
||||
# Test using the Representation model for multilingual language detection
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of multilingual text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for multilingual model
|
||||
cohere_instance.model = "embed-multilingual-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = "This is a test prompt that will exceed the max tokens limit for multilingual model."
|
||||
with pytest.raises(ValueError):
|
||||
embedding = cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_embedding(cohere_instance):
|
||||
# Test using the Representation model for multilingual light text embedding
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
embedding = cohere_instance.embed("Generate multilingual light embeddings.")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_classification(cohere_instance):
|
||||
# Test using the Representation model for multilingual light text classification
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
classification = cohere_instance.classify("Classify multilingual light text.")
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_language_detection(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test using the Representation model for multilingual light language detection
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of multilingual light text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for multilingual light model
|
||||
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = "This is a test prompt that will exceed the max tokens limit for multilingual light model."
|
||||
with pytest.raises(ValueError):
|
||||
embedding = cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_command_light_model(cohere_instance):
|
||||
# Test using the Command Light model for text generation
|
||||
cohere_instance.model = "command-light"
|
||||
response = cohere_instance("Generate text using Command Light model.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_base_light_model(cohere_instance):
|
||||
# Test using the Base Light model for text generation
|
||||
cohere_instance.model = "base-light"
|
||||
response = cohere_instance("Generate text using Base Light model.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_generate_summarize_endpoint(cohere_instance):
|
||||
# Test using the Co.summarize() endpoint for text summarization
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance.summarize("Summarize this text.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_embedding(cohere_instance):
|
||||
# Test using the Representation model for English text embedding
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
embedding = cohere_instance.embed("Generate English embeddings.")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_classification(cohere_instance):
|
||||
# Test using the Representation model for English text classification
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
classification = cohere_instance.classify("Classify English text.")
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_language_detection(cohere_instance):
|
||||
# Test using the Representation model for English language detection
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
language = cohere_instance.detect_language("Detect the language of English text.")
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_max_tokens_limit_exceeded(cohere_instance):
|
||||
# Test handling max tokens limit exceeded error for English model
|
||||
cohere_instance.model = "embed-english-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = (
|
||||
"This is a test prompt that will exceed the max tokens limit for English model."
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
embedding = cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_embedding(cohere_instance):
|
||||
# Test using the Representation model for English light text embedding
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
embedding = cohere_instance.embed("Generate English light embeddings.")
|
||||
assert isinstance(embedding, list)
|
||||
assert len(embedding) > 0
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_classification(cohere_instance):
|
||||
# Test using the Representation model for English light text classification
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
classification = cohere_instance.classify("Classify English light text.")
|
||||
assert isinstance(classification, dict)
|
||||
assert "class" in classification
|
||||
assert "score" in classification
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_language_detection(cohere_instance):
|
||||
# Test using the Representation model for English light language detection
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
language = cohere_instance.detect_language(
|
||||
"Detect the language of English light text."
|
||||
)
|
||||
assert isinstance(language, str)
|
||||
|
||||
|
||||
def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
|
||||
cohere_instance,
|
||||
):
|
||||
# Test handling max tokens limit exceeded error for English light model
|
||||
cohere_instance.model = "embed-english-light-v3.0"
|
||||
cohere_instance.max_tokens = 10
|
||||
prompt = "This is a test prompt that will exceed the max tokens limit for English light model."
|
||||
with pytest.raises(ValueError):
|
||||
embedding = cohere_instance.embed(prompt)
|
||||
|
||||
|
||||
def test_cohere_command_model(cohere_instance):
|
||||
# Test using the Command model for text generation
|
||||
cohere_instance.model = "command"
|
||||
response = cohere_instance("Generate text using the Command model.")
|
||||
assert isinstance(response, str)
|
||||
|
||||
|
||||
# Add more production-grade test cases based on real-world scenarios
|
||||
|
||||
|
||||
def test_cohere_invalid_model(cohere_instance):
|
||||
# Test using an invalid model name
|
||||
cohere_instance.model = "invalid-model"
|
||||
with pytest.raises(ValueError):
|
||||
response = cohere_instance("Generate text using an invalid model.")
|
||||
|
||||
|
||||
def test_cohere_streaming_generation(cohere_instance):
|
||||
# Test streaming generation with the Command model
|
||||
cohere_instance.model = "command"
|
||||
prompt = "Generate text using streaming."
|
||||
chunks = list(cohere_instance.stream(prompt))
|
||||
assert isinstance(chunks, list)
|
||||
assert len(chunks) > 0
|
||||
assert all(isinstance(chunk, GenerationChunk) for chunk in chunks)
|
||||
|
||||
|
||||
def test_cohere_base_model_generation_with_max_tokens(cohere_instance):
|
||||
# Test generating text using the base model with a specified max_tokens limit
|
||||
cohere_instance.model = "base"
|
||||
cohere_instance.max_tokens = 20
|
||||
prompt = "Generate text with max_tokens limit."
|
||||
response = cohere_instance(prompt)
|
||||
assert len(response.split()) <= 20
|
||||
|
||||
|
||||
def test_cohere_command_light_generation_with_stop(cohere_instance):
|
||||
# Test generating text using the command-light model with stop words
|
||||
cohere_instance.model = "command-light"
|
||||
prompt = "Generate text with stop words."
|
||||
stop = ["stop", "words"]
|
||||
response = cohere_instance(prompt, stop=stop)
|
||||
assert all(word not in response for word in stop)
|
Loading…
Reference in new issue