parent
df222331f7
commit
a584ce6bb3
@ -0,0 +1,22 @@
|
|||||||
|
mkdocs
|
||||||
|
mkdocs-material
|
||||||
|
mkdocs-glightbox
|
||||||
|
mkdocs-git-authors-plugin
|
||||||
|
mkdocs-git-revision-date-plugin
|
||||||
|
mkdocs-git-committers-plugin
|
||||||
|
mkdocstrings
|
||||||
|
mike
|
||||||
|
mkdocs-jupyter
|
||||||
|
mkdocs-git-committers-plugin-2
|
||||||
|
mkdocs-git-revision-date-localized-plugin
|
||||||
|
mkdocs-redirects
|
||||||
|
mkdocs-material-extensions
|
||||||
|
mkdocs-simple-hooks
|
||||||
|
mkdocs-awesome-pages-plugin
|
||||||
|
mkdocs-versioning
|
||||||
|
mkdocs-mermaid2-plugin
|
||||||
|
mkdocs-include-markdown-plugin
|
||||||
|
mkdocs-enumerate-headings-plugin
|
||||||
|
mkdocs-autolinks-plugin
|
||||||
|
mkdocs-minify-html-plugin
|
||||||
|
mkdocs-autolinks-plugin
|
@ -0,0 +1,7 @@
|
|||||||
|
from swarms import llama3Hosted
|
||||||
|
|
||||||
|
llama3 = llama3Hosted()
|
||||||
|
|
||||||
|
task = "What is the capital of France?"
|
||||||
|
response = llama3.run(task)
|
||||||
|
print(response)
|
@ -0,0 +1,29 @@
|
|||||||
|
torch>=2.1.1,<3.0
|
||||||
|
transformers>=4.39.0,<5.0.0
|
||||||
|
asyncio>=3.4.3,<4.0
|
||||||
|
langchain-community==0.0.29
|
||||||
|
langchain-experimental==0.0.55
|
||||||
|
backoff==2.2.1
|
||||||
|
toml
|
||||||
|
pypdf==4.1.0
|
||||||
|
ratelimit==2.2.1
|
||||||
|
loguru==0.7.2
|
||||||
|
pydantic==2.7.1
|
||||||
|
tenacity==8.2.3
|
||||||
|
Pillow==10.3.0
|
||||||
|
psutil
|
||||||
|
sentry-sdk
|
||||||
|
python-dotenv
|
||||||
|
opencv-python-headless
|
||||||
|
PyYAML
|
||||||
|
docstring_parser==0.16
|
||||||
|
black>=23.1,<25.0
|
||||||
|
ruff>=0.0.249,<0.4.5
|
||||||
|
types-toml>=0.10.8.1
|
||||||
|
types-pytz>=2023.3,<2025.0
|
||||||
|
types-chardet>=5.0.4.6
|
||||||
|
mypy-protobuf>=3.0.0
|
||||||
|
pytest>=8.1.1
|
||||||
|
termcolor>=2.4.0
|
||||||
|
pandas>=2.2.2
|
||||||
|
fastapi>=0.110.1
|
@ -1,134 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
from swarms.models.base_llm import BaseLLM
|
|
||||||
from swarms.structs.message import Message
|
|
||||||
|
|
||||||
|
|
||||||
class Mistral(BaseLLM):
|
|
||||||
"""
|
|
||||||
Mistral is an all-new llm
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ai_name (str, optional): Name of the AI. Defaults to "Mistral".
|
|
||||||
system_prompt (str, optional): System prompt. Defaults to None.
|
|
||||||
model_name (str, optional): Model name. Defaults to "mistralai/Mistral-7B-v0.1".
|
|
||||||
device (str, optional): Device to use. Defaults to "cuda".
|
|
||||||
use_flash_attention (bool, optional): Whether to use flash attention. Defaults to False.
|
|
||||||
temperature (float, optional): Temperature. Defaults to 1.0.
|
|
||||||
max_length (int, optional): Max length. Defaults to 100.
|
|
||||||
do_sample (bool, optional): Whether to sample. Defaults to True.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from swarms.models import Mistral
|
|
||||||
|
|
||||||
model = Mistral(device="cuda", use_flash_attention=True, temperature=0.7, max_length=200)
|
|
||||||
|
|
||||||
task = "My favourite condiment is"
|
|
||||||
result = model.run(task)
|
|
||||||
print(result)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
ai_name: str = "Node Model Agent",
|
|
||||||
system_prompt: str = None,
|
|
||||||
model_name: str = "mistralai/Mistral-7B-v0.1",
|
|
||||||
device: str = "cuda",
|
|
||||||
use_flash_attention: bool = False,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
max_length: int = 100,
|
|
||||||
do_sample: bool = True,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.ai_name = ai_name
|
|
||||||
self.system_prompt = system_prompt
|
|
||||||
self.model_name = model_name
|
|
||||||
self.device = device
|
|
||||||
self.use_flash_attention = use_flash_attention
|
|
||||||
self.temperature = temperature
|
|
||||||
self.max_length = max_length
|
|
||||||
self.do_sample = do_sample
|
|
||||||
|
|
||||||
# Check if the specified device is available
|
|
||||||
if not torch.cuda.is_available() and device == "cuda":
|
|
||||||
raise ValueError(
|
|
||||||
"CUDA is not available. Please choose a different"
|
|
||||||
" device."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.history = []
|
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_name, *args, **kwargs
|
|
||||||
)
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_name, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model.to(self.device)
|
|
||||||
|
|
||||||
def run(self, task: str, *args, **kwargs):
|
|
||||||
"""Run the model on a given task."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_inputs = self.tokenizer([task], return_tensors="pt").to(
|
|
||||||
self.device
|
|
||||||
)
|
|
||||||
generated_ids = self.model.generate(
|
|
||||||
**model_inputs,
|
|
||||||
max_length=self.max_length,
|
|
||||||
do_sample=self.do_sample,
|
|
||||||
temperature=self.temperature,
|
|
||||||
max_new_tokens=self.max_length,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
output_text = self.tokenizer.batch_decode(generated_ids)[0]
|
|
||||||
return output_text
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Error running the model: {str(e)}")
|
|
||||||
|
|
||||||
def chat(self, msg: str = None, streaming: bool = False):
|
|
||||||
"""
|
|
||||||
Run chat
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msg (str, optional): Message to send to the agent. Defaults to None.
|
|
||||||
language (str, optional): Language to use. Defaults to None.
|
|
||||||
streaming (bool, optional): Whether to stream the response. Defaults to False.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Response from the agent
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
--------------
|
|
||||||
agent = MultiModalAgent()
|
|
||||||
agent.chat("Hello")
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
# add users message to the history
|
|
||||||
self.history.append(Message("User", msg))
|
|
||||||
|
|
||||||
# process msg
|
|
||||||
try:
|
|
||||||
response = self.agent.run(msg)
|
|
||||||
|
|
||||||
# add agent's response to the history
|
|
||||||
self.history.append(Message("Agent", response))
|
|
||||||
|
|
||||||
# if streaming is = True
|
|
||||||
if streaming:
|
|
||||||
return self._stream_response(response)
|
|
||||||
else:
|
|
||||||
response
|
|
||||||
|
|
||||||
except Exception as error:
|
|
||||||
error_message = f"Error processing message: {str(error)}"
|
|
||||||
|
|
||||||
# add error to history
|
|
||||||
self.history.append(Message("Agent", error_message))
|
|
||||||
|
|
||||||
return error_message
|
|
@ -1,75 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
from swarms.models.base_llm import BaseLLM
|
|
||||||
|
|
||||||
|
|
||||||
class Mixtral(BaseLLM):
|
|
||||||
"""Mixtral model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): The name or path of the pre-trained Mixtral model.
|
|
||||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> from swarms.models import Mixtral
|
|
||||||
>>> mixtral = Mixtral()
|
|
||||||
>>> mixtral.run("Test task")
|
|
||||||
'Generated text'
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str = "mistralai/Mixtral-8x7B-v0.1",
|
|
||||||
max_new_tokens: int = 500,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initializes a Mixtral model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name (str): The name or path of the pre-trained Mixtral model.
|
|
||||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
"""
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.model_name = model_name
|
|
||||||
self.max_new_tokens = max_new_tokens
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def run(self, task: Optional[str] = None, **kwargs):
|
|
||||||
"""
|
|
||||||
Generates text based on the given task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str, optional): The task or prompt for text generation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The generated text.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
inputs = self.tokenizer(task, return_tensors="pt")
|
|
||||||
|
|
||||||
outputs = self.model.generate(
|
|
||||||
**inputs,
|
|
||||||
max_new_tokens=self.max_new_tokens,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.tokenizer.decode(
|
|
||||||
outputs[0],
|
|
||||||
skip_special_tokens=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return out
|
|
||||||
except Exception as error:
|
|
||||||
print(f"There is an error: {error} in Mixtral model.")
|
|
||||||
raise error
|
|
@ -1,38 +0,0 @@
|
|||||||
import logging
|
|
||||||
from functools import wraps
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def block(
|
|
||||||
function: Callable[..., Any],
|
|
||||||
device: str = None,
|
|
||||||
verbose: bool = False,
|
|
||||||
) -> Callable[..., Any]:
|
|
||||||
"""
|
|
||||||
A decorator that transforms a function into a block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
function (Callable[..., Any]): The function to transform.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Callable[..., Any]: The transformed function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@wraps(function)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
# Here you can add code to execute the function on various hardwares
|
|
||||||
# For now, we'll just call the function normally
|
|
||||||
try:
|
|
||||||
return function(*args, **kwargs)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(f"Error in {function.__name__}: {error}")
|
|
||||||
raise error
|
|
||||||
|
|
||||||
# Set the wrapper function's name and docstring to those of the original function
|
|
||||||
wrapper.__name__ = function.__name__
|
|
||||||
wrapper.__doc__ = function.__doc__
|
|
||||||
|
|
||||||
return wrapper
|
|
@ -1,116 +0,0 @@
|
|||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from swarms.structs.base_structure import BaseStructure
|
|
||||||
|
|
||||||
|
|
||||||
class BlocksDict(BaseStructure):
|
|
||||||
"""
|
|
||||||
A class representing a dictionary of blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): The name of the blocks dictionary.
|
|
||||||
description (str): The description of the blocks dictionary.
|
|
||||||
blocks (Dict[str, Any]): The dictionary of blocks.
|
|
||||||
parent (Optional[Any], optional): The parent of the blocks dictionary. Defaults to None.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
parent (Optional[Any]): The parent of the blocks dictionary.
|
|
||||||
blocks (Dict[str, Any]): The dictionary of blocks.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
add(key: str, block: Any): Add a block to the dictionary.
|
|
||||||
remove(key: str): Remove a block from the dictionary.
|
|
||||||
get(key: str): Get a block from the dictionary.
|
|
||||||
update(key: str, block: Any): Update a block in the dictionary.
|
|
||||||
keys(): Get a list of keys in the dictionary.
|
|
||||||
values(): Get a list of values in the dictionary.
|
|
||||||
items(): Get a list of key-value pairs in the dictionary.
|
|
||||||
clear(): Clear all blocks from the dictionary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
blocks: Dict[str, Any],
|
|
||||||
parent: Optional[Any] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(name=name, description=description, **kwargs)
|
|
||||||
self.parent = parent
|
|
||||||
self.blocks = blocks
|
|
||||||
|
|
||||||
def add(self, key: str, block: Any):
|
|
||||||
"""
|
|
||||||
Add a block to the dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key (str): The key of the block.
|
|
||||||
block (Any): The block to be added.
|
|
||||||
"""
|
|
||||||
self.blocks[key] = block
|
|
||||||
|
|
||||||
def remove(self, key: str):
|
|
||||||
"""
|
|
||||||
Remove a block from the dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key (str): The key of the block to be removed.
|
|
||||||
"""
|
|
||||||
del self.blocks[key]
|
|
||||||
|
|
||||||
def get(self, key: str):
|
|
||||||
"""
|
|
||||||
Get a block from the dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key (str): The key of the block to be retrieved.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Any: The retrieved block.
|
|
||||||
"""
|
|
||||||
return self.blocks.get(key)
|
|
||||||
|
|
||||||
def update(self, key: str, block: Any):
|
|
||||||
"""
|
|
||||||
Update a block in the dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key (str): The key of the block to be updated.
|
|
||||||
block (Any): The updated block.
|
|
||||||
"""
|
|
||||||
self.blocks[key] = block
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
"""
|
|
||||||
Get a list of keys in the dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: A list of keys.
|
|
||||||
"""
|
|
||||||
return list(self.blocks.keys())
|
|
||||||
|
|
||||||
def values(self):
|
|
||||||
"""
|
|
||||||
Get a list of values in the dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Any]: A list of values.
|
|
||||||
"""
|
|
||||||
return list(self.blocks.values())
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
"""
|
|
||||||
Get a list of key-value pairs in the dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Tuple[str, Any]]: A list of key-value pairs.
|
|
||||||
"""
|
|
||||||
return list(self.blocks.items())
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
"""
|
|
||||||
Clear all blocks from the dictionary.
|
|
||||||
"""
|
|
||||||
self.blocks.clear()
|
|
@ -1,161 +0,0 @@
|
|||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
from swarms.structs.base_structure import BaseStructure
|
|
||||||
|
|
||||||
|
|
||||||
class BlocksList(BaseStructure):
|
|
||||||
"""
|
|
||||||
A class representing a list of blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name (str): The name of the blocks list.
|
|
||||||
description (str): The description of the blocks list.
|
|
||||||
blocks (List[Any]): The list of blocks.
|
|
||||||
parent (Optional[Any], optional): The parent of the blocks list. Defaults to None.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
parent (Optional[Any]): The parent of the blocks list.
|
|
||||||
blocks (List[Any]): The list of blocks.
|
|
||||||
|
|
||||||
Methods:
|
|
||||||
add(block: Any): Add a block to the list.
|
|
||||||
remove(block: Any): Remove a block from the list.
|
|
||||||
update(block: Any): Update a block in the list.
|
|
||||||
get(index: int): Get a block at the specified index.
|
|
||||||
get_all(): Get all blocks in the list.
|
|
||||||
get_by_name(name: str): Get blocks by name.
|
|
||||||
get_by_type(type: str): Get blocks by type.
|
|
||||||
get_by_id(id: str): Get blocks by ID.
|
|
||||||
get_by_parent(parent: str): Get blocks by parent.
|
|
||||||
get_by_parent_id(parent_id: str): Get blocks by parent ID.
|
|
||||||
get_by_parent_name(parent_name: str): Get blocks by parent name.
|
|
||||||
get_by_parent_type(parent_type: str): Get blocks by parent type.
|
|
||||||
get_by_parent_description(parent_description: str): Get blocks by parent description.
|
|
||||||
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> from swarms.structs.block import Block
|
|
||||||
>>> from swarms.structs.blockslist import BlocksList
|
|
||||||
>>> block = Block("block", "A block")
|
|
||||||
>>> blockslist = BlocksList("blockslist", "A list of blocks", [block])
|
|
||||||
>>> blockslist
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
blocks: List[Any],
|
|
||||||
parent: Optional[Any] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(name=name, description=description, **kwargs)
|
|
||||||
self.name = name
|
|
||||||
self.description = description
|
|
||||||
self.blocks = blocks
|
|
||||||
self.parent = parent
|
|
||||||
|
|
||||||
def add(self, block: Any):
|
|
||||||
self.blocks.append(block)
|
|
||||||
|
|
||||||
def remove(self, block: Any):
|
|
||||||
self.blocks.remove(block)
|
|
||||||
|
|
||||||
def update(self, block: Any):
|
|
||||||
self.blocks[self.blocks.index(block)] = block
|
|
||||||
|
|
||||||
def get(self, index: int):
|
|
||||||
return self.blocks[index]
|
|
||||||
|
|
||||||
def get_all(self):
|
|
||||||
return self.blocks
|
|
||||||
|
|
||||||
def run_block(self, block: Any, task: str, *args, **kwargs):
|
|
||||||
"""Run the block for the specified task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): The task to be performed by the block.
|
|
||||||
*args: Variable length argument list.
|
|
||||||
**kwargs: Arbitrary keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The output of the block.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If an error occurs during the execution of the block.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return block.run(task, *args, **kwargs)
|
|
||||||
except Exception as error:
|
|
||||||
print(f"[Error] [Block] {error}")
|
|
||||||
raise error
|
|
||||||
|
|
||||||
def get_by_name(self, name: str):
|
|
||||||
return [block for block in self.blocks if block.name == name]
|
|
||||||
|
|
||||||
def get_by_type(self, type: str):
|
|
||||||
return [block for block in self.blocks if block.type == type]
|
|
||||||
|
|
||||||
def get_by_id(self, id: str):
|
|
||||||
return [block for block in self.blocks if block.id == id]
|
|
||||||
|
|
||||||
def get_by_parent(self, parent: str):
|
|
||||||
return [block for block in self.blocks if block.parent == parent]
|
|
||||||
|
|
||||||
def get_by_parent_id(self, parent_id: str):
|
|
||||||
return [
|
|
||||||
block for block in self.blocks if block.parent_id == parent_id
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_by_parent_name(self, parent_name: str):
|
|
||||||
return [
|
|
||||||
block
|
|
||||||
for block in self.blocks
|
|
||||||
if block.parent_name == parent_name
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_by_parent_type(self, parent_type: str):
|
|
||||||
return [
|
|
||||||
block
|
|
||||||
for block in self.blocks
|
|
||||||
if block.parent_type == parent_type
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_by_parent_description(self, parent_description: str):
|
|
||||||
return [
|
|
||||||
block
|
|
||||||
for block in self.blocks
|
|
||||||
if block.parent_description == parent_description
|
|
||||||
]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.blocks)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
return self.blocks[index]
|
|
||||||
|
|
||||||
def __setitem__(self, index, value):
|
|
||||||
self.blocks[index] = value
|
|
||||||
|
|
||||||
def __delitem__(self, index):
|
|
||||||
del self.blocks[index]
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self.blocks)
|
|
||||||
|
|
||||||
def __reversed__(self):
|
|
||||||
return reversed(self.blocks)
|
|
||||||
|
|
||||||
def __contains__(self, item):
|
|
||||||
return item in self.blocks
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.name}({self.blocks})"
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.name}({self.blocks})"
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return self.blocks == other.blocks
|
|
@ -1,93 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from functools import partial
|
|
||||||
from typing import Any, Literal, Sequence
|
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
|
|
||||||
class Document(Serializable):
|
|
||||||
"""Class for storing a piece of text and associated metadata."""
|
|
||||||
|
|
||||||
page_content: str
|
|
||||||
"""String text."""
|
|
||||||
metadata: dict = Field(default_factory=dict)
|
|
||||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
|
||||||
documents, etc.).
|
|
||||||
"""
|
|
||||||
type: Literal["Document"] = "Document"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
"""Return whether this class is serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDocumentTransformer(ABC):
|
|
||||||
"""Abstract base class for document transformation systems.
|
|
||||||
|
|
||||||
A document transformation system takes a sequence of Documents and returns a
|
|
||||||
sequence of transformed Documents.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
|
||||||
embeddings: Embeddings
|
|
||||||
similarity_fn: Callable = cosine_similarity
|
|
||||||
similarity_threshold: float = 0.95
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
def transform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
stateful_documents = get_stateful_documents(documents)
|
|
||||||
embedded_documents = _get_embeddings_from_stateful_docs(
|
|
||||||
self.embeddings, stateful_documents
|
|
||||||
)
|
|
||||||
included_idxs = _filter_similar_embeddings(
|
|
||||||
embedded_documents, self.similarity_fn, self.similarity_threshold
|
|
||||||
)
|
|
||||||
return [stateful_documents[i] for i in sorted(included_idxs)]
|
|
||||||
|
|
||||||
async def atransform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def transform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
"""Transform a list of documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: A sequence of Documents to be transformed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of transformed Documents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def atransform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
"""Asynchronously transform a list of documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: A sequence of Documents to be transformed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of transformed Documents.
|
|
||||||
"""
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
partial(self.transform_documents, **kwargs),
|
|
||||||
documents,
|
|
||||||
)
|
|
@ -1,151 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from swarms.structs.agent import Agent
|
|
||||||
from swarms.utils.parse_code import extract_code_from_markdown
|
|
||||||
|
|
||||||
|
|
||||||
class LongContextSwarmLeader:
|
|
||||||
"""
|
|
||||||
Represents a leader in a long context swarm.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- llm (str): The language model to use for the agent.
|
|
||||||
- agents (List[Agent]): The agents in the swarm.
|
|
||||||
- prompt_template_json (str): The SOP template in JSON format.
|
|
||||||
- return_parsed (bool): Whether to return the parsed output.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
llm,
|
|
||||||
agents: List[Agent] = None,
|
|
||||||
prompt_template_json: str = None,
|
|
||||||
return_parsed: bool = False,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.llm = llm
|
|
||||||
self.agents = agents
|
|
||||||
self.prompt_template_json = prompt_template_json
|
|
||||||
self.return_parsed = return_parsed
|
|
||||||
|
|
||||||
# Create an instance of the Agent class
|
|
||||||
self.agent = Agent(
|
|
||||||
llm=llm,
|
|
||||||
system_prompt=None,
|
|
||||||
sop=self.prompt_template_json,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def prep_schema(self, task: str, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Returns a formatted string containing the metadata of all agents in the swarm.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- task (str): The description of the task.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- prompt (str): The formatted string containing the agent metadata.
|
|
||||||
"""
|
|
||||||
prompt = f"""
|
|
||||||
|
|
||||||
You need to recruit a team of members to solve a
|
|
||||||
task. Select the appropriate member based on the
|
|
||||||
task description:
|
|
||||||
|
|
||||||
# Task Description
|
|
||||||
{task}
|
|
||||||
|
|
||||||
# Members
|
|
||||||
|
|
||||||
Your output must follow this JSON schema below in markdown format:
|
|
||||||
{{
|
|
||||||
"agent_id": "string",
|
|
||||||
"agent_name": "string",
|
|
||||||
"agent_description": "string"
|
|
||||||
}}
|
|
||||||
|
|
||||||
"""
|
|
||||||
for agent in self.agents:
|
|
||||||
prompt += (
|
|
||||||
f"Member Name: {agent.ai_name}\nMember ID:"
|
|
||||||
f" {agent.id}\nMember Description:"
|
|
||||||
f" {agent.description}\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def prep_schema_second(self, task_description: str, task: str):
|
|
||||||
prompt = f"""
|
|
||||||
You are the leader of a team of {len(self.agents)}
|
|
||||||
members. Your team will need to collaborate to
|
|
||||||
solve a task. The rule is:
|
|
||||||
|
|
||||||
1. Only you know the task description and task
|
|
||||||
objective; the other members do not.
|
|
||||||
2. But they will receive different documents that
|
|
||||||
may contain answers, and you need to send them
|
|
||||||
an instruction to query their document.
|
|
||||||
3. Your instruction need to include your
|
|
||||||
understanding of the task and what you need them
|
|
||||||
to focus on. If necessary, your instructions can
|
|
||||||
explicitly include the task objective.
|
|
||||||
4. Finally, you need to complete the task based on
|
|
||||||
the query results they return.
|
|
||||||
|
|
||||||
# Task Description:
|
|
||||||
{task_description}
|
|
||||||
|
|
||||||
# Task Objective:
|
|
||||||
{task}
|
|
||||||
|
|
||||||
# Generate Instruction for Members:
|
|
||||||
Now, you need to generate an instruction for all
|
|
||||||
team members. You can ask them to answer a
|
|
||||||
certain question, or to extract information related
|
|
||||||
to the task, based on their respective documents.
|
|
||||||
Your output must following the JSON
|
|
||||||
format: {{"type": "instruction", "content":
|
|
||||||
"your_instruction_content"}}
|
|
||||||
|
|
||||||
"""
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def run(self, task: str, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Executes the specified task using the agent's run method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The task to be executed.
|
|
||||||
*args: Additional positional arguments for the task.
|
|
||||||
**kwargs: Additional keyword arguments for the task.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result of the task execution.
|
|
||||||
"""
|
|
||||||
task = self.prep_schema(task)
|
|
||||||
out = self.agent.run(task, *args, **kwargs)
|
|
||||||
|
|
||||||
if self.return_parsed:
|
|
||||||
out = extract_code_from_markdown(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
# class LongContextSwarm(BaseSwarm):
|
|
||||||
# def __init__(
|
|
||||||
# self,
|
|
||||||
# agents: List[Agent],
|
|
||||||
# Leader: Agent,
|
|
||||||
# team_loops: int,
|
|
||||||
# *args,
|
|
||||||
# **kwargs,
|
|
||||||
# ):
|
|
||||||
# super().__init__()
|
|
||||||
# self.agents = agents
|
|
||||||
# self.leader = Leader
|
|
||||||
# self.team_loops = team_loops
|
|
||||||
# self.chunks = len(agents)
|
|
@ -1,181 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
from typing import Callable, List
|
|
||||||
|
|
||||||
from termcolor import colored
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelParallelizer:
|
|
||||||
"""
|
|
||||||
ModelParallelizer, a class that parallelizes the execution of a task
|
|
||||||
across multiple language models. It is a wrapper around the
|
|
||||||
LanguageModel class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llms (List[Callable]): A list of language models.
|
|
||||||
retry_attempts (int): The number of retry attempts.
|
|
||||||
max_loops (int): The number of iterations to run the task.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
llms (List[Callable]): A list of language models.
|
|
||||||
retry_attempts (int): The number of retry attempts.
|
|
||||||
max_loops (int): The number of iterations to run the task.
|
|
||||||
last_responses (List[str]): The last responses from the language
|
|
||||||
models.
|
|
||||||
task_history (List[str]): The task history.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> from swarms.structs import ModelParallelizer
|
|
||||||
>>> from swarms.llms import OpenAIChat
|
|
||||||
>>> llms = [
|
|
||||||
... OpenAIChat(
|
|
||||||
... temperature=0.5,
|
|
||||||
... openai_api_key="OPENAI_API_KEY",
|
|
||||||
... ),
|
|
||||||
... OpenAIChat(
|
|
||||||
... temperature=0.5,
|
|
||||||
... openai_api_key="OPENAI_API_KEY",
|
|
||||||
... ),
|
|
||||||
... ]
|
|
||||||
>>> mp = ModelParallelizer(llms)
|
|
||||||
>>> mp.run("Generate a 10,000 word blog on health and wellness.")
|
|
||||||
['Generate a 10,000 word blog on health and wellness.', 'Generate a 10,000 word blog on health and wellness.']
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
llms: List[Callable] = None,
|
|
||||||
retry_attempts: int = 3,
|
|
||||||
max_loops: int = None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.llms = llms
|
|
||||||
self.retry_attempts = retry_attempts
|
|
||||||
self.max_loops = max_loops
|
|
||||||
self.last_responses = None
|
|
||||||
self.task_history = []
|
|
||||||
|
|
||||||
def run(self, task: str):
|
|
||||||
"""Run the task string"""
|
|
||||||
try:
|
|
||||||
for i in range(self.max_loops):
|
|
||||||
with ThreadPoolExecutor() as executor:
|
|
||||||
responses = executor.map(
|
|
||||||
lambda llm: llm(task), self.llms
|
|
||||||
)
|
|
||||||
return list(responses)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_all(self, task):
|
|
||||||
"""Run the task on all LLMs"""
|
|
||||||
responses = []
|
|
||||||
for llm in self.llms:
|
|
||||||
responses.append(llm(task))
|
|
||||||
return responses
|
|
||||||
|
|
||||||
# New Features
|
|
||||||
def save_responses_to_file(self, filename):
|
|
||||||
"""Save responses to file"""
|
|
||||||
with open(filename, "w") as file:
|
|
||||||
table = [
|
|
||||||
[f"LLM {i+1}", response]
|
|
||||||
for i, response in enumerate(self.last_responses)
|
|
||||||
]
|
|
||||||
file.write(table)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_llms_from_file(cls, filename):
|
|
||||||
"""Load llms from file"""
|
|
||||||
with open(filename) as file:
|
|
||||||
llms = [line.strip() for line in file.readlines()]
|
|
||||||
return cls(llms)
|
|
||||||
|
|
||||||
def get_task_history(self):
|
|
||||||
"""Get Task history"""
|
|
||||||
return self.task_history
|
|
||||||
|
|
||||||
def summary(self):
|
|
||||||
"""Summary"""
|
|
||||||
print("Tasks History:")
|
|
||||||
for i, task in enumerate(self.task_history):
|
|
||||||
print(f"{i + 1}. {task}")
|
|
||||||
print("\nLast Responses:")
|
|
||||||
table = [
|
|
||||||
[f"LLM {i+1}", response]
|
|
||||||
for i, response in enumerate(self.last_responses)
|
|
||||||
]
|
|
||||||
print(
|
|
||||||
colored(
|
|
||||||
table,
|
|
||||||
"cyan",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def arun(self, task: str):
|
|
||||||
"""Asynchronous run the task string"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
futures = [
|
|
||||||
loop.run_in_executor(None, lambda llm: llm(task), llm)
|
|
||||||
for llm in self.llms
|
|
||||||
]
|
|
||||||
for response in await asyncio.gather(*futures):
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
def concurrent_run(self, task: str) -> List[str]:
|
|
||||||
"""Synchronously run the task on all llms and collect responses"""
|
|
||||||
try:
|
|
||||||
with ThreadPoolExecutor() as executor:
|
|
||||||
future_to_llm = {
|
|
||||||
executor.submit(llm, task): llm for llm in self.llms
|
|
||||||
}
|
|
||||||
responses = []
|
|
||||||
for future in as_completed(future_to_llm):
|
|
||||||
try:
|
|
||||||
responses.append(future.result())
|
|
||||||
except Exception as error:
|
|
||||||
print(
|
|
||||||
f"{future_to_llm[future]} generated an"
|
|
||||||
f" exception: {error}"
|
|
||||||
)
|
|
||||||
self.last_responses = responses
|
|
||||||
self.task_history.append(task)
|
|
||||||
return responses
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
|
|
||||||
)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
def add_llm(self, llm: Callable):
|
|
||||||
"""Add an llm to the god mode"""
|
|
||||||
logger.info(f"[INFO][ModelParallelizer] Adding LLM {llm}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.llms.append(llm)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
|
|
||||||
)
|
|
||||||
raise error
|
|
||||||
|
|
||||||
def remove_llm(self, llm: Callable):
|
|
||||||
"""Remove an llm from the god mode"""
|
|
||||||
logger.info(f"[INFO][ModelParallelizer] Removing LLM {llm}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.llms.remove(llm)
|
|
||||||
except Exception as error:
|
|
||||||
logger.error(
|
|
||||||
f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
|
|
||||||
)
|
|
||||||
raise error
|
|
@ -1,141 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from swarms.models.llama_function_caller import LlamaFunctionCaller
|
|
||||||
|
|
||||||
|
|
||||||
# Define fixtures if needed
|
|
||||||
@pytest.fixture
|
|
||||||
def llama_caller():
|
|
||||||
# Initialize the LlamaFunctionCaller with a sample model
|
|
||||||
return LlamaFunctionCaller()
|
|
||||||
|
|
||||||
|
|
||||||
# Basic test for model loading
|
|
||||||
def test_llama_model_loading(llama_caller):
|
|
||||||
assert llama_caller.model is not None
|
|
||||||
assert llama_caller.tokenizer is not None
|
|
||||||
|
|
||||||
|
|
||||||
# Test adding and calling custom functions
|
|
||||||
def test_llama_custom_function(llama_caller):
|
|
||||||
def sample_function(arg1, arg2):
|
|
||||||
return f"Sample function called with args: {arg1}, {arg2}"
|
|
||||||
|
|
||||||
llama_caller.add_func(
|
|
||||||
name="sample_function",
|
|
||||||
function=sample_function,
|
|
||||||
description="Sample custom function",
|
|
||||||
arguments=[
|
|
||||||
{
|
|
||||||
"name": "arg1",
|
|
||||||
"type": "string",
|
|
||||||
"description": "Argument 1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "arg2",
|
|
||||||
"type": "string",
|
|
||||||
"description": "Argument 2",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = llama_caller.call_function(
|
|
||||||
"sample_function", arg1="arg1_value", arg2="arg2_value"
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
result
|
|
||||||
== "Sample function called with args: arg1_value, arg2_value"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Test streaming user prompts
|
|
||||||
def test_llama_streaming(llama_caller):
|
|
||||||
user_prompt = "Tell me about the tallest mountain in the world."
|
|
||||||
response = llama_caller(user_prompt)
|
|
||||||
assert isinstance(response, str)
|
|
||||||
assert len(response) > 0
|
|
||||||
|
|
||||||
|
|
||||||
# Test custom function not found
|
|
||||||
def test_llama_custom_function_not_found(llama_caller):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
llama_caller.call_function("non_existent_function")
|
|
||||||
|
|
||||||
|
|
||||||
# Test invalid arguments for custom function
|
|
||||||
def test_llama_custom_function_invalid_arguments(llama_caller):
|
|
||||||
def sample_function(arg1, arg2):
|
|
||||||
return f"Sample function called with args: {arg1}, {arg2}"
|
|
||||||
|
|
||||||
llama_caller.add_func(
|
|
||||||
name="sample_function",
|
|
||||||
function=sample_function,
|
|
||||||
description="Sample custom function",
|
|
||||||
arguments=[
|
|
||||||
{
|
|
||||||
"name": "arg1",
|
|
||||||
"type": "string",
|
|
||||||
"description": "Argument 1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "arg2",
|
|
||||||
"type": "string",
|
|
||||||
"description": "Argument 2",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
llama_caller.call_function("sample_function", arg1="arg1_value")
|
|
||||||
|
|
||||||
|
|
||||||
# Test streaming with custom runtime
|
|
||||||
def test_llama_custom_runtime():
|
|
||||||
llama_caller = LlamaFunctionCaller(
|
|
||||||
model_id="Your-Model-ID",
|
|
||||||
cache_dir="Your-Cache-Directory",
|
|
||||||
runtime="cuda",
|
|
||||||
)
|
|
||||||
user_prompt = "Tell me about the tallest mountain in the world."
|
|
||||||
response = llama_caller(user_prompt)
|
|
||||||
assert isinstance(response, str)
|
|
||||||
assert len(response) > 0
|
|
||||||
|
|
||||||
|
|
||||||
# Test caching functionality
|
|
||||||
def test_llama_cache():
|
|
||||||
llama_caller = LlamaFunctionCaller(
|
|
||||||
model_id="Your-Model-ID",
|
|
||||||
cache_dir="Your-Cache-Directory",
|
|
||||||
runtime="cuda",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform a request to populate the cache
|
|
||||||
user_prompt = "Tell me about the tallest mountain in the world."
|
|
||||||
response = llama_caller(user_prompt)
|
|
||||||
|
|
||||||
# Check if the response is retrieved from the cache
|
|
||||||
llama_caller.model.from_cache = True
|
|
||||||
response_from_cache = llama_caller(user_prompt)
|
|
||||||
assert response == response_from_cache
|
|
||||||
|
|
||||||
|
|
||||||
# Test response length within max_tokens limit
|
|
||||||
def test_llama_response_length():
|
|
||||||
llama_caller = LlamaFunctionCaller(
|
|
||||||
model_id="Your-Model-ID",
|
|
||||||
cache_dir="Your-Cache-Directory",
|
|
||||||
runtime="cuda",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate a long prompt
|
|
||||||
long_prompt = "A " + "test " * 100 # Approximately 500 tokens
|
|
||||||
|
|
||||||
# Ensure the response does not exceed max_tokens
|
|
||||||
response = llama_caller(long_prompt)
|
|
||||||
assert len(response.split()) <= 500
|
|
||||||
|
|
||||||
|
|
||||||
# Add more test cases as needed to cover different aspects of your code
|
|
||||||
|
|
||||||
# ...
|
|
@ -1,46 +0,0 @@
|
|||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from swarms.models.mistral import Mistral
|
|
||||||
|
|
||||||
|
|
||||||
def test_mistral_initialization():
|
|
||||||
mistral = Mistral(device="cpu")
|
|
||||||
assert isinstance(mistral, Mistral)
|
|
||||||
assert mistral.ai_name == "Node Model Agent"
|
|
||||||
assert mistral.system_prompt is None
|
|
||||||
assert mistral.model_name == "mistralai/Mistral-7B-v0.1"
|
|
||||||
assert mistral.device == "cpu"
|
|
||||||
assert mistral.use_flash_attention is False
|
|
||||||
assert mistral.temperature == 1.0
|
|
||||||
assert mistral.max_length == 100
|
|
||||||
assert mistral.history == []
|
|
||||||
|
|
||||||
|
|
||||||
@patch("your_module.AutoModelForCausalLM.from_pretrained")
|
|
||||||
@patch("your_module.AutoTokenizer.from_pretrained")
|
|
||||||
def test_mistral_load_model(mock_tokenizer, mock_model):
|
|
||||||
mistral = Mistral(device="cpu")
|
|
||||||
mistral.load_model()
|
|
||||||
mock_model.assert_called_once()
|
|
||||||
mock_tokenizer.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
@patch("your_module.Mistral.load_model")
|
|
||||||
def test_mistral_run(mock_load_model):
|
|
||||||
mistral = Mistral(device="cpu")
|
|
||||||
mistral.run("What's the weather in miami")
|
|
||||||
mock_load_model.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
@patch("your_module.Mistral.run")
|
|
||||||
def test_mistral_chat(mock_run):
|
|
||||||
mistral = Mistral(device="cpu")
|
|
||||||
mistral.chat("What's the weather in miami")
|
|
||||||
mock_run.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_mistral__stream_response():
|
|
||||||
mistral = Mistral(device="cpu")
|
|
||||||
response = "It's sunny in Miami."
|
|
||||||
tokens = list(mistral._stream_response(response))
|
|
||||||
assert tokens == ["It's", "sunny", "in", "Miami."]
|
|
@ -1,51 +0,0 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from swarms.models.mixtral import Mixtral
|
|
||||||
|
|
||||||
|
|
||||||
@patch("swarms.models.mixtral.AutoTokenizer")
|
|
||||||
@patch("swarms.models.mixtral.AutoModelForCausalLM")
|
|
||||||
def test_mixtral_init(mock_model, mock_tokenizer):
|
|
||||||
mixtral = Mixtral()
|
|
||||||
mock_tokenizer.from_pretrained.assert_called_once()
|
|
||||||
mock_model.from_pretrained.assert_called_once()
|
|
||||||
assert mixtral.model_name == "mistralai/Mixtral-8x7B-v0.1"
|
|
||||||
assert mixtral.max_new_tokens == 20
|
|
||||||
|
|
||||||
|
|
||||||
@patch("swarms.models.mixtral.AutoTokenizer")
|
|
||||||
@patch("swarms.models.mixtral.AutoModelForCausalLM")
|
|
||||||
def test_mixtral_run(mock_model, mock_tokenizer):
|
|
||||||
mixtral = Mixtral()
|
|
||||||
mock_tokenizer_instance = MagicMock()
|
|
||||||
mock_model_instance = MagicMock()
|
|
||||||
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
|
|
||||||
mock_model.from_pretrained.return_value = mock_model_instance
|
|
||||||
mock_tokenizer_instance.return_tensors = "pt"
|
|
||||||
mock_model_instance.generate.return_value = [101, 102, 103]
|
|
||||||
mock_tokenizer_instance.decode.return_value = "Generated text"
|
|
||||||
result = mixtral.run("Test task")
|
|
||||||
assert result == "Generated text"
|
|
||||||
mock_tokenizer_instance.assert_called_once_with(
|
|
||||||
"Test task", return_tensors="pt"
|
|
||||||
)
|
|
||||||
mock_model_instance.generate.assert_called_once()
|
|
||||||
mock_tokenizer_instance.decode.assert_called_once_with(
|
|
||||||
[101, 102, 103], skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@patch("swarms.models.mixtral.AutoTokenizer")
|
|
||||||
@patch("swarms.models.mixtral.AutoModelForCausalLM")
|
|
||||||
def test_mixtral_run_error(mock_model, mock_tokenizer):
|
|
||||||
mixtral = Mixtral()
|
|
||||||
mock_tokenizer_instance = MagicMock()
|
|
||||||
mock_model_instance = MagicMock()
|
|
||||||
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
|
|
||||||
mock_model.from_pretrained.return_value = mock_model_instance
|
|
||||||
mock_tokenizer_instance.return_tensors = "pt"
|
|
||||||
mock_model_instance.generate.side_effect = Exception("Test error")
|
|
||||||
with pytest.raises(Exception, match="Test error"):
|
|
||||||
mixtral.run("Test task")
|
|
@ -1,86 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
from swarms.models.mpt import MPT7B
|
|
||||||
|
|
||||||
|
|
||||||
def test_mpt7b_init():
|
|
||||||
mpt = MPT7B(
|
|
||||||
"mosaicml/mpt-7b-storywriter",
|
|
||||||
"EleutherAI/gpt-neox-20b",
|
|
||||||
max_tokens=150,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(mpt, MPT7B)
|
|
||||||
assert mpt.model_name == "mosaicml/mpt-7b-storywriter"
|
|
||||||
assert mpt.tokenizer_name == "EleutherAI/gpt-neox-20b"
|
|
||||||
assert isinstance(mpt.tokenizer, AutoTokenizer)
|
|
||||||
assert isinstance(mpt.model, AutoModelForCausalLM)
|
|
||||||
assert mpt.max_tokens == 150
|
|
||||||
|
|
||||||
|
|
||||||
def test_mpt7b_run():
|
|
||||||
mpt = MPT7B(
|
|
||||||
"mosaicml/mpt-7b-storywriter",
|
|
||||||
"EleutherAI/gpt-neox-20b",
|
|
||||||
max_tokens=150,
|
|
||||||
)
|
|
||||||
output = mpt.run(
|
|
||||||
"generate", "Once upon a time in a land far, far away..."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(output, str)
|
|
||||||
assert output.startswith("Once upon a time in a land far, far away...")
|
|
||||||
|
|
||||||
|
|
||||||
def test_mpt7b_run_invalid_task():
|
|
||||||
mpt = MPT7B(
|
|
||||||
"mosaicml/mpt-7b-storywriter",
|
|
||||||
"EleutherAI/gpt-neox-20b",
|
|
||||||
max_tokens=150,
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
mpt.run(
|
|
||||||
"invalid_task",
|
|
||||||
"Once upon a time in a land far, far away...",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mpt7b_generate():
|
|
||||||
mpt = MPT7B(
|
|
||||||
"mosaicml/mpt-7b-storywriter",
|
|
||||||
"EleutherAI/gpt-neox-20b",
|
|
||||||
max_tokens=150,
|
|
||||||
)
|
|
||||||
output = mpt.generate("Once upon a time in a land far, far away...")
|
|
||||||
|
|
||||||
assert isinstance(output, str)
|
|
||||||
assert output.startswith("Once upon a time in a land far, far away...")
|
|
||||||
|
|
||||||
|
|
||||||
def test_mpt7b_batch_generate():
|
|
||||||
mpt = MPT7B(
|
|
||||||
"mosaicml/mpt-7b-storywriter",
|
|
||||||
"EleutherAI/gpt-neox-20b",
|
|
||||||
max_tokens=150,
|
|
||||||
)
|
|
||||||
prompts = ["In the deep jungles,", "At the heart of the city,"]
|
|
||||||
outputs = mpt.batch_generate(prompts, temperature=0.7)
|
|
||||||
|
|
||||||
assert isinstance(outputs, list)
|
|
||||||
assert len(outputs) == len(prompts)
|
|
||||||
for output in outputs:
|
|
||||||
assert isinstance(output, str)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mpt7b_unfreeze_model():
|
|
||||||
mpt = MPT7B(
|
|
||||||
"mosaicml/mpt-7b-storywriter",
|
|
||||||
"EleutherAI/gpt-neox-20b",
|
|
||||||
max_tokens=150,
|
|
||||||
)
|
|
||||||
mpt.unfreeze_model()
|
|
||||||
|
|
||||||
for param in mpt.model.parameters():
|
|
||||||
assert param.requires_grad
|
|
Loading…
Reference in new issue