[MIGRATION][Migrate SwarmMatcher, JsonFormer, and other packages and modules to a new package called swarms-utils to remove torch and transformers depedency] [next need to do numpy]
parent
f14b282a27
commit
e0da12be90
@ -1,256 +0,0 @@
|
||||
from typing import List, Optional, Dict, Any, Callable
|
||||
from loguru import logger
|
||||
from swarms.agents.exceptions import (
|
||||
ToolExecutionError,
|
||||
ToolValidationError,
|
||||
ToolNotFoundError,
|
||||
ToolParameterError,
|
||||
)
|
||||
|
||||
|
||||
class ToolAgent:
|
||||
"""
|
||||
A wrapper class for vLLM that provides a similar interface to LiteLLM.
|
||||
This class handles model initialization and inference using vLLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "meta-llama/Llama-2-7b-chat-hf",
|
||||
system_prompt: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
temperature: float = 0.5,
|
||||
max_tokens: int = 4000,
|
||||
max_completion_tokens: int = 4000,
|
||||
tools_list_dictionary: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: str = "auto",
|
||||
parallel_tool_calls: bool = False,
|
||||
retry_attempts: int = 3,
|
||||
retry_interval: float = 1.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the vLLM wrapper with the given parameters.
|
||||
Args:
|
||||
model_name (str): The name of the model to use. Defaults to "meta-llama/Llama-2-7b-chat-hf".
|
||||
system_prompt (str, optional): The system prompt to use. Defaults to None.
|
||||
stream (bool): Whether to stream the output. Defaults to False.
|
||||
temperature (float): The temperature for sampling. Defaults to 0.5.
|
||||
max_tokens (int): The maximum number of tokens to generate. Defaults to 4000.
|
||||
max_completion_tokens (int): The maximum number of completion tokens. Defaults to 4000.
|
||||
tools_list_dictionary (List[Dict[str, Any]], optional): List of available tools. Defaults to None.
|
||||
tool_choice (str): How to choose tools. Defaults to "auto".
|
||||
parallel_tool_calls (bool): Whether to allow parallel tool calls. Defaults to False.
|
||||
retry_attempts (int): Number of retry attempts for failed operations. Defaults to 3.
|
||||
retry_interval (float): Time to wait between retries in seconds. Defaults to 1.0.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.system_prompt = system_prompt
|
||||
self.stream = stream
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.tools_list_dictionary = tools_list_dictionary
|
||||
self.tool_choice = tool_choice
|
||||
self.parallel_tool_calls = parallel_tool_calls
|
||||
self.retry_attempts = retry_attempts
|
||||
self.retry_interval = retry_interval
|
||||
|
||||
# Initialize vLLM
|
||||
try:
|
||||
self.llm = LLM(model=model_name, **kwargs)
|
||||
self.sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolExecutionError(
|
||||
"model_initialization",
|
||||
e,
|
||||
{"model_name": model_name, "kwargs": kwargs},
|
||||
)
|
||||
|
||||
def _validate_tool(
|
||||
self, tool_name: str, parameters: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Validate tool parameters before execution.
|
||||
Args:
|
||||
tool_name (str): Name of the tool to validate
|
||||
parameters (Dict[str, Any]): Parameters to validate
|
||||
Raises:
|
||||
ToolValidationError: If validation fails
|
||||
"""
|
||||
if not self.tools_list_dictionary:
|
||||
raise ToolValidationError(
|
||||
tool_name,
|
||||
"parameters",
|
||||
"No tools available for validation",
|
||||
)
|
||||
|
||||
tool_spec = next(
|
||||
(
|
||||
tool
|
||||
for tool in self.tools_list_dictionary
|
||||
if tool["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not tool_spec:
|
||||
raise ToolNotFoundError(tool_name)
|
||||
|
||||
required_params = {
|
||||
param["name"]
|
||||
for param in tool_spec["parameters"]
|
||||
if param.get("required", True)
|
||||
}
|
||||
|
||||
missing_params = required_params - set(parameters.keys())
|
||||
if missing_params:
|
||||
raise ToolParameterError(
|
||||
tool_name,
|
||||
f"Missing required parameters: {', '.join(missing_params)}",
|
||||
)
|
||||
|
||||
def _execute_with_retry(
|
||||
self, func: Callable, *args, **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a function with retry logic.
|
||||
Args:
|
||||
func (Callable): Function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
Returns:
|
||||
Any: Result of the function execution
|
||||
Raises:
|
||||
ToolExecutionError: If all retry attempts fail
|
||||
"""
|
||||
last_error = None
|
||||
for attempt in range(self.retry_attempts):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"Attempt {attempt + 1}/{self.retry_attempts} failed: {str(e)}"
|
||||
)
|
||||
if attempt < self.retry_attempts - 1:
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
raise ToolExecutionError(
|
||||
func.__name__,
|
||||
last_error,
|
||||
{"attempts": self.retry_attempts},
|
||||
)
|
||||
|
||||
def run(self, task: str, *args, **kwargs) -> str:
|
||||
"""
|
||||
Run the tool agent for the specified task.
|
||||
Args:
|
||||
task (str): The task to be performed by the tool agent.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
Returns:
|
||||
The output of the tool agent.
|
||||
Raises:
|
||||
ToolExecutionError: If an error occurs during execution.
|
||||
"""
|
||||
try:
|
||||
if not self.llm:
|
||||
raise ToolExecutionError(
|
||||
"run",
|
||||
Exception("LLM not initialized"),
|
||||
{"task": task},
|
||||
)
|
||||
|
||||
logger.info(f"Running task: {task}")
|
||||
|
||||
# Prepare the prompt
|
||||
prompt = self._prepare_prompt(task)
|
||||
|
||||
# Execute with retry logic
|
||||
outputs = self._execute_with_retry(
|
||||
self.llm.generate, prompt, self.sampling_params
|
||||
)
|
||||
|
||||
response = outputs[0].outputs[0].text.strip()
|
||||
return response
|
||||
|
||||
except Exception as error:
|
||||
logger.error(f"Error running task: {error}")
|
||||
raise ToolExecutionError(
|
||||
"run",
|
||||
error,
|
||||
{"task": task, "args": args, "kwargs": kwargs},
|
||||
)
|
||||
|
||||
def _prepare_prompt(self, task: str) -> str:
|
||||
"""
|
||||
Prepare the prompt for the given task.
|
||||
Args:
|
||||
task (str): The task to prepare the prompt for.
|
||||
Returns:
|
||||
str: The prepared prompt.
|
||||
"""
|
||||
if self.system_prompt:
|
||||
return f"{self.system_prompt}\n\nUser: {task}\nAssistant:"
|
||||
return f"User: {task}\nAssistant:"
|
||||
|
||||
def __call__(self, task: str, *args, **kwargs) -> str:
|
||||
"""
|
||||
Call the model for the given task.
|
||||
Args:
|
||||
task (str): The task to run the model for.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
Returns:
|
||||
str: The model's response.
|
||||
"""
|
||||
return self.run(task, *args, **kwargs)
|
||||
|
||||
def batched_run(
|
||||
self, tasks: List[str], batch_size: int = 10
|
||||
) -> List[str]:
|
||||
"""
|
||||
Run the model for multiple tasks in batches.
|
||||
Args:
|
||||
tasks (List[str]): List of tasks to run.
|
||||
batch_size (int): Size of each batch. Defaults to 10.
|
||||
Returns:
|
||||
List[str]: List of model responses.
|
||||
Raises:
|
||||
ToolExecutionError: If an error occurs during batch execution.
|
||||
"""
|
||||
logger.info(
|
||||
f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
|
||||
)
|
||||
results = []
|
||||
|
||||
try:
|
||||
for i in range(0, len(tasks), batch_size):
|
||||
batch = tasks[i : i + batch_size]
|
||||
for task in batch:
|
||||
logger.info(f"Running task: {task}")
|
||||
try:
|
||||
result = self.run(task)
|
||||
results.append(result)
|
||||
except ToolExecutionError as e:
|
||||
logger.error(
|
||||
f"Failed to execute task '{task}': {e}"
|
||||
)
|
||||
results.append(f"Error: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info("Completed all tasks.")
|
||||
return results
|
||||
|
||||
except Exception as error:
|
||||
logger.error(f"Error in batch execution: {error}")
|
||||
raise ToolExecutionError(
|
||||
"batched_run",
|
||||
error,
|
||||
{"tasks": tasks, "batch_size": batch_size},
|
||||
)
|
@ -1,599 +0,0 @@
|
||||
import json
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from swarms.utils.loguru_logger import initialize_logger
|
||||
|
||||
logger = initialize_logger(log_folder="swarm_matcher")
|
||||
|
||||
|
||||
class SwarmType(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
embedding: Optional[List[float]] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
|
||||
|
||||
class SwarmMatcherConfig(BaseModel):
|
||||
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_dim: int = (
|
||||
512 # Dimension of the sentence-transformers model
|
||||
)
|
||||
|
||||
|
||||
class SwarmMatcher:
|
||||
"""
|
||||
A class for matching tasks to swarm types based on their descriptions.
|
||||
It utilizes a transformer model to generate embeddings for task and swarm type descriptions,
|
||||
and then calculates the dot product to find the best match.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SwarmMatcherConfig):
|
||||
"""
|
||||
Initializes the SwarmMatcher with a configuration.
|
||||
|
||||
Args:
|
||||
config (SwarmMatcherConfig): The configuration for the SwarmMatcher.
|
||||
"""
|
||||
logger.add("swarm_matcher_debug.log", level="DEBUG")
|
||||
logger.debug("Initializing SwarmMatcher")
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"torch package not found. Pip install torch."
|
||||
)
|
||||
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"transformers package not found. Pip install transformers."
|
||||
)
|
||||
|
||||
self.torch = torch
|
||||
try:
|
||||
self.config = config
|
||||
self.tokenizer = (
|
||||
transformers.AutoTokenizer.from_pretrained(
|
||||
config.model_name
|
||||
)
|
||||
)
|
||||
self.model = transformers.AutoModel.from_pretrained(
|
||||
config.model_name
|
||||
)
|
||||
self.swarm_types: List[SwarmType] = []
|
||||
logger.debug("SwarmMatcher initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing SwarmMatcher: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
)
|
||||
def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Generates an embedding for a given text using the configured model.
|
||||
|
||||
Args:
|
||||
text (str): The text for which to generate an embedding.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The embedding vector for the text.
|
||||
"""
|
||||
logger.debug(f"Getting embedding for text: {text[:50]}...")
|
||||
try:
|
||||
inputs = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
)
|
||||
with self.torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
embedding = (
|
||||
outputs.last_hidden_state.mean(dim=1)
|
||||
.squeeze()
|
||||
.numpy()
|
||||
)
|
||||
logger.debug("Embedding generated successfully")
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding: {str(e)}")
|
||||
raise
|
||||
|
||||
def add_swarm_type(self, swarm_type: SwarmType):
|
||||
"""
|
||||
Adds a swarm type to the list of swarm types, generating an embedding for its description.
|
||||
|
||||
Args:
|
||||
swarm_type (SwarmType): The swarm type to add.
|
||||
"""
|
||||
logger.debug(f"Adding swarm type: {swarm_type.name}")
|
||||
try:
|
||||
embedding = self.get_embedding(swarm_type.description)
|
||||
swarm_type.embedding = embedding.tolist()
|
||||
self.swarm_types.append(swarm_type)
|
||||
logger.info(f"Added swarm type: {swarm_type.name}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding swarm type {swarm_type.name}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def find_best_match(self, task: str) -> Tuple[str, float]:
|
||||
"""
|
||||
Finds the best match for a given task among the registered swarm types.
|
||||
|
||||
Args:
|
||||
task (str): The task for which to find the best match.
|
||||
|
||||
Returns:
|
||||
Tuple[str, float]: A tuple containing the name of the best matching swarm type and the score.
|
||||
"""
|
||||
logger.debug(f"Finding best match for task: {task[:50]}...")
|
||||
try:
|
||||
task_embedding = self.get_embedding(task)
|
||||
best_match = None
|
||||
best_score = -float("inf")
|
||||
for swarm_type in self.swarm_types:
|
||||
score = np.dot(
|
||||
task_embedding, np.array(swarm_type.embedding)
|
||||
)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = swarm_type
|
||||
logger.info(
|
||||
f"Best match for task: {best_match.name} (score: {best_score})"
|
||||
)
|
||||
return best_match.name, float(best_score)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error finding best match for task: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def auto_select_swarm(self, task: str) -> str:
|
||||
"""
|
||||
Automatically selects the best swarm type for a given task based on their descriptions.
|
||||
|
||||
Args:
|
||||
task (str): The task for which to select a swarm type.
|
||||
|
||||
Returns:
|
||||
str: The name of the selected swarm type.
|
||||
"""
|
||||
logger.debug(f"Auto-selecting swarm for task: {task[:50]}...")
|
||||
best_match, score = self.find_best_match(task)
|
||||
logger.info(f"Task: {task}")
|
||||
logger.info(f"Selected Swarm Type: {best_match}")
|
||||
logger.info(f"Confidence Score: {score:.2f}")
|
||||
return best_match
|
||||
|
||||
def run_multiple(self, tasks: List[str], *args, **kwargs) -> str:
|
||||
swarms = []
|
||||
|
||||
for task in tasks:
|
||||
output = self.auto_select_swarm(task)
|
||||
|
||||
# Append
|
||||
swarms.append(output)
|
||||
|
||||
return swarms
|
||||
|
||||
def save_swarm_types(self, filename: str):
|
||||
"""
|
||||
Saves the registered swarm types to a JSON file.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to which to save the swarm types.
|
||||
"""
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
json.dump([st.dict() for st in self.swarm_types], f)
|
||||
logger.info(f"Saved swarm types to {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving swarm types: {str(e)}")
|
||||
raise
|
||||
|
||||
def load_swarm_types(self, filename: str):
|
||||
"""
|
||||
Loads swarm types from a JSON file.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file from which to load the swarm types.
|
||||
"""
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
swarm_types_data = json.load(f)
|
||||
self.swarm_types = [
|
||||
SwarmType(**st) for st in swarm_types_data
|
||||
]
|
||||
logger.info(f"Loaded swarm types from {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading swarm types: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def initialize_swarm_types(matcher: SwarmMatcher):
|
||||
logger.debug("Initializing swarm types")
|
||||
swarm_types = [
|
||||
SwarmType(
|
||||
name="AgentRearrange",
|
||||
description="Optimize agent order and rearrange flow for multi-step tasks, ensuring efficient task allocation and minimizing bottlenecks. Keywords: orchestration, coordination, pipeline optimization, task scheduling, resource allocation, workflow management, agent organization, process optimization",
|
||||
),
|
||||
SwarmType(
|
||||
name="MixtureOfAgents",
|
||||
description="Combine diverse expert agents for comprehensive analysis, fostering a collaborative approach to problem-solving and leveraging individual strengths. Keywords: multi-agent system, expert collaboration, distributed intelligence, collective problem solving, agent specialization, team coordination, hybrid approaches, knowledge synthesis",
|
||||
),
|
||||
SwarmType(
|
||||
name="SpreadSheetSwarm",
|
||||
description="Collaborative data processing and analysis in a spreadsheet-like environment, facilitating real-time data sharing and visualization. Keywords: data analysis, tabular processing, collaborative editing, data transformation, spreadsheet operations, data visualization, real-time collaboration, structured data",
|
||||
),
|
||||
SwarmType(
|
||||
name="SequentialWorkflow",
|
||||
description="Execute tasks in a step-by-step, sequential process workflow, ensuring a logical and methodical approach to task execution. Keywords: linear processing, waterfall methodology, step-by-step execution, ordered tasks, sequential operations, process flow, systematic approach, staged execution",
|
||||
),
|
||||
SwarmType(
|
||||
name="ConcurrentWorkflow",
|
||||
description="Process multiple tasks or data sources concurrently in parallel, maximizing productivity and reducing processing time. Keywords: parallel processing, multi-threading, asynchronous execution, distributed computing, concurrent operations, simultaneous tasks, parallel workflows, scalable processing",
|
||||
),
|
||||
SwarmType(
|
||||
name="HierarchicalSwarm",
|
||||
description="Organize agents in a hierarchical structure with clear reporting lines and delegation of responsibilities. Keywords: management hierarchy, organizational structure, delegation, supervision, chain of command, tiered organization, structured coordination",
|
||||
),
|
||||
# SwarmType(
|
||||
# name="AdaptiveSwarm",
|
||||
# description="Dynamically adjust agent behavior and swarm configuration based on task requirements and performance feedback. Keywords: dynamic adaptation, self-optimization, feedback loops, learning systems, flexible configuration, responsive behavior, adaptive algorithms",
|
||||
# ),
|
||||
# SwarmType(
|
||||
# name="ConsensusSwarm",
|
||||
# description="Achieve group decisions through consensus mechanisms and voting protocols among multiple agents. Keywords: group decision making, voting systems, collective intelligence, agreement protocols, democratic processes, collaborative decisions",
|
||||
# ),
|
||||
]
|
||||
|
||||
for swarm_type in swarm_types:
|
||||
matcher.add_swarm_type(swarm_type)
|
||||
logger.debug("Swarm types initialized")
|
||||
|
||||
|
||||
def swarm_matcher(task: str, *args, **kwargs):
|
||||
"""
|
||||
Runs the SwarmMatcher example with predefined tasks and swarm types.
|
||||
"""
|
||||
config = SwarmMatcherConfig()
|
||||
matcher = SwarmMatcher(config)
|
||||
initialize_swarm_types(matcher)
|
||||
|
||||
# matcher.save_swarm_types(f"swarm_logs/{uuid4().hex}.json")
|
||||
|
||||
swarm_type = matcher.auto_select_swarm(task)
|
||||
|
||||
logger.info(f"{swarm_type}")
|
||||
|
||||
return swarm_type
|
||||
|
||||
|
||||
# from typing import List, Tuple, Dict
|
||||
# from pydantic import BaseModel, Field
|
||||
# from loguru import logger
|
||||
# from uuid import uuid4
|
||||
# import chromadb
|
||||
# import json
|
||||
# from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
|
||||
# class SwarmType(BaseModel):
|
||||
# """A swarm type with its name, description and optional metadata"""
|
||||
|
||||
# id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
# name: str
|
||||
# description: str
|
||||
# metadata: Dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
# class SwarmMatcherConfig(BaseModel):
|
||||
# """Configuration for the SwarmMatcher"""
|
||||
|
||||
# collection_name: str = "swarm_types"
|
||||
# distance_metric: str = "cosine" # or "l2" or "ip"
|
||||
# embedding_function: str = (
|
||||
# "sentence-transformers/all-mpnet-base-v2" # Better model than MiniLM
|
||||
# )
|
||||
# persist_directory: str = "./chroma_db"
|
||||
|
||||
|
||||
# class SwarmMatcher:
|
||||
# """
|
||||
# An improved swarm matcher that uses ChromaDB for better vector similarity search.
|
||||
# Features:
|
||||
# - Persistent storage of embeddings
|
||||
# - Better vector similarity search with multiple distance metrics
|
||||
# - Improved embedding model
|
||||
# - Metadata filtering capabilities
|
||||
# - Batch operations support
|
||||
# """
|
||||
|
||||
# def __init__(self, config: SwarmMatcherConfig):
|
||||
# """Initialize the improved swarm matcher"""
|
||||
# logger.add("swarm_matcher.log", rotation="100 MB")
|
||||
# self.config = config
|
||||
|
||||
# # Initialize ChromaDB client with persistence
|
||||
# self.chroma_client = chromadb.Client()
|
||||
|
||||
# # Get or create collection
|
||||
# try:
|
||||
# self.collection = self.chroma_client.get_collection(
|
||||
# name=config.collection_name,
|
||||
# )
|
||||
# except ValueError:
|
||||
# self.collection = self.chroma_client.create_collection(
|
||||
# name=config.collection_name,
|
||||
# metadata={"hnsw:space": config.distance_metric},
|
||||
# )
|
||||
|
||||
# logger.info(
|
||||
# f"Initialized SwarmMatcher with collection '{config.collection_name}'"
|
||||
# )
|
||||
|
||||
# def add_swarm_type(self, swarm_type: SwarmType) -> None:
|
||||
# """Add a single swarm type to the collection"""
|
||||
# try:
|
||||
# self.collection.add(
|
||||
# ids=[swarm_type.id],
|
||||
# documents=[swarm_type.description],
|
||||
# metadatas=[
|
||||
# {"name": swarm_type.name, **swarm_type.metadata}
|
||||
# ],
|
||||
# )
|
||||
# logger.info(f"Added swarm type: {swarm_type.name}")
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Error adding swarm type {swarm_type.name}: {str(e)}"
|
||||
# )
|
||||
# raise
|
||||
|
||||
# def add_swarm_types(self, swarm_types: List[SwarmType]) -> None:
|
||||
# """Add multiple swarm types in batch"""
|
||||
# try:
|
||||
# self.collection.add(
|
||||
# ids=[st.id for st in swarm_types],
|
||||
# documents=[st.description for st in swarm_types],
|
||||
# metadatas=[
|
||||
# {"name": st.name, **st.metadata}
|
||||
# for st in swarm_types
|
||||
# ],
|
||||
# )
|
||||
# logger.info(f"Added {len(swarm_types)} swarm types")
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Error adding swarm types in batch: {str(e)}"
|
||||
# )
|
||||
# raise
|
||||
|
||||
# @retry(
|
||||
# stop=stop_after_attempt(3),
|
||||
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
# )
|
||||
# def find_best_matches(
|
||||
# self,
|
||||
# task: str,
|
||||
# n_results: int = 3,
|
||||
# score_threshold: float = 0.7,
|
||||
# ) -> List[Tuple[str, float]]:
|
||||
# """
|
||||
# Find the best matching swarm types for a given task
|
||||
# Returns multiple matches with their scores
|
||||
# """
|
||||
# try:
|
||||
# results = self.collection.query(
|
||||
# query_texts=[task],
|
||||
# n_results=n_results,
|
||||
# include=["metadatas", "distances"],
|
||||
# )
|
||||
|
||||
# matches = []
|
||||
# for metadata, distance in zip(
|
||||
# results["metadatas"][0], results["distances"][0]
|
||||
# ):
|
||||
# # Convert distance to similarity score (1 - normalized_distance)
|
||||
# score = 1 - (
|
||||
# distance / 2
|
||||
# ) # Normalize cosine distance to [0,1]
|
||||
# if score >= score_threshold:
|
||||
# matches.append((metadata["name"], score))
|
||||
|
||||
# logger.info(f"Found {len(matches)} matches for task")
|
||||
# return matches
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error finding matches for task: {str(e)}")
|
||||
# raise
|
||||
|
||||
# def auto_select_swarm(self, task: str) -> str:
|
||||
# """
|
||||
# Automatically select the best swarm type for a task
|
||||
# Returns only the top match
|
||||
# """
|
||||
# matches = self.find_best_matches(task, n_results=1)
|
||||
# if not matches:
|
||||
# logger.warning("No suitable matches found for task")
|
||||
# return "SequentialWorkflow" # Default fallback
|
||||
|
||||
# best_match, score = matches[0]
|
||||
# logger.info(
|
||||
# f"Selected swarm type '{best_match}' with confidence {score:.3f}"
|
||||
# )
|
||||
# return best_match
|
||||
|
||||
# def run_multiple(self, tasks: List[str]) -> List[str]:
|
||||
# """Process multiple tasks in batch"""
|
||||
# return [self.auto_select_swarm(task) for task in tasks]
|
||||
|
||||
# def save_swarm_types(self, filename: str) -> None:
|
||||
# """Export swarm types to JSON"""
|
||||
# try:
|
||||
# all_data = self.collection.get(
|
||||
# include=["metadatas", "documents"]
|
||||
# )
|
||||
# swarm_types = [
|
||||
# SwarmType(
|
||||
# id=id_,
|
||||
# name=metadata["name"],
|
||||
# description=document,
|
||||
# metadata={
|
||||
# k: v
|
||||
# for k, v in metadata.items()
|
||||
# if k != "name"
|
||||
# },
|
||||
# )
|
||||
# for id_, metadata, document in zip(
|
||||
# all_data["ids"],
|
||||
# all_data["metadatas"],
|
||||
# all_data["documents"],
|
||||
# )
|
||||
# ]
|
||||
|
||||
# with open(filename, "w") as f:
|
||||
# json.dump(
|
||||
# [st.dict() for st in swarm_types], f, indent=2
|
||||
# )
|
||||
# logger.info(f"Saved swarm types to {filename}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error saving swarm types: {str(e)}")
|
||||
# raise
|
||||
|
||||
# def load_swarm_types(self, filename: str) -> None:
|
||||
# """Import swarm types from JSON"""
|
||||
# try:
|
||||
# with open(filename, "r") as f:
|
||||
# swarm_types_data = json.load(f)
|
||||
# swarm_types = [SwarmType(**st) for st in swarm_types_data]
|
||||
# self.add_swarm_types(swarm_types)
|
||||
# logger.info(f"Loaded swarm types from {filename}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error loading swarm types: {str(e)}")
|
||||
# raise
|
||||
|
||||
|
||||
# def initialize_default_swarm_types(matcher: SwarmMatcher) -> None:
|
||||
# """Initialize the matcher with default swarm types"""
|
||||
# swarm_types = [
|
||||
# SwarmType(
|
||||
# name="AgentRearrange",
|
||||
# description="""
|
||||
# Optimize agent order and rearrange flow for multi-step tasks, ensuring efficient task allocation
|
||||
# and minimizing bottlenecks. Specialized in orchestration, coordination, pipeline optimization,
|
||||
# task scheduling, resource allocation, workflow management, agent organization, and process optimization.
|
||||
# Best for tasks requiring complex agent interactions and workflow optimization.
|
||||
# """,
|
||||
# metadata={
|
||||
# "category": "optimization",
|
||||
# "complexity": "high",
|
||||
# },
|
||||
# ),
|
||||
# SwarmType(
|
||||
# name="MixtureOfAgents",
|
||||
# description="""
|
||||
# Combine diverse expert agents for comprehensive analysis, fostering a collaborative approach
|
||||
# to problem-solving and leveraging individual strengths. Focuses on multi-agent systems,
|
||||
# expert collaboration, distributed intelligence, collective problem solving, agent specialization,
|
||||
# team coordination, hybrid approaches, and knowledge synthesis. Ideal for complex problems
|
||||
# requiring multiple areas of expertise.
|
||||
# """,
|
||||
# metadata={
|
||||
# "category": "collaboration",
|
||||
# "complexity": "high",
|
||||
# },
|
||||
# ),
|
||||
# SwarmType(
|
||||
# name="SpreadSheetSwarm",
|
||||
# description="""
|
||||
# Collaborative data processing and analysis in a spreadsheet-like environment, facilitating
|
||||
# real-time data sharing and visualization. Specializes in data analysis, tabular processing,
|
||||
# collaborative editing, data transformation, spreadsheet operations, data visualization,
|
||||
# real-time collaboration, and structured data handling. Perfect for data-intensive tasks
|
||||
# requiring structured analysis.
|
||||
# """,
|
||||
# metadata={
|
||||
# "category": "data_processing",
|
||||
# "complexity": "medium",
|
||||
# },
|
||||
# ),
|
||||
# SwarmType(
|
||||
# name="SequentialWorkflow",
|
||||
# description="""
|
||||
# Execute tasks in a step-by-step, sequential process workflow, ensuring a logical and methodical
|
||||
# approach to task execution. Focuses on linear processing, waterfall methodology, step-by-step
|
||||
# execution, ordered tasks, sequential operations, process flow, systematic approach, and staged
|
||||
# execution. Best for tasks requiring strict order and dependencies.
|
||||
# """,
|
||||
# metadata={"category": "workflow", "complexity": "low"},
|
||||
# ),
|
||||
# SwarmType(
|
||||
# name="ConcurrentWorkflow",
|
||||
# description="""
|
||||
# Process multiple tasks or data sources concurrently in parallel, maximizing productivity
|
||||
# and reducing processing time. Specializes in parallel processing, multi-threading,
|
||||
# asynchronous execution, distributed computing, concurrent operations, simultaneous tasks,
|
||||
# parallel workflows, and scalable processing. Ideal for independent tasks that can be
|
||||
# processed simultaneously.
|
||||
# """,
|
||||
# metadata={"category": "workflow", "complexity": "medium"},
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# matcher.add_swarm_types(swarm_types)
|
||||
# logger.info("Initialized default swarm types")
|
||||
|
||||
|
||||
# def create_swarm_matcher(
|
||||
# persist_dir: str = "./chroma_db",
|
||||
# collection_name: str = "swarm_types",
|
||||
# ) -> SwarmMatcher:
|
||||
# """Convenience function to create and initialize a swarm matcher"""
|
||||
# config = SwarmMatcherConfig(
|
||||
# persist_directory=persist_dir, collection_name=collection_name
|
||||
# )
|
||||
# matcher = SwarmMatcher(config)
|
||||
# initialize_default_swarm_types(matcher)
|
||||
# return matcher
|
||||
|
||||
|
||||
# # Example usage
|
||||
# def swarm_matcher(task: str) -> str:
|
||||
# # Create and initialize matcher
|
||||
# matcher = create_swarm_matcher()
|
||||
|
||||
# swarm_type = matcher.auto_select_swarm(task)
|
||||
# print(f"Task: {task}\nSelected Swarm: {swarm_type}\n")
|
||||
|
||||
# return swarm_type
|
||||
|
||||
|
||||
# # # Example usage
|
||||
# # if __name__ == "__main__":
|
||||
# # # Create and initialize matcher
|
||||
# # matcher = create_swarm_matcher()
|
||||
|
||||
# # # Example tasks
|
||||
# # tasks = [
|
||||
# # "Analyze this spreadsheet of sales data and create visualizations",
|
||||
# # "Coordinate multiple AI agents to solve a complex problem",
|
||||
# # "Process these tasks one after another in a specific order",
|
||||
# # "Write multiple blog posts about the latest advancements in swarm intelligence all at once",
|
||||
# # "Write a blog post about the latest advancements in swarm intelligence",
|
||||
# # ]
|
||||
|
||||
# # # Process tasks
|
||||
# # for task in tasks:
|
||||
# # swarm_type = matcher.auto_select_swarm(task)
|
||||
# # print(f"Task: {task}\nSelected Swarm: {swarm_type}\n")
|
@ -1,421 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from swarms.tools.logits_processor import (
|
||||
NumberStoppingCriteria,
|
||||
OutputNumbersTokens,
|
||||
StringStoppingCriteria,
|
||||
)
|
||||
from swarms.utils.auto_download_check_packages import (
|
||||
auto_check_and_download_package,
|
||||
)
|
||||
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
auto_check_and_download_package(
|
||||
"transformers", package_manager="pip"
|
||||
)
|
||||
import transformers
|
||||
|
||||
|
||||
GENERATION_MARKER = "|GENERATION|"
|
||||
|
||||
|
||||
class Jsonformer:
|
||||
"""
|
||||
Initializes the FormatTools class.
|
||||
|
||||
Args:
|
||||
model (PreTrainedModel): The pre-trained model.
|
||||
tokenizer (PreTrainedTokenizer): The tokenizer for the model.
|
||||
json_schema (Dict[str, Any]): The JSON schema.
|
||||
prompt (str): The prompt for generation.
|
||||
|
||||
Keyword Args:
|
||||
debug (bool, optional): Whether to enable debug mode. Defaults to False.
|
||||
max_array_length (int, optional): The maximum length of an array. Defaults to 10.
|
||||
max_number_tokens (int, optional): The maximum number of tokens for numbers. Defaults to 6.
|
||||
temperature (float, optional): The temperature for generation. Defaults to 1.0.
|
||||
max_string_token_length (int, optional): The maximum length of a string token. Defaults to 10.
|
||||
"""
|
||||
|
||||
value: Dict[str, Any] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: transformers.PreTrainedModel = None, # type: ignore
|
||||
tokenizer: transformers.PreTrainedTokenizer = None, # type: ignore
|
||||
json_schema: Union[Dict[str, Any], BaseModel] = None,
|
||||
schemas: List[Union[Dict[str, Any], BaseModel]] = [],
|
||||
prompt: str = None,
|
||||
*,
|
||||
debug: bool = False,
|
||||
max_array_length: int = 10,
|
||||
max_number_tokens: int = 6,
|
||||
temperature: float = 1.0,
|
||||
max_string_token_length: int = 10,
|
||||
llm: Any = None,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.json_schema = json_schema
|
||||
self.prompt = prompt
|
||||
self.llm = llm
|
||||
self.schemas = schemas
|
||||
|
||||
self.number_logit_processor = OutputNumbersTokens(
|
||||
self.tokenizer, self.prompt
|
||||
)
|
||||
|
||||
self.generation_marker = "|GENERATION|"
|
||||
self.debug_on = debug
|
||||
self.max_array_length = max_array_length
|
||||
|
||||
self.max_number_tokens = max_number_tokens
|
||||
self.temperature = temperature
|
||||
self.max_string_token_length = max_string_token_length
|
||||
|
||||
def generate_number(
|
||||
self, temperature: Union[float, None] = None, iterations=0
|
||||
):
|
||||
"""
|
||||
Generates a number based on the given prompt.
|
||||
|
||||
Args:
|
||||
temperature (float, optional): The temperature value for number generation. Defaults to None.
|
||||
iterations (int, optional): The number of iterations for generating a valid number. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
float: The generated number.
|
||||
|
||||
Raises:
|
||||
ValueError: If a valid number cannot be generated after 3 iterations.
|
||||
"""
|
||||
if self.model:
|
||||
prompt = self.get_prompt()
|
||||
self.debug("[generate_number]", prompt, is_prompt=True)
|
||||
input_tokens = self.tokenizer.encode(
|
||||
prompt, return_tensors="pt"
|
||||
).to(self.model.device)
|
||||
|
||||
response = self.model.generate(
|
||||
input_tokens,
|
||||
max_new_tokens=self.max_number_tokens,
|
||||
num_return_sequences=1,
|
||||
logits_processor=[self.number_logit_processor],
|
||||
stopping_criteria=[
|
||||
NumberStoppingCriteria(
|
||||
self.tokenizer, len(input_tokens[0])
|
||||
)
|
||||
],
|
||||
temperature=temperature or self.temperature,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
)
|
||||
response = self.tokenizer.decode(
|
||||
response[0], skip_special_tokens=True
|
||||
)
|
||||
|
||||
response = response[len(prompt) :]
|
||||
response = response.strip().rstrip(".")
|
||||
self.debug("[generate_number]", response)
|
||||
try:
|
||||
return float(response)
|
||||
except ValueError:
|
||||
if iterations > 3:
|
||||
raise ValueError(
|
||||
"Failed to generate a valid number"
|
||||
)
|
||||
|
||||
return self.generate_number(
|
||||
temperature=self.temperature * 1.3,
|
||||
iterations=iterations + 1,
|
||||
)
|
||||
elif self.llm:
|
||||
prompt = self.get_prompt()
|
||||
self.debug("[generate_number]", prompt, is_prompt=True)
|
||||
response = self.llm(prompt)
|
||||
response = response[len(prompt) :]
|
||||
response = response.strip().rstrip(".")
|
||||
self.debug("[generate_number]", response)
|
||||
try:
|
||||
return float(response)
|
||||
except ValueError:
|
||||
if iterations > 3:
|
||||
raise ValueError(
|
||||
"Failed to generate a valid number"
|
||||
)
|
||||
|
||||
return self.generate_number(
|
||||
temperature=self.temperature * 1.3,
|
||||
iterations=iterations + 1,
|
||||
)
|
||||
|
||||
elif self.llm and self.model:
|
||||
raise ValueError("Both LLM and model cannot be None")
|
||||
|
||||
def generate_boolean(self) -> bool:
|
||||
"""
|
||||
Generates a boolean value based on the given prompt.
|
||||
|
||||
Returns:
|
||||
bool: The generated boolean value.
|
||||
"""
|
||||
if self.model:
|
||||
prompt = self.get_prompt()
|
||||
self.debug("[generate_boolean]", prompt, is_prompt=True)
|
||||
|
||||
input_tensor = self.tokenizer.encode(
|
||||
prompt, return_tensors="pt"
|
||||
)
|
||||
output = self.model.forward(
|
||||
input_tensor.to(self.model.device)
|
||||
)
|
||||
logits = output.logits[0, -1]
|
||||
|
||||
# todo: this assumes that "true" and "false" are both tokenized to a single token
|
||||
# this is probably not true for all tokenizers
|
||||
# this can be fixed by looking at only the first token of both "true" and "false"
|
||||
true_token_id = self.tokenizer.convert_tokens_to_ids(
|
||||
"true"
|
||||
)
|
||||
false_token_id = self.tokenizer.convert_tokens_to_ids(
|
||||
"false"
|
||||
)
|
||||
|
||||
result = logits[true_token_id] > logits[false_token_id]
|
||||
|
||||
self.debug("[generate_boolean]", result)
|
||||
|
||||
return result.item()
|
||||
|
||||
elif self.llm:
|
||||
prompt = self.get_prompt()
|
||||
self.debug("[generate_boolean]", prompt, is_prompt=True)
|
||||
|
||||
output = self.llm(prompt)
|
||||
|
||||
return output if output == "true" or "false" else None
|
||||
|
||||
else:
|
||||
raise ValueError("Both LLM and model cannot be None")
|
||||
|
||||
def generate_string(self) -> str:
|
||||
if self.model:
|
||||
prompt = self.get_prompt() + '"'
|
||||
self.debug("[generate_string]", prompt, is_prompt=True)
|
||||
input_tokens = self.tokenizer.encode(
|
||||
prompt, return_tensors="pt"
|
||||
).to(self.model.device)
|
||||
|
||||
response = self.model.generate(
|
||||
input_tokens,
|
||||
max_new_tokens=self.max_string_token_length,
|
||||
num_return_sequences=1,
|
||||
temperature=self.temperature,
|
||||
stopping_criteria=[
|
||||
StringStoppingCriteria(
|
||||
self.tokenizer, len(input_tokens[0])
|
||||
)
|
||||
],
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
# Some models output the prompt as part of the response
|
||||
# This removes the prompt from the response if it is present
|
||||
if (
|
||||
len(response[0]) >= len(input_tokens[0])
|
||||
and (
|
||||
response[0][: len(input_tokens[0])]
|
||||
== input_tokens
|
||||
).all()
|
||||
):
|
||||
response = response[0][len(input_tokens[0]) :]
|
||||
if response.shape[0] == 1:
|
||||
response = response[0]
|
||||
|
||||
response = self.tokenizer.decode(
|
||||
response, skip_special_tokens=True
|
||||
)
|
||||
|
||||
self.debug("[generate_string]", "|" + response + "|")
|
||||
|
||||
if response.count('"') < 1:
|
||||
return response
|
||||
|
||||
return response.split('"')[0].strip()
|
||||
|
||||
elif self.llm:
|
||||
prompt = self.get_prompt() + '"'
|
||||
self.debug("[generate_string]", prompt, is_prompt=True)
|
||||
|
||||
response = self.llm(prompt)
|
||||
|
||||
# Some models output the prompt as part of the response
|
||||
# This removes the prompt from the response if it is present
|
||||
if (
|
||||
len(response[0]) >= len(input_tokens[0])
|
||||
and (
|
||||
response[0][: len(input_tokens[0])]
|
||||
== input_tokens
|
||||
).all()
|
||||
):
|
||||
response = response[0][len(input_tokens[0]) :]
|
||||
if response.shape[0] == 1:
|
||||
response = response[0]
|
||||
|
||||
self.debug("[generate_string]", "|" + response + "|")
|
||||
|
||||
if response.count('"') < 1:
|
||||
return response
|
||||
|
||||
return response.split('"')[0].strip()
|
||||
|
||||
else:
|
||||
raise ValueError("Both LLM and model cannot be None")
|
||||
|
||||
def generate_object(
|
||||
self, properties: Dict[str, Any], obj: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
for key, schema in properties.items():
|
||||
self.debug("[generate_object] generating value for", key)
|
||||
obj[key] = self.generate_value(schema, obj, key)
|
||||
return obj
|
||||
|
||||
def generate_value(
|
||||
self,
|
||||
schema: Dict[str, Any],
|
||||
obj: Union[Dict[str, Any], List[Any]],
|
||||
key: Union[str, None] = None,
|
||||
) -> Any:
|
||||
schema_type = schema["type"]
|
||||
if schema_type == "number":
|
||||
if key:
|
||||
obj[key] = self.generation_marker
|
||||
else:
|
||||
obj.append(self.generation_marker)
|
||||
return self.generate_number()
|
||||
elif schema_type == "boolean":
|
||||
if key:
|
||||
obj[key] = self.generation_marker
|
||||
else:
|
||||
obj.append(self.generation_marker)
|
||||
return self.generate_boolean()
|
||||
elif schema_type == "string":
|
||||
if key:
|
||||
obj[key] = self.generation_marker
|
||||
else:
|
||||
obj.append(self.generation_marker)
|
||||
return self.generate_string()
|
||||
elif schema_type == "array":
|
||||
new_array = []
|
||||
obj[key] = new_array
|
||||
return self.generate_array(schema["items"], new_array)
|
||||
elif schema_type == "object":
|
||||
new_obj = {}
|
||||
if key:
|
||||
obj[key] = new_obj
|
||||
else:
|
||||
obj.append(new_obj)
|
||||
return self.generate_object(schema["properties"], new_obj)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported schema type: {schema_type}"
|
||||
)
|
||||
|
||||
def generate_array(
|
||||
self, item_schema: Dict[str, Any], obj: Dict[str, Any]
|
||||
) -> list:
|
||||
if self.model:
|
||||
for _ in range(self.max_array_length):
|
||||
# forces array to have at least one element
|
||||
element = self.generate_value(item_schema, obj)
|
||||
obj[-1] = element
|
||||
|
||||
obj.append(self.generation_marker)
|
||||
input_prompt = self.get_prompt()
|
||||
obj.pop()
|
||||
input_tensor = self.tokenizer.encode(
|
||||
input_prompt, return_tensors="pt"
|
||||
)
|
||||
output = self.model.forward(
|
||||
input_tensor.to(self.model.device)
|
||||
)
|
||||
logits = output.logits[0, -1]
|
||||
|
||||
top_indices = logits.topk(30).indices
|
||||
sorted_token_ids = top_indices[
|
||||
logits[top_indices].argsort(descending=True)
|
||||
]
|
||||
|
||||
found_comma = False
|
||||
found_close_bracket = False
|
||||
|
||||
for token_id in sorted_token_ids:
|
||||
decoded_token = self.tokenizer.decode(token_id)
|
||||
if "," in decoded_token:
|
||||
found_comma = True
|
||||
break
|
||||
if "]" in decoded_token:
|
||||
found_close_bracket = True
|
||||
break
|
||||
|
||||
if found_close_bracket or not found_comma:
|
||||
break
|
||||
|
||||
return obj
|
||||
|
||||
elif self.llm:
|
||||
for _ in range(self.max_array_length):
|
||||
# forces array to have at least one element
|
||||
element = self.generate_value(item_schema, obj)
|
||||
obj[-1] = element
|
||||
|
||||
obj.append(self.generation_marker)
|
||||
input_prompt = self.get_prompt()
|
||||
obj.pop()
|
||||
output = self.llm(input_prompt)
|
||||
|
||||
found_comma = False
|
||||
found_close_bracket = False
|
||||
|
||||
for token_id in output:
|
||||
decoded_token = str(token_id)
|
||||
if "," in decoded_token:
|
||||
found_comma = True
|
||||
break
|
||||
if "]" in decoded_token:
|
||||
found_close_bracket = True
|
||||
break
|
||||
|
||||
if found_close_bracket or not found_comma:
|
||||
break
|
||||
|
||||
return obj
|
||||
|
||||
def get_prompt(self):
|
||||
template = """{prompt}\nOutput result in the following JSON schema format:\n{schema}\nResult: {progress}"""
|
||||
progress = json.dumps(self.value)
|
||||
gen_marker_index = progress.find(
|
||||
f'"{self.generation_marker}"'
|
||||
)
|
||||
if gen_marker_index != -1:
|
||||
progress = progress[:gen_marker_index]
|
||||
else:
|
||||
raise ValueError("Failed to find generation marker")
|
||||
|
||||
prompt = template.format(
|
||||
prompt=self.prompt,
|
||||
schema=json.dumps(self.json_schema),
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def __call__(self) -> Dict[str, Any]:
|
||||
self.value = {}
|
||||
generated_data = self.generate_object(
|
||||
self.json_schema["properties"], self.value
|
||||
)
|
||||
return generated_data
|
@ -1,109 +0,0 @@
|
||||
from swarms.utils.auto_download_check_packages import (
|
||||
auto_check_and_download_package,
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
auto_check_and_download_package(
|
||||
"torch", package_manager="pip", upgrade=True
|
||||
)
|
||||
import torch
|
||||
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
auto_check_and_download_package(
|
||||
"transformers", package_manager="pip", upgrade=True
|
||||
)
|
||||
import transformers
|
||||
|
||||
|
||||
class StringStoppingCriteria(transformers.StoppingCriteria):
|
||||
def __init__(
|
||||
self, tokenizer: Any, prompt_length: int # type: ignore
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.prompt_length = prompt_length
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor, # type: ignore
|
||||
_,
|
||||
) -> bool:
|
||||
if len(input_ids[0]) <= self.prompt_length:
|
||||
return False
|
||||
|
||||
last_token_id = input_ids[0][-1]
|
||||
last_token = self.tokenizer.decode(
|
||||
last_token_id, skip_special_tokens=True
|
||||
)
|
||||
|
||||
result = '"' in last_token
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class NumberStoppingCriteria(transformers.StoppingCriteria):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Any, # type: ignore
|
||||
prompt_length: int,
|
||||
precision: int = 3,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.precision = precision
|
||||
self.prompt_length = prompt_length
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor, # type: ignore
|
||||
scores: torch.FloatTensor, # type: ignore
|
||||
) -> bool:
|
||||
decoded = self.tokenizer.decode(
|
||||
input_ids[0][self.prompt_length :],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
if decoded.count(".") > 1:
|
||||
return True
|
||||
|
||||
if (
|
||||
decoded.count(".") == 1
|
||||
and len(decoded.strip().split(".")[1]) > self.precision
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
len(decoded) > 1
|
||||
and any(c.isdigit() for c in decoded)
|
||||
and decoded[-1] in [" ", "\n"]
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class OutputNumbersTokens(transformers.LogitsWarper):
|
||||
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, prompt: str): # type: ignore
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenized_prompt = tokenizer(prompt, return_tensors="pt")
|
||||
vocab_size = len(tokenizer)
|
||||
self.allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
|
||||
|
||||
for _, token_id in tokenizer.get_vocab().items():
|
||||
token_str = tokenizer.decode(token_id).strip()
|
||||
|
||||
if token_str == "" or (
|
||||
all(c.isdigit() or c == "." for c in token_str)
|
||||
and token_str.count(".") <= 1
|
||||
):
|
||||
self.allowed_mask[token_id] = True
|
||||
|
||||
def __call__(self, _, scores):
|
||||
mask = self.allowed_mask.expand_as(scores)
|
||||
scores[~mask] = -float("inf")
|
||||
|
||||
return scores
|
Loading…
Reference in new issue