[MIGRATION][Migrate SwarmMatcher, JsonFormer, and other packages and modules to a new package called swarms-utils to remove the torch and transformers dependency] [next need to do numpy]
parent f14b282a27
commit e0da12be90
@@ -1,256 +0,0 @@
import time
from typing import List, Optional, Dict, Any, Callable

from loguru import logger
from vllm import LLM, SamplingParams

from swarms.agents.exceptions import (
    ToolExecutionError,
    ToolValidationError,
    ToolNotFoundError,
    ToolParameterError,
)


class ToolAgent:
    """
    A wrapper class for vLLM that provides a similar interface to LiteLLM.
    This class handles model initialization and inference using vLLM.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        system_prompt: Optional[str] = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        retry_attempts: int = 3,
        retry_interval: float = 1.0,
        *args,
        **kwargs,
    ):
        """
        Initialize the vLLM wrapper with the given parameters.

        Args:
            model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool): Whether to stream the output. Defaults to False.
            temperature (float): The temperature for sampling. Defaults to 0.5.
            max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
            max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
            tool_choice (str): How to choose tools. Defaults to "auto".
            parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
            retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3.
            retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls
        self.retry_attempts = retry_attempts
        self.retry_interval = retry_interval

        # Initialize vLLM
        try:
            self.llm = LLM(model=model_name, **kwargs)
            self.sampling_params = SamplingParams(
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            raise ToolExecutionError(
                "model_initialization",
                e,
                {"model_name": model_name, "kwargs": kwargs},
            )

    def _validate_tool(
        self, tool_name: str, parameters: Dict[str, Any]
    ) -> None:
        """
        Validate tool parameters before execution.

        Args:
            tool_name (str): Name of the tool to validate
            parameters (Dict[str, Any]): Parameters to validate

        Raises:
            ToolValidationError: If validation fails
        """
        if not self.tools_list_dictionary:
            raise ToolValidationError(
                tool_name,
                "parameters",
                "No tools available for validation",
            )

        tool_spec = next(
            (
                tool
                for tool in self.tools_list_dictionary
                if tool["name"] == tool_name
            ),
            None,
        )

        if not tool_spec:
            raise ToolNotFoundError(tool_name)

        required_params = {
            param["name"]
            for param in tool_spec["parameters"]
            if param.get("required", True)
        }

        missing_params = required_params - set(parameters.keys())
        if missing_params:
            raise ToolParameterError(
                tool_name,
                f"Missing required parameters: {', '.join(missing_params)}",
            )

    def _execute_with_retry(
        self, func: Callable, *args, **kwargs
    ) -> Any:
        """
        Execute a function with retry logic.

        Args:
            func (Callable): Function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Any: Result of the function execution

        Raises:
            ToolExecutionError: If all retry attempts fail
        """
        last_error = None
        for attempt in range(self.retry_attempts):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}"
                )
                if attempt < self.retry_attempts - 1:
                    time.sleep(self.retry_interval)

        raise ToolExecutionError(
            func.__name__,
            last_error,
            {"attempts": self.retry_attempts},
        )

    def run(self, task: str, *args, **kwargs) -> str:
        """
        Run the tool agent for the specified task.

        Args:
            task (str): The task to be performed by the tool agent.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The output of the tool agent.

        Raises:
            ToolExecutionError: If an error occurs during execution.
        """
        try:
            if not self.llm:
                raise ToolExecutionError(
                    "run",
                    Exception("LLM not initialized"),
                    {"task": task},
                )

            logger.info(f"Running task: {task}")

            # Prepare the prompt
            prompt = self._prepare_prompt(task)

            # Execute with retry logic
            outputs = self._execute_with_retry(
                self.llm.generate, prompt, self.sampling_params
            )

            response = outputs[0].outputs[0].text.strip()
            return response

        except Exception as error:
            logger.error(f"Error running task: {error}")
            raise ToolExecutionError(
                "run",
                error,
                {"task": task, "args": args, "kwargs": kwargs},
            )

    def _prepare_prompt(self, task: str) -> str:
        """
        Prepare the prompt for the given task.

        Args:
            task (str): The task to prepare the prompt for.

        Returns:
            str: The prepared prompt.
        """
        if self.system_prompt:
            return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
        return f"User: {task}\nAssistant:"

    def __call__(self, task: str, *args, **kwargs) -> str:
        """
        Call the model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The model's response.
        """
        return self.run(task, *args, **kwargs)

    def batched_run(
        self, tasks: List[str], batch_size: int = 10
    ) -> List[str]:
        """
        Run the model for multiple tasks in batches.

        Args:
            tasks (List[str]): List of tasks to run.
            batch_size (int): Size of each batch. Defaults to 10.

        Returns:
            List[str]: List of model responses.

        Raises:
            ToolExecutionError: If an error occurs during batch execution.
        """
        logger.info(
            f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
        )
        results = []

        try:
            for i in range(0, len(tasks), batch_size):
                batch = tasks[i : i + batch_size]
                for task in batch:
                    logger.info(f"Running task: {task}")
                    try:
                        result = self.run(task)
                        results.append(result)
                    except ToolExecutionError as e:
                        logger.error(
                            f"Failed to execute task '{task}': {e}"
                        )
                        results.append(f"Error: {str(e)}")
                        continue

            logger.info("Completed all tasks.")
            return results

        except Exception as error:
            logger.error(f"Error in batch execution: {error}")
            raise ToolExecutionError(
                "batched_run",
                error,
                {"tasks": tasks, "batch_size": batch_size},
            )
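
For reference, a minimal usage sketch of the removed ToolAgent; it assumes vllm is installed and the default Llama-2 weights are accessible, and the tasks shown are illustrative only:

# Minimal usage sketch (assumes `pip install vllm` and access to the model weights).
agent = ToolAgent(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    system_prompt="You are a helpful assistant.",
)

# Single task; raises ToolExecutionError after exhausting the retries.
answer = agent.run("Summarize the benefits of multi-agent systems.")

# Batched tasks; per-task failures become "Error: ..." entries instead of raising.
answers = agent.batched_run(
    ["Define swarm intelligence.", "List three use cases for agent swarms."],
    batch_size=2,
)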
@@ -1,599 +0,0 @@
import json
from typing import List, Optional, Tuple

import numpy as np
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential

from swarms.utils.loguru_logger import initialize_logger

logger = initialize_logger(log_folder="swarm_matcher")


class SwarmType(BaseModel):
    name: str
    description: str
    embedding: Optional[List[float]] = Field(
        default=None, exclude=True
    )


class SwarmMatcherConfig(BaseModel):
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dim: int = (
        384  # all-MiniLM-L6-v2 produces 384-dimensional embeddings
    )


class SwarmMatcher:
    """
    A class for matching tasks to swarm types based on their descriptions.
    It utilizes a transformer model to generate embeddings for task and swarm type descriptions,
    and then calculates the dot product to find the best match.
    """

    def __init__(self, config: SwarmMatcherConfig):
        """
        Initializes the SwarmMatcher with a configuration.

        Args:
            config (SwarmMatcherConfig): The configuration for the SwarmMatcher.
        """
        logger.add("swarm_matcher_debug.log", level="DEBUG")
        logger.debug("Initializing SwarmMatcher")

        try:
            import torch
        except ImportError:
            raise ImportError(
                "torch package not found. Install it with `pip install torch`."
            )

        try:
            import transformers
        except ImportError:
            raise ImportError(
                "transformers package not found. Install it with `pip install transformers`."
            )

        self.torch = torch
        try:
            self.config = config
            self.tokenizer = (
                transformers.AutoTokenizer.from_pretrained(
                    config.model_name
                )
            )
            self.model = transformers.AutoModel.from_pretrained(
                config.model_name
            )
            self.swarm_types: List[SwarmType] = []
            logger.debug("SwarmMatcher initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing SwarmMatcher: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    def get_embedding(self, text: str) -> np.ndarray:
        """
        Generates an embedding for a given text using the configured model.

        Args:
            text (str): The text for which to generate an embedding.

        Returns:
            np.ndarray: The embedding vector for the text.
        """
        logger.debug(f"Getting embedding for text: {text[:50]}...")
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            with self.torch.no_grad():
                outputs = self.model(**inputs)
                embedding = (
                    outputs.last_hidden_state.mean(dim=1)
                    .squeeze()
                    .numpy()
                )
            logger.debug("Embedding generated successfully")
            return embedding
        except Exception as e:
            logger.error(f"Error generating embedding: {str(e)}")
            raise

    def add_swarm_type(self, swarm_type: SwarmType):
        """
        Adds a swarm type to the list of swarm types, generating an embedding for its description.

        Args:
            swarm_type (SwarmType): The swarm type to add.
        """
        logger.debug(f"Adding swarm type: {swarm_type.name}")
        try:
            embedding = self.get_embedding(swarm_type.description)
            swarm_type.embedding = embedding.tolist()
            self.swarm_types.append(swarm_type)
            logger.info(f"Added swarm type: {swarm_type.name}")
        except Exception as e:
            logger.error(
                f"Error adding swarm type {swarm_type.name}: {str(e)}"
            )
            raise

    def find_best_match(self, task: str) -> Tuple[str, float]:
        """
        Finds the best match for a given task among the registered swarm types.

        Args:
            task (str): The task for which to find the best match.

        Returns:
            Tuple[str, float]: A tuple containing the name of the best matching swarm type and the score.
        """
        logger.debug(f"Finding best match for task: {task[:50]}...")
        try:
            task_embedding = self.get_embedding(task)
            best_match = None
            best_score = -float("inf")
            for swarm_type in self.swarm_types:
                score = np.dot(
                    task_embedding, np.array(swarm_type.embedding)
                )
                if score > best_score:
                    best_score = score
                    best_match = swarm_type
            if best_match is None:
                raise ValueError(
                    "No swarm types registered; call add_swarm_type first."
                )
            logger.info(
                f"Best match for task: {best_match.name} (score: {best_score})"
            )
            return best_match.name, float(best_score)
        except Exception as e:
            logger.error(
                f"Error finding best match for task: {str(e)}"
            )
            raise

    def auto_select_swarm(self, task: str) -> str:
        """
        Automatically selects the best swarm type for a given task based on their descriptions.

        Args:
            task (str): The task for which to select a swarm type.

        Returns:
            str: The name of the selected swarm type.
        """
        logger.debug(f"Auto-selecting swarm for task: {task[:50]}...")
        best_match, score = self.find_best_match(task)
        logger.info(f"Task: {task}")
        logger.info(f"Selected Swarm Type: {best_match}")
        logger.info(f"Confidence Score: {score:.2f}")
        return best_match

    def run_multiple(
        self, tasks: List[str], *args, **kwargs
    ) -> List[str]:
        """
        Auto-selects a swarm type for each task.

        Args:
            tasks (List[str]): The tasks to match.

        Returns:
            List[str]: The selected swarm type name for each task.
        """
        swarms = []

        for task in tasks:
            output = self.auto_select_swarm(task)

            # Append
            swarms.append(output)

        return swarms

    def save_swarm_types(self, filename: str):
        """
        Saves the registered swarm types to a JSON file.

        Args:
            filename (str): The name of the file to which to save the swarm types.
        """
        try:
            with open(filename, "w") as f:
                json.dump([st.dict() for st in self.swarm_types], f)
            logger.info(f"Saved swarm types to {filename}")
        except Exception as e:
            logger.error(f"Error saving swarm types: {str(e)}")
            raise

    def load_swarm_types(self, filename: str):
        """
        Loads swarm types from a JSON file.

        Args:
            filename (str): The name of the file from which to load the swarm types.
        """
        try:
            with open(filename, "r") as f:
                swarm_types_data = json.load(f)
                self.swarm_types = [
                    SwarmType(**st) for st in swarm_types_data
                ]
            logger.info(f"Loaded swarm types from {filename}")
        except Exception as e:
            logger.error(f"Error loading swarm types: {str(e)}")
            raise


def initialize_swarm_types(matcher: SwarmMatcher):
    logger.debug("Initializing swarm types")
    swarm_types = [
        SwarmType(
            name="AgentRearrange",
            description="Optimize agent order and rearrange flow for multi-step tasks, ensuring efficient task allocation and minimizing bottlenecks. Keywords: orchestration, coordination, pipeline optimization, task scheduling, resource allocation, workflow management, agent organization, process optimization",
        ),
        SwarmType(
            name="MixtureOfAgents",
            description="Combine diverse expert agents for comprehensive analysis, fostering a collaborative approach to problem-solving and leveraging individual strengths. Keywords: multi-agent system, expert collaboration, distributed intelligence, collective problem solving, agent specialization, team coordination, hybrid approaches, knowledge synthesis",
        ),
        SwarmType(
            name="SpreadSheetSwarm",
            description="Collaborative data processing and analysis in a spreadsheet-like environment, facilitating real-time data sharing and visualization. Keywords: data analysis, tabular processing, collaborative editing, data transformation, spreadsheet operations, data visualization, real-time collaboration, structured data",
        ),
        SwarmType(
            name="SequentialWorkflow",
            description="Execute tasks in a step-by-step, sequential process workflow, ensuring a logical and methodical approach to task execution. Keywords: linear processing, waterfall methodology, step-by-step execution, ordered tasks, sequential operations, process flow, systematic approach, staged execution",
        ),
        SwarmType(
            name="ConcurrentWorkflow",
            description="Process multiple tasks or data sources concurrently in parallel, maximizing productivity and reducing processing time. Keywords: parallel processing, multi-threading, asynchronous execution, distributed computing, concurrent operations, simultaneous tasks, parallel workflows, scalable processing",
        ),
        SwarmType(
            name="HierarchicalSwarm",
            description="Organize agents in a hierarchical structure with clear reporting lines and delegation of responsibilities. Keywords: management hierarchy, organizational structure, delegation, supervision, chain of command, tiered organization, structured coordination",
        ),
        # SwarmType(
        #     name="AdaptiveSwarm",
        #     description="Dynamically adjust agent behavior and swarm configuration based on task requirements and performance feedback. Keywords: dynamic adaptation, self-optimization, feedback loops, learning systems, flexible configuration, responsive behavior, adaptive algorithms",
        # ),
        # SwarmType(
        #     name="ConsensusSwarm",
        #     description="Achieve group decisions through consensus mechanisms and voting protocols among multiple agents. Keywords: group decision making, voting systems, collective intelligence, agreement protocols, democratic processes, collaborative decisions",
        # ),
    ]

    for swarm_type in swarm_types:
        matcher.add_swarm_type(swarm_type)
    logger.debug("Swarm types initialized")


def swarm_matcher(task: str, *args, **kwargs):
    """
    Creates a SwarmMatcher, registers the predefined swarm types, and
    returns the best-matching swarm type for the given task.
    """
    config = SwarmMatcherConfig()
    matcher = SwarmMatcher(config)
    initialize_swarm_types(matcher)

    # matcher.save_swarm_types(f"swarm_logs/{uuid4().hex}.json")

    swarm_type = matcher.auto_select_swarm(task)

    logger.info(f"{swarm_type}")

    return swarm_type
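
For reference, a minimal usage sketch of the embedding-based matcher above, before the commented-out ChromaDB variant below; it assumes torch, transformers, and the default sentence-transformers checkpoint are available, and the task strings are illustrative:

# Minimal usage sketch (assumes torch and transformers are installed).
config = SwarmMatcherConfig()  # sentence-transformers/all-MiniLM-L6-v2 by default
matcher = SwarmMatcher(config)
initialize_swarm_types(matcher)  # registers the built-in swarm types

name, score = matcher.find_best_match(
    "Process these tasks one after another in a specific order"
)
# name is e.g. "SequentialWorkflow"; score is a raw dot product, so it is
# only meaningful relative to the other swarm types for the same task.

# Or, in one call via the module-level helper:
selected = swarm_matcher("Coordinate multiple AI agents to solve a complex problem")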
# from typing import List, Tuple, Dict
# from pydantic import BaseModel, Field
# from loguru import logger
# from uuid import uuid4
# import chromadb
# import json
# from tenacity import retry, stop_after_attempt, wait_exponential


# class SwarmType(BaseModel):
#     """A swarm type with its name, description and optional metadata"""

#     id: str = Field(default_factory=lambda: str(uuid4()))
#     name: str
#     description: str
#     metadata: Dict = Field(default_factory=dict)


# class SwarmMatcherConfig(BaseModel):
#     """Configuration for the SwarmMatcher"""

#     collection_name: str = "swarm_types"
#     distance_metric: str = "cosine"  # or "l2" or "ip"
#     embedding_function: str = (
#         "sentence-transformers/all-mpnet-base-v2"  # Better model than MiniLM
#     )
#     persist_directory: str = "./chroma_db"


# class SwarmMatcher:
#     """
#     An improved swarm matcher that uses ChromaDB for better vector similarity search.
#     Features:
#     - Persistent storage of embeddings
#     - Better vector similarity search with multiple distance metrics
#     - Improved embedding model
#     - Metadata filtering capabilities
#     - Batch operations support
#     """

#     def __init__(self, config: SwarmMatcherConfig):
#         """Initialize the improved swarm matcher"""
#         logger.add("swarm_matcher.log", rotation="100 MB")
#         self.config = config

#         # Initialize ChromaDB client with persistence
#         self.chroma_client = chromadb.Client()

#         # Get or create collection
#         try:
#             self.collection = self.chroma_client.get_collection(
#                 name=config.collection_name,
#             )
#         except ValueError:
#             self.collection = self.chroma_client.create_collection(
#                 name=config.collection_name,
#                 metadata={"hnsw:space": config.distance_metric},
#             )

#         logger.info(
#             f"Initialized SwarmMatcher with collection '{config.collection_name}'"
#         )

#     def add_swarm_type(self, swarm_type: SwarmType) -> None:
#         """Add a single swarm type to the collection"""
#         try:
#             self.collection.add(
#                 ids=[swarm_type.id],
#                 documents=[swarm_type.description],
#                 metadatas=[
#                     {"name": swarm_type.name, **swarm_type.metadata}
#                 ],
#             )
#             logger.info(f"Added swarm type: {swarm_type.name}")
#         except Exception as e:
#             logger.error(
#                 f"Error adding swarm type {swarm_type.name}: {str(e)}"
#             )
#             raise

#     def add_swarm_types(self, swarm_types: List[SwarmType]) -> None:
#         """Add multiple swarm types in batch"""
#         try:
#             self.collection.add(
#                 ids=[st.id for st in swarm_types],
#                 documents=[st.description for st in swarm_types],
#                 metadatas=[
#                     {"name": st.name, **st.metadata}
#                     for st in swarm_types
#                 ],
#             )
#             logger.info(f"Added {len(swarm_types)} swarm types")
#         except Exception as e:
#             logger.error(
#                 f"Error adding swarm types in batch: {str(e)}"
#             )
#             raise

#     @retry(
#         stop=stop_after_attempt(3),
#         wait=wait_exponential(multiplier=1, min=4, max=10),
#     )
#     def find_best_matches(
#         self,
#         task: str,
#         n_results: int = 3,
#         score_threshold: float = 0.7,
#     ) -> List[Tuple[str, float]]:
#         """
#         Find the best matching swarm types for a given task
#         Returns multiple matches with their scores
#         """
#         try:
#             results = self.collection.query(
#                 query_texts=[task],
#                 n_results=n_results,
#                 include=["metadatas", "distances"],
#             )

#             matches = []
#             for metadata, distance in zip(
#                 results["metadatas"][0], results["distances"][0]
#             ):
#                 # Convert distance to similarity score (1 - normalized_distance)
#                 score = 1 - (
#                     distance / 2
#                 )  # Normalize cosine distance to [0,1]
#                 if score >= score_threshold:
#                     matches.append((metadata["name"], score))

#             logger.info(f"Found {len(matches)} matches for task")
#             return matches

#         except Exception as e:
#             logger.error(f"Error finding matches for task: {str(e)}")
#             raise

#     def auto_select_swarm(self, task: str) -> str:
#         """
#         Automatically select the best swarm type for a task
#         Returns only the top match
#         """
#         matches = self.find_best_matches(task, n_results=1)
#         if not matches:
#             logger.warning("No suitable matches found for task")
#             return "SequentialWorkflow"  # Default fallback

#         best_match, score = matches[0]
#         logger.info(
#             f"Selected swarm type '{best_match}' with confidence {score:.3f}"
#         )
#         return best_match

#     def run_multiple(self, tasks: List[str]) -> List[str]:
#         """Process multiple tasks in batch"""
#         return [self.auto_select_swarm(task) for task in tasks]

#     def save_swarm_types(self, filename: str) -> None:
#         """Export swarm types to JSON"""
#         try:
#             all_data = self.collection.get(
#                 include=["metadatas", "documents"]
#             )
#             swarm_types = [
#                 SwarmType(
#                     id=id_,
#                     name=metadata["name"],
#                     description=document,
#                     metadata={
#                         k: v
#                         for k, v in metadata.items()
#                         if k != "name"
#                     },
#                 )
#                 for id_, metadata, document in zip(
#                     all_data["ids"],
#                     all_data["metadatas"],
#                     all_data["documents"],
#                 )
#             ]

#             with open(filename, "w") as f:
#                 json.dump(
#                     [st.dict() for st in swarm_types], f, indent=2
#                 )
#             logger.info(f"Saved swarm types to {filename}")
#         except Exception as e:
#             logger.error(f"Error saving swarm types: {str(e)}")
#             raise

#     def load_swarm_types(self, filename: str) -> None:
#         """Import swarm types from JSON"""
#         try:
#             with open(filename, "r") as f:
#                 swarm_types_data = json.load(f)
#             swarm_types = [SwarmType(**st) for st in swarm_types_data]
#             self.add_swarm_types(swarm_types)
#             logger.info(f"Loaded swarm types from {filename}")
#         except Exception as e:
#             logger.error(f"Error loading swarm types: {str(e)}")
#             raise


# def initialize_default_swarm_types(matcher: SwarmMatcher) -> None:
#     """Initialize the matcher with default swarm types"""
#     swarm_types = [
#         SwarmType(
#             name="AgentRearrange",
#             description="""
#             Optimize agent order and rearrange flow for multi-step tasks, ensuring efficient task allocation
#             and minimizing bottlenecks. Specialized in orchestration, coordination, pipeline optimization,
#             task scheduling, resource allocation, workflow management, agent organization, and process optimization.
#             Best for tasks requiring complex agent interactions and workflow optimization.
#             """,
#             metadata={
#                 "category": "optimization",
#                 "complexity": "high",
#             },
#         ),
#         SwarmType(
#             name="MixtureOfAgents",
#             description="""
#             Combine diverse expert agents for comprehensive analysis, fostering a collaborative approach
#             to problem-solving and leveraging individual strengths. Focuses on multi-agent systems,
#             expert collaboration, distributed intelligence, collective problem solving, agent specialization,
#             team coordination, hybrid approaches, and knowledge synthesis. Ideal for complex problems
#             requiring multiple areas of expertise.
#             """,
#             metadata={
#                 "category": "collaboration",
#                 "complexity": "high",
#             },
#         ),
#         SwarmType(
#             name="SpreadSheetSwarm",
#             description="""
#             Collaborative data processing and analysis in a spreadsheet-like environment, facilitating
#             real-time data sharing and visualization. Specializes in data analysis, tabular processing,
#             collaborative editing, data transformation, spreadsheet operations, data visualization,
#             real-time collaboration, and structured data handling. Perfect for data-intensive tasks
#             requiring structured analysis.
#             """,
#             metadata={
#                 "category": "data_processing",
#                 "complexity": "medium",
#             },
#         ),
#         SwarmType(
#             name="SequentialWorkflow",
#             description="""
#             Execute tasks in a step-by-step, sequential process workflow, ensuring a logical and methodical
#             approach to task execution. Focuses on linear processing, waterfall methodology, step-by-step
#             execution, ordered tasks, sequential operations, process flow, systematic approach, and staged
#             execution. Best for tasks requiring strict order and dependencies.
#             """,
#             metadata={"category": "workflow", "complexity": "low"},
#         ),
#         SwarmType(
#             name="ConcurrentWorkflow",
#             description="""
#             Process multiple tasks or data sources concurrently in parallel, maximizing productivity
#             and reducing processing time. Specializes in parallel processing, multi-threading,
#             asynchronous execution, distributed computing, concurrent operations, simultaneous tasks,
#             parallel workflows, and scalable processing. Ideal for independent tasks that can be
#             processed simultaneously.
#             """,
#             metadata={"category": "workflow", "complexity": "medium"},
#         ),
#     ]

#     matcher.add_swarm_types(swarm_types)
#     logger.info("Initialized default swarm types")


# def create_swarm_matcher(
#     persist_dir: str = "./chroma_db",
#     collection_name: str = "swarm_types",
# ) -> SwarmMatcher:
#     """Convenience function to create and initialize a swarm matcher"""
#     config = SwarmMatcherConfig(
#         persist_directory=persist_dir, collection_name=collection_name
#     )
#     matcher = SwarmMatcher(config)
#     initialize_default_swarm_types(matcher)
#     return matcher


# # Example usage
# def swarm_matcher(task: str) -> str:
#     # Create and initialize matcher
#     matcher = create_swarm_matcher()

#     swarm_type = matcher.auto_select_swarm(task)
#     print(f"Task: {task}\nSelected Swarm: {swarm_type}\n")

#     return swarm_type


# # # Example usage
# # if __name__ == "__main__":
# #     # Create and initialize matcher
# #     matcher = create_swarm_matcher()

# #     # Example tasks
# #     tasks = [
# #         "Analyze this spreadsheet of sales data and create visualizations",
# #         "Coordinate multiple AI agents to solve a complex problem",
# #         "Process these tasks one after another in a specific order",
# #         "Write multiple blog posts about the latest advancements in swarm intelligence all at once",
# #         "Write a blog post about the latest advancements in swarm intelligence",
# #     ]

# #     # Process tasks
# #     for task in tasks:
# #         swarm_type = matcher.auto_select_swarm(task)
# #         print(f"Task: {task}\nSelected Swarm: {swarm_type}\n")
@@ -1,421 +0,0 @@
import json
from typing import Any, Dict, List, Union

from pydantic import BaseModel

from swarms.tools.logits_processor import (
    NumberStoppingCriteria,
    OutputNumbersTokens,
    StringStoppingCriteria,
)
from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)

try:
    import transformers
except ImportError:
    auto_check_and_download_package(
        "transformers", package_manager="pip"
    )
    import transformers


GENERATION_MARKER = "|GENERATION|"


class Jsonformer:
    """
    Initializes the Jsonformer class.

    Args:
        model (PreTrainedModel): The pre-trained model.
        tokenizer (PreTrainedTokenizer): The tokenizer for the model.
        json_schema (Dict[str, Any]): The JSON schema.
        prompt (str): The prompt for generation.

    Keyword Args:
        debug (bool, optional): Whether to enable debug mode. Defaults to False.
        max_array_length (int, optional): The maximum length of an array. Defaults to 10.
        max_number_tokens (int, optional): The maximum number of tokens for numbers. Defaults to 6.
        temperature (float, optional): The temperature for generation. Defaults to 1.0.
        max_string_token_length (int, optional): The maximum length of a string token. Defaults to 10.
    """

    # Scratch object for the JSON being built; reset in __call__.
    value: Dict[str, Any] = {}

    def __init__(
        self,
        model: transformers.PreTrainedModel = None,  # type: ignore
        tokenizer: transformers.PreTrainedTokenizer = None,  # type: ignore
        json_schema: Union[Dict[str, Any], BaseModel] = None,
        schemas: List[Union[Dict[str, Any], BaseModel]] = None,
        prompt: str = None,
        *,
        debug: bool = False,
        max_array_length: int = 10,
        max_number_tokens: int = 6,
        temperature: float = 1.0,
        max_string_token_length: int = 10,
        llm: Any = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.json_schema = json_schema
        self.prompt = prompt
        self.llm = llm
        self.schemas = schemas if schemas is not None else []

        # The digit-only logits processor requires a tokenizer; skip it
        # when running in llm-callable mode without one.
        if self.tokenizer is not None:
            self.number_logit_processor = OutputNumbersTokens(
                self.tokenizer, self.prompt
            )
        else:
            self.number_logit_processor = None

        self.generation_marker = "|GENERATION|"
        self.debug_on = debug
        self.max_array_length = max_array_length

        self.max_number_tokens = max_number_tokens
        self.temperature = temperature
        self.max_string_token_length = max_string_token_length

    def debug(self, caller: str, value: Any, is_prompt: bool = False):
        # Minimal debug helper (the original file referenced self.debug
        # without defining it); prints only when debug=True.
        if self.debug_on:
            print(caller, value)

    def generate_number(
        self, temperature: Union[float, None] = None, iterations=0
    ):
        """
        Generates a number based on the given prompt.

        Args:
            temperature (float, optional): The temperature value for number generation. Defaults to None.
            iterations (int, optional): The number of iterations for generating a valid number. Defaults to 0.

        Returns:
            float: The generated number.

        Raises:
            ValueError: If a valid number cannot be generated after 3 iterations.
        """
        if self.model:
            prompt = self.get_prompt()
            self.debug("[generate_number]", prompt, is_prompt=True)
            input_tokens = self.tokenizer.encode(
                prompt, return_tensors="pt"
            ).to(self.model.device)

            response = self.model.generate(
                input_tokens,
                max_new_tokens=self.max_number_tokens,
                num_return_sequences=1,
                logits_processor=[self.number_logit_processor],
                stopping_criteria=[
                    NumberStoppingCriteria(
                        self.tokenizer, len(input_tokens[0])
                    )
                ],
                temperature=temperature or self.temperature,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            response = self.tokenizer.decode(
                response[0], skip_special_tokens=True
            )

            response = response[len(prompt) :]
            response = response.strip().rstrip(".")
            self.debug("[generate_number]", response)
            try:
                return float(response)
            except ValueError:
                if iterations > 3:
                    raise ValueError(
                        "Failed to generate a valid number"
                    )

                return self.generate_number(
                    temperature=self.temperature * 1.3,
                    iterations=iterations + 1,
                )
        elif self.llm:
            prompt = self.get_prompt()
            self.debug("[generate_number]", prompt, is_prompt=True)
            response = self.llm(prompt)
            response = response[len(prompt) :]
            response = response.strip().rstrip(".")
            self.debug("[generate_number]", response)
            try:
                return float(response)
            except ValueError:
                if iterations > 3:
                    raise ValueError(
                        "Failed to generate a valid number"
                    )

                return self.generate_number(
                    temperature=self.temperature * 1.3,
                    iterations=iterations + 1,
                )

        else:
            # Neither a local model nor an LLM callable was provided.
            raise ValueError("Both LLM and model cannot be None")

    def generate_boolean(self) -> bool:
        """
        Generates a boolean value based on the given prompt.

        Returns:
            bool: The generated boolean value.
        """
        if self.model:
            prompt = self.get_prompt()
            self.debug("[generate_boolean]", prompt, is_prompt=True)

            input_tensor = self.tokenizer.encode(
                prompt, return_tensors="pt"
            )
            output = self.model.forward(
                input_tensor.to(self.model.device)
            )
            logits = output.logits[0, -1]

            # todo: this assumes that "true" and "false" are both tokenized to a single token
            # this is probably not true for all tokenizers
            # this can be fixed by looking at only the first token of both "true" and "false"
            true_token_id = self.tokenizer.convert_tokens_to_ids(
                "true"
            )
            false_token_id = self.tokenizer.convert_tokens_to_ids(
                "false"
            )

            result = logits[true_token_id] > logits[false_token_id]

            self.debug("[generate_boolean]", result)

            return result.item()

        elif self.llm:
            prompt = self.get_prompt()
            self.debug("[generate_boolean]", prompt, is_prompt=True)

            output = self.llm(prompt).strip().lower()

            # Only accept an explicit "true"/"false" answer.
            return (
                output == "true"
                if output in ("true", "false")
                else None
            )

        else:
            raise ValueError("Both LLM and model cannot be None")

    def generate_string(self) -> str:
        if self.model:
            prompt = self.get_prompt() + '"'
            self.debug("[generate_string]", prompt, is_prompt=True)
            input_tokens = self.tokenizer.encode(
                prompt, return_tensors="pt"
            ).to(self.model.device)

            response = self.model.generate(
                input_tokens,
                max_new_tokens=self.max_string_token_length,
                num_return_sequences=1,
                temperature=self.temperature,
                stopping_criteria=[
                    StringStoppingCriteria(
                        self.tokenizer, len(input_tokens[0])
                    )
                ],
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Some models output the prompt as part of the response
            # This removes the prompt from the response if it is present
            if (
                len(response[0]) >= len(input_tokens[0])
                and (
                    response[0][: len(input_tokens[0])]
                    == input_tokens
                ).all()
            ):
                response = response[0][len(input_tokens[0]) :]
            if response.shape[0] == 1:
                response = response[0]

            response = self.tokenizer.decode(
                response, skip_special_tokens=True
            )

            self.debug("[generate_string]", "|" + response + "|")

            if response.count('"') < 1:
                return response

            return response.split('"')[0].strip()

        elif self.llm:
            prompt = self.get_prompt() + '"'
            self.debug("[generate_string]", prompt, is_prompt=True)

            response = self.llm(prompt)

            # Some models echo the prompt as part of the response;
            # strip it if present (string-based here, since there are
            # no input token ids in this branch).
            if response.startswith(prompt):
                response = response[len(prompt) :]

            self.debug("[generate_string]", "|" + response + "|")

            if response.count('"') < 1:
                return response

            return response.split('"')[0].strip()

        else:
            raise ValueError("Both LLM and model cannot be None")

    def generate_object(
        self, properties: Dict[str, Any], obj: Dict[str, Any]
    ) -> Dict[str, Any]:
        for key, schema in properties.items():
            self.debug("[generate_object] generating value for", key)
            obj[key] = self.generate_value(schema, obj, key)
        return obj

    def generate_value(
        self,
        schema: Dict[str, Any],
        obj: Union[Dict[str, Any], List[Any]],
        key: Union[str, None] = None,
    ) -> Any:
        schema_type = schema["type"]
        if schema_type == "number":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_number()
        elif schema_type == "boolean":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_boolean()
        elif schema_type == "string":
            if key:
                obj[key] = self.generation_marker
            else:
                obj.append(self.generation_marker)
            return self.generate_string()
        elif schema_type == "array":
            new_array = []
            if key:
                obj[key] = new_array
            else:
                obj.append(new_array)
            return self.generate_array(schema["items"], new_array)
        elif schema_type == "object":
            new_obj = {}
            if key:
                obj[key] = new_obj
            else:
                obj.append(new_obj)
            return self.generate_object(schema["properties"], new_obj)
        else:
            raise ValueError(
                f"Unsupported schema type: {schema_type}"
            )

    def generate_array(
        self, item_schema: Dict[str, Any], obj: Dict[str, Any]
    ) -> list:
        if self.model:
            for _ in range(self.max_array_length):
                # forces array to have at least one element
                element = self.generate_value(item_schema, obj)
                obj[-1] = element

                obj.append(self.generation_marker)
                input_prompt = self.get_prompt()
                obj.pop()
                input_tensor = self.tokenizer.encode(
                    input_prompt, return_tensors="pt"
                )
                output = self.model.forward(
                    input_tensor.to(self.model.device)
                )
                logits = output.logits[0, -1]

                top_indices = logits.topk(30).indices
                sorted_token_ids = top_indices[
                    logits[top_indices].argsort(descending=True)
                ]

                found_comma = False
                found_close_bracket = False

                for token_id in sorted_token_ids:
                    decoded_token = self.tokenizer.decode(token_id)
                    if "," in decoded_token:
                        found_comma = True
                        break
                    if "]" in decoded_token:
                        found_close_bracket = True
                        break

                if found_close_bracket or not found_comma:
                    break

            return obj

        elif self.llm:
            for _ in range(self.max_array_length):
                # forces array to have at least one element
                element = self.generate_value(item_schema, obj)
                obj[-1] = element

                obj.append(self.generation_marker)
                input_prompt = self.get_prompt()
                obj.pop()
                output = self.llm(input_prompt)

                found_comma = False
                found_close_bracket = False

                # Scan the raw completion for a separator or closing bracket.
                for token in output:
                    decoded_token = str(token)
                    if "," in decoded_token:
                        found_comma = True
                        break
                    if "]" in decoded_token:
                        found_close_bracket = True
                        break

                if found_close_bracket or not found_comma:
                    break

            return obj

    def get_prompt(self):
        template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
        progress = json.dumps(self.value)
        gen_marker_index = progress.find(
            f'"{self.generation_marker}"'
        )
        if gen_marker_index != -1:
            progress = progress[:gen_marker_index]
        else:
            raise ValueError("Failed to find generation marker")

        prompt = template.format(
            prompt=self.prompt,
            schema=json.dumps(self.json_schema),
            progress=progress,
        )

        return prompt

    def __call__(self) -> Dict[str, Any]:
        self.value = {}
        generated_data = self.generate_object(
            self.json_schema["properties"], self.value
        )
        return generated_data
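
For reference, a minimal usage sketch of Jsonformer with a HuggingFace causal LM; the gpt2 checkpoint and the schema below are illustrative, not part of the original module:

# Minimal usage sketch (model choice is illustrative).
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "number"},
        "is_student": {"type": "boolean"},
    },
}

jsonformer = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=schema,
    prompt="Generate a profile for a fictional person.",
)
result = jsonformer()  # e.g. {"name": "...", "age": 30.0, "is_student": False}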
@@ -1,109 +0,0 @@
from swarms.utils.auto_download_check_packages import (
    auto_check_and_download_package,
)
from typing import Any


try:
    import torch
except ImportError:
    auto_check_and_download_package(
        "torch", package_manager="pip", upgrade=True
    )
    import torch

try:
    import transformers
except ImportError:
    auto_check_and_download_package(
        "transformers", package_manager="pip", upgrade=True
    )
    import transformers


class StringStoppingCriteria(transformers.StoppingCriteria):
    """Stops generation once the model closes the string with a double quote."""

    def __init__(
        self, tokenizer: Any, prompt_length: int  # type: ignore
    ):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(
        self,
        input_ids: torch.LongTensor,  # type: ignore
        _,
    ) -> bool:
        if len(input_ids[0]) <= self.prompt_length:
            return False

        last_token_id = input_ids[0][-1]
        last_token = self.tokenizer.decode(
            last_token_id, skip_special_tokens=True
        )

        result = '"' in last_token

        return result


class NumberStoppingCriteria(transformers.StoppingCriteria):
    """Stops generation once a complete number of bounded precision has been emitted."""

    def __init__(
        self,
        tokenizer: Any,  # type: ignore
        prompt_length: int,
        precision: int = 3,
    ):
        self.tokenizer = tokenizer
        self.precision = precision
        self.prompt_length = prompt_length

    def __call__(
        self,
        input_ids: torch.LongTensor,  # type: ignore
        scores: torch.FloatTensor,  # type: ignore
    ) -> bool:
        decoded = self.tokenizer.decode(
            input_ids[0][self.prompt_length :],
            skip_special_tokens=True,
        )

        # More than one decimal point can never be a valid number.
        if decoded.count(".") > 1:
            return True

        # Stop once the fractional part exceeds the allowed precision.
        if (
            decoded.count(".") == 1
            and len(decoded.strip().split(".")[1]) > self.precision
        ):
            return True

        # Stop when at least one digit has been produced and the model
        # emits whitespace, which terminates the number.
        if (
            len(decoded) > 1
            and any(c.isdigit() for c in decoded)
            and decoded[-1] in [" ", "\n"]
        ):
            return True

        return False


class OutputNumbersTokens(transformers.LogitsWarper):
    """Masks the vocabulary so only digit and decimal-point tokens can be sampled."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, prompt: str):  # type: ignore
        self.tokenizer = tokenizer
        self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
        vocab_size = len(tokenizer)
        self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)

        for _, token_id in tokenizer.get_vocab().items():
            token_str = tokenizer.decode(token_id).strip()

            if token_str == "" or (
                all(c.isdigit() or c == "." for c in token_str)
                and token_str.count(".") <= 1
            ):
                self.allowed_mask[token_id] = True

    def __call__(self, _, scores):
        mask = self.allowed_mask.expand_as(scores)
        scores[~mask] = -float("inf")

        return scores
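
To see these pieces in isolation, here is a sketch that constrains a model to emit only a decimal number, mirroring how Jsonformer wires the classes into generate(); the gpt2 checkpoint and prompt are placeholders:

# Sketch: digit-only generation (model and prompt are placeholders).
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

prompt = "The answer as a plain number: "
input_ids = tokenizer.encode(prompt, return_tensors="pt")

output = model.generate(
    input_ids,
    max_new_tokens=6,
    logits_processor=[OutputNumbersTokens(tokenizer, prompt)],
    stopping_criteria=[NumberStoppingCriteria(tokenizer, len(input_ids[0]))],
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0][len(input_ids[0]) :], skip_special_tokens=True))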