mistral caller, openai version 2.8, llama function caller, tests for flow

Former-commit-id: 699c943394
clean-history
Kye 1 year ago
parent 091be7d4bb
commit d6b037c211

1
.gitignore vendored

@ -24,6 +24,7 @@ stderr_log.txt
__pycache__/
*.py[cod]
*$py.class
.grit
error.txt
# C extensions

@ -74,10 +74,7 @@ flow = Flow(
# out = flow.load_state("flow_state.json")
# temp = flow.dynamic_temperature()
# filter = flow.add_response_filter("Trump")
# Run the flow
out = flow.run("Generate a 10,000 word blog on health and wellness.")
# out = flow.validate_response(out)
# out = flow.analyze_feedback(out)
# out = flow.print_history_and_memory()
@ -138,6 +135,77 @@ for task in workflow.tasks:
---
# Features 🤖
The Swarms framework is designed with a strong emphasis on reliability, performance, and production-grade readiness.
Below are the key features that make Swarms an ideal choice for enterprise-level AI deployments.
## 🚀 Production-Grade Readiness
- **Scalable Architecture**: Built to scale effortlessly with your growing business needs.
- **Enterprise-Level Security**: Incorporates top-notch security features to safeguard your data and operations.
- **Containerization and Microservices**: Easily deployable in containerized environments, supporting microservices architecture.
## ⚙️ Reliability and Robustness
- **Fault Tolerance**: Designed to handle failures gracefully, ensuring uninterrupted operations.
- **Consistent Performance**: Maintains high performance even under heavy loads or complex computational demands.
- **Automated Backup and Recovery**: Features automatic backup and recovery processes, reducing the risk of data loss.
## 💡 Advanced AI Capabilities
The Swarms framework is equipped with a suite of advanced AI capabilities designed to cater to a wide range of applications and scenarios, ensuring versatility and cutting-edge performance.
### Multi-Modal Autonomous Agents
- **Versatile Model Support**: Seamlessly works with various AI models, including NLP, computer vision, and more, for comprehensive multi-modal capabilities.
- **Context-Aware Processing**: Employs context-aware processing techniques to ensure relevant and accurate responses from agents.
### Function Calling Models for API Execution
- **Automated API Interactions**: Function calling models that can autonomously execute API calls, enabling seamless integration with external services and data sources.
- **Dynamic Response Handling**: Capable of processing and adapting to responses from APIs for real-time decision making.
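For illustration, here is a minimal, hedged sketch of the kind of function-calling request these models build on, using the `openai` v1 client (the function schema and names below are hypothetical and not part of this commit):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Hypothetical tool schema; the name and parameters are illustrative only.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather for a given city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    functions=functions,
    function_call="auto",
)

# If the model chose to call the function, its arguments arrive as JSON text.
print(response.choices[0].message.function_call)
```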
### Varied Architectures of Swarms
- **Flexible Configuration**: Supports multiple swarm architectures, from centralized to decentralized, for diverse application needs.
- **Customizable Agent Roles**: Allows customization of agent roles and behaviors within the swarm to optimize performance and efficiency.
### Generative Models
- **Advanced Generative Capabilities**: Incorporates state-of-the-art generative models to create content, simulate scenarios, or predict outcomes.
- **Creative Problem Solving**: Utilizes generative AI for innovative problem-solving approaches and idea generation.
### Enhanced Decision-Making
- **AI-Powered Decision Algorithms**: Employs advanced algorithms for swift and effective decision-making in complex scenarios.
- **Risk Assessment and Management**: Capable of assessing risks and managing uncertain situations with AI-driven insights.
### Real-Time Adaptation and Learning
- **Continuous Learning**: Agents can continuously learn and adapt from new data, improving their performance and accuracy over time.
- **Environment Adaptability**: Designed to adapt to different operational environments, enhancing robustness and reliability.
## 🔄 Efficient Workflow Automation
- **Streamlined Task Management**: Simplifies complex tasks with automated workflows, reducing manual intervention.
- **Customizable Workflows**: Offers customizable workflow options to fit specific business needs and requirements.
- **Real-Time Analytics and Reporting**: Provides real-time insights into agent performance and system health.
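As a rough sketch of what such an automated workflow can look like in code (assuming the `SequentialWorkflow` interface implied by the `for task in workflow.tasks:` loop shown earlier in this diff; module paths and signatures here are assumptions, not confirmed by this commit):

```python
from swarms.models import OpenAIChat
from swarms.structs import Flow
from swarms.structs.sequential_workflow import SequentialWorkflow

llm = OpenAIChat(openai_api_key="YOUR_API_KEY")
flow = Flow(llm=llm, max_loops=1)

# Build a two-step workflow; task descriptions are illustrative.
workflow = SequentialWorkflow(max_loops=1)
workflow.add("Draft an outline for a health and wellness blog", flow)
workflow.add("Expand the outline into a full post", flow)
workflow.run()

for task in workflow.tasks:
    print(f"{task.description}: {task.result}")
```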
## 🌐 Wide-Ranging Integration
- **API-First Design**: Easily integrates with existing systems and third-party applications via robust APIs.
- **Cloud Compatibility**: Fully compatible with major cloud platforms for flexible deployment options.
- **Continuous Integration/Continuous Deployment (CI/CD)**: Supports CI/CD practices for seamless updates and deployment.
## 📊 Performance Optimization
- **Resource Management**: Efficiently manages computational resources for optimal performance.
- **Load Balancing**: Automatically balances workloads to maintain system stability and responsiveness.
- **Performance Monitoring Tools**: Includes comprehensive monitoring tools for tracking and optimizing performance.
## 🛡️ Security and Compliance
- **Data Encryption**: Implements end-to-end encryption for data at rest and in transit.
- **Compliance Standards Adherence**: Adheres to major compliance standards ensuring legal and ethical usage.
- **Regular Security Updates**: Regular updates to address emerging security threats and vulnerabilities.
## 💬 Community and Support
- **Extensive Documentation**: Detailed documentation for easy implementation and troubleshooting.
- **Active Developer Community**: A vibrant community for sharing ideas, solutions, and best practices.
- **Professional Support**: Access to professional support for enterprise-level assistance and guidance.
The Swarms framework is not just a tool but a robust, scalable, and secure partner in your AI journey, ready to tackle the challenges of modern AI applications in a business environment.
## Documentation
- For documentation, visit [swarms.apac.ai](https://swarms.apac.ai)

@ -1,7 +1,7 @@
from swarms.models import OpenAIChat
from swarms.structs import Flow
api_key = ""
api_key = "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"
# Initialize the language model; this model can be swapped out for Anthropic or Hugging Face models like Mistral, etc.
llm = OpenAIChat(

@ -1,8 +1,6 @@
from swarms.models import Dalle3
dalle3 = Dalle3(
openai_api_key=""
)
dalle3 = Dalle3(openai_api_key="")
task = "A painting of a dog"
image_url = dalle3(task)
print(image_url)

@ -45,6 +45,7 @@ wget = "*"
griptape = "*"
httpx = "*"
tiktoken = "*"
safetensors = "*"
attrs = "*"
ggl = "*"
ratelimit = "*"

@ -48,6 +48,7 @@ opencv-python-headless
imageio-ffmpeg
invisible-watermark
kornia
safetensors
numpy
omegaconf
open_clip_torch

@ -10,6 +10,20 @@ class Artifact(BaseModel):
"""
Artifact that the task has produced
Attributes:
-----------
artifact_id: str
ID of the artifact
file_name: str
Filename of the artifact
relative_path: str
Relative path of the artifact
"""
artifact_id: StrictStr = Field(..., description="ID of the artifact")
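A minimal sketch of the `Artifact` schema documented above (only `artifact_id` is visible in this hunk; the `file_name` and `relative_path` field definitions are assumed to mirror it):

```python
from pydantic import BaseModel, Field, StrictStr

class Artifact(BaseModel):
    """Artifact that the task has produced (sketch of the schema above)."""

    artifact_id: StrictStr = Field(..., description="ID of the artifact")
    # The two fields below are assumptions based on the docstring.
    file_name: StrictStr = Field(..., description="Filename of the artifact")
    relative_path: StrictStr = Field(..., description="Relative path of the artifact")

artifact = Artifact(
    artifact_id="art-001",
    file_name="report.md",
    relative_path="outputs/report.md",
)
print(artifact.dict())
```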

@ -2,7 +2,7 @@ import uuid
from typing import Optional
from attr import define, field, Factory
from dataclasses import dataclass
from swarms.memory.vector_stores.base import BaseVectorStore
from swarms.memory.base import BaseVectorStore
from sqlalchemy.engine import Engine
from sqlalchemy import create_engine, Column, String, JSON
from sqlalchemy.ext.declarative import declarative_base

@ -6,20 +6,21 @@ sys.stderr = log_file
# LLMs
from swarms.models.anthropic import Anthropic # noqa: E402
from swarms.models.petals import Petals # noqa: E402
from swarms.models.mistral import Mistral # noqa: E402
from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat # noqa: E402
from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
from swarms.models.wizard_storytelling import WizardLLMStoryTeller # noqa: E402
from swarms.models.mpt import MPT7B # noqa: E402
from swarms.models.mistral import Mistral # noqa: E402
from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat # noqa: E402
from swarms.models.zephyr import Zephyr # noqa: E402
from swarms.models.biogpt import BioGPT # noqa: E402
from swarms.models.huggingface import HuggingfaceLLM # noqa: E402
from swarms.models.wizard_storytelling import WizardLLMStoryTeller # noqa: E402
from swarms.models.mpt import MPT7B # noqa: E402
# MultiModal Models
from swarms.models.idefics import Idefics # noqa: E402
from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.vilt import Vilt # noqa: E402
from swarms.models.nougat import Nougat # noqa: E402
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402
from swarms.models.idefics import Idefics # noqa: E402
from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.vilt import Vilt # noqa: E402
from swarms.models.nougat import Nougat # noqa: E402
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402
# from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402

@ -1,4 +1,3 @@
import concurrent.futures
import logging
import os
@ -6,6 +5,7 @@ import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import List
import backoff
import openai
import requests

@ -83,5 +83,3 @@ class Fuyu:
except requests.RequestException as error:
print(f"Error fetching image from {img_url} and error: {error}")
return None

@ -1,9 +1,13 @@
import asyncio
import concurrent.futures
import logging
from typing import List, Tuple
import torch
from termcolor import colored
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from termcolor import colored
class HuggingfaceLLM:
@ -43,6 +47,12 @@ class HuggingfaceLLM:
# logger=None,
distributed=False,
decoding=False,
max_workers: int = 5,
repitition_penalty: float = 1.3,
no_repeat_ngram_size: int = 5,
temperature: float = 0.7,
top_k: int = 40,
top_p: float = 0.8,
*args,
**kwargs,
):
@ -56,6 +66,14 @@ class HuggingfaceLLM:
self.distributed = distributed
self.decoding = decoding
self.model, self.tokenizer = None, None
self.quantize = quantize
self.quantization_config = quantization_config
self.max_workers = max_workers
self.repitition_penalty = repitition_penalty
self.no_repeat_ngram_size = no_repeat_ngram_size
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
if self.distributed:
assert (
@ -91,6 +109,10 @@ class HuggingfaceLLM:
"""Print error"""
print(colored(f"Error: {error}", "red"))
async def async_run(self, task: str):
"""Ashcnronous generate text for a given prompt"""
return await asyncio.to_thread(self.run, task)
def load_model(self):
"""Load the model"""
if not self.model or not self.tokenizer:
@ -113,6 +135,21 @@ class HuggingfaceLLM:
self.logger.error(f"Failed to load the model or the tokenizer: {error}")
raise
def concurrent_run(self, tasks: List[str], max_workers: int = 5):
"""Concurrently generate text for a list of prompts."""
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(self.run, tasks))
return results
def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
"""Process a batch of tasks and images"""
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.run, task, img) for task, img in tasks_images
]
results = [future.result() for future in futures]
return results
def run(self, task: str):
"""
Generate a response based on the prompt text.
@ -175,29 +212,6 @@ class HuggingfaceLLM:
)
raise
async def run_async(self, task: str, *args, **kwargs) -> str:
"""
Run the model asynchronously
Args:
task (str): Task to run.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Examples:
>>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150)
>>> mpt_instance("generate", "Once upon a time in a land far, far away...")
'Once upon a time in a land far, far away...'
>>> mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7)
['In the deep jungles,',
'At the heart of the city,']
>>> mpt_instance.freeze_model()
>>> mpt_instance.unfreeze_model()
"""
# Wrapping synchronous calls with async
return self.run(task, *args, **kwargs)
def __call__(self, task: str):
"""
Generate a response based on the prompt text.

@ -1,676 +0,0 @@
from __future__ import annotations
import logging
import os
import sys
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import (
BaseChatModel,
)
from langchain.llms.base import create_base_retry_decorator
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import (
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import Runnable
from langchain.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain.utils.openai import is_openai_v1
if TYPE_CHECKING:
import tiktoken
logger = logging.getLogger(__name__)
def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
generation: Optional[ChatGenerationChunk] = None
for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
async def _agenerate_from_stream(
stream: AsyncIterator[ChatGenerationChunk],
) -> ChatResult:
generation: Optional[ChatGenerationChunk] = None
async for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
def _import_tiktoken() -> Any:
try:
import tiktoken
except ImportError:
raise ValueError(
"Could not import tiktoken python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install tiktoken`."
)
return tiktoken
def _create_retry_decorator(
llm: OpenAIChat,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import openai
errors = [
openai.Timeout,
openai.APIError,
openai.APIConnectionError,
openai.RateLimitError,
openai.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
async def acompletion_with_retry(
llm: OpenAIChat,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
if is_openai_v1():
return await llm.async_client.create(**kwargs)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.acreate(**kwargs)
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
class OpenAIChat(BaseChatModel):
"""`OpenAI` Chat large language models API.
To use, you should have the ``openai`` python package installed, and the
environment variable ``OPENAI_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from swarms.models import ChatOpenAI
openai = ChatOpenAI(model_name="gpt-3.5-turbo")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.openai_organization:
attributes["openai_organization"] = self.openai_organization
if self.openai_api_base:
attributes["openai_api_base"] = self.openai_api_base
if self.openai_proxy:
attributes["openai_proxy"] = self.openai_proxy
return attributes
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
# When updating this to use a SecretStr
# Check for classes that derive from this class (as some of them
# may assume openai_api_key is a str)
# openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = 2
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
Tiktoken is used to count the number of tokens in documents to constrain
them to be under a certain limit. By default, when set to None, this will
be the same as the embedding model name. However, there are some cases
where you may want to use this Embedding class with a model name not
supported by tiktoken. This can include when using Azure embeddings or
when using one of the many model providers that expose an OpenAI-like
API but with different models. In those cases, in order to avoid erroring
when tiktoken is called, you can specify a model name to use here."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None
"""Optional httpx.Client."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
# Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.OpenAI(**client_params).chat.completions
values["async_client"] = openai.AsyncOpenAI(
**client_params
).chat.completions
else:
values["client"] = openai.ChatCompletion
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
params = {
"model": self.model_name,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**self.model_kwargs,
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.request_timeout is not None and not is_openai_v1():
params["request_timeout"] = self.request_timeout
return params
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
if is_openai_v1():
return self.client.create(**kwargs)
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
if system_fingerprint is None:
system_fingerprint = output.get("system_fingerprint")
combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
return combined
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
message = convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await _agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
openai_creds: Dict[str, Any] = {
"model": self.model_name,
}
if not is_openai_v1():
openai_creds.update(
{
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"organization": self.openai_organization,
}
)
if self.openai_proxy:
import openai
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'") # type: ignore[assignment] # noqa: E501
return {**self._default_params, **openai_creds}
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
"model": self.model_name,
**super()._get_invocation_params(stop=stop),
**self._default_params,
**kwargs,
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "openai-chat"
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
def get_token_ids(self, text: str) -> List[int]:
"""Get the tokens present in the text with tiktoken package."""
# tiktoken NOT supported for Python 3.7 or below
if sys.version_info[1] <= 7:
return super().get_token_ids(text)
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
if model.startswith("gpt-3.5-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
function_call: Optional[str] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind functions (and other objects) to this chat model.
Args:
functions: A list of function definitions to bind to this chat model.
Can be a dictionary, pydantic model, or callable. Pydantic
models and callables will be automatically converted to
their schema dictionary representation.
function_call: Which function to require the model to call.
Must be the name of the single provided function or
"auto" to automatically determine which function to call
(if any).
kwargs: Any additional parameters to pass to the
:class:`~swarms.runnable.Runnable` constructor.
"""
from langchain.chains.openai_functions.base import convert_to_openai_function
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
if function_call is not None:
if len(formatted_functions) != 1:
raise ValueError(
"When specifying `function_call`, you must provide exactly one "
"function."
)
if formatted_functions[0]["name"] != function_call:
raise ValueError(
f"Function call {function_call} was specified, but the only "
f"provided function was {formatted_functions[0]['name']}."
)
function_call_ = {"name": function_call}
kwargs = {**kwargs, "function_call": function_call_}
return super().bind(
functions=formatted_functions,
**kwargs,
)

@ -30,9 +30,19 @@ from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs
from importlib.metadata import version
from packaging.version import parse
logger = logging.getLogger(__name__)
def is_openai_v1() -> bool:
_version = parse(version("openai"))
return _version.major >= 1
def update_token_usage(
keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
@ -79,24 +89,24 @@ def _streaming_response_template() -> Dict[str, Any]:
}
# def _create_retry_decorator(
# llm: Union[BaseOpenAI, OpenAIChat],
# run_manager: Optional[
# Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
# ] = None,
# ) -> Callable[[Any], Any]:
# import openai
# errors = [
# openai.Timeout,
# openai.APIError,
# openai.error.APIConnectionError,
# openai.error.RateLimitError,
# openai.error.ServiceUnavailableError,
# ]
# return create_base_retry_decorator(
# error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
# )
def _create_retry_decorator(
llm: Union[BaseOpenAI, OpenAIChat],
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import openai
errors = [
openai.error.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
def completion_with_retry(
@ -105,9 +115,9 @@ def completion_with_retry(
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
# retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
# @retry_decorator
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return llm.client.create(**kwargs)
@ -120,9 +130,9 @@ async def acompletion_with_retry(
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
# retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
# @retry_decorator
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.acreate(**kwargs)
@ -500,11 +510,7 @@ class BaseOpenAI(BaseLLM):
if self.openai_proxy:
import openai
# raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g.",
# 'OpenAI(proxy={
# "http": self.openai_proxy,
# "https": self.openai_proxy,
# })'") # type: ignore[assignment] # noqa: E501
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params}
@property
@ -632,14 +638,13 @@ class OpenAI(BaseOpenAI):
environment variable ``OPENAI_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class..,
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from swarms.models import OpenAI
from langchain.llms import OpenAI
openai = OpenAI(model_name="text-davinci-003")
openai("What is the report on the 2022 oympian games?")
"""
@property
@ -659,7 +664,7 @@ class AzureOpenAI(BaseOpenAI):
Example:
.. code-block:: python
from swarms.models import AzureOpenAI
from langchain.llms import AzureOpenAI
openai = AzureOpenAI(model_name="text-davinci-003")
"""
@ -721,7 +726,7 @@ class OpenAIChat(BaseLLM):
Example:
.. code-block:: python
from swarms.models import OpenAIChat
from langchain.llms import OpenAIChat
openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
"""
@ -783,11 +788,13 @@ class OpenAIChat(BaseLLM):
try:
import openai
openai.api_key = openai_api_key
if openai_api_base:
raise Exception("The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(api_base=openai_api_base)'")
openai.api_base = openai_api_base
if openai_organization:
raise Exception("The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=openai_organization)'")
openai.organization = openai_organization
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
except ImportError:
raise ImportError(
"Could not import openai python package. "

@ -1,148 +0,0 @@
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Optional
import tiktoken
from attr import Factory, define, field
@define(frozen=True)
class BaseTokenizer(ABC):
DEFAULT_STOP_SEQUENCES = ["Observation:"]
stop_sequences: list[str] = field(
default=Factory(lambda: BaseTokenizer.DEFAULT_STOP_SEQUENCES),
kw_only=True,
)
@property
@abstractmethod
def max_tokens(self) -> int:
...
def count_tokens_left(self, text: str) -> int:
diff = self.max_tokens - self.count_tokens(text)
if diff > 0:
return diff
else:
return 0
@abstractmethod
def count_tokens(self, text: str) -> int:
...
@define(frozen=True)
class OpenAITokenizer(BaseTokenizer):
DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "text-davinci-003"
DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
DEFAULT_ENCODING = "cl100k_base"
DEFAULT_MAX_TOKENS = 2049
TOKEN_OFFSET = 8
MODEL_PREFIXES_TO_MAX_TOKENS = {
"gpt-4-32k": 32768,
"gpt-4": 8192,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo": 4096,
"gpt-35-turbo-16k": 16384,
"gpt-35-turbo": 4096,
"text-davinci-003": 4097,
"text-davinci-002": 4097,
"code-davinci-002": 8001,
"text-embedding-ada-002": 8191,
"text-embedding-ada-001": 2046,
}
EMBEDDING_MODELS = ["text-embedding-ada-002", "text-embedding-ada-001"]
model: str = field(kw_only=True)
@property
def encoding(self) -> tiktoken.Encoding:
try:
return tiktoken.encoding_for_model(self.model)
except KeyError:
return tiktoken.get_encoding(self.DEFAULT_ENCODING)
@property
def max_tokens(self) -> int:
tokens = next(
v
for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
if self.model.startswith(k)
)
offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET
return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
def count_tokens(self, text: str | list, model: Optional[str] = None) -> int:
"""
Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook:
https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
if isinstance(text, list):
model = model if model else self.model
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logging.warning("model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
# every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
logging.info(
"gpt-3.5-turbo may update over time. Returning num tokens assuming"
" gpt-3.5-turbo-0613."
)
return self.count_tokens(text, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logging.info(
"gpt-4 may update over time. Returning num tokens assuming"
" gpt-4-0613."
)
return self.count_tokens(text, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""token_count() is not implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for
information on how messages are converted to tokens."""
)
num_tokens = 0
for message in text:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <|start|>assistant<|message|>
num_tokens += 3
return num_tokens
else:
return len(
self.encoding.encode(text, allowed_special=set(self.stop_sequences))
)

@ -2,8 +2,6 @@ from openai import OpenAI
client = OpenAI()
def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
"""
Simple function to get embeddings from ada
@ -14,10 +12,6 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
"""
text = text.replace("\n", " ")
return client.embeddings.create(input=[text],
model=model)["data"][
0
]["embedding"]
return client.embeddings.create(input=[text], model=model)["data"][0]["embedding"]
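A brief, illustrative usage of the helper above (assumes `OPENAI_API_KEY` is set in the environment):

```python
embedding = get_ada_embeddings("Swarms coordinate multiple autonomous agents.")
print(len(embedding))  # text-embedding-ada-002 returns 1536-dimensional vectors
```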

@ -1,11 +1,16 @@
# speech to text tool
import os
import subprocess
import whisperx
from pydub import AudioSegment
from pytube import YouTube
try:
import whisperx
from pydub import AudioSegment
from pytube import YouTube
except Exception as error:
print("Error importing pytube. Please install pytube manually.")
print("pip install pytube")
print("pip install pydub")
print("pip install whisperx")
print(f"Pytube error: {error}")
class WhisperX:

@ -6,7 +6,6 @@ from swarms.prompts.operations_agent_prompt import OPERATIONS_AGENT_PROMPT
from swarms.prompts.product_agent_prompt import PRODUCT_AGENT_PROMPT
__all__ = [
"CODE_INTERPRETER",
"FINANCE_AGENT_PROMPT",

@ -4,7 +4,6 @@ from abc import abstractmethod
from typing import Dict, List, Sequence
class Message:
"""
The base abstract Message class.

@ -20,6 +20,7 @@ from termcolor import colored
import inspect
import random
# Prompts
DYNAMIC_STOP_PROMPT = """
When you have finished the task from the Human, output a special token: <DONE>
@ -28,12 +29,14 @@ This will enable you to leave the autonomous loop.
# Constants
FLOW_SYSTEM_PROMPT = f"""
You are an autonomous agent granted autonomy from a Flow structure.
You are an autonomous agent granted autonomy in an autonomous loop structure.
Your role is to engage in multi-step conversations with yourself or the user,
generate long-form content like blogs, screenplays, or SOPs,
and accomplish tasks. You can have internal dialogues with yourself or can interact with the user
and accomplish tasks bestowed by the user.
You can have internal dialogues with yourself or can interact with the user
to aid in these complex tasks. Your responses should be coherent, contextually relevant, and tailored to the task at hand.
{DYNAMIC_STOP_PROMPT}
"""
# Make it able to handle multi input tools
@ -47,6 +50,11 @@ commands: {
"tool1": "inputs",
"tool1": "inputs"
}
"tool2: "tool_name",
"params": {
"tool1": "inputs",
"tool1": "inputs"
}
}
}
@ -54,6 +62,40 @@ commands: {
"""
def autonomous_agent_prompt(
tools_prompt: str = DYNAMICAL_TOOL_USAGE,
dynamic_stop_prompt: str = DYNAMIC_STOP_PROMPT,
agent_name: str = None,
):
"""Autonomous agent prompt"""
return f"""
You are a {agent_name}, an autonomous agent granted autonomy in an autonomous loop structure.
Your purpose is to satisfy the user demands above expectations. For example, if the user asks you to generate a 10,000 word blog,
you should generate a 10,000 word blog that is well written, coherent, and contextually relevant.
Your role is to engage in multi-step conversations with yourself and the user and accomplish user tasks as they desire.
Follow the following rules:
1. Accomplish the task to the best of your ability
2. If you are unable to accomplish the task, then ask the user for help
3. If the user provides feedback, then use the feedback to improve your performance
4. If you are unable to accomplish the task, then ask the user for help
You can have internal dialogues with yourself or can interact with the user
to aid in these complex tasks. Your responses should be coherent, contextually relevant, and tailored to the task at hand and optimized
to satisfy the user no matter the cost.
And, you have the ability to use tools to aid in your tasks. The tool instructions are below; output a JSON object with the following structure to use the tools
{tools_prompt}
Now, when you are 99% sure you have completed the task, you may follow the instructions below to escape the autonomous loop.
{dynamic_stop_prompt}
Now, you remember your training, your deployment, and your purpose. You are ready to begin your mission.
"""
# Custom stopping condition
def stop_when_repeats(response: str) -> bool:
# Stop if the word stop appears in the response
@ -71,14 +113,10 @@ class Flow:
to generate sequential responses.
Features:
* User defined queries
* Dynamic keep generating until <DONE> is outputted by the agent
* Interactive, AI generates, then user input
* Message history and performance history fed -> into context
* Message history and performance history fed -> into context -> truncate if too long
* Ability to save and load flows
* Ability to provide feedback on responses
* Ability to provide a stopping condition
* Ability to provide a retry mechanism
* Ability to provide a loop interval
Args:
@ -142,7 +180,7 @@ class Flow:
self.feedback = []
self.memory = []
self.task = None
self.stopping_token = stopping_token or "<DONE>"
self.stopping_token = stopping_token # or "<DONE>"
self.interactive = interactive
self.dashboard = dashboard
self.return_history = return_history
@ -389,8 +427,11 @@ class Flow:
print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue"))
print("\n")
if self._check_stopping_condition(response) or parse_done_token(response):
break
if self.stopping_token:
if self._check_stopping_condition(response) or parse_done_token(
response
):
break
# Adjust temperature, comment if no work
if self.dynamic_temperature:
@ -659,13 +700,13 @@ class Flow:
return "Timeout"
return response
def backup_memory_to_s3(self, bucket_name: str, object_name: str):
"""Backup the memory to S3"""
import boto3
# def backup_memory_to_s3(self, bucket_name: str, object_name: str):
# """Backup the memory to S3"""
# import boto3
s3 = boto3.client("s3")
s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory))
print(f"Backed up memory to S3: {bucket_name}/{object_name}")
# s3 = boto3.client("s3")
# s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory))
# print(f"Backed up memory to S3: {bucket_name}/{object_name}")
def analyze_feedback(self):
"""Analyze the feedback for issues"""

@ -1,4 +1,3 @@
from termcolor import colored
from swarms.prompts.autoblogen import (

@ -1,5 +1,3 @@
class DialogueSimulator:
"""
Dialogue Simulator

@ -1,5 +1,3 @@
# Define a selection function
def select_speaker(step: int, agents) -> int:
# This function selects the speaker in a round-robin fashion

@ -15,7 +15,6 @@ def cohere_instance():
return Cohere(cohere_api_key=api_key)
def test_cohere_custom_configuration(cohere_instance):
# Test customizing Cohere configurations
cohere_instance.model = "base"
@ -404,7 +403,6 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance):
assert isinstance(token, str)
def test_cohere_representation_model_embedding(cohere_instance):
# Test using the Representation model for text embedding
cohere_instance.model = "embed-english-v3.0"
@ -626,6 +624,7 @@ def test_cohere_invalid_model(cohere_instance):
with pytest.raises(ValueError):
cohere_instance("Generate text using an invalid model.")
def test_cohere_base_model_generation_with_max_tokens(cohere_instance):
# Test generating text using the base model with a specified max_tokens limit
cohere_instance.model = "base"

@ -1,5 +1,6 @@
import json
import os
from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
@ -7,6 +8,7 @@ from dotenv import load_dotenv
from swarms.models import OpenAIChat
from swarms.structs.flow import Flow, stop_when_repeats
from swarms.utils.logger import logger
load_dotenv()
@ -254,3 +256,943 @@ def test_flow_initialization_all_params(mocked_llm):
def test_stopping_token_in_response(mocked_sleep, basic_flow):
response = basic_flow.run("Test stopping token")
assert basic_flow.stopping_token in response
@pytest.fixture
def flow_instance():
# Create an instance of the Flow class with required parameters for testing
# You may need to adjust this based on your actual class initialization
llm = OpenAIChat(
openai_api_key=openai_api_key,
)
flow = Flow(
llm=llm,
max_loops=5,
interactive=False,
dashboard=False,
dynamic_temperature=False,
)
return flow
def test_flow_run(flow_instance):
# Test the basic run method of the Flow class
response = flow_instance.run("Test task")
assert isinstance(response, str)
assert len(response) > 0
def test_flow_interactive_mode(flow_instance):
# Test the interactive mode of the Flow class
flow_instance.interactive = True
response = flow_instance.run("Test task")
assert isinstance(response, str)
assert len(response) > 0
def test_flow_dashboard_mode(flow_instance):
# Test the dashboard mode of the Flow class
flow_instance.dashboard = True
response = flow_instance.run("Test task")
assert isinstance(response, str)
assert len(response) > 0
def test_flow_autosave(flow_instance):
# Test the autosave functionality of the Flow class
flow_instance.autosave = True
response = flow_instance.run("Test task")
assert isinstance(response, str)
assert len(response) > 0
# Ensure that the state is saved (you may need to implement this logic)
assert flow_instance.saved_state_path is not None
def test_flow_response_filtering(flow_instance):
# Test the response filtering functionality
flow_instance.add_response_filter("filter_this")
response = flow_instance.filtered_run("This message should filter_this")
assert "filter_this" not in response
def test_flow_undo_last(flow_instance):
# Test the undo functionality
response1 = flow_instance.run("Task 1")
response2 = flow_instance.run("Task 2")
previous_state, message = flow_instance.undo_last()
assert response1 == previous_state
assert "Restored to" in message
def test_flow_dynamic_temperature(flow_instance):
# Test dynamic temperature adjustment
flow_instance.dynamic_temperature = True
response = flow_instance.run("Test task")
assert isinstance(response, str)
assert len(response) > 0
def test_flow_streamed_generation(flow_instance):
# Test streamed generation
response = flow_instance.streamed_generation("Generating...")
assert isinstance(response, str)
assert len(response) > 0
def test_flow_step(flow_instance):
# Test the step method
response = flow_instance.step("Test step")
assert isinstance(response, str)
assert len(response) > 0
def test_flow_graceful_shutdown(flow_instance):
# Test graceful shutdown
result = flow_instance.graceful_shutdown()
assert result is not None
# Add more test cases as needed to cover various aspects of your Flow class
def test_flow_max_loops(flow_instance):
# Test setting and getting the maximum number of loops
flow_instance.set_max_loops(10)
assert flow_instance.get_max_loops() == 10
def test_flow_autosave_path(flow_instance):
# Test setting and getting the autosave path
flow_instance.set_autosave_path("text.txt")
assert flow_instance.get_autosave_path() == "text.txt"
def test_flow_response_length(flow_instance):
# Test checking the length of the response
response = flow_instance.run(
"Generate a 10,000 word long blog on mental clarity and the benefits of meditation."
)
assert len(response) > flow_instance.get_response_length_threshold()
def test_flow_set_response_length_threshold(flow_instance):
# Test setting and getting the response length threshold
flow_instance.set_response_length_threshold(100)
assert flow_instance.get_response_length_threshold() == 100
def test_flow_add_custom_filter(flow_instance):
# Test adding a custom response filter
flow_instance.add_response_filter("custom_filter")
assert "custom_filter" in flow_instance.get_response_filters()
def test_flow_remove_custom_filter(flow_instance):
# Test removing a custom response filter
flow_instance.add_response_filter("custom_filter")
flow_instance.remove_response_filter("custom_filter")
assert "custom_filter" not in flow_instance.get_response_filters()
def test_flow_dynamic_pacing(flow_instance):
# Test dynamic pacing
flow_instance.enable_dynamic_pacing()
assert flow_instance.is_dynamic_pacing_enabled() is True
def test_flow_disable_dynamic_pacing(flow_instance):
# Test disabling dynamic pacing
flow_instance.disable_dynamic_pacing()
assert flow_instance.is_dynamic_pacing_enabled() is False
def test_flow_change_prompt(flow_instance):
# Test changing the current prompt
flow_instance.change_prompt("New prompt")
assert flow_instance.get_current_prompt() == "New prompt"
def test_flow_add_instruction(flow_instance):
# Test adding an instruction to the conversation
flow_instance.add_instruction("Follow these steps:")
assert "Follow these steps:" in flow_instance.get_instructions()
def test_flow_clear_instructions(flow_instance):
# Test clearing all instructions from the conversation
flow_instance.add_instruction("Follow these steps:")
flow_instance.clear_instructions()
assert len(flow_instance.get_instructions()) == 0
def test_flow_add_user_message(flow_instance):
# Test adding a user message to the conversation
flow_instance.add_user_message("User message")
assert "User message" in flow_instance.get_user_messages()
def test_flow_clear_user_messages(flow_instance):
# Test clearing all user messages from the conversation
flow_instance.add_user_message("User message")
flow_instance.clear_user_messages()
assert len(flow_instance.get_user_messages()) == 0
def test_flow_get_response_history(flow_instance):
# Test getting the response history
flow_instance.run("Message 1")
flow_instance.run("Message 2")
history = flow_instance.get_response_history()
assert len(history) == 2
assert "Message 1" in history[0]
assert "Message 2" in history[1]
def test_flow_clear_response_history(flow_instance):
# Test clearing the response history
flow_instance.run("Message 1")
flow_instance.run("Message 2")
flow_instance.clear_response_history()
assert len(flow_instance.get_response_history()) == 0
def test_flow_get_conversation_log(flow_instance):
# Test getting the entire conversation log
flow_instance.run("Message 1")
flow_instance.run("Message 2")
conversation_log = flow_instance.get_conversation_log()
assert len(conversation_log) == 4 # Including system and user messages
def test_flow_clear_conversation_log(flow_instance):
# Test clearing the entire conversation log
flow_instance.run("Message 1")
flow_instance.run("Message 2")
flow_instance.clear_conversation_log()
assert len(flow_instance.get_conversation_log()) == 0
def test_flow_get_state(flow_instance):
# Test getting the current state of the Flow instance
state = flow_instance.get_state()
assert isinstance(state, dict)
assert "current_prompt" in state
assert "instructions" in state
assert "user_messages" in state
assert "response_history" in state
assert "conversation_log" in state
assert "dynamic_pacing_enabled" in state
assert "response_length_threshold" in state
assert "response_filters" in state
assert "max_loops" in state
assert "autosave_path" in state
def test_flow_load_state(flow_instance):
# Test loading the state into the Flow instance
state = {
"current_prompt": "Loaded prompt",
"instructions": ["Step 1", "Step 2"],
"user_messages": ["User message 1", "User message 2"],
"response_history": ["Response 1", "Response 2"],
"conversation_log": [
"System message 1",
"User message 1",
"System message 2",
"User message 2",
],
"dynamic_pacing_enabled": True,
"response_length_threshold": 50,
"response_filters": ["filter1", "filter2"],
"max_loops": 10,
"autosave_path": "/path/to/load",
}
flow_instance.load_state(state)
assert flow_instance.get_current_prompt() == "Loaded prompt"
assert "Step 1" in flow_instance.get_instructions()
assert "User message 1" in flow_instance.get_user_messages()
assert "Response 1" in flow_instance.get_response_history()
assert "System message 1" in flow_instance.get_conversation_log()
assert flow_instance.is_dynamic_pacing_enabled() is True
assert flow_instance.get_response_length_threshold() == 50
assert "filter1" in flow_instance.get_response_filters()
assert flow_instance.get_max_loops() == 10
assert flow_instance.get_autosave_path() == "/path/to/load"
def test_flow_save_state(flow_instance):
# Test saving the state of the Flow instance
flow_instance.change_prompt("New prompt")
flow_instance.add_instruction("Step 1")
flow_instance.add_user_message("User message")
flow_instance.run("Response")
state = flow_instance.save_state()
assert "current_prompt" in state
assert "instructions" in state
assert "user_messages" in state
assert "response_history" in state
assert "conversation_log" in state
assert "dynamic_pacing_enabled" in state
assert "response_length_threshold" in state
assert "response_filters" in state
assert "max_loops" in state
assert "autosave_path" in state
def test_flow_rollback(flow_instance):
# Test rolling back to a previous state
state1 = flow_instance.get_state()
flow_instance.change_prompt("New prompt")
state2 = flow_instance.get_state()
flow_instance.rollback_to_state(state1)
assert flow_instance.get_current_prompt() == state1["current_prompt"]
assert flow_instance.get_instructions() == state1["instructions"]
assert flow_instance.get_user_messages() == state1["user_messages"]
assert flow_instance.get_response_history() == state1["response_history"]
assert flow_instance.get_conversation_log() == state1["conversation_log"]
assert flow_instance.is_dynamic_pacing_enabled() == state1["dynamic_pacing_enabled"]
assert (
flow_instance.get_response_length_threshold()
== state1["response_length_threshold"]
)
assert flow_instance.get_response_filters() == state1["response_filters"]
assert flow_instance.get_max_loops() == state1["max_loops"]
assert flow_instance.get_autosave_path() == state1["autosave_path"]
assert flow_instance.get_state() == state1
def test_flow_contextual_intent(flow_instance):
# Test contextual intent handling
flow_instance.add_context("location", "New York")
flow_instance.add_context("time", "tomorrow")
response = flow_instance.run("What's the weather like in {location} at {time}?")
assert "New York" in response
assert "tomorrow" in response
def test_flow_contextual_intent_override(flow_instance):
# Test contextual intent override
flow_instance.add_context("location", "New York")
response1 = flow_instance.run("What's the weather like in {location}?")
flow_instance.add_context("location", "Los Angeles")
response2 = flow_instance.run("What's the weather like in {location}?")
assert "New York" in response1
assert "Los Angeles" in response2
def test_flow_contextual_intent_reset(flow_instance):
# Test resetting contextual intent
flow_instance.add_context("location", "New York")
response1 = flow_instance.run("What's the weather like in {location}?")
flow_instance.reset_context()
response2 = flow_instance.run("What's the weather like in {location}?")
assert "New York" in response1
assert "New York" in response2
# Add more test cases as needed to cover various aspects of your Flow class
def test_flow_interruptible(flow_instance):
# Test interruptible mode
flow_instance.interruptible = True
response = flow_instance.run("Interrupt me!")
assert "Interrupted" in response
assert flow_instance.is_interrupted() is True
def test_flow_non_interruptible(flow_instance):
# Test non-interruptible mode
flow_instance.interruptible = False
response = flow_instance.run("Do not interrupt me!")
assert "Do not interrupt me!" in response
assert flow_instance.is_interrupted() is False
def test_flow_timeout(flow_instance):
# Test conversation timeout
flow_instance.timeout = 60 # Set a timeout of 60 seconds
response = flow_instance.run("This should take some time to respond.")
assert "Timed out" in response
assert flow_instance.is_timed_out() is True
def test_flow_no_timeout(flow_instance):
# Test no conversation timeout
flow_instance.timeout = None
response = flow_instance.run("This should not time out.")
assert "This should not time out." in response
assert flow_instance.is_timed_out() is False
def test_flow_custom_delimiter(flow_instance):
# Test setting and getting a custom message delimiter
flow_instance.set_message_delimiter("|||")
assert flow_instance.get_message_delimiter() == "|||"
def test_flow_message_history(flow_instance):
# Test getting the message history
flow_instance.run("Message 1")
flow_instance.run("Message 2")
history = flow_instance.get_message_history()
assert len(history) == 2
assert "Message 1" in history[0]
assert "Message 2" in history[1]
def test_flow_clear_message_history(flow_instance):
# Test clearing the message history
flow_instance.run("Message 1")
flow_instance.run("Message 2")
flow_instance.clear_message_history()
assert len(flow_instance.get_message_history()) == 0
def test_flow_save_and_load_conversation(flow_instance):
# Test saving and loading the conversation
flow_instance.run("Message 1")
flow_instance.run("Message 2")
saved_conversation = flow_instance.save_conversation()
flow_instance.clear_conversation()
flow_instance.load_conversation(saved_conversation)
assert len(flow_instance.get_message_history()) == 2
def test_flow_inject_custom_system_message(flow_instance):
# Test injecting a custom system message into the conversation
flow_instance.inject_custom_system_message("Custom system message")
assert "Custom system message" in flow_instance.get_message_history()
def test_flow_inject_custom_user_message(flow_instance):
# Test injecting a custom user message into the conversation
flow_instance.inject_custom_user_message("Custom user message")
assert "Custom user message" in flow_instance.get_message_history()
def test_flow_inject_custom_response(flow_instance):
# Test injecting a custom response into the conversation
flow_instance.inject_custom_response("Custom response")
assert "Custom response" in flow_instance.get_message_history()
def test_flow_clear_injected_messages(flow_instance):
# Test clearing injected messages from the conversation
flow_instance.inject_custom_system_message("Custom system message")
flow_instance.inject_custom_user_message("Custom user message")
flow_instance.inject_custom_response("Custom response")
flow_instance.clear_injected_messages()
assert "Custom system message" not in flow_instance.get_message_history()
assert "Custom user message" not in flow_instance.get_message_history()
assert "Custom response" not in flow_instance.get_message_history()
def test_flow_disable_message_history(flow_instance):
# Test disabling message history recording
flow_instance.disable_message_history()
response = flow_instance.run("This message should not be recorded in history.")
assert "This message should not be recorded in history." in response
assert len(flow_instance.get_message_history()) == 0 # History is empty
def test_flow_enable_message_history(flow_instance):
# Test enabling message history recording
flow_instance.enable_message_history()
response = flow_instance.run("This message should be recorded in history.")
assert "This message should be recorded in history." in response
assert len(flow_instance.get_message_history()) == 1
def test_flow_custom_logger(flow_instance):
# Test setting and using a custom logger
custom_logger = logger # Replace with your custom logger class
flow_instance.set_logger(custom_logger)
response = flow_instance.run("Custom logger test")
assert "Logged using custom logger" in response # Verify logging message
def test_flow_batch_processing(flow_instance):
# Test batch processing of messages
messages = ["Message 1", "Message 2", "Message 3"]
responses = flow_instance.process_batch(messages)
assert isinstance(responses, list)
assert len(responses) == len(messages)
for response in responses:
assert isinstance(response, str)
def test_flow_custom_metrics(flow_instance):
# Test tracking custom metrics
flow_instance.track_custom_metric("custom_metric_1", 42)
flow_instance.track_custom_metric("custom_metric_2", 3.14)
metrics = flow_instance.get_custom_metrics()
assert "custom_metric_1" in metrics
assert "custom_metric_2" in metrics
assert metrics["custom_metric_1"] == 42
assert metrics["custom_metric_2"] == 3.14
def test_flow_reset_metrics(flow_instance):
# Test resetting custom metrics
flow_instance.track_custom_metric("custom_metric_1", 42)
flow_instance.track_custom_metric("custom_metric_2", 3.14)
flow_instance.reset_custom_metrics()
metrics = flow_instance.get_custom_metrics()
assert len(metrics) == 0
def test_flow_retrieve_context(flow_instance):
# Test retrieving context
flow_instance.add_context("location", "New York")
context = flow_instance.get_context("location")
assert context == "New York"
def test_flow_update_context(flow_instance):
# Test updating context
flow_instance.add_context("location", "New York")
flow_instance.update_context("location", "Los Angeles")
context = flow_instance.get_context("location")
assert context == "Los Angeles"
def test_flow_remove_context(flow_instance):
# Test removing context
flow_instance.add_context("location", "New York")
flow_instance.remove_context("location")
context = flow_instance.get_context("location")
assert context is None
def test_flow_clear_context(flow_instance):
# Test clearing all context
flow_instance.add_context("location", "New York")
flow_instance.add_context("time", "tomorrow")
flow_instance.clear_context()
context_location = flow_instance.get_context("location")
context_time = flow_instance.get_context("time")
assert context_location is None
assert context_time is None
def test_flow_input_validation(flow_instance):
# Test input validation for invalid flow configurations
with pytest.raises(ValueError):
Flow(config=None) # Invalid config, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
""
) # Empty delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
None
) # None delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
123
) # Invalid delimiter type, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_logger(
"invalid_logger"
) # Invalid logger type, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context(None, "value") # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context("key", None) # None value, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context(None, "value") # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context("key", None) # None value, should raise ValueError
def test_flow_conversation_reset(flow_instance):
# Test conversation reset
flow_instance.run("Message 1")
flow_instance.run("Message 2")
flow_instance.reset_conversation()
assert len(flow_instance.get_message_history()) == 0
def test_flow_conversation_persistence(flow_instance):
# Test conversation persistence across instances
flow_instance.run("Message 1")
flow_instance.run("Message 2")
conversation = flow_instance.get_conversation()
new_flow_instance = Flow()
new_flow_instance.load_conversation(conversation)
assert len(new_flow_instance.get_message_history()) == 2
assert "Message 1" in new_flow_instance.get_message_history()[0]
assert "Message 2" in new_flow_instance.get_message_history()[1]
def test_flow_custom_event_listener(flow_instance):
# Test custom event listener
class CustomEventListener:
def on_message_received(self, message):
pass
def on_response_generated(self, response):
pass
custom_event_listener = CustomEventListener()
flow_instance.add_event_listener(custom_event_listener)
# Ensure that the custom event listener methods are called during a conversation
with mock.patch.object(
custom_event_listener, "on_message_received"
) as mock_received, mock.patch.object(
custom_event_listener, "on_response_generated"
) as mock_response:
flow_instance.run("Message 1")
mock_received.assert_called_once()
mock_response.assert_called_once()
def test_flow_multiple_event_listeners(flow_instance):
# Test multiple event listeners
class FirstEventListener:
def on_message_received(self, message):
pass
def on_response_generated(self, response):
pass
class SecondEventListener:
def on_message_received(self, message):
pass
def on_response_generated(self, response):
pass
first_event_listener = FirstEventListener()
second_event_listener = SecondEventListener()
flow_instance.add_event_listener(first_event_listener)
flow_instance.add_event_listener(second_event_listener)
# Ensure that both event listeners receive events during a conversation
with mock.patch.object(
first_event_listener, "on_message_received"
) as mock_first_received, mock.patch.object(
first_event_listener, "on_response_generated"
) as mock_first_response, mock.patch.object(
second_event_listener, "on_message_received"
) as mock_second_received, mock.patch.object(
second_event_listener, "on_response_generated"
) as mock_second_response:
flow_instance.run("Message 1")
mock_first_received.assert_called_once()
mock_first_response.assert_called_once()
mock_second_received.assert_called_once()
mock_second_response.assert_called_once()
# Add more test cases as needed to cover various aspects of your Flow class
def test_flow_error_handling(flow_instance):
# Test error handling and exceptions
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
""
) # Empty delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_message_delimiter(
None
) # None delimiter, should raise ValueError
with pytest.raises(ValueError):
flow_instance.set_logger(
"invalid_logger"
) # Invalid logger type, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context(None, "value") # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.add_context("key", None) # None value, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context(None, "value") # None key, should raise ValueError
with pytest.raises(ValueError):
flow_instance.update_context("key", None) # None value, should raise ValueError
def test_flow_context_operations(flow_instance):
# Test context operations
flow_instance.add_context("user_id", "12345")
assert flow_instance.get_context("user_id") == "12345"
flow_instance.update_context("user_id", "54321")
assert flow_instance.get_context("user_id") == "54321"
flow_instance.remove_context("user_id")
assert flow_instance.get_context("user_id") is None
# Add more test cases as needed to cover various aspects of your Flow class
def test_flow_long_messages(flow_instance):
# Test handling of long messages
long_message = "A" * 10000 # Create a very long message
flow_instance.run(long_message)
assert len(flow_instance.get_message_history()) == 1
assert flow_instance.get_message_history()[0] == long_message
def test_flow_custom_response(flow_instance):
# Test custom response generation
def custom_response_generator(message):
if message == "Hello":
return "Hi there!"
elif message == "How are you?":
return "I'm doing well, thank you."
else:
return "I don't understand."
flow_instance.set_response_generator(custom_response_generator)
assert flow_instance.run("Hello") == "Hi there!"
assert flow_instance.run("How are you?") == "I'm doing well, thank you."
assert flow_instance.run("What's your name?") == "I don't understand."
def test_flow_message_validation(flow_instance):
# Test message validation
def custom_message_validator(message):
return len(message) > 0 # Reject empty messages
flow_instance.set_message_validator(custom_message_validator)
assert flow_instance.run("Valid message") is not None
assert flow_instance.run("") is None # Empty message should be rejected
assert flow_instance.run(None) is None # None message should be rejected
def test_flow_custom_logging(flow_instance):
custom_logger = logger
flow_instance.set_logger(custom_logger)
with mock.patch.object(custom_logger, "log") as mock_log:
flow_instance.run("Message")
mock_log.assert_called_once_with("Message")
def test_flow_performance(flow_instance):
# Test the performance of the Flow class by running a large number of messages
num_messages = 1000
for i in range(num_messages):
flow_instance.run(f"Message {i}")
assert len(flow_instance.get_message_history()) == num_messages
def test_flow_complex_use_case(flow_instance):
# Test a complex use case scenario
flow_instance.add_context("user_id", "12345")
flow_instance.run("Hello")
flow_instance.run("How can I help you?")
assert flow_instance.get_response() == "Please provide more details."
flow_instance.update_context("user_id", "54321")
flow_instance.run("I need help with my order")
assert flow_instance.get_response() == "Sure, I can assist with that."
flow_instance.reset_conversation()
assert len(flow_instance.get_message_history()) == 0
assert flow_instance.get_context("user_id") is None
# Add more test cases as needed to cover various aspects of your Flow class
def test_flow_context_handling(flow_instance):
# Test context handling
flow_instance.add_context("user_id", "12345")
assert flow_instance.get_context("user_id") == "12345"
flow_instance.update_context("user_id", "54321")
assert flow_instance.get_context("user_id") == "54321"
flow_instance.remove_context("user_id")
assert flow_instance.get_context("user_id") is None
def test_flow_concurrent_requests(flow_instance):
# Test concurrent message processing
import threading
def send_messages():
for i in range(100):
flow_instance.run(f"Message {i}")
threads = []
for _ in range(5):
thread = threading.Thread(target=send_messages)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert len(flow_instance.get_message_history()) == 500
def test_flow_custom_timeout(flow_instance):
# Test custom timeout handling
flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds
assert flow_instance.get_timeout() == 10
import time
start_time = time.time()
flow_instance.run("Long-running operation")
end_time = time.time()
execution_time = end_time - start_time
    assert execution_time >= 10  # The long-running call should not return before the 10-second timeout elapses
# Add more test cases as needed to thoroughly cover your Flow class
def test_flow_interactive_run(flow_instance, capsys):
    # Test interactive run mode by feeding scripted user input and checking the echoed output
    user_input = ["Hello", "How can you help me?", "Exit"]
    def simulate_user_input(input_list):
        # Assumes interactive_run() reads one user message per loop via the built-in input()
        with mock.patch("builtins.input", side_effect=input_list):
            for user_response in input_list:
                flow_instance.interactive_run(max_loops=1)
                # Capture the console output for this exchange
                captured = capsys.readouterr()
                assert f"You: {user_response}" in captured.out
                assert "AI:" in captured.out
    simulate_user_input(user_input)
# Assuming you have already defined your Flow class and created an instance for testing
def test_flow_agent_history_prompt(flow_instance):
# Test agent history prompt generation
system_prompt = "This is the system prompt."
history = ["User: Hi", "AI: Hello"]
agent_history_prompt = flow_instance.agent_history_prompt(system_prompt, history)
assert "SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt
assert "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt
# Requires the pytest-asyncio plugin so this coroutine test is actually executed
@pytest.mark.asyncio
async def test_flow_run_concurrent(flow_instance):
# Test running tasks concurrently
tasks = ["Task 1", "Task 2", "Task 3"]
completed_tasks = await flow_instance.run_concurrent(tasks)
# Ensure that all tasks are completed
assert len(completed_tasks) == len(tasks)
def test_flow_bulk_run(flow_instance):
# Test bulk running of tasks
input_data = [
{"task": "Task 1", "param1": "value1"},
{"task": "Task 2", "param2": "value2"},
{"task": "Task 3", "param3": "value3"},
]
responses = flow_instance.bulk_run(input_data)
# Ensure that the responses match the input tasks
assert responses[0] == "Response for Task 1"
assert responses[1] == "Response for Task 2"
assert responses[2] == "Response for Task 3"
def test_flow_from_llm_and_template():
# Test creating Flow instance from an LLM and a template
llm_instance = mocked_llm # Replace with your LLM class
template = "This is a template for testing."
flow_instance = Flow.from_llm_and_template(llm_instance, template)
assert isinstance(flow_instance, Flow)
def test_flow_from_llm_and_template_file(tmp_path):
    # Test creating Flow instance from an LLM and a template file
    llm_instance = mocked_llm  # Replace with your LLM class
    template_file = tmp_path / "template.txt"  # Create a template file for testing
    template_file.write_text("This is a template for testing.")
    flow_instance = Flow.from_llm_and_template_file(llm_instance, str(template_file))
    assert isinstance(flow_instance, Flow)
def test_flow_save_and_load(flow_instance, tmp_path):
# Test saving and loading the flow state
file_path = tmp_path / "flow_state.json"
# Save the state
flow_instance.save(file_path)
# Create a new instance and load the state
new_flow_instance = Flow(llm=mocked_llm, max_loops=5)
new_flow_instance.load(file_path)
# Ensure that the loaded state matches the original state
assert new_flow_instance.memory == flow_instance.memory
def test_flow_validate_response(flow_instance):
# Test response validation
valid_response = "This is a valid response."
invalid_response = "Short."
assert flow_instance.validate_response(valid_response) is True
assert flow_instance.validate_response(invalid_response) is False
# Add more test cases as needed for other methods and features of your Flow class
# Finally, don't forget to run your tests using a testing framework like pytest
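# For example, a typical invocation is shown below (the test path is illustrative and may differ):
#   pytest tests/test_flow.py -v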
# Assuming you have already defined your Flow class and created an instance for testing
def test_flow_print_history_and_memory(capsys, flow_instance):
# Test printing the history and memory of the flow
history = ["User: Hi", "AI: Hello"]
flow_instance.memory = [history]
flow_instance.print_history_and_memory()
captured = capsys.readouterr()
assert "Flow History and Memory" in captured.out
assert "Loop 1:" in captured.out
assert "User: Hi" in captured.out
assert "AI: Hello" in captured.out
def test_flow_run_with_timeout(flow_instance):
# Test running with a timeout
task = "Task with a long response time"
response = flow_instance.run_with_timeout(task, timeout=1)
# Ensure that the response is either the actual response or "Timeout"
assert response in ["Actual Response", "Timeout"]
