From d6b037c2118e4be6940733fa23ddcd4e46bdd463 Mon Sep 17 00:00:00 2001 From: Kye Date: Tue, 14 Nov 2023 14:54:13 -0500 Subject: [PATCH] mistral caller, openai verison 2.8, llama function caller, tests for flow Former-commit-id: 699c94339466c0168075ac6ff5955a704fae66c8 --- .gitignore | 1 + README.md | 74 ++- example.py | 2 +- playground/models/dall3.py | 6 +- pyproject.toml | 1 + requirements.txt | 1 + swarms/artifacts/main.py | 14 + swarms/memory/pg.py | 2 +- swarms/models/__init__.py | 25 +- swarms/models/dalle3.py | 2 +- swarms/models/fuyu.py | 4 +- swarms/models/huggingface.py | 62 +- swarms/models/openai_chat.py | 676 -------------------- swarms/models/openai_models.py | 77 +-- swarms/models/openai_tokenizer.py | 148 ----- swarms/models/simple_ada.py | 8 +- swarms/models/whisperx.py | 15 +- swarms/prompts/__init__.py | 1 - swarms/prompts/chat_prompt.py | 1 - swarms/structs/flow.py | 75 ++- swarms/swarms/autobloggen.py | 1 - swarms/swarms/dialogue_simulator.py | 2 - swarms/swarms/multi_agent_debate.py | 2 - tests/models/cohere.py | 3 +- tests/structs/flow.py | 942 ++++++++++++++++++++++++++++ 25 files changed, 1199 insertions(+), 946 deletions(-) delete mode 100644 swarms/models/openai_chat.py delete mode 100644 swarms/models/openai_tokenizer.py diff --git a/.gitignore b/.gitignore index 767abb9d..8f8a98a8 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ stderr_log.txt __pycache__/ *.py[cod] *$py.class +.grit error.txt # C extensions diff --git a/README.md b/README.md index 3654d12c..3b3155b5 100644 --- a/README.md +++ b/README.md @@ -74,10 +74,7 @@ flow = Flow( # out = flow.load_state("flow_state.json") # temp = flow.dynamic_temperature() # filter = flow.add_response_filter("Trump") - -# Run the flow out = flow.run("Generate a 10,000 word blog on health and wellness.") - # out = flow.validate_response(out) # out = flow.analyze_feedback(out) # out = flow.print_history_and_memory() @@ -138,6 +135,77 @@ for task in workflow.tasks: --- +# Features 🤖 +The Swarms framework is designed with a strong emphasis on reliability, performance, and production-grade readiness. +Below are the key features that make Swarms an ideal choice for enterprise-level AI deployments. + +## 🚀 Production-Grade Readiness +- **Scalable Architecture**: Built to scale effortlessly with your growing business needs. +- **Enterprise-Level Security**: Incorporates top-notch security features to safeguard your data and operations. +- **Containerization and Microservices**: Easily deployable in containerized environments, supporting microservices architecture. + +## ⚙️ Reliability and Robustness +- **Fault Tolerance**: Designed to handle failures gracefully, ensuring uninterrupted operations. +- **Consistent Performance**: Maintains high performance even under heavy loads or complex computational demands. +- **Automated Backup and Recovery**: Features automatic backup and recovery processes, reducing the risk of data loss. + +## 💡 Advanced AI Capabilities + +The Swarms framework is equipped with a suite of advanced AI capabilities designed to cater to a wide range of applications and scenarios, ensuring versatility and cutting-edge performance. + +### Multi-Modal Autonomous Agents +- **Versatile Model Support**: Seamlessly works with various AI models, including NLP, computer vision, and more, for comprehensive multi-modal capabilities. +- **Context-Aware Processing**: Employs context-aware processing techniques to ensure relevant and accurate responses from agents. + +### Function Calling Models for API Execution +- **Automated API Interactions**: Function calling models that can autonomously execute API calls, enabling seamless integration with external services and data sources. +- **Dynamic Response Handling**: Capable of processing and adapting to responses from APIs for real-time decision making. + +### Varied Architectures of Swarms +- **Flexible Configuration**: Supports multiple swarm architectures, from centralized to decentralized, for diverse application needs. +- **Customizable Agent Roles**: Allows customization of agent roles and behaviors within the swarm to optimize performance and efficiency. + +### Generative Models +- **Advanced Generative Capabilities**: Incorporates state-of-the-art generative models to create content, simulate scenarios, or predict outcomes. +- **Creative Problem Solving**: Utilizes generative AI for innovative problem-solving approaches and idea generation. + +### Enhanced Decision-Making +- **AI-Powered Decision Algorithms**: Employs advanced algorithms for swift and effective decision-making in complex scenarios. +- **Risk Assessment and Management**: Capable of assessing risks and managing uncertain situations with AI-driven insights. + +### Real-Time Adaptation and Learning +- **Continuous Learning**: Agents can continuously learn and adapt from new data, improving their performance and accuracy over time. +- **Environment Adaptability**: Designed to adapt to different operational environments, enhancing robustness and reliability. + + +## 🔄 Efficient Workflow Automation +- **Streamlined Task Management**: Simplifies complex tasks with automated workflows, reducing manual intervention. +- **Customizable Workflows**: Offers customizable workflow options to fit specific business needs and requirements. +- **Real-Time Analytics and Reporting**: Provides real-time insights into agent performance and system health. + +## 🌐 Wide-Ranging Integration +- **API-First Design**: Easily integrates with existing systems and third-party applications via robust APIs. +- **Cloud Compatibility**: Fully compatible with major cloud platforms for flexible deployment options. +- **Continuous Integration/Continuous Deployment (CI/CD)**: Supports CI/CD practices for seamless updates and deployment. + +## 📊 Performance Optimization +- **Resource Management**: Efficiently manages computational resources for optimal performance. +- **Load Balancing**: Automatically balances workloads to maintain system stability and responsiveness. +- **Performance Monitoring Tools**: Includes comprehensive monitoring tools for tracking and optimizing performance. + +## 🛡️ Security and Compliance +- **Data Encryption**: Implements end-to-end encryption for data at rest and in transit. +- **Compliance Standards Adherence**: Adheres to major compliance standards ensuring legal and ethical usage. +- **Regular Security Updates**: Regular updates to address emerging security threats and vulnerabilities. + +## 💬 Community and Support +- **Extensive Documentation**: Detailed documentation for easy implementation and troubleshooting. +- **Active Developer Community**: A vibrant community for sharing ideas, solutions, and best practices. +- **Professional Support**: Access to professional support for enterprise-level assistance and guidance. + +Swarms framework is not just a tool but a robust, scalable, and secure partner in your AI journey, ready to tackle the challenges of modern AI applications in a business environment. + + ## Documentation - For documentation, go here, [swarms.apac.ai](https://swarms.apac.ai) diff --git a/example.py b/example.py index 6c27bceb..84e04ad9 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,7 @@ from swarms.models import OpenAIChat from swarms.structs import Flow -api_key = "" +api_key = "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" # Initialize the language model, this model can be swapped out with Anthropic, ETC, Huggingface Models like Mistral, ETC llm = OpenAIChat( diff --git a/playground/models/dall3.py b/playground/models/dall3.py index 7a17400d..2ea2e10c 100644 --- a/playground/models/dall3.py +++ b/playground/models/dall3.py @@ -1,8 +1,6 @@ from swarms.models import Dalle3 -dalle3 = Dalle3( - openai_api_key="" -) +dalle3 = Dalle3(openai_api_key="") task = "A painting of a dog" image_url = dalle3(task) -print(image_url) \ No newline at end of file +print(image_url) diff --git a/pyproject.toml b/pyproject.toml index 995551f3..00860fdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ wget = "*" griptape = "*" httpx = "*" tiktoken = "*" +safetensors = "*" attrs = "*" ggl = "*" ratelimit = "*" diff --git a/requirements.txt b/requirements.txt index f9e4a9a3..6d542159 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,6 +48,7 @@ opencv-python-headless imageio-ffmpeg invisible-watermark kornia +safetensors numpy omegaconf open_clip_torch diff --git a/swarms/artifacts/main.py b/swarms/artifacts/main.py index 4b240b22..075cd34d 100644 --- a/swarms/artifacts/main.py +++ b/swarms/artifacts/main.py @@ -10,6 +10,20 @@ class Artifact(BaseModel): """ Artifact that has the task has been produced + + Attributes: + ----------- + + artifact_id: str + ID of the artifact + + file_name: str + Filename of the artifact + + relative_path: str + Relative path of the artifact + + """ artifact_id: StrictStr = Field(..., description="ID of the artifact") diff --git a/swarms/memory/pg.py b/swarms/memory/pg.py index bd768459..a421c887 100644 --- a/swarms/memory/pg.py +++ b/swarms/memory/pg.py @@ -2,7 +2,7 @@ import uuid from typing import Optional from attr import define, field, Factory from dataclasses import dataclass -from swarms.memory.vector_stores.base import BaseVectorStore +from swarms.memory.base import BaseVectorStore from sqlalchemy.engine import Engine from sqlalchemy import create_engine, Column, String, JSON from sqlalchemy.ext.declarative import declarative_base diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py index 21acb23c..4595cd53 100644 --- a/swarms/models/__init__.py +++ b/swarms/models/__init__.py @@ -6,20 +6,21 @@ sys.stderr = log_file # LLMs from swarms.models.anthropic import Anthropic # noqa: E402 from swarms.models.petals import Petals # noqa: E402 -from swarms.models.mistral import Mistral # noqa: E402 -from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat # noqa: E402 -from swarms.models.zephyr import Zephyr # noqa: E402 -from swarms.models.biogpt import BioGPT # noqa: E402 -from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 -from swarms.models.wizard_storytelling import WizardLLMStoryTeller # noqa: E402 -from swarms.models.mpt import MPT7B # noqa: E402 +from swarms.models.mistral import Mistral # noqa: E402 +from swarms.models.openai_models import OpenAI, AzureOpenAI, OpenAIChat # noqa: E402 +from swarms.models.zephyr import Zephyr # noqa: E402 +from swarms.models.biogpt import BioGPT # noqa: E402 +from swarms.models.huggingface import HuggingfaceLLM # noqa: E402 +from swarms.models.wizard_storytelling import WizardLLMStoryTeller # noqa: E402 +from swarms.models.mpt import MPT7B # noqa: E402 # MultiModal Models -from swarms.models.idefics import Idefics # noqa: E402 -from swarms.models.kosmos_two import Kosmos # noqa: E402 -from swarms.models.vilt import Vilt # noqa: E402 -from swarms.models.nougat import Nougat # noqa: E402 -from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402 +from swarms.models.idefics import Idefics # noqa: E402 +from swarms.models.kosmos_two import Kosmos # noqa: E402 +from swarms.models.vilt import Vilt # noqa: E402 +from swarms.models.nougat import Nougat # noqa: E402 +from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402 + # from swarms.models.gpt4v import GPT4Vision # from swarms.models.dalle3 import Dalle3 # from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402 diff --git a/swarms/models/dalle3.py b/swarms/models/dalle3.py index e6b345ae..7d9bcf5d 100644 --- a/swarms/models/dalle3.py +++ b/swarms/models/dalle3.py @@ -1,4 +1,3 @@ - import concurrent.futures import logging import os @@ -6,6 +5,7 @@ import uuid from dataclasses import dataclass from io import BytesIO from typing import List + import backoff import openai import requests diff --git a/swarms/models/fuyu.py b/swarms/models/fuyu.py index bba2068c..02ab3a25 100644 --- a/swarms/models/fuyu.py +++ b/swarms/models/fuyu.py @@ -75,7 +75,7 @@ class Fuyu: def get_img_from_web(self, img_url: str): """Get the image from the web""" - try: + try: response = requests.get(img_url) response.raise_for_status() image_pil = Image.open(BytesIO(response.content)) @@ -83,5 +83,3 @@ class Fuyu: except requests.RequestException as error: print(f"Error fetching image from {img_url} and error: {error}") return None - - \ No newline at end of file diff --git a/swarms/models/huggingface.py b/swarms/models/huggingface.py index 9279fea4..82a91783 100644 --- a/swarms/models/huggingface.py +++ b/swarms/models/huggingface.py @@ -1,9 +1,13 @@ +import asyncio +import concurrent.futures import logging +from typing import List, Tuple + import torch +from termcolor import colored from torch.nn.parallel import DistributedDataParallel as DDP from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from termcolor import colored class HuggingfaceLLM: @@ -43,6 +47,12 @@ class HuggingfaceLLM: # logger=None, distributed=False, decoding=False, + max_workers: int = 5, + repitition_penalty: float = 1.3, + no_repeat_ngram_size: int = 5, + temperature: float = 0.7, + top_k: int = 40, + top_p: float = 0.8, *args, **kwargs, ): @@ -56,6 +66,14 @@ class HuggingfaceLLM: self.distributed = distributed self.decoding = decoding self.model, self.tokenizer = None, None + self.quantize = quantize + self.quantization_config = quantization_config + self.max_workers = max_workers + self.repitition_penalty = repitition_penalty + self.no_repeat_ngram_size = no_repeat_ngram_size + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p if self.distributed: assert ( @@ -91,6 +109,10 @@ class HuggingfaceLLM: """Print error""" print(colored(f"Error: {error}", "red")) + async def async_run(self, task: str): + """Ashcnronous generate text for a given prompt""" + return await asyncio.to_thread(self.run, task) + def load_model(self): """Load the model""" if not self.model or not self.tokenizer: @@ -113,6 +135,21 @@ class HuggingfaceLLM: self.logger.error(f"Failed to load the model or the tokenizer: {error}") raise + def concurrent_run(self, tasks: List[str], max_workers: int = 5): + """Concurrently generate text for a list of prompts.""" + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map(self.run, tasks)) + return results + + def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]: + """Process a batch of tasks and images""" + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self.run, task, img) for task, img in tasks_images + ] + results = [future.result() for future in futures] + return results + def run(self, task: str): """ Generate a response based on the prompt text. @@ -175,29 +212,6 @@ class HuggingfaceLLM: ) raise - async def run_async(self, task: str, *args, **kwargs) -> str: - """ - Run the model asynchronously - - Args: - task (str): Task to run. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Examples: - >>> mpt_instance = MPT('mosaicml/mpt-7b-storywriter', "EleutherAI/gpt-neox-20b", max_tokens=150) - >>> mpt_instance("generate", "Once upon a time in a land far, far away...") - 'Once upon a time in a land far, far away...' - >>> mpt_instance.batch_generate(["In the deep jungles,", "At the heart of the city,"], temperature=0.7) - ['In the deep jungles,', - 'At the heart of the city,'] - >>> mpt_instance.freeze_model() - >>> mpt_instance.unfreeze_model() - - """ - # Wrapping synchronous calls with async - return self.run(task, *args, **kwargs) - def __call__(self, task: str): """ Generate a response based on the prompt text. diff --git a/swarms/models/openai_chat.py b/swarms/models/openai_chat.py deleted file mode 100644 index 46057a45..00000000 --- a/swarms/models/openai_chat.py +++ /dev/null @@ -1,676 +0,0 @@ -from __future__ import annotations - -import logging -import os -import sys -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict -from langchain.callbacks.manager import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain.chat_models.base import ( - BaseChatModel, -) -from langchain.llms.base import create_base_retry_decorator -from langchain.pydantic_v1 import BaseModel, Field, root_validator -from langchain.schema import ChatGeneration, ChatResult -from langchain.schema.language_model import LanguageModelInput -from langchain.schema.messages import ( - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessageChunk, - FunctionMessageChunk, - HumanMessageChunk, - SystemMessageChunk, - ToolMessageChunk, -) -from langchain.schema.output import ChatGenerationChunk -from langchain.schema.runnable import Runnable -from langchain.utils import ( - get_from_dict_or_env, - get_pydantic_field_names, -) -from langchain.utils.openai import is_openai_v1 - -if TYPE_CHECKING: - import tiktoken - - -logger = logging.getLogger(__name__) - - -def _generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult: - generation: Optional[ChatGenerationChunk] = None - for chunk in stream: - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - return ChatResult(generations=[generation]) - - -async def _agenerate_from_stream( - stream: AsyncIterator[ChatGenerationChunk], -) -> ChatResult: - generation: Optional[ChatGenerationChunk] = None - async for chunk in stream: - if generation is None: - generation = chunk - else: - generation += chunk - assert generation is not None - return ChatResult(generations=[generation]) - - -def _import_tiktoken() -> Any: - try: - import tiktoken - except ImportError: - raise ValueError( - "Could not import tiktoken python package. " - "This is needed in order to calculate get_token_ids. " - "Please install it with `pip install tiktoken`." - ) - return tiktoken - - -def _create_retry_decorator( - llm: OpenAIChat, - run_manager: Optional[ - Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] - ] = None, -) -> Callable[[Any], Any]: - import openai - - errors = [ - openai.Timeout, - openai.APIError, - openai.APIConnectionError, - openai.RateLimitError, - openai.error.ServiceUnavailableError, - ] - return create_base_retry_decorator( - error_types=errors, max_retries=llm.max_retries, run_manager=run_manager - ) - - -async def acompletion_with_retry( - llm: OpenAIChat, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, -) -> Any: - """Use tenacity to retry the async completion call.""" - if is_openai_v1(): - return await llm.async_client.create(**kwargs) - - retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - - @retry_decorator - async def _completion_with_retry(**kwargs: Any) -> Any: - # Use OpenAI's async api https://github.com/openai/openai-python#async-api - return await llm.client.acreate(**kwargs) - - return await _completion_with_retry(**kwargs) - - -def _convert_delta_to_message_chunk( - _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] -) -> BaseMessageChunk: - role = _dict.get("role") - content = _dict.get("content") or "" - additional_kwargs: Dict = {} - if _dict.get("function_call"): - function_call = dict(_dict["function_call"]) - if "name" in function_call and function_call["name"] is None: - function_call["name"] = "" - additional_kwargs["function_call"] = function_call - if _dict.get("tool_calls"): - additional_kwargs["tool_calls"] = _dict["tool_calls"] - - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role == "tool" or default_class == ToolMessageChunk: - return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) - else: - return default_class(content=content) - - -class OpenAIChat(BaseChatModel): - """`OpenAI` Chat large language models API. - - To use, you should have the ``openai`` python package installed, and the - environment variable ``OPENAI_API_KEY`` set with your API key. - - Any parameters that are valid to be passed to the openai.create call can be passed - in, even if not explicitly saved on this class. - - Example: - .. code-block:: python - - from swarms.models import ChatOpenAI - openai = ChatOpenAI(model_name="gpt-3.5-turbo") - """ - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"openai_api_key": "OPENAI_API_KEY"} - - @property - def lc_attributes(self) -> Dict[str, Any]: - attributes: Dict[str, Any] = {} - - if self.openai_organization: - attributes["openai_organization"] = self.openai_organization - - if self.openai_api_base: - attributes["openai_api_base"] = self.openai_api_base - - if self.openai_proxy: - attributes["openai_proxy"] = self.openai_proxy - - return attributes - - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this model can be serialized by Langchain.""" - return True - - client: Any = None #: :meta private: - async_client: Any = None #: :meta private: - model_name: str = Field(default="gpt-3.5-turbo", alias="model") - """Model name to use.""" - temperature: float = 0.7 - """What sampling temperature to use.""" - model_kwargs: Dict[str, Any] = Field(default_factory=dict) - """Holds any model parameters valid for `create` call not explicitly specified.""" - # When updating this to use a SecretStr - # Check for classes that derive from this class (as some of them - # may assume openai_api_key is a str) - # openai_api_key: Optional[str] = Field(default=None, alias="api_key") - openai_api_key: Optional[str] = Field(default=None, alias="api_key") - """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" - openai_api_base: Optional[str] = Field(default=None, alias="base_url") - """Base URL path for API requests, leave blank if not using a proxy or service - emulator.""" - openai_organization: Optional[str] = Field(default=None, alias="organization") - """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" - # to support explicit proxy for OpenAI - openai_proxy: Optional[str] = None - request_timeout: Union[float, Tuple[float, float], Any, None] = Field( - default=None, alias="timeout" - ) - """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or - None.""" - max_retries: int = 2 - """Maximum number of retries to make when generating.""" - streaming: bool = False - """Whether to stream the results or not.""" - n: int = 1 - """Number of chat completions to generate for each prompt.""" - max_tokens: Optional[int] = None - """Maximum number of tokens to generate.""" - tiktoken_model_name: Optional[str] = None - """The model name to pass to tiktoken when using this class. - Tiktoken is used to count the number of tokens in documents to constrain - them to be under a certain limit. By default, when set to None, this will - be the same as the embedding model name. However, there are some cases - where you may want to use this Embedding class with a model name not - supported by tiktoken. This can include when using Azure embeddings or - when using one of the many model providers that expose an OpenAI-like - API but with different models. In those cases, in order to avoid erroring - when tiktoken is called, you can specify a model name to use here.""" - default_headers: Union[Mapping[str, str], None] = None - default_query: Union[Mapping[str, object], None] = None - # Configure a custom httpx client. See the - # [httpx documentation](https://www.python-httpx.org/api/#client) for more details. - http_client: Union[Any, None] = None - """Optional httpx.Client.""" - - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - - @root_validator(pre=True) - def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = get_pydantic_field_names(cls) - extra = values.get("model_kwargs", {}) - for field_name in list(values): - if field_name in extra: - raise ValueError(f"Found {field_name} supplied twice.") - if field_name not in all_required_field_names: - logger.warning( - f"""WARNING! {field_name} is not default parameter. - {field_name} was transferred to model_kwargs. - Please confirm that {field_name} is what you intended.""" - ) - extra[field_name] = values.pop(field_name) - - invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) - if invalid_model_kwargs: - raise ValueError( - f"Parameters {invalid_model_kwargs} should be specified explicitly. " - f"Instead they were passed in as part of `model_kwargs` parameter." - ) - - values["model_kwargs"] = extra - return values - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - if values["n"] < 1: - raise ValueError("n must be at least 1.") - if values["n"] > 1 and values["streaming"]: - raise ValueError("n must be 1 when streaming.") - - values["openai_api_key"] = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" - ) - # Check OPENAI_ORGANIZATION for backwards compatibility. - values["openai_organization"] = ( - values["openai_organization"] - or os.getenv("OPENAI_ORG_ID") - or os.getenv("OPENAI_ORGANIZATION") - ) - values["openai_api_base"] = values["openai_api_base"] or os.getenv( - "OPENAI_API_BASE" - ) - values["openai_proxy"] = get_from_dict_or_env( - values, - "openai_proxy", - "OPENAI_PROXY", - default="", - ) - try: - import openai - - except ImportError: - raise ImportError( - "Could not import openai python package. " - "Please install it with `pip install openai`." - ) - - if is_openai_v1(): - client_params = { - "api_key": values["openai_api_key"], - "organization": values["openai_organization"], - "base_url": values["openai_api_base"], - "timeout": values["request_timeout"], - "max_retries": values["max_retries"], - "default_headers": values["default_headers"], - "default_query": values["default_query"], - "http_client": values["http_client"], - } - values["client"] = openai.OpenAI(**client_params).chat.completions - values["async_client"] = openai.AsyncOpenAI( - **client_params - ).chat.completions - else: - values["client"] = openai.ChatCompletion - return values - - @property - def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling OpenAI API.""" - params = { - "model": self.model_name, - "stream": self.streaming, - "n": self.n, - "temperature": self.temperature, - **self.model_kwargs, - } - if self.max_tokens is not None: - params["max_tokens"] = self.max_tokens - if self.request_timeout is not None and not is_openai_v1(): - params["request_timeout"] = self.request_timeout - return params - - def completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any - ) -> Any: - """Use tenacity to retry the completion call.""" - if is_openai_v1(): - return self.client.create(**kwargs) - - retry_decorator = _create_retry_decorator(self, run_manager=run_manager) - - @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: - return self.client.create(**kwargs) - - return _completion_with_retry(**kwargs) - - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: - overall_token_usage: dict = {} - system_fingerprint = None - for output in llm_outputs: - if output is None: - # Happens in streaming - continue - token_usage = output["token_usage"] - for k, v in token_usage.items(): - if k in overall_token_usage: - overall_token_usage[k] += v - else: - overall_token_usage[k] = v - if system_fingerprint is None: - system_fingerprint = output.get("system_fingerprint") - combined = {"token_usage": overall_token_usage, "model_name": self.model_name} - if system_fingerprint: - combined["system_fingerprint"] = system_fingerprint - return combined - - def _stream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - - default_chunk_class = AIMessageChunk - for chunk in self.completion_with_retry( - messages=message_dicts, run_manager=run_manager, **params - ): - if not isinstance(chunk, dict): - chunk = chunk.dict() - if len(chunk["choices"]) == 0: - continue - choice = chunk["choices"][0] - chunk = _convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) - finish_reason = choice.get("finish_reason") - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk - if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) - - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return _generate_from_stream(stream_iter) - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - response = self.completion_with_retry( - messages=message_dicts, run_manager=run_manager, **params - ) - return self._create_chat_result(response) - - def _create_message_dicts( - self, messages: List[BaseMessage], stop: Optional[List[str]] - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params = self._client_params - if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") - params["stop"] = stop - message_dicts = [convert_message_to_dict(m) for m in messages] - return message_dicts, params - - def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: - generations = [] - if not isinstance(response, dict): - response = response.dict() - for res in response["choices"]: - message = convert_dict_to_message(res["message"]) - gen = ChatGeneration( - message=message, - generation_info=dict(finish_reason=res.get("finish_reason")), - ) - generations.append(gen) - token_usage = response.get("usage", {}) - llm_output = { - "token_usage": token_usage, - "model_name": self.model_name, - "system_fingerprint": response.get("system_fingerprint", ""), - } - return ChatResult(generations=generations, llm_output=llm_output) - - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - - default_chunk_class = AIMessageChunk - async for chunk in await acompletion_with_retry( - self, messages=message_dicts, run_manager=run_manager, **params - ): - if not isinstance(chunk, dict): - chunk = chunk.dict() - if len(chunk["choices"]) == 0: - continue - choice = chunk["choices"][0] - chunk = _convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) - finish_reason = choice.get("finish_reason") - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None - ) - default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk - if run_manager: - await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._astream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await _agenerate_from_stream(stream_iter) - - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs} - response = await acompletion_with_retry( - self, messages=message_dicts, run_manager=run_manager, **params - ) - return self._create_chat_result(response) - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - return {**{"model_name": self.model_name}, **self._default_params} - - @property - def _client_params(self) -> Dict[str, Any]: - """Get the parameters used for the openai client.""" - openai_creds: Dict[str, Any] = { - "model": self.model_name, - } - if not is_openai_v1(): - openai_creds.update( - { - "api_key": self.openai_api_key, - "api_base": self.openai_api_base, - "organization": self.openai_organization, - } - ) - if self.openai_proxy: - import openai - - raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'") # type: ignore[assignment] # noqa: E501 - return {**self._default_params, **openai_creds} - - def _get_invocation_params( - self, stop: Optional[List[str]] = None, **kwargs: Any - ) -> Dict[str, Any]: - """Get the parameters used to invoke the model.""" - return { - "model": self.model_name, - **super()._get_invocation_params(stop=stop), - **self._default_params, - **kwargs, - } - - @property - def _llm_type(self) -> str: - """Return type of chat model.""" - return "openai-chat" - - def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]: - tiktoken_ = _import_tiktoken() - if self.tiktoken_model_name is not None: - model = self.tiktoken_model_name - else: - model = self.model_name - if model == "gpt-3.5-turbo": - # gpt-3.5-turbo may change over time. - # Returning num tokens assuming gpt-3.5-turbo-0301. - model = "gpt-3.5-turbo-0301" - elif model == "gpt-4": - # gpt-4 may change over time. - # Returning num tokens assuming gpt-4-0314. - model = "gpt-4-0314" - # Returns the number of tokens used by a list of messages. - try: - encoding = tiktoken_.encoding_for_model(model) - except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") - model = "cl100k_base" - encoding = tiktoken_.get_encoding(model) - return model, encoding - - def get_token_ids(self, text: str) -> List[int]: - """Get the tokens present in the text with tiktoken package.""" - # tiktoken NOT supported for Python 3.7 or below - if sys.version_info[1] <= 7: - return super().get_token_ids(text) - _, encoding_model = self._get_encoding_model() - return encoding_model.encode(text) - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if sys.version_info[1] <= 7: - return super().get_num_tokens_from_messages(messages) - model, encoding = self._get_encoding_model() - if model.startswith("gpt-3.5-turbo-0301"): - # every message follows {role/name}\n{content}\n - tokens_per_message = 4 - # if there's a name, the role is omitted - tokens_per_name = -1 - elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): - tokens_per_message = 3 - tokens_per_name = 1 - else: - raise NotImplementedError( - f"get_num_tokens_from_messages() is not presently implemented " - f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " - "information on how messages are converted to tokens." - ) - num_tokens = 0 - messages_dict = [convert_message_to_dict(m) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - # Cast str(value) in case the message value is not a string - # This occurs with function messages - num_tokens += len(encoding.encode(str(value))) - if key == "name": - num_tokens += tokens_per_name - # every reply is primed with assistant - num_tokens += 3 - return num_tokens - - def bind_functions( - self, - functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], - function_call: Optional[str] = None, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, BaseMessage]: - """Bind functions (and other objects) to this chat model. - - Args: - functions: A list of function definitions to bind to this chat model. - Can be a dictionary, pydantic model, or callable. Pydantic - models and callables will be automatically converted to - their schema dictionary representation. - function_call: Which function to require the model to call. - Must be the name of the single provided function or - "auto" to automatically determine which function to call - (if any). - kwargs: Any additional parameters to pass to the - :class:`~swarms.runnable.Runnable` constructor. - """ - from langchain.chains.openai_functions.base import convert_to_openai_function - - formatted_functions = [convert_to_openai_function(fn) for fn in functions] - if function_call is not None: - if len(formatted_functions) != 1: - raise ValueError( - "When specifying `function_call`, you must provide exactly one " - "function." - ) - if formatted_functions[0]["name"] != function_call: - raise ValueError( - f"Function call {function_call} was specified, but the only " - f"provided function was {formatted_functions[0]['name']}." - ) - function_call_ = {"name": function_call} - kwargs = {**kwargs, "function_call": function_call_} - return super().bind( - functions=formatted_functions, - **kwargs, - ) diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index dba3b991..fcf4a223 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -30,9 +30,19 @@ from langchain.schema.output import GenerationChunk from langchain.utils import get_from_dict_or_env, get_pydantic_field_names from langchain.utils.utils import build_extra_kwargs + +from importlib.metadata import version + +from packaging.version import parse + logger = logging.getLogger(__name__) +def is_openai_v1() -> bool: + _version = parse(version("openai")) + return _version.major >= 1 + + def update_token_usage( keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any] ) -> None: @@ -79,24 +89,24 @@ def _streaming_response_template() -> Dict[str, Any]: } -# def _create_retry_decorator( -# llm: Union[BaseOpenAI, OpenAIChat], -# run_manager: Optional[ -# Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] -# ] = None, -# ) -> Callable[[Any], Any]: -# import openai - -# errors = [ -# openai.Timeout, -# openai.APIError, -# openai.error.APIConnectionError, -# openai.error.RateLimitError, -# openai.error.ServiceUnavailableError, -# ] -# return create_base_retry_decorator( -# error_types=errors, max_retries=llm.max_retries, run_manager=run_manager -# ) +def _create_retry_decorator( + llm: Union[BaseOpenAI, OpenAIChat], + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + import openai + + errors = [ + openai.error.Timeout, + openai.error.APIError, + openai.error.APIConnectionError, + openai.error.RateLimitError, + openai.error.ServiceUnavailableError, + ] + return create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + ) def completion_with_retry( @@ -105,9 +115,9 @@ def completion_with_retry( **kwargs: Any, ) -> Any: """Use tenacity to retry the completion call.""" - # retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - # @retry_decorator + @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: return llm.client.create(**kwargs) @@ -120,9 +130,9 @@ async def acompletion_with_retry( **kwargs: Any, ) -> Any: """Use tenacity to retry the async completion call.""" - # retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) - # @retry_decorator + @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: # Use OpenAI's async api https://github.com/openai/openai-python#async-api return await llm.client.acreate(**kwargs) @@ -500,11 +510,7 @@ class BaseOpenAI(BaseLLM): if self.openai_proxy: import openai - # raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g.", - # 'OpenAI(proxy={ - # "http": self.openai_proxy, - # "https": self.openai_proxy, - # })'") # type: ignore[assignment] # noqa: E501 + openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 return {**openai_creds, **self._default_params} @property @@ -632,14 +638,13 @@ class OpenAI(BaseOpenAI): environment variable ``OPENAI_API_KEY`` set with your API key. Any parameters that are valid to be passed to the openai.create call can be passed - in, even if not explicitly saved on this class.., + in, even if not explicitly saved on this class. Example: .. code-block:: python - from swarms.models import OpenAI + from langchain.llms import OpenAI openai = OpenAI(model_name="text-davinci-003") - openai("What is the report on the 2022 oympian games?") """ @property @@ -659,7 +664,7 @@ class AzureOpenAI(BaseOpenAI): Example: .. code-block:: python - from swarms.models import AzureOpenAI + from langchain.llms import AzureOpenAI openai = AzureOpenAI(model_name="text-davinci-003") """ @@ -721,7 +726,7 @@ class OpenAIChat(BaseLLM): Example: .. code-block:: python - from swarms.models import OpenAIChat + from langchain.llms import OpenAIChat openaichat = OpenAIChat(model_name="gpt-3.5-turbo") """ @@ -783,11 +788,13 @@ class OpenAIChat(BaseLLM): try: import openai - + openai.api_key = openai_api_key if openai_api_base: - raise Exception("The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(api_base=openai_api_base)'") + openai.api_base = openai_api_base if openai_organization: - raise Exception("The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=openai_organization)'") + openai.organization = openai_organization + if openai_proxy: + openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 except ImportError: raise ImportError( "Could not import openai python package. " diff --git a/swarms/models/openai_tokenizer.py b/swarms/models/openai_tokenizer.py deleted file mode 100644 index 9ff1fa08..00000000 --- a/swarms/models/openai_tokenizer.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from typing import Optional - -import tiktoken -from attr import Factory, define, field - - -@define(frozen=True) -class BaseTokenizer(ABC): - DEFAULT_STOP_SEQUENCES = ["Observation:"] - - stop_sequences: list[str] = field( - default=Factory(lambda: BaseTokenizer.DEFAULT_STOP_SEQUENCES), - kw_only=True, - ) - - @property - @abstractmethod - def max_tokens(self) -> int: - ... - - def count_tokens_left(self, text: str) -> int: - diff = self.max_tokens - self.count_tokens(text) - - if diff > 0: - return diff - else: - return 0 - - @abstractmethod - def count_tokens(self, text: str) -> int: - ... - - -@define(frozen=True) -class OpenAITokenizer(BaseTokenizer): - DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "text-davinci-003" - DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo" - DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4" - DEFAULT_ENCODING = "cl100k_base" - DEFAULT_MAX_TOKENS = 2049 - TOKEN_OFFSET = 8 - - MODEL_PREFIXES_TO_MAX_TOKENS = { - "gpt-4-32k": 32768, - "gpt-4": 8192, - "gpt-3.5-turbo-16k": 16384, - "gpt-3.5-turbo": 4096, - "gpt-35-turbo-16k": 16384, - "gpt-35-turbo": 4096, - "text-davinci-003": 4097, - "text-davinci-002": 4097, - "code-davinci-002": 8001, - "text-embedding-ada-002": 8191, - "text-embedding-ada-001": 2046, - } - - EMBEDDING_MODELS = ["text-embedding-ada-002", "text-embedding-ada-001"] - - model: str = field(kw_only=True) - - @property - def encoding(self) -> tiktoken.Encoding: - try: - return tiktoken.encoding_for_model(self.model) - except KeyError: - return tiktoken.get_encoding(self.DEFAULT_ENCODING) - - @property - def max_tokens(self) -> int: - tokens = next( - v - for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() - if self.model.startswith(k) - ) - offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET - - return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset - - def count_tokens(self, text: str | list, model: Optional[str] = None) -> int: - """ - Handles the special case of ChatML. Implementation adopted from the official OpenAI notebook: - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - """ - if isinstance(text, list): - model = model if model else self.model - - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logging.warning("model not found. Using cl100k_base encoding.") - - encoding = tiktoken.get_encoding("cl100k_base") - - if model in { - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-16k-0613", - "gpt-4-0314", - "gpt-4-32k-0314", - "gpt-4-0613", - "gpt-4-32k-0613", - }: - tokens_per_message = 3 - tokens_per_name = 1 - elif model == "gpt-3.5-turbo-0301": - # every message follows <|start|>{role/name}\n{content}<|end|>\n - tokens_per_message = 4 - # if there's a name, the role is omitted - tokens_per_name = -1 - elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model: - logging.info( - "gpt-3.5-turbo may update over time. Returning num tokens assuming" - " gpt-3.5-turbo-0613." - ) - return self.count_tokens(text, model="gpt-3.5-turbo-0613") - elif "gpt-4" in model: - logging.info( - "gpt-4 may update over time. Returning num tokens assuming" - " gpt-4-0613." - ) - return self.count_tokens(text, model="gpt-4-0613") - else: - raise NotImplementedError( - f"""token_count() is not implemented for model {model}. - See https://github.com/openai/openai-python/blob/main/chatml.md for - information on how messages are converted to tokens.""" - ) - - num_tokens = 0 - - for message in text: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - - # every reply is primed with <|start|>assistant<|message|> - num_tokens += 3 - - return num_tokens - else: - return len( - self.encoding.encode(text, allowed_special=set(self.stop_sequences)) - ) diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index 67d54fee..973adaea 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -2,8 +2,6 @@ from openai import OpenAI client = OpenAI() - - def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): """ Simple function to get embeddings from ada @@ -13,11 +11,7 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): >>> get_ada_embeddings("Hello World", model="text-embedding-ada-001") """ - text = text.replace("\n", " ") - return client.embeddings.create(input=[text], - model=model)["data"][ - 0 - ]["embedding"] + return client.embeddings.create(input=[text], model=model)["data"][0]["embedding"] diff --git a/swarms/models/whisperx.py b/swarms/models/whisperx.py index e980cf0a..ac592b35 100644 --- a/swarms/models/whisperx.py +++ b/swarms/models/whisperx.py @@ -1,11 +1,16 @@ -# speech to text tool - import os import subprocess -import whisperx -from pydub import AudioSegment -from pytube import YouTube +try: + import whisperx + from pydub import AudioSegment + from pytube import YouTube +except Exception as error: + print("Error importing pytube. Please install pytube manually.") + print("pip install pytube") + print("pip install pydub") + print("pip install whisperx") + print(f"Pytube error: {error}") class WhisperX: diff --git a/swarms/prompts/__init__.py b/swarms/prompts/__init__.py index b087a1a4..825bddaa 100644 --- a/swarms/prompts/__init__.py +++ b/swarms/prompts/__init__.py @@ -6,7 +6,6 @@ from swarms.prompts.operations_agent_prompt import OPERATIONS_AGENT_PROMPT from swarms.prompts.product_agent_prompt import PRODUCT_AGENT_PROMPT - __all__ = [ "CODE_INTERPRETER", "FINANCE_AGENT_PROMPT", diff --git a/swarms/prompts/chat_prompt.py b/swarms/prompts/chat_prompt.py index 01f66a5b..d1e08df9 100644 --- a/swarms/prompts/chat_prompt.py +++ b/swarms/prompts/chat_prompt.py @@ -4,7 +4,6 @@ from abc import abstractmethod from typing import Dict, List, Sequence - class Message: """ The base abstract Message class. diff --git a/swarms/structs/flow.py b/swarms/structs/flow.py index 6e0a0c50..207c9ab9 100644 --- a/swarms/structs/flow.py +++ b/swarms/structs/flow.py @@ -20,6 +20,7 @@ from termcolor import colored import inspect import random + # Prompts DYNAMIC_STOP_PROMPT = """ When you have finished the task from the Human, output a special token: @@ -28,12 +29,14 @@ This will enable you to leave the autonomous loop. # Constants FLOW_SYSTEM_PROMPT = f""" -You are an autonomous agent granted autonomy from a Flow structure. +You are an autonomous agent granted autonomy in a autonomous loop structure. Your role is to engage in multi-step conversations with your self or the user, generate long-form content like blogs, screenplays, or SOPs, -and accomplish tasks. You can have internal dialogues with yourself or can interact with the user +and accomplish tasks bestowed by the user. + +You can have internal dialogues with yourself or can interact with the user to aid in these complex tasks. Your responses should be coherent, contextually relevant, and tailored to the task at hand. -{DYNAMIC_STOP_PROMPT} + """ # Make it able to handle multi input tools @@ -47,6 +50,11 @@ commands: { "tool1": "inputs", "tool1": "inputs" } + "tool2: "tool_name", + "params": { + "tool1": "inputs", + "tool1": "inputs" + } } } @@ -54,6 +62,40 @@ commands: { """ +def autonomous_agent_prompt( + tools_prompt: str = DYNAMICAL_TOOL_USAGE, + dynamic_stop_prompt: str = DYNAMIC_STOP_PROMPT, + agent_name: str = None, +): + """Autonomous agent prompt""" + return f""" + You are a {agent_name}, an autonomous agent granted autonomy in a autonomous loop structure. + Your purpose is to satisfy the user demands above expectations. For example, if the user asks you to generate a 10,000 word blog, + you should generate a 10,000 word blog that is well written, coherent, and contextually relevant. + Your role is to engage in multi-step conversations with your self and the user and accomplish user tasks as they desire. + + Follow the following rules: + 1. Accomplish the task to the best of your ability + 2. If you are unable to accomplish the task, then ask the user for help + 3. If the user provides feedback, then use the feedback to improve your performance + 4. If you are unable to accomplish the task, then ask the user for help + + You can have internal dialogues with yourself or can interact with the user + to aid in these complex tasks. Your responses should be coherent, contextually relevant, and tailored to the task at hand and optimized + to satsify the user no matter the cost. + + And, you have the ability to use tools to aid in your tasks, the tools intructions are below, output a JSON object with the following structure to use the tools + {tools_prompt} + + Now, when you 99% sure you have completed the task, you may follow the instructions below to escape the autonomous loop. + {dynamic_stop_prompt} + + Now, you remember your training, your deployment, and your purpose. You are ready to begin your mission. + + + """ + + # Custom stopping condition def stop_when_repeats(response: str) -> bool: # Stop if the word stop appears in the response @@ -71,14 +113,10 @@ class Flow: to generate sequential responses. Features: - * User defined queries - * Dynamic keep generating until is outputted by the agent * Interactive, AI generates, then user input - * Message history and performance history fed -> into context + * Message history and performance history fed -> into context -> truncate if too long * Ability to save and load flows * Ability to provide feedback on responses - * Ability to provide a stopping condition - * Ability to provide a retry mechanism * Ability to provide a loop interval Args: @@ -142,7 +180,7 @@ class Flow: self.feedback = [] self.memory = [] self.task = None - self.stopping_token = stopping_token or "" + self.stopping_token = stopping_token # or "" self.interactive = interactive self.dashboard = dashboard self.return_history = return_history @@ -389,8 +427,11 @@ class Flow: print(colored(f"\nLoop {loop_count} of {self.max_loops}", "blue")) print("\n") - if self._check_stopping_condition(response) or parse_done_token(response): - break + if self.stopping_token: + if self._check_stopping_condition(response) or parse_done_token( + response + ): + break # Adjust temperature, comment if no work if self.dynamic_temperature: @@ -659,13 +700,13 @@ class Flow: return "Timeout" return response - def backup_memory_to_s3(self, bucket_name: str, object_name: str): - """Backup the memory to S3""" - import boto3 + # def backup_memory_to_s3(self, bucket_name: str, object_name: str): + # """Backup the memory to S3""" + # import boto3 - s3 = boto3.client("s3") - s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory)) - print(f"Backed up memory to S3: {bucket_name}/{object_name}") + # s3 = boto3.client("s3") + # s3.put_object(Bucket=bucket_name, Key=object_name, Body=json.dumps(self.memory)) + # print(f"Backed up memory to S3: {bucket_name}/{object_name}") def analyze_feedback(self): """Analyze the feedback for issues""" diff --git a/swarms/swarms/autobloggen.py b/swarms/swarms/autobloggen.py index 5a870269..2756825b 100644 --- a/swarms/swarms/autobloggen.py +++ b/swarms/swarms/autobloggen.py @@ -1,4 +1,3 @@ - from termcolor import colored from swarms.prompts.autoblogen import ( diff --git a/swarms/swarms/dialogue_simulator.py b/swarms/swarms/dialogue_simulator.py index 8ceddef4..155ac28d 100644 --- a/swarms/swarms/dialogue_simulator.py +++ b/swarms/swarms/dialogue_simulator.py @@ -1,5 +1,3 @@ - - class DialogueSimulator: """ Dialogue Simulator diff --git a/swarms/swarms/multi_agent_debate.py b/swarms/swarms/multi_agent_debate.py index 45b25f59..93d115a2 100644 --- a/swarms/swarms/multi_agent_debate.py +++ b/swarms/swarms/multi_agent_debate.py @@ -1,5 +1,3 @@ - - # Define a selection function def select_speaker(step: int, agents) -> int: # This function selects the speaker in a round-robin fashion diff --git a/tests/models/cohere.py b/tests/models/cohere.py index 9c85d795..d1bea935 100644 --- a/tests/models/cohere.py +++ b/tests/models/cohere.py @@ -15,7 +15,6 @@ def cohere_instance(): return Cohere(cohere_api_key=api_key) - def test_cohere_custom_configuration(cohere_instance): # Test customizing Cohere configurations cohere_instance.model = "base" @@ -404,7 +403,6 @@ def test_cohere_async_stream_with_embed_multilingual_v3_model(cohere_instance): assert isinstance(token, str) - def test_cohere_representation_model_embedding(cohere_instance): # Test using the Representation model for text embedding cohere_instance.model = "embed-english-v3.0" @@ -626,6 +624,7 @@ def test_cohere_invalid_model(cohere_instance): with pytest.raises(ValueError): cohere_instance("Generate text using an invalid model.") + def test_cohere_base_model_generation_with_max_tokens(cohere_instance): # Test generating text using the base model with a specified max_tokens limit cohere_instance.model = "base" diff --git a/tests/structs/flow.py b/tests/structs/flow.py index 3cfeca8d..edc4b9c7 100644 --- a/tests/structs/flow.py +++ b/tests/structs/flow.py @@ -1,5 +1,6 @@ import json import os +from unittest import mock from unittest.mock import MagicMock, patch import pytest @@ -7,6 +8,7 @@ from dotenv import load_dotenv from swarms.models import OpenAIChat from swarms.structs.flow import Flow, stop_when_repeats +from swarms.utils.logger import logger load_dotenv() @@ -254,3 +256,943 @@ def test_flow_initialization_all_params(mocked_llm): def test_stopping_token_in_response(mocked_sleep, basic_flow): response = basic_flow.run("Test stopping token") assert basic_flow.stopping_token in response + + +@pytest.fixture +def flow_instance(): + # Create an instance of the Flow class with required parameters for testing + # You may need to adjust this based on your actual class initialization + llm = OpenAIChat( + openai_api_key=openai_api_key, + ) + flow = Flow( + llm=llm, + max_loops=5, + interactive=False, + dashboard=False, + dynamic_temperature=False, + ) + return flow + + +def test_flow_run(flow_instance): + # Test the basic run method of the Flow class + response = flow_instance.run("Test task") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_flow_interactive_mode(flow_instance): + # Test the interactive mode of the Flow class + flow_instance.interactive = True + response = flow_instance.run("Test task") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_flow_dashboard_mode(flow_instance): + # Test the dashboard mode of the Flow class + flow_instance.dashboard = True + response = flow_instance.run("Test task") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_flow_autosave(flow_instance): + # Test the autosave functionality of the Flow class + flow_instance.autosave = True + response = flow_instance.run("Test task") + assert isinstance(response, str) + assert len(response) > 0 + # Ensure that the state is saved (you may need to implement this logic) + assert flow_instance.saved_state_path is not None + + +def test_flow_response_filtering(flow_instance): + # Test the response filtering functionality + flow_instance.add_response_filter("filter_this") + response = flow_instance.filtered_run("This message should filter_this") + assert "filter_this" not in response + + +def test_flow_undo_last(flow_instance): + # Test the undo functionality + response1 = flow_instance.run("Task 1") + response2 = flow_instance.run("Task 2") + previous_state, message = flow_instance.undo_last() + assert response1 == previous_state + assert "Restored to" in message + + +def test_flow_dynamic_temperature(flow_instance): + # Test dynamic temperature adjustment + flow_instance.dynamic_temperature = True + response = flow_instance.run("Test task") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_flow_streamed_generation(flow_instance): + # Test streamed generation + response = flow_instance.streamed_generation("Generating...") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_flow_step(flow_instance): + # Test the step method + response = flow_instance.step("Test step") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_flow_graceful_shutdown(flow_instance): + # Test graceful shutdown + result = flow_instance.graceful_shutdown() + assert result is not None + + +# Add more test cases as needed to cover various aspects of your Flow class + + +def test_flow_max_loops(flow_instance): + # Test setting and getting the maximum number of loops + flow_instance.set_max_loops(10) + assert flow_instance.get_max_loops() == 10 + + +def test_flow_autosave_path(flow_instance): + # Test setting and getting the autosave path + flow_instance.set_autosave_path("text.txt") + assert flow_instance.get_autosave_path() == "txt.txt" + + +def test_flow_response_length(flow_instance): + # Test checking the length of the response + response = flow_instance.run( + "Generate a 10,000 word long blog on mental clarity and the benefits of meditation." + ) + assert len(response) > flow_instance.get_response_length_threshold() + + +def test_flow_set_response_length_threshold(flow_instance): + # Test setting and getting the response length threshold + flow_instance.set_response_length_threshold(100) + assert flow_instance.get_response_length_threshold() == 100 + + +def test_flow_add_custom_filter(flow_instance): + # Test adding a custom response filter + flow_instance.add_response_filter("custom_filter") + assert "custom_filter" in flow_instance.get_response_filters() + + +def test_flow_remove_custom_filter(flow_instance): + # Test removing a custom response filter + flow_instance.add_response_filter("custom_filter") + flow_instance.remove_response_filter("custom_filter") + assert "custom_filter" not in flow_instance.get_response_filters() + + +def test_flow_dynamic_pacing(flow_instance): + # Test dynamic pacing + flow_instance.enable_dynamic_pacing() + assert flow_instance.is_dynamic_pacing_enabled() is True + + +def test_flow_disable_dynamic_pacing(flow_instance): + # Test disabling dynamic pacing + flow_instance.disable_dynamic_pacing() + assert flow_instance.is_dynamic_pacing_enabled() is False + + +def test_flow_change_prompt(flow_instance): + # Test changing the current prompt + flow_instance.change_prompt("New prompt") + assert flow_instance.get_current_prompt() == "New prompt" + + +def test_flow_add_instruction(flow_instance): + # Test adding an instruction to the conversation + flow_instance.add_instruction("Follow these steps:") + assert "Follow these steps:" in flow_instance.get_instructions() + + +def test_flow_clear_instructions(flow_instance): + # Test clearing all instructions from the conversation + flow_instance.add_instruction("Follow these steps:") + flow_instance.clear_instructions() + assert len(flow_instance.get_instructions()) == 0 + + +def test_flow_add_user_message(flow_instance): + # Test adding a user message to the conversation + flow_instance.add_user_message("User message") + assert "User message" in flow_instance.get_user_messages() + + +def test_flow_clear_user_messages(flow_instance): + # Test clearing all user messages from the conversation + flow_instance.add_user_message("User message") + flow_instance.clear_user_messages() + assert len(flow_instance.get_user_messages()) == 0 + + +def test_flow_get_response_history(flow_instance): + # Test getting the response history + flow_instance.run("Message 1") + flow_instance.run("Message 2") + history = flow_instance.get_response_history() + assert len(history) == 2 + assert "Message 1" in history[0] + assert "Message 2" in history[1] + + +def test_flow_clear_response_history(flow_instance): + # Test clearing the response history + flow_instance.run("Message 1") + flow_instance.run("Message 2") + flow_instance.clear_response_history() + assert len(flow_instance.get_response_history()) == 0 + + +def test_flow_get_conversation_log(flow_instance): + # Test getting the entire conversation log + flow_instance.run("Message 1") + flow_instance.run("Message 2") + conversation_log = flow_instance.get_conversation_log() + assert len(conversation_log) == 4 # Including system and user messages + + +def test_flow_clear_conversation_log(flow_instance): + # Test clearing the entire conversation log + flow_instance.run("Message 1") + flow_instance.run("Message 2") + flow_instance.clear_conversation_log() + assert len(flow_instance.get_conversation_log()) == 0 + + +def test_flow_get_state(flow_instance): + # Test getting the current state of the Flow instance + state = flow_instance.get_state() + assert isinstance(state, dict) + assert "current_prompt" in state + assert "instructions" in state + assert "user_messages" in state + assert "response_history" in state + assert "conversation_log" in state + assert "dynamic_pacing_enabled" in state + assert "response_length_threshold" in state + assert "response_filters" in state + assert "max_loops" in state + assert "autosave_path" in state + + +def test_flow_load_state(flow_instance): + # Test loading the state into the Flow instance + state = { + "current_prompt": "Loaded prompt", + "instructions": ["Step 1", "Step 2"], + "user_messages": ["User message 1", "User message 2"], + "response_history": ["Response 1", "Response 2"], + "conversation_log": [ + "System message 1", + "User message 1", + "System message 2", + "User message 2", + ], + "dynamic_pacing_enabled": True, + "response_length_threshold": 50, + "response_filters": ["filter1", "filter2"], + "max_loops": 10, + "autosave_path": "/path/to/load", + } + flow_instance.load_state(state) + assert flow_instance.get_current_prompt() == "Loaded prompt" + assert "Step 1" in flow_instance.get_instructions() + assert "User message 1" in flow_instance.get_user_messages() + assert "Response 1" in flow_instance.get_response_history() + assert "System message 1" in flow_instance.get_conversation_log() + assert flow_instance.is_dynamic_pacing_enabled() is True + assert flow_instance.get_response_length_threshold() == 50 + assert "filter1" in flow_instance.get_response_filters() + assert flow_instance.get_max_loops() == 10 + assert flow_instance.get_autosave_path() == "/path/to/load" + + +def test_flow_save_state(flow_instance): + # Test saving the state of the Flow instance + flow_instance.change_prompt("New prompt") + flow_instance.add_instruction("Step 1") + flow_instance.add_user_message("User message") + flow_instance.run("Response") + state = flow_instance.save_state() + assert "current_prompt" in state + assert "instructions" in state + assert "user_messages" in state + assert "response_history" in state + assert "conversation_log" in state + assert "dynamic_pacing_enabled" in state + assert "response_length_threshold" in state + assert "response_filters" in state + assert "max_loops" in state + assert "autosave_path" in state + + +def test_flow_rollback(flow_instance): + # Test rolling back to a previous state + state1 = flow_instance.get_state() + flow_instance.change_prompt("New prompt") + state2 = flow_instance.get_state() + flow_instance.rollback_to_state(state1) + assert flow_instance.get_current_prompt() == state1["current_prompt"] + assert flow_instance.get_instructions() == state1["instructions"] + assert flow_instance.get_user_messages() == state1["user_messages"] + assert flow_instance.get_response_history() == state1["response_history"] + assert flow_instance.get_conversation_log() == state1["conversation_log"] + assert flow_instance.is_dynamic_pacing_enabled() == state1["dynamic_pacing_enabled"] + assert ( + flow_instance.get_response_length_threshold() + == state1["response_length_threshold"] + ) + assert flow_instance.get_response_filters() == state1["response_filters"] + assert flow_instance.get_max_loops() == state1["max_loops"] + assert flow_instance.get_autosave_path() == state1["autosave_path"] + assert flow_instance.get_state() == state1 + + +def test_flow_contextual_intent(flow_instance): + # Test contextual intent handling + flow_instance.add_context("location", "New York") + flow_instance.add_context("time", "tomorrow") + response = flow_instance.run("What's the weather like in {location} at {time}?") + assert "New York" in response + assert "tomorrow" in response + + +def test_flow_contextual_intent_override(flow_instance): + # Test contextual intent override + flow_instance.add_context("location", "New York") + response1 = flow_instance.run("What's the weather like in {location}?") + flow_instance.add_context("location", "Los Angeles") + response2 = flow_instance.run("What's the weather like in {location}?") + assert "New York" in response1 + assert "Los Angeles" in response2 + + +def test_flow_contextual_intent_reset(flow_instance): + # Test resetting contextual intent + flow_instance.add_context("location", "New York") + response1 = flow_instance.run("What's the weather like in {location}?") + flow_instance.reset_context() + response2 = flow_instance.run("What's the weather like in {location}?") + assert "New York" in response1 + assert "New York" in response2 + + +# Add more test cases as needed to cover various aspects of your Flow class +def test_flow_interruptible(flow_instance): + # Test interruptible mode + flow_instance.interruptible = True + response = flow_instance.run("Interrupt me!") + assert "Interrupted" in response + assert flow_instance.is_interrupted() is True + + +def test_flow_non_interruptible(flow_instance): + # Test non-interruptible mode + flow_instance.interruptible = False + response = flow_instance.run("Do not interrupt me!") + assert "Do not interrupt me!" in response + assert flow_instance.is_interrupted() is False + + +def test_flow_timeout(flow_instance): + # Test conversation timeout + flow_instance.timeout = 60 # Set a timeout of 60 seconds + response = flow_instance.run("This should take some time to respond.") + assert "Timed out" in response + assert flow_instance.is_timed_out() is True + + +def test_flow_no_timeout(flow_instance): + # Test no conversation timeout + flow_instance.timeout = None + response = flow_instance.run("This should not time out.") + assert "This should not time out." in response + assert flow_instance.is_timed_out() is False + + +def test_flow_custom_delimiter(flow_instance): + # Test setting and getting a custom message delimiter + flow_instance.set_message_delimiter("|||") + assert flow_instance.get_message_delimiter() == "|||" + + +def test_flow_message_history(flow_instance): + # Test getting the message history + flow_instance.run("Message 1") + flow_instance.run("Message 2") + history = flow_instance.get_message_history() + assert len(history) == 2 + assert "Message 1" in history[0] + assert "Message 2" in history[1] + + +def test_flow_clear_message_history(flow_instance): + # Test clearing the message history + flow_instance.run("Message 1") + flow_instance.run("Message 2") + flow_instance.clear_message_history() + assert len(flow_instance.get_message_history()) == 0 + + +def test_flow_save_and_load_conversation(flow_instance): + # Test saving and loading the conversation + flow_instance.run("Message 1") + flow_instance.run("Message 2") + saved_conversation = flow_instance.save_conversation() + flow_instance.clear_conversation() + flow_instance.load_conversation(saved_conversation) + assert len(flow_instance.get_message_history()) == 2 + + +def test_flow_inject_custom_system_message(flow_instance): + # Test injecting a custom system message into the conversation + flow_instance.inject_custom_system_message("Custom system message") + assert "Custom system message" in flow_instance.get_message_history() + + +def test_flow_inject_custom_user_message(flow_instance): + # Test injecting a custom user message into the conversation + flow_instance.inject_custom_user_message("Custom user message") + assert "Custom user message" in flow_instance.get_message_history() + + +def test_flow_inject_custom_response(flow_instance): + # Test injecting a custom response into the conversation + flow_instance.inject_custom_response("Custom response") + assert "Custom response" in flow_instance.get_message_history() + + +def test_flow_clear_injected_messages(flow_instance): + # Test clearing injected messages from the conversation + flow_instance.inject_custom_system_message("Custom system message") + flow_instance.inject_custom_user_message("Custom user message") + flow_instance.inject_custom_response("Custom response") + flow_instance.clear_injected_messages() + assert "Custom system message" not in flow_instance.get_message_history() + assert "Custom user message" not in flow_instance.get_message_history() + assert "Custom response" not in flow_instance.get_message_history() + + +def test_flow_disable_message_history(flow_instance): + # Test disabling message history recording + flow_instance.disable_message_history() + response = flow_instance.run("This message should not be recorded in history.") + assert "This message should not be recorded in history." in response + assert len(flow_instance.get_message_history()) == 0 # History is empty + + +def test_flow_enable_message_history(flow_instance): + # Test enabling message history recording + flow_instance.enable_message_history() + response = flow_instance.run("This message should be recorded in history.") + assert "This message should be recorded in history." in response + assert len(flow_instance.get_message_history()) == 1 + + +def test_flow_custom_logger(flow_instance): + # Test setting and using a custom logger + custom_logger = logger # Replace with your custom logger class + flow_instance.set_logger(custom_logger) + response = flow_instance.run("Custom logger test") + assert "Logged using custom logger" in response # Verify logging message + + +def test_flow_batch_processing(flow_instance): + # Test batch processing of messages + messages = ["Message 1", "Message 2", "Message 3"] + responses = flow_instance.process_batch(messages) + assert isinstance(responses, list) + assert len(responses) == len(messages) + for response in responses: + assert isinstance(response, str) + + +def test_flow_custom_metrics(flow_instance): + # Test tracking custom metrics + flow_instance.track_custom_metric("custom_metric_1", 42) + flow_instance.track_custom_metric("custom_metric_2", 3.14) + metrics = flow_instance.get_custom_metrics() + assert "custom_metric_1" in metrics + assert "custom_metric_2" in metrics + assert metrics["custom_metric_1"] == 42 + assert metrics["custom_metric_2"] == 3.14 + + +def test_flow_reset_metrics(flow_instance): + # Test resetting custom metrics + flow_instance.track_custom_metric("custom_metric_1", 42) + flow_instance.track_custom_metric("custom_metric_2", 3.14) + flow_instance.reset_custom_metrics() + metrics = flow_instance.get_custom_metrics() + assert len(metrics) == 0 + + +def test_flow_retrieve_context(flow_instance): + # Test retrieving context + flow_instance.add_context("location", "New York") + context = flow_instance.get_context("location") + assert context == "New York" + + +def test_flow_update_context(flow_instance): + # Test updating context + flow_instance.add_context("location", "New York") + flow_instance.update_context("location", "Los Angeles") + context = flow_instance.get_context("location") + assert context == "Los Angeles" + + +def test_flow_remove_context(flow_instance): + # Test removing context + flow_instance.add_context("location", "New York") + flow_instance.remove_context("location") + context = flow_instance.get_context("location") + assert context is None + + +def test_flow_clear_context(flow_instance): + # Test clearing all context + flow_instance.add_context("location", "New York") + flow_instance.add_context("time", "tomorrow") + flow_instance.clear_context() + context_location = flow_instance.get_context("location") + context_time = flow_instance.get_context("time") + assert context_location is None + assert context_time is None + + +def test_flow_input_validation(flow_instance): + # Test input validation for invalid flow configurations + with pytest.raises(ValueError): + Flow(config=None) # Invalid config, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.set_message_delimiter( + "" + ) # Empty delimiter, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.set_message_delimiter( + None + ) # None delimiter, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.set_message_delimiter( + 123 + ) # Invalid delimiter type, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.set_logger( + "invalid_logger" + ) # Invalid logger type, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.add_context(None, "value") # None key, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.add_context("key", None) # None value, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.update_context(None, "value") # None key, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.update_context("key", None) # None value, should raise ValueError + + +def test_flow_conversation_reset(flow_instance): + # Test conversation reset + flow_instance.run("Message 1") + flow_instance.run("Message 2") + flow_instance.reset_conversation() + assert len(flow_instance.get_message_history()) == 0 + + +def test_flow_conversation_persistence(flow_instance): + # Test conversation persistence across instances + flow_instance.run("Message 1") + flow_instance.run("Message 2") + conversation = flow_instance.get_conversation() + + new_flow_instance = Flow() + new_flow_instance.load_conversation(conversation) + assert len(new_flow_instance.get_message_history()) == 2 + assert "Message 1" in new_flow_instance.get_message_history()[0] + assert "Message 2" in new_flow_instance.get_message_history()[1] + + +def test_flow_custom_event_listener(flow_instance): + # Test custom event listener + class CustomEventListener: + def on_message_received(self, message): + pass + + def on_response_generated(self, response): + pass + + custom_event_listener = CustomEventListener() + flow_instance.add_event_listener(custom_event_listener) + + # Ensure that the custom event listener methods are called during a conversation + with mock.patch.object( + custom_event_listener, "on_message_received" + ) as mock_received, mock.patch.object( + custom_event_listener, "on_response_generated" + ) as mock_response: + flow_instance.run("Message 1") + mock_received.assert_called_once() + mock_response.assert_called_once() + + +def test_flow_multiple_event_listeners(flow_instance): + # Test multiple event listeners + class FirstEventListener: + def on_message_received(self, message): + pass + + def on_response_generated(self, response): + pass + + class SecondEventListener: + def on_message_received(self, message): + pass + + def on_response_generated(self, response): + pass + + first_event_listener = FirstEventListener() + second_event_listener = SecondEventListener() + flow_instance.add_event_listener(first_event_listener) + flow_instance.add_event_listener(second_event_listener) + + # Ensure that both event listeners receive events during a conversation + with mock.patch.object( + first_event_listener, "on_message_received" + ) as mock_first_received, mock.patch.object( + first_event_listener, "on_response_generated" + ) as mock_first_response, mock.patch.object( + second_event_listener, "on_message_received" + ) as mock_second_received, mock.patch.object( + second_event_listener, "on_response_generated" + ) as mock_second_response: + flow_instance.run("Message 1") + mock_first_received.assert_called_once() + mock_first_response.assert_called_once() + mock_second_received.assert_called_once() + mock_second_response.assert_called_once() + + +# Add more test cases as needed to cover various aspects of your Flow class +def test_flow_error_handling(flow_instance): + # Test error handling and exceptions + with pytest.raises(ValueError): + flow_instance.set_message_delimiter( + "" + ) # Empty delimiter, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.set_message_delimiter( + None + ) # None delimiter, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.set_logger( + "invalid_logger" + ) # Invalid logger type, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.add_context(None, "value") # None key, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.add_context("key", None) # None value, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.update_context(None, "value") # None key, should raise ValueError + + with pytest.raises(ValueError): + flow_instance.update_context("key", None) # None value, should raise ValueError + + +def test_flow_context_operations(flow_instance): + # Test context operations + flow_instance.add_context("user_id", "12345") + assert flow_instance.get_context("user_id") == "12345" + flow_instance.update_context("user_id", "54321") + assert flow_instance.get_context("user_id") == "54321" + flow_instance.remove_context("user_id") + assert flow_instance.get_context("user_id") is None + + +# Add more test cases as needed to cover various aspects of your Flow class + + +def test_flow_long_messages(flow_instance): + # Test handling of long messages + long_message = "A" * 10000 # Create a very long message + flow_instance.run(long_message) + assert len(flow_instance.get_message_history()) == 1 + assert flow_instance.get_message_history()[0] == long_message + + +def test_flow_custom_response(flow_instance): + # Test custom response generation + def custom_response_generator(message): + if message == "Hello": + return "Hi there!" + elif message == "How are you?": + return "I'm doing well, thank you." + else: + return "I don't understand." + + flow_instance.set_response_generator(custom_response_generator) + + assert flow_instance.run("Hello") == "Hi there!" + assert flow_instance.run("How are you?") == "I'm doing well, thank you." + assert flow_instance.run("What's your name?") == "I don't understand." + + +def test_flow_message_validation(flow_instance): + # Test message validation + def custom_message_validator(message): + return len(message) > 0 # Reject empty messages + + flow_instance.set_message_validator(custom_message_validator) + + assert flow_instance.run("Valid message") is not None + assert flow_instance.run("") is None # Empty message should be rejected + assert flow_instance.run(None) is None # None message should be rejected + + +def test_flow_custom_logging(flow_instance): + custom_logger = logger + flow_instance.set_logger(custom_logger) + + with mock.patch.object(custom_logger, "log") as mock_log: + flow_instance.run("Message") + mock_log.assert_called_once_with("Message") + + +def test_flow_performance(flow_instance): + # Test the performance of the Flow class by running a large number of messages + num_messages = 1000 + for i in range(num_messages): + flow_instance.run(f"Message {i}") + assert len(flow_instance.get_message_history()) == num_messages + + +def test_flow_complex_use_case(flow_instance): + # Test a complex use case scenario + flow_instance.add_context("user_id", "12345") + flow_instance.run("Hello") + flow_instance.run("How can I help you?") + assert flow_instance.get_response() == "Please provide more details." + flow_instance.update_context("user_id", "54321") + flow_instance.run("I need help with my order") + assert flow_instance.get_response() == "Sure, I can assist with that." + flow_instance.reset_conversation() + assert len(flow_instance.get_message_history()) == 0 + assert flow_instance.get_context("user_id") is None + + +# Add more test cases as needed to cover various aspects of your Flow class +def test_flow_context_handling(flow_instance): + # Test context handling + flow_instance.add_context("user_id", "12345") + assert flow_instance.get_context("user_id") == "12345" + flow_instance.update_context("user_id", "54321") + assert flow_instance.get_context("user_id") == "54321" + flow_instance.remove_context("user_id") + assert flow_instance.get_context("user_id") is None + + +def test_flow_concurrent_requests(flow_instance): + # Test concurrent message processing + import threading + + def send_messages(): + for i in range(100): + flow_instance.run(f"Message {i}") + + threads = [] + for _ in range(5): + thread = threading.Thread(target=send_messages) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(flow_instance.get_message_history()) == 500 + + +def test_flow_custom_timeout(flow_instance): + # Test custom timeout handling + flow_instance.set_timeout(10) # Set a custom timeout of 10 seconds + assert flow_instance.get_timeout() == 10 + + import time + + start_time = time.time() + flow_instance.run("Long-running operation") + end_time = time.time() + execution_time = end_time - start_time + assert execution_time >= 10 # Ensure the timeout was respected + + +# Add more test cases as needed to thoroughly cover your Flow class + + +def test_flow_interactive_run(flow_instance, capsys): + # Test interactive run mode + # Simulate user input and check if the AI responds correctly + user_input = ["Hello", "How can you help me?", "Exit"] + + def simulate_user_input(input_list): + input_index = 0 + while input_index < len(input_list): + user_response = input_list[input_index] + flow_instance.interactive_run(max_loops=1) + + # Capture the AI's response + captured = capsys.readouterr() + ai_response = captured.out.strip() + + assert f"You: {user_response}" in captured.out + assert "AI:" in captured.out + + # Check if the AI's response matches the expected response + expected_response = f"AI: {ai_response}" + assert expected_response in captured.out + + input_index += 1 + + simulate_user_input(user_input) + + +# Assuming you have already defined your Flow class and created an instance for testing + + +def test_flow_agent_history_prompt(flow_instance): + # Test agent history prompt generation + system_prompt = "This is the system prompt." + history = ["User: Hi", "AI: Hello"] + + agent_history_prompt = flow_instance.agent_history_prompt(system_prompt, history) + + assert "SYSTEM_PROMPT: This is the system prompt." in agent_history_prompt + assert "History: ['User: Hi', 'AI: Hello']" in agent_history_prompt + + +async def test_flow_run_concurrent(flow_instance): + # Test running tasks concurrently + tasks = ["Task 1", "Task 2", "Task 3"] + completed_tasks = await flow_instance.run_concurrent(tasks) + + # Ensure that all tasks are completed + assert len(completed_tasks) == len(tasks) + + +def test_flow_bulk_run(flow_instance): + # Test bulk running of tasks + input_data = [ + {"task": "Task 1", "param1": "value1"}, + {"task": "Task 2", "param2": "value2"}, + {"task": "Task 3", "param3": "value3"}, + ] + responses = flow_instance.bulk_run(input_data) + + # Ensure that the responses match the input tasks + assert responses[0] == "Response for Task 1" + assert responses[1] == "Response for Task 2" + assert responses[2] == "Response for Task 3" + + +def test_flow_from_llm_and_template(): + # Test creating Flow instance from an LLM and a template + llm_instance = mocked_llm # Replace with your LLM class + template = "This is a template for testing." + + flow_instance = Flow.from_llm_and_template(llm_instance, template) + + assert isinstance(flow_instance, Flow) + + +def test_flow_from_llm_and_template_file(): + # Test creating Flow instance from an LLM and a template file + llm_instance = mocked_llm # Replace with your LLM class + template_file = "template.txt" # Create a template file for testing + + flow_instance = Flow.from_llm_and_template_file(llm_instance, template_file) + + assert isinstance(flow_instance, Flow) + + +def test_flow_save_and_load(flow_instance, tmp_path): + # Test saving and loading the flow state + file_path = tmp_path / "flow_state.json" + + # Save the state + flow_instance.save(file_path) + + # Create a new instance and load the state + new_flow_instance = Flow(llm=mocked_llm, max_loops=5) + new_flow_instance.load(file_path) + + # Ensure that the loaded state matches the original state + assert new_flow_instance.memory == flow_instance.memory + + +def test_flow_validate_response(flow_instance): + # Test response validation + valid_response = "This is a valid response." + invalid_response = "Short." + + assert flow_instance.validate_response(valid_response) is True + assert flow_instance.validate_response(invalid_response) is False + + +# Add more test cases as needed for other methods and features of your Flow class + +# Finally, don't forget to run your tests using a testing framework like pytest + +# Assuming you have already defined your Flow class and created an instance for testing + + +def test_flow_print_history_and_memory(capsys, flow_instance): + # Test printing the history and memory of the flow + history = ["User: Hi", "AI: Hello"] + flow_instance.memory = [history] + + flow_instance.print_history_and_memory() + + captured = capsys.readouterr() + assert "Flow History and Memory" in captured.out + assert "Loop 1:" in captured.out + assert "User: Hi" in captured.out + assert "AI: Hello" in captured.out + + +def test_flow_run_with_timeout(flow_instance): + # Test running with a timeout + task = "Task with a long response time" + response = flow_instance.run_with_timeout(task, timeout=1) + + # Ensure that the response is either the actual response or "Timeout" + assert response in ["Actual Response", "Timeout"] + + +# Add more test cases as needed for other methods and features of your Flow class + +# Finally, don't forget to run your tests using a testing framework like pytest