parent
bab2835472
commit
bf9a747fa3
@ -0,0 +1,6 @@
|
|||||||
|
from swarms.models.cohere_chat import Cohere
|
||||||
|
|
||||||
|
|
||||||
|
cohere = Cohere(model="command-light", cohere_api_key="")
|
||||||
|
|
||||||
|
out = cohere("Hello, how are you?")
|
@ -1,335 +0,0 @@
|
|||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForLLMRun,
|
|
||||||
CallbackManagerForLLMRun,
|
|
||||||
)
|
|
||||||
from langchain.chat_models.base import (
|
|
||||||
BaseChatModel,
|
|
||||||
_agenerate_from_stream,
|
|
||||||
_generate_from_stream,
|
|
||||||
)
|
|
||||||
from langchain.llms.cohere import BaseCohere
|
|
||||||
from langchain.schema.messages import (
|
|
||||||
AIMessage,
|
|
||||||
AIMessageChunk,
|
|
||||||
BaseMessage,
|
|
||||||
ChatMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
)
|
|
||||||
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
||||||
|
|
||||||
|
|
||||||
def get_role(message: BaseMessage) -> str:
|
|
||||||
"""Get the role of the message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The message.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The role of the message.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the message is of an unknown type.
|
|
||||||
"""
|
|
||||||
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
|
|
||||||
return "User"
|
|
||||||
elif isinstance(message, AIMessage):
|
|
||||||
return "Chatbot"
|
|
||||||
elif isinstance(message, SystemMessage):
|
|
||||||
return "System"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_cohere_chat_request(
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
*,
|
|
||||||
connectors: Optional[List[Dict[str, str]]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Get the request for the Cohere chat API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The messages.
|
|
||||||
connectors: The connectors.
|
|
||||||
**kwargs: The keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The request for the Cohere chat API.
|
|
||||||
"""
|
|
||||||
documents = (
|
|
||||||
None
|
|
||||||
if "source_documents" not in kwargs
|
|
||||||
else [
|
|
||||||
{
|
|
||||||
"snippet": doc.page_content,
|
|
||||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
|
||||||
}
|
|
||||||
for i, doc in enumerate(kwargs["source_documents"])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
kwargs.pop("source_documents", None)
|
|
||||||
maybe_connectors = connectors if documents is None else None
|
|
||||||
|
|
||||||
# by enabling automatic prompt truncation, the probability of request failure is
|
|
||||||
# reduced with minimal impact on response quality
|
|
||||||
prompt_truncation = (
|
|
||||||
"AUTO" if documents is not None or connectors is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"message": messages[0].content,
|
|
||||||
"chat_history": [
|
|
||||||
{"role": get_role(x), "message": x.content} for x in messages[1:]
|
|
||||||
],
|
|
||||||
"documents": documents,
|
|
||||||
"connectors": maybe_connectors,
|
|
||||||
"prompt_truncation": prompt_truncation,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CohereChat(BaseChatModel, BaseCohere):
|
|
||||||
"""`Cohere` chat large language models.
|
|
||||||
|
|
||||||
To use, you should have the ``cohere`` python package installed, and the
|
|
||||||
environment variable ``COHERE_API_KEY`` set with your API key, or pass
|
|
||||||
it as a named parameter to the constructor.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from swarms.models.cohere import CohereChat, HumanMessage
|
|
||||||
|
|
||||||
chat = CohereChat(model="foo")
|
|
||||||
result = chat([HumanMessage(content="Hello")])
|
|
||||||
print(result.content)
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
allow_population_by_field_name = True
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _llm_type(self) -> str:
|
|
||||||
"""Return type of chat model."""
|
|
||||||
return "cohere-chat"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
|
||||||
"""Get the default parameters for calling Cohere API."""
|
|
||||||
return {
|
|
||||||
"temperature": self.temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
|
||||||
"""Get the identifying parameters."""
|
|
||||||
return {**{"model": self.model}, **self._default_params}
|
|
||||||
|
|
||||||
def _stream(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
|
||||||
"""
|
|
||||||
Stream the output
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The messages.
|
|
||||||
stop: The stop tokens.
|
|
||||||
run_manager: The callback manager.
|
|
||||||
**kwargs: The keyword arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
|
||||||
stream = self.client.chat(**request, stream=True)
|
|
||||||
|
|
||||||
for data in stream:
|
|
||||||
if data.event_type == "text-generation":
|
|
||||||
delta = data.text
|
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_llm_new_token(delta)
|
|
||||||
|
|
||||||
async def _astream(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
|
||||||
"""
|
|
||||||
Stream generations from the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The messages.
|
|
||||||
stop: The stop tokens.
|
|
||||||
run_manager: The callback manager.
|
|
||||||
**kwargs: The keyword arguments.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
The generations.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
async for generation in model._astream(messages):
|
|
||||||
print(generation.message.content)
|
|
||||||
"""
|
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
|
||||||
stream = await self.async_client.chat(**request, stream=True)
|
|
||||||
|
|
||||||
async for data in stream:
|
|
||||||
if data.event_type == "text-generation":
|
|
||||||
delta = data.text
|
|
||||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
|
||||||
if run_manager:
|
|
||||||
await run_manager.on_llm_new_token(delta)
|
|
||||||
|
|
||||||
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
|
|
||||||
"""Get the generation info from cohere API response."""
|
|
||||||
return {
|
|
||||||
"documents": response.documents,
|
|
||||||
"citations": response.citations,
|
|
||||||
"search_results": response.search_results,
|
|
||||||
"search_queries": response.search_queries,
|
|
||||||
"token_count": response.token_count,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _run(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
"""
|
|
||||||
Run the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The messages.
|
|
||||||
stop: The stop tokens.
|
|
||||||
run_manager: The callback manager.
|
|
||||||
**kwargs: The keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
result = model._run(messages)
|
|
||||||
print(result.content)
|
|
||||||
"""
|
|
||||||
if self.streaming:
|
|
||||||
stream_iter = self._stream(
|
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return _generate_from_stream(stream_iter)
|
|
||||||
|
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
|
||||||
response = self.client.chat(**request)
|
|
||||||
|
|
||||||
message = AIMessage(content=response.text)
|
|
||||||
generation_info = None
|
|
||||||
if hasattr(response, "documents"):
|
|
||||||
generation_info = self._get_generation_info(response)
|
|
||||||
return ChatResult(
|
|
||||||
generations=[
|
|
||||||
ChatGeneration(message=message, generation_info=generation_info)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
"""
|
|
||||||
__Call__ the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The messages.
|
|
||||||
stop: The stop tokens.
|
|
||||||
run_manager: The callback manager.
|
|
||||||
**kwargs: The keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result.
|
|
||||||
"""
|
|
||||||
if self.streaming:
|
|
||||||
stream_iter = self._stream(
|
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return _generate_from_stream(stream_iter)
|
|
||||||
|
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
|
||||||
response = self.client.chat(**request)
|
|
||||||
|
|
||||||
message = AIMessage(content=response.text)
|
|
||||||
generation_info = None
|
|
||||||
if hasattr(response, "documents"):
|
|
||||||
generation_info = self._get_generation_info(response)
|
|
||||||
return ChatResult(
|
|
||||||
generations=[
|
|
||||||
ChatGeneration(message=message, generation_info=generation_info)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _arun(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
"""
|
|
||||||
Asynchronously run the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: The messages.
|
|
||||||
stop: The stop tokens.
|
|
||||||
run_manager: The callback manager.
|
|
||||||
**kwargs: The keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
result = await model._arun(messages)
|
|
||||||
print(result.content)
|
|
||||||
|
|
||||||
"""
|
|
||||||
if self.streaming:
|
|
||||||
stream_iter = self._astream(
|
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await _agenerate_from_stream(stream_iter)
|
|
||||||
|
|
||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
|
||||||
response = self.client.chat(**request, stream=False)
|
|
||||||
|
|
||||||
message = AIMessage(content=response.text)
|
|
||||||
generation_info = None
|
|
||||||
if hasattr(response, "documents"):
|
|
||||||
generation_info = self._get_generation_info(response)
|
|
||||||
return ChatResult(
|
|
||||||
generations=[
|
|
||||||
ChatGeneration(message=message, generation_info=generation_info)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
"""Calculate number of tokens."""
|
|
||||||
return len(self.client.tokenize(text).tokens)
|
|
@ -0,0 +1,247 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from pydantic import Extra, Field, root_validator
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
|
||||||
|
import cohere
|
||||||
|
|
||||||
|
min_seconds = 4
|
||||||
|
max_seconds = 10
|
||||||
|
# Wait 2^x * 1 second between each retry starting with
|
||||||
|
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||||
|
return retry(
|
||||||
|
reraise=True,
|
||||||
|
stop=stop_after_attempt(llm.max_retries),
|
||||||
|
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||||
|
retry=(retry_if_exception_type(cohere.error.CohereError)),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def completion_with_retry(llm, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator(llm)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return llm.client.generate(**kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def acompletion_with_retry(llm, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator(llm)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return await llm.async_client.generate(**kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCohere(Serializable):
|
||||||
|
"""Base class for Cohere models."""
|
||||||
|
|
||||||
|
client: Any #: :meta private:
|
||||||
|
async_client: Any #: :meta private:
|
||||||
|
model: Optional[str] = Field(default=None, description="Model name to use.")
|
||||||
|
"""Model name to use."""
|
||||||
|
|
||||||
|
temperature: float = 0.75
|
||||||
|
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||||
|
|
||||||
|
cohere_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
|
||||||
|
streaming: bool = Field(default=False)
|
||||||
|
"""Whether to stream the results."""
|
||||||
|
|
||||||
|
user_agent: str = "langchain"
|
||||||
|
"""Identifier for the application making the request."""
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
try:
|
||||||
|
import cohere
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import cohere python package. "
|
||||||
|
"Please install it with `pip install cohere`."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cohere_api_key = get_from_dict_or_env(
|
||||||
|
values, "cohere_api_key", "COHERE_API_KEY"
|
||||||
|
)
|
||||||
|
client_name = values["user_agent"]
|
||||||
|
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
|
||||||
|
values["async_client"] = cohere.AsyncClient(
|
||||||
|
cohere_api_key, client_name=client_name
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class Cohere(LLM, BaseCohere):
|
||||||
|
"""Cohere large language models.
|
||||||
|
|
||||||
|
To use, you should have the ``cohere`` python package installed, and the
|
||||||
|
environment variable ``COHERE_API_KEY`` set with your API key, or pass
|
||||||
|
it as a named parameter to the constructor.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import Cohere
|
||||||
|
|
||||||
|
cohere = Cohere(model="gptd-instruct-tft", cohere_api_key="my-api-key")
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_tokens: int = 256
|
||||||
|
"""Denotes the number of tokens to predict per generation."""
|
||||||
|
|
||||||
|
k: int = 0
|
||||||
|
"""Number of most likely tokens to consider at each step."""
|
||||||
|
|
||||||
|
p: int = 1
|
||||||
|
"""Total probability mass of tokens to consider at each step."""
|
||||||
|
|
||||||
|
frequency_penalty: float = 0.0
|
||||||
|
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
|
||||||
|
|
||||||
|
presence_penalty: float = 0.0
|
||||||
|
"""Penalizes repeated tokens. Between 0 and 1."""
|
||||||
|
|
||||||
|
truncate: Optional[str] = None
|
||||||
|
"""Specify how the client handles inputs longer than the maximum token
|
||||||
|
length: Truncate from START, END or NONE"""
|
||||||
|
|
||||||
|
max_retries: int = 10
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling Cohere API."""
|
||||||
|
return {
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"k": self.k,
|
||||||
|
"p": self.p,
|
||||||
|
"frequency_penalty": self.frequency_penalty,
|
||||||
|
"presence_penalty": self.presence_penalty,
|
||||||
|
"truncate": self.truncate,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {"cohere_api_key": "COHERE_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {**{"model": self.model}, **self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of llm."""
|
||||||
|
return "cohere"
|
||||||
|
|
||||||
|
def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
|
||||||
|
params = self._default_params
|
||||||
|
if self.stop is not None and stop is not None:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
elif self.stop is not None:
|
||||||
|
params["stop_sequences"] = self.stop
|
||||||
|
else:
|
||||||
|
params["stop_sequences"] = stop
|
||||||
|
return {**params, **kwargs}
|
||||||
|
|
||||||
|
def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
|
||||||
|
text = response.generations[0].text
|
||||||
|
# If stop tokens are provided, Cohere's endpoint returns them.
|
||||||
|
# In order to make this consistent with other endpoints, we strip them.
|
||||||
|
if stop:
|
||||||
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Call out to Cohere's generate endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string generated by the model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
response = cohere("Tell me a joke.")
|
||||||
|
"""
|
||||||
|
params = self._invocation_params(stop, **kwargs)
|
||||||
|
response = completion_with_retry(
|
||||||
|
self, model=self.model, prompt=prompt, **params
|
||||||
|
)
|
||||||
|
_stop = params.get("stop_sequences")
|
||||||
|
return self._process_response(response, _stop)
|
||||||
|
|
||||||
|
async def _acall(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Async call out to Cohere's generate endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string generated by the model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
response = await cohere("Tell me a joke.")
|
||||||
|
"""
|
||||||
|
params = self._invocation_params(stop, **kwargs)
|
||||||
|
response = await acompletion_with_retry(
|
||||||
|
self, model=self.model, prompt=prompt, **params
|
||||||
|
)
|
||||||
|
_stop = params.get("stop_sequences")
|
||||||
|
return self._process_response(response, _stop)
|
@ -0,0 +1,655 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cohere.models.response import GenerationChunk
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from swarms.models.cohere_chat import BaseCohere, Cohere
|
||||||
|
|
||||||
|
# Load the environment variables
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv("COHERE_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cohere_instance():
|
||||||
|
return Cohere(cohere_api_key=api_key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_wrap_prompt(cohere_instance):
|
||||||
|
prompt = "What is the meaning of life?"
|
||||||
|
wrapped_prompt = cohere_instance._wrap_prompt(prompt)
|
||||||
|
assert wrapped_prompt.startswith(cohere_instance.HUMAN_PROMPT)
|
||||||
|
assert wrapped_prompt.endswith(cohere_instance.AI_PROMPT)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_convert_prompt(cohere_instance):
|
||||||
|
prompt = "What is the meaning of life?"
|
||||||
|
converted_prompt = cohere_instance.convert_prompt(prompt)
|
||||||
|
assert converted_prompt.startswith(cohere_instance.HUMAN_PROMPT)
|
||||||
|
assert converted_prompt.endswith(cohere_instance.AI_PROMPT)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_stop(cohere_instance):
|
||||||
|
response = cohere_instance("Translate to French.", stop=["stop1", "stop2"])
|
||||||
|
assert response == "Mocked Response from Cohere"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_stop(cohere_instance):
|
||||||
|
generator = cohere_instance.stream("Write a story.", stop=["stop1", "stop2"])
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_stop(cohere_instance):
|
||||||
|
response = cohere_instance.async_call("Tell me a joke.", stop=["stop1", "stop2"])
|
||||||
|
assert response == "Mocked Response from Cohere"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_stop(cohere_instance):
|
||||||
|
async_generator = cohere_instance.async_stream(
|
||||||
|
"Translate to French.", stop=["stop1", "stop2"]
|
||||||
|
)
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_get_num_tokens_with_count_tokens(cohere_instance):
|
||||||
|
cohere_instance.count_tokens = Mock(return_value=10)
|
||||||
|
text = "This is a test sentence."
|
||||||
|
num_tokens = cohere_instance.get_num_tokens(text)
|
||||||
|
assert num_tokens == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_get_num_tokens_without_count_tokens(cohere_instance):
|
||||||
|
del cohere_instance.count_tokens
|
||||||
|
with pytest.raises(NameError):
|
||||||
|
text = "This is a test sentence."
|
||||||
|
cohere_instance.get_num_tokens(text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_wrap_prompt_without_human_ai_prompt(cohere_instance):
|
||||||
|
del cohere_instance.HUMAN_PROMPT
|
||||||
|
del cohere_instance.AI_PROMPT
|
||||||
|
prompt = "What is the meaning of life?"
|
||||||
|
with pytest.raises(NameError):
|
||||||
|
cohere_instance._wrap_prompt(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_cohere_import():
|
||||||
|
with patch.dict("sys.modules", {"cohere": None}):
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_cohere_validate_environment():
|
||||||
|
values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"}
|
||||||
|
validated_values = BaseCohere.validate_environment(values)
|
||||||
|
assert "client" in validated_values
|
||||||
|
assert "async_client" in validated_values
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_cohere_validate_environment_without_cohere():
|
||||||
|
values = {"cohere_api_key": "my-api-key", "user_agent": "langchain"}
|
||||||
|
with patch.dict("sys.modules", {"cohere": None}):
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
BaseCohere.validate_environment(values)
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for benchmarking generations with various models
|
||||||
|
def test_cohere_generate_with_command_light(cohere_instance):
|
||||||
|
cohere_instance.model = "command-light"
|
||||||
|
response = cohere_instance("Generate text with Command Light model.")
|
||||||
|
assert response.startswith("Generated text with Command Light model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_command(cohere_instance):
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
response = cohere_instance("Generate text with Command model.")
|
||||||
|
assert response.startswith("Generated text with Command model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_base_light(cohere_instance):
|
||||||
|
cohere_instance.model = "base-light"
|
||||||
|
response = cohere_instance("Generate text with Base Light model.")
|
||||||
|
assert response.startswith("Generated text with Base Light model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_base(cohere_instance):
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
response = cohere_instance("Generate text with Base model.")
|
||||||
|
assert response.startswith("Generated text with Base model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_english_v2(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v2.0"
|
||||||
|
response = cohere_instance("Generate embeddings with English v2.0 model.")
|
||||||
|
assert response.startswith("Generated embeddings with English v2.0 model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_english_light_v2(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-light-v2.0"
|
||||||
|
response = cohere_instance("Generate embeddings with English Light v2.0 model.")
|
||||||
|
assert response.startswith("Generated embeddings with English Light v2.0 model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_multilingual_v2(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v2.0"
|
||||||
|
response = cohere_instance("Generate embeddings with Multilingual v2.0 model.")
|
||||||
|
assert response.startswith("Generated embeddings with Multilingual v2.0 model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_english_v3(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
response = cohere_instance("Generate embeddings with English v3.0 model.")
|
||||||
|
assert response.startswith("Generated embeddings with English v3.0 model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_english_light_v3(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-light-v3.0"
|
||||||
|
response = cohere_instance("Generate embeddings with English Light v3.0 model.")
|
||||||
|
assert response.startswith("Generated embeddings with English Light v3.0 model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_multilingual_v3(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
response = cohere_instance("Generate embeddings with Multilingual v3.0 model.")
|
||||||
|
assert response.startswith("Generated embeddings with Multilingual v3.0 model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_with_embed_multilingual_light_v3(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||||
|
response = cohere_instance(
|
||||||
|
"Generate embeddings with Multilingual Light v3.0 model."
|
||||||
|
)
|
||||||
|
assert response.startswith(
|
||||||
|
"Generated embeddings with Multilingual Light v3.0 model"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Add more test cases to benchmark other models and functionalities
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_command_model(cohere_instance):
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_base_model(cohere_instance):
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_embed_english_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v2.0"
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_embed_english_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_embed_multilingual_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v2.0"
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_embed_multilingual_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_invalid_model(cohere_instance):
|
||||||
|
cohere_instance.model = "invalid-model"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
response = cohere_instance("Translate to French.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_long_prompt(cohere_instance):
|
||||||
|
prompt = "This is a very long prompt. " * 100
|
||||||
|
response = cohere_instance(prompt)
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_call_with_max_tokens_limit_exceeded(cohere_instance):
|
||||||
|
cohere_instance.max_tokens = 10
|
||||||
|
prompt = "This is a test prompt that will exceed the max tokens limit."
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
response = cohere_instance(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_command_model(cohere_instance):
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
generator = cohere_instance.stream("Write a story.")
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_base_model(cohere_instance):
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
generator = cohere_instance.stream("Write a story.")
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_embed_english_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v2.0"
|
||||||
|
generator = cohere_instance.stream("Write a story.")
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_embed_english_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
generator = cohere_instance.stream("Write a story.")
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_embed_multilingual_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v2.0"
|
||||||
|
generator = cohere_instance.stream("Write a story.")
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_with_embed_multilingual_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
generator = cohere_instance.stream("Write a story.")
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_command_model(cohere_instance):
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
response = cohere_instance.async_call("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_base_model(cohere_instance):
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
response = cohere_instance.async_call("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_embed_english_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v2.0"
|
||||||
|
response = cohere_instance.async_call("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_embed_english_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
response = cohere_instance.async_call("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_embed_multilingual_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v2.0"
|
||||||
|
response = cohere_instance.async_call("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_call_with_embed_multilingual_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
response = cohere_instance.async_call("Translate to French.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_command_model(cohere_instance):
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
async_generator = cohere_instance.async_stream("Write a story.")
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_base_model(cohere_instance):
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
async_generator = cohere_instance.async_stream("Write a story.")
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_embed_english_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v2.0"
|
||||||
|
async_generator = cohere_instance.async_stream("Write a story.")
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_embed_english_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
async_generator = cohere_instance.async_stream("Write a story.")
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_embed_multilingual_v2_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v2.0"
|
||||||
|
async_generator = cohere_instance.async_stream("Write a story.")
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance):
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
async_generator = cohere_instance.async_stream("Write a story.")
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_custom_configuration(cohere_instance):
|
||||||
|
# Test customizing Cohere configurations
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.temperature = 0.5
|
||||||
|
cohere_instance.max_tokens = 100
|
||||||
|
cohere_instance.k = 1
|
||||||
|
cohere_instance.p = 0.8
|
||||||
|
cohere_instance.frequency_penalty = 0.2
|
||||||
|
cohere_instance.presence_penalty = 0.4
|
||||||
|
response = cohere_instance("Customize configurations.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_api_error_handling(cohere_instance):
|
||||||
|
# Test error handling when the API key is invalid
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
response = cohere_instance("Error handling with invalid API key.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_async_api_error_handling(cohere_instance):
|
||||||
|
# Test async error handling when the API key is invalid
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
response = cohere_instance.async_call("Error handling with invalid API key.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_stream_api_error_handling(cohere_instance):
|
||||||
|
# Test error handling in streaming mode when the API key is invalid
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.cohere_api_key = "invalid-api-key"
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
generator = cohere_instance.stream("Error handling with invalid API key.")
|
||||||
|
for token in generator:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_streaming_mode(cohere_instance):
|
||||||
|
# Test the streaming mode for large text generation
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.streaming = True
|
||||||
|
prompt = "Generate a lengthy text using streaming mode."
|
||||||
|
generator = cohere_instance.stream(prompt)
|
||||||
|
for token in generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_streaming_mode_async(cohere_instance):
|
||||||
|
# Test the async streaming mode for large text generation
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.streaming = True
|
||||||
|
prompt = "Generate a lengthy text using async streaming mode."
|
||||||
|
async_generator = cohere_instance.async_stream(prompt)
|
||||||
|
for token in async_generator:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_embedding(cohere_instance):
|
||||||
|
# Test using the Representation model for text embedding
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
embedding = cohere_instance.embed("Generate an embedding for this text.")
|
||||||
|
assert isinstance(embedding, list)
|
||||||
|
assert len(embedding) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_classification(cohere_instance):
|
||||||
|
# Test using the Representation model for text classification
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
classification = cohere_instance.classify("Classify this text.")
|
||||||
|
assert isinstance(classification, dict)
|
||||||
|
assert "class" in classification
|
||||||
|
assert "score" in classification
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_language_detection(cohere_instance):
|
||||||
|
# Test using the Representation model for language detection
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
language = cohere_instance.detect_language("Detect the language of this text.")
|
||||||
|
assert isinstance(language, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_max_tokens_limit_exceeded(cohere_instance):
|
||||||
|
# Test handling max tokens limit exceeded error
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
cohere_instance.max_tokens = 10
|
||||||
|
prompt = "This is a test prompt that will exceed the max tokens limit."
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
embedding = cohere_instance.embed(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
# Add more production-grade test cases based on real-world scenarios
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_embedding(cohere_instance):
|
||||||
|
# Test using the Representation model for multilingual text embedding
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
embedding = cohere_instance.embed("Generate multilingual embeddings.")
|
||||||
|
assert isinstance(embedding, list)
|
||||||
|
assert len(embedding) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_classification(cohere_instance):
|
||||||
|
# Test using the Representation model for multilingual text classification
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
classification = cohere_instance.classify("Classify multilingual text.")
|
||||||
|
assert isinstance(classification, dict)
|
||||||
|
assert "class" in classification
|
||||||
|
assert "score" in classification
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_language_detection(cohere_instance):
|
||||||
|
# Test using the Representation model for multilingual language detection
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
language = cohere_instance.detect_language(
|
||||||
|
"Detect the language of multilingual text."
|
||||||
|
)
|
||||||
|
assert isinstance(language, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_max_tokens_limit_exceeded(
|
||||||
|
cohere_instance,
|
||||||
|
):
|
||||||
|
# Test handling max tokens limit exceeded error for multilingual model
|
||||||
|
cohere_instance.model = "embed-multilingual-v3.0"
|
||||||
|
cohere_instance.max_tokens = 10
|
||||||
|
prompt = "This is a test prompt that will exceed the max tokens limit for multilingual model."
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
embedding = cohere_instance.embed(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_light_embedding(cohere_instance):
|
||||||
|
# Test using the Representation model for multilingual light text embedding
|
||||||
|
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||||
|
embedding = cohere_instance.embed("Generate multilingual light embeddings.")
|
||||||
|
assert isinstance(embedding, list)
|
||||||
|
assert len(embedding) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_light_classification(cohere_instance):
|
||||||
|
# Test using the Representation model for multilingual light text classification
|
||||||
|
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||||
|
classification = cohere_instance.classify("Classify multilingual light text.")
|
||||||
|
assert isinstance(classification, dict)
|
||||||
|
assert "class" in classification
|
||||||
|
assert "score" in classification
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_light_language_detection(
|
||||||
|
cohere_instance,
|
||||||
|
):
|
||||||
|
# Test using the Representation model for multilingual light language detection
|
||||||
|
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||||
|
language = cohere_instance.detect_language(
|
||||||
|
"Detect the language of multilingual light text."
|
||||||
|
)
|
||||||
|
assert isinstance(language, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_multilingual_light_max_tokens_limit_exceeded(
|
||||||
|
cohere_instance,
|
||||||
|
):
|
||||||
|
# Test handling max tokens limit exceeded error for multilingual light model
|
||||||
|
cohere_instance.model = "embed-multilingual-light-v3.0"
|
||||||
|
cohere_instance.max_tokens = 10
|
||||||
|
prompt = "This is a test prompt that will exceed the max tokens limit for multilingual light model."
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
embedding = cohere_instance.embed(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_command_light_model(cohere_instance):
|
||||||
|
# Test using the Command Light model for text generation
|
||||||
|
cohere_instance.model = "command-light"
|
||||||
|
response = cohere_instance("Generate text using Command Light model.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_base_light_model(cohere_instance):
|
||||||
|
# Test using the Base Light model for text generation
|
||||||
|
cohere_instance.model = "base-light"
|
||||||
|
response = cohere_instance("Generate text using Base Light model.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_generate_summarize_endpoint(cohere_instance):
|
||||||
|
# Test using the Co.summarize() endpoint for text summarization
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
response = cohere_instance.summarize("Summarize this text.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_embedding(cohere_instance):
|
||||||
|
# Test using the Representation model for English text embedding
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
embedding = cohere_instance.embed("Generate English embeddings.")
|
||||||
|
assert isinstance(embedding, list)
|
||||||
|
assert len(embedding) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_classification(cohere_instance):
|
||||||
|
# Test using the Representation model for English text classification
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
classification = cohere_instance.classify("Classify English text.")
|
||||||
|
assert isinstance(classification, dict)
|
||||||
|
assert "class" in classification
|
||||||
|
assert "score" in classification
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_language_detection(cohere_instance):
|
||||||
|
# Test using the Representation model for English language detection
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
language = cohere_instance.detect_language("Detect the language of English text.")
|
||||||
|
assert isinstance(language, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_max_tokens_limit_exceeded(cohere_instance):
|
||||||
|
# Test handling max tokens limit exceeded error for English model
|
||||||
|
cohere_instance.model = "embed-english-v3.0"
|
||||||
|
cohere_instance.max_tokens = 10
|
||||||
|
prompt = (
|
||||||
|
"This is a test prompt that will exceed the max tokens limit for English model."
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
embedding = cohere_instance.embed(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_light_embedding(cohere_instance):
|
||||||
|
# Test using the Representation model for English light text embedding
|
||||||
|
cohere_instance.model = "embed-english-light-v3.0"
|
||||||
|
embedding = cohere_instance.embed("Generate English light embeddings.")
|
||||||
|
assert isinstance(embedding, list)
|
||||||
|
assert len(embedding) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_light_classification(cohere_instance):
|
||||||
|
# Test using the Representation model for English light text classification
|
||||||
|
cohere_instance.model = "embed-english-light-v3.0"
|
||||||
|
classification = cohere_instance.classify("Classify English light text.")
|
||||||
|
assert isinstance(classification, dict)
|
||||||
|
assert "class" in classification
|
||||||
|
assert "score" in classification
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_light_language_detection(cohere_instance):
|
||||||
|
# Test using the Representation model for English light language detection
|
||||||
|
cohere_instance.model = "embed-english-light-v3.0"
|
||||||
|
language = cohere_instance.detect_language(
|
||||||
|
"Detect the language of English light text."
|
||||||
|
)
|
||||||
|
assert isinstance(language, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_representation_model_english_light_max_tokens_limit_exceeded(
|
||||||
|
cohere_instance,
|
||||||
|
):
|
||||||
|
# Test handling max tokens limit exceeded error for English light model
|
||||||
|
cohere_instance.model = "embed-english-light-v3.0"
|
||||||
|
cohere_instance.max_tokens = 10
|
||||||
|
prompt = "This is a test prompt that will exceed the max tokens limit for English light model."
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
embedding = cohere_instance.embed(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_command_model(cohere_instance):
|
||||||
|
# Test using the Command model for text generation
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
response = cohere_instance("Generate text using the Command model.")
|
||||||
|
assert isinstance(response, str)
|
||||||
|
|
||||||
|
|
||||||
|
# Add more production-grade test cases based on real-world scenarios
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_invalid_model(cohere_instance):
|
||||||
|
# Test using an invalid model name
|
||||||
|
cohere_instance.model = "invalid-model"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
response = cohere_instance("Generate text using an invalid model.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_streaming_generation(cohere_instance):
|
||||||
|
# Test streaming generation with the Command model
|
||||||
|
cohere_instance.model = "command"
|
||||||
|
prompt = "Generate text using streaming."
|
||||||
|
chunks = list(cohere_instance.stream(prompt))
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, GenerationChunk) for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_base_model_generation_with_max_tokens(cohere_instance):
|
||||||
|
# Test generating text using the base model with a specified max_tokens limit
|
||||||
|
cohere_instance.model = "base"
|
||||||
|
cohere_instance.max_tokens = 20
|
||||||
|
prompt = "Generate text with max_tokens limit."
|
||||||
|
response = cohere_instance(prompt)
|
||||||
|
assert len(response.split()) <= 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_cohere_command_light_generation_with_stop(cohere_instance):
|
||||||
|
# Test generating text using the command-light model with stop words
|
||||||
|
cohere_instance.model = "command-light"
|
||||||
|
prompt = "Generate text with stop words."
|
||||||
|
stop = ["stop", "words"]
|
||||||
|
response = cohere_instance(prompt, stop=stop)
|
||||||
|
assert all(word not in response for word in stop)
|
Loading…
Reference in new issue