[REFACTORING]

pull/430/head
Kye 10 months ago
parent 997fd1e143
commit f715a0c5bc

@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "4.6.0"
version = "4.6.1"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]
@ -41,7 +41,6 @@ loguru = "0.7.2"
pydantic = "2.6.4"
tenacity = "8.2.3"
Pillow = "10.2.0"
termcolor = "2.2.0"
rich = "13.5.2"
psutil = "*"
sentry-sdk = "*"

@ -18,6 +18,5 @@ pydantic==2.6.4
tenacity==8.2.3
Pillow==10.2.0
termcolor==2.2.0
rich==13.5.2
psutil
sentry-sdk

@ -1,362 +0,0 @@
Modularizing the framework for scalability and reliability means breaking the overall architecture into smaller, more manageable pieces and adding features that improve reliability. Here is a list of ideas to achieve this:
### 1. Dynamic Agent Management
To keep the swarm cost-effective and efficient, create and destroy agents dynamically as the workload changes:
**Idea**: Instead of having a fixed number of agents, allow the `AutoScaler` to both instantiate and destroy agents as necessary.
**Example**:
```python
class AutoScaler:
# ...
def remove_agent(self):
with self.lock:
if self.agents_pool:
agent_to_remove = self.agents_pool.pop()
del agent_to_remove
```
### 2. Task Segmentation & Aggregation
Breaking down tasks into sub-tasks and then aggregating results ensures scalability:
**Idea**: Create a method in the `Orchestrator` to break down larger tasks into smaller tasks and another method to aggregate results from sub-tasks.
**Example**:
```python
class Orchestrator(ABC):
# ...
def segment_task(self, main_task: str) -> List[str]:
# Break down main_task into smaller tasks
# ...
return sub_tasks
def aggregate_results(self, sub_results: List[Any]) -> Any:
# Combine results from sub-tasks into a cohesive output
# ...
return main_result
```
### 3. Enhanced Task Queuing
**Idea**: Prioritize tasks based on importance or deadlines.
**Example**: Use a priority queue for the `task_queue`, ensuring tasks of higher importance are tackled first.
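A minimal sketch of the priority-queue idea, assuming lower numbers mean higher priority. The `PrioritizedTaskQueue` name and its methods are hypothetical and not part of the framework; the tie-breaking counter is there so that tasks with equal priority come out in insertion order and task objects are never compared with each other.

```python
import itertools
import queue
from typing import Any


class PrioritizedTaskQueue:
    """Hypothetical wrapper around queue.PriorityQueue (illustration only)."""

    def __init__(self) -> None:
        self._queue = queue.PriorityQueue()
        # Monotonic counter breaks ties, so dict tasks are never compared.
        self._counter = itertools.count()

    def put(self, task: Any, priority: int = 0) -> None:
        # Lower priority values are served first.
        self._queue.put((priority, next(self._counter), task))

    def get(self) -> Any:
        _priority, _order, task = self._queue.get()
        return task

    def empty(self) -> bool:
        return self._queue.empty()


task_queue = PrioritizedTaskQueue()
task_queue.put({"content": "summarize report"}, priority=5)
task_queue.put({"content": "fix outage"}, priority=0)
first = task_queue.get()  # the priority-0 task is returned first
```

Storing plain `(priority, task)` tuples would raise a `TypeError` whenever two dict tasks share a priority, which is why the sketch adds the counter.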
### 4. Error Recovery & Retry Mechanisms
**Idea**: Introduce a retry mechanism for tasks that fail due to transient errors.
**Example**:
```python
class Orchestrator(ABC):
MAX_RETRIES = 3
retry_counts = defaultdict(int)
# ...
def assign_task(self, agent_id, task):
# ...
except Exception as error:
if self.retry_counts[task] < self.MAX_RETRIES:
self.retry_counts[task] += 1
self.task_queue.put(task)
```
### 5. Swarm Communication & Collaboration
**Idea**: Allow agents to communicate or request help from their peers.
**Example**: Implement a `request_assistance` method within agents where, upon facing a challenging task, they can ask for help from other agents.
### 6. Database Management
**Idea**: Periodically clean, optimize, and back up the vector database to ensure data integrity and optimal performance.
### 7. Logging & Monitoring
**Idea**: Implement advanced logging and monitoring capabilities to provide insights into swarm performance, potential bottlenecks, and failures.
**Example**: Use tools like Elasticsearch, Logstash, and Kibana (ELK stack) to monitor logs in real-time.
### 8. Load Balancing
**Idea**: Distribute incoming tasks among agents evenly, ensuring no single agent is overloaded.
**Example**: Use algorithms or tools that assign tasks based on current agent workloads.
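One simple way to do this is a least-loaded policy: track how many tasks each agent is currently running and hand the next task to the agent with the fewest. The sketch below is an assumption about how that could look; `LeastLoadedBalancer`, `acquire`, and `release` are hypothetical names, and agents are identified by arbitrary hashable IDs.

```python
import threading
from typing import Dict, Hashable, Iterable


class LeastLoadedBalancer:
    """Hypothetical helper that picks the agent with the fewest active tasks."""

    def __init__(self, agent_ids: Iterable[Hashable]) -> None:
        self._active: Dict[Hashable, int] = {a: 0 for a in agent_ids}
        self._lock = threading.Lock()

    def acquire(self) -> Hashable:
        # Choose the least-loaded agent; ties go to the earliest registered one.
        with self._lock:
            agent_id = min(self._active, key=self._active.get)
            self._active[agent_id] += 1
            return agent_id

    def release(self, agent_id: Hashable) -> None:
        # Call when the agent finishes a task.
        with self._lock:
            self._active[agent_id] = max(0, self._active[agent_id] - 1)

    def load(self, agent_id: Hashable) -> int:
        with self._lock:
            return self._active[agent_id]


balancer = LeastLoadedBalancer(["agent-a", "agent-b"])
chosen = balancer.acquire()  # both idle, so the first registered agent is chosen
balancer.release(chosen)
```

The orchestrator would call `acquire()` before dispatching a task and `release()` in a `finally` block, so a failed task does not leave an agent marked as busy.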
### 9. Feedback Loop
**Idea**: Allow the system to learn from its mistakes or inefficiencies. Agents can rate the difficulty of their tasks and this information can be used to adjust future task assignments.
### 10. Agent Specialization
**Idea**: Not all agents are equal. Some might be better suited to certain tasks.
**Example**: Maintain a performance profile for each agent, categorizing them based on their strengths. Assign tasks to agents based on their specialization for optimal performance.
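A performance profile can be as small as a per-agent, per-category success counter. The following sketch assumes tasks carry a category label and that success is reported as a boolean; `AgentProfiles` and its methods are hypothetical and only illustrate the bookkeeping.

```python
from collections import defaultdict
from typing import Hashable, Sequence


class AgentProfiles:
    """Hypothetical success-rate tracker per agent and task category."""

    def __init__(self) -> None:
        # _stats[agent_id][category] == [successes, attempts]
        self._stats = defaultdict(lambda: defaultdict(lambda: [0, 0]))

    def record(self, agent_id: Hashable, category: str, success: bool) -> None:
        entry = self._stats[agent_id][category]
        entry[0] += int(success)
        entry[1] += 1

    def success_rate(self, agent_id: Hashable, category: str) -> float:
        successes, attempts = self._stats[agent_id][category]
        # Agents with no history in a category score 0.0.
        return successes / attempts if attempts else 0.0

    def best_agent(self, agent_ids: Sequence[Hashable], category: str) -> Hashable:
        # Ties resolve to the first agent in agent_ids.
        return max(agent_ids, key=lambda a: self.success_rate(a, category))


profiles = AgentProfiles()
profiles.record("agent-a", "coding", True)
profiles.record("agent-a", "coding", False)
profiles.record("agent-b", "coding", True)
specialist = profiles.best_agent(["agent-a", "agent-b"], "coding")
```

A real implementation would probably also need to explore: always routing to the current best agent means the others never get a chance to improve their scores.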
By implementing these ideas and constantly iterating based on real-world usage and performance metrics, it's possible to create a robust and scalable multi-agent collaboration framework.
# 10 improvements to the `Orchestrator` class to enable more flexibility and usability:
1. Dynamic Agent Creation: Allow the number of agents to be specified at runtime, rather than being fixed at the time of instantiation.
```
def add_agents(self, num_agents: int):
for _ in range(num_agents):
self.agents.put(self.agent())
self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
```
1. Agent Removal: Allow agents to be removed from the pool.
```
def remove_agents(self, num_agents: int):
for _ in range(num_agents):
if not self.agents.empty():
self.agents.get()
self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
```
1. Task Prioritization: Allow tasks to be prioritized.
```
from queue import PriorityQueue
def __init__(self, agent, agent_list: List[Any], task_queue: List[Any], collection_name: str = "swarm", api_key: str = None, model_name: str = None):
# ...
self.task_queue = PriorityQueue()
# ...
def add_task(self, task: Dict[str, Any], priority: int = 0):
self.task_queue.put((priority, task))
```
1. Task Status: Track the status of tasks.
```
from enum import Enum
class TaskStatus(Enum):
QUEUED = 1
RUNNING = 2
COMPLETED = 3
FAILED = 4
# In assign_task method
self.current_tasks[id(task)] = TaskStatus.RUNNING
# On successful completion
self.current_tasks[id(task)] = TaskStatus.COMPLETED
# On failure
self.current_tasks[id(task)] = TaskStatus.FAILED
```
1. Result Retrieval: Allow results to be retrieved by task ID.
```
def retrieve_result(self, task_id: int) -> Any:
return self.collection.query(query_texts=[str(task_id)], n_results=1)
```
1. Batch Task Assignment: Allow multiple tasks to be assigned at once.
```
def assign_tasks(self, tasks: List[Dict[str, Any]]):
for task in tasks:
self.task_queue.put(task)
```
1. Error Handling: Improve error handling by re-queuing failed tasks.
```
# In assign_task method
except Exception as error:
logging.error(f"Failed to process task {id(task)} by agent {id(agent)}. Error: {error}")
self.task_queue.put(task)
```
1. Agent Status: Track the status of agents (e.g., idle, working).
```
self.agent_status = {id(agent): "idle" for agent in self.agents.queue}
# In assign_task method
self.agent_status[id(agent)] = "working"
# On task completion
self.agent_status[id(agent)] = "idle"
```
1. Custom Embedding Function: Allow a custom embedding function to be used.
```
def __init__(self, agent, agent_list: List[Any], task_queue: List[Any], collection_name: str = "swarm", api_key: str = None, model_name: str = None, embed_func=None):
# ...
self.embed_func = embed_func if embed_func else self.embed
# ...
def embed(self, input, api_key, model_name):
# ...
embedding = self.embed_func(input)
# ...
```
1. Agent Communication: Allow agents to communicate with each other.
```
def communicate(self, sender_id: int, receiver_id: int, message: str):
message_vector = self.embed_func(message)
self.collection.add(embeddings=[message_vector], documents=[message], ids=[f"{sender_id}_to_{receiver_id}"])
```
```
import logging
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List
from enum import Enum
import chromadb
from chromadb.utils import embedding_functions
class TaskStatus(Enum):
QUEUED = 1
RUNNING = 2
COMPLETED = 3
FAILED = 4
class Orchestrator:
def __init__(self, agent, agent_list: List[Any], task_queue: List[Any], collection_name: str = "swarm", api_key: str = None, model_name: str = None, embed_func=None):
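self.agent = agent
# Stored because embed() reads self.api_key and self.model_name
self.api_key = api_key
self.model_name = model_name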
self.agents = queue.Queue()
self.agent_status = {}
self.add_agents(agent_list)
self.task_queue = queue.PriorityQueue()
self.chroma_client = chromadb.Client()
self.collection = self.chroma_client.create_collection(name = collection_name)
self.current_tasks = {}
self.lock = threading.Lock()
self.condition = threading.Condition(self.lock)
self.embed_func = embed_func if embed_func else self.embed
def add_agents(self, num_agents: int):
for _ in range(num_agents):
agent = self.agent()
self.agents.put(agent)
self.agent_status[id(agent)] = "idle"
self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
def remove_agents(self, num_agents: int):
for _ in range(num_agents):
if not self.agents.empty():
agent = self.agents.get()
del self.agent_status[id(agent)]
self.executor = ThreadPoolExecutor(max_workers=self.agents.qsize())
def assign_task(self, agent_id: int, task: Dict[str, Any]) -> None:
while True:
with self.condition:
# A PriorityQueue object is always truthy, so test emptiness explicitly
while self.task_queue.empty():
self.condition.wait()
agent = self.agents.get()
# Queue entries are (priority, task) tuples
priority, task = self.task_queue.get()
try:
self.agent_status[id(agent)] = "working"
# Run the task on the agent taken from the pool
result = agent.run(task["content"])
vector_representation = self.embed_func(result)
self.collection.add(embeddings=[vector_representation], documents=[str(id(task))], ids=[str(id(task))])
logging.info(f"Task {id(task)} has been processed by agent {id(agent)}")
self.current_tasks[id(task)] = TaskStatus.COMPLETED
except Exception as error:
logging.error(f"Failed to process task {id(task)} by agent {id(agent)}. Error: {error}")
self.current_tasks[id(task)] = TaskStatus.FAILED
# Re-queue the failed task with its original priority
self.task_queue.put((priority, task))
finally:
with self.condition:
self.agent_status[id(agent)] = "idle"
self.agents.put(agent)
self.condition.notify()
def embed(self, input):
openai = embedding_functions.OpenAIEmbeddingFunction(api_key=self.api_key, model_name=self.model_name)
embedding = openai(input)
return embedding
def retrieve_results(self, agent_id: int) -> Any:
try:
results = self.collection.query(query_texts=[str(agent_id)], n_results=10)
return results
except Exception as e:
logging.error(f"Failed to retrieve results from agent {agent_id}. Error: {e}")
raise
def update_vector_db(self, data) -> None:
try:
self.collection.add(embeddings=[data["vector"]], documents=[str(data["task_id"])], ids=[str(data["task_id"])])
except Exception as e:
logging.error(f"Failed to update the vector database. Error: {e}")
raise
def get_vector_db(self):
return self.collection
def append_to_db(self, result: str):
try:
self.collection.add(documents=[result], ids=[str(id(result))])
except Exception as e:
logging.error(f"Failed to append the agent output to database. Error: {e}")
raise
def run(self, objective:str):
if not objective or not isinstance(objective, str):
logging.error("Invalid objective")
raise ValueError("A valid objective is required")
try:
self.task_queue.put((0, objective))
results = [self.assign_task(agent_id, task) for agent_id, task in zip(range(len(self.agents)), self.task_queue)]
for result in results:
self.append_to_db(result)
logging.info(f"Successfully ran swarms with results: {results}")
return results
except Exception as e:
logging.error(f"An error occurred in swarm: {e}")
return None
def chat(self, sender_id: int, receiver_id: int, message: str):
message_vector = self.embed_func(message)
# Store the message in the vector database
self.collection.add(embeddings=[message_vector], documents=[message], ids=[f"{sender_id}_to_{receiver_id}"])
def assign_tasks(self, tasks: List[Dict[str, Any]], priority: int = 0):
for task in tasks:
self.task_queue.put((priority, task))
def retrieve_result(self, task_id: int) -> Any:
try:
result = self.collection.query(query_texts=[str(task_id)], n_results=1)
return result
except Exception as e:
logging.error(f"Failed to retrieve result for task {task_id}. Error: {e}")
raise
```
With these improvements, the `Orchestrator` class now supports dynamic agent creation and removal, task prioritization, task status tracking, result retrieval by task ID, batch task assignment, improved error handling, agent status tracking, custom embedding functions, and agent communication. This should make the class more flexible and easier to use when creating swarms of LLMs.

@ -1,5 +1,10 @@
from swarms.structs.agent import Agent
from swarms.structs.agent_job import AgentJob
from swarms.structs.agent_process import (
AgentProcess,
AgentProcessQueue,
)
from swarms.structs.auto_swarm import AutoSwarm, AutoSwarmRouter
from swarms.structs.autoscaler import AutoScaler
from swarms.structs.base import BaseStructure
from swarms.structs.base_swarm import AbstractSwarm
@ -74,12 +79,6 @@ from swarms.structs.utils import (
find_token_in_text,
parse_tasks,
)
from swarms.structs.auto_swarm import AutoSwarm, AutoSwarmRouter
from swarms.structs.agent_process import (
AgentProcess,
AgentProcessQueue,
)
__all__ = [
"Agent",

@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Sequence
from swarms.tools.tool import BaseTool
from pydantic import BaseModel
@dataclass
class Step:
class Step(BaseModel):
"""
Represents a step in a process.
@ -17,8 +16,10 @@ class Step:
tool (BaseTool): The tool used to execute the step.
"""
task: str
id: int
dep: List[int]
args: Dict[str, str]
tool: BaseTool
task: str = None
id: int = 0
dep: List[int] = []
args: Dict[str, str] = {}
tool: BaseTool = None
tools: Sequence[BaseTool] = []
metadata: Dict[str, str] = {}

@ -4,7 +4,7 @@ import queue
import threading
from typing import List, Optional
from fastapi import FastAPI
# from fastapi import FastAPI
from swarms.structs.agent import Agent
from swarms.structs.base import BaseStructure
@ -89,9 +89,6 @@ class SwarmNetwork(BaseStructure):
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
if api_enabled:
self.api = FastAPI()
# For each agent in the pool, run it on it's own thread
if agents is not None:
for agent in agents:

@ -1,80 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def continuous_tensor(
inputs: torch.Tensor, seq_length: torch.LongTensor
):
"""Convert batched tensor to continuous tensor.
Args:
inputs (Tensor): batched tensor.
seq_length (Tensor): length of each sequence.
Return:
Tensor: continuoused tensor.
"""
assert inputs.dim() > 1
if inputs.size(1) == 1:
return inputs.reshape(1, -1)
inputs = [inp[:slen] for inp, slen in zip(inputs, seq_length)]
inputs = torch.cat(inputs).unsqueeze(0)
return inputs
def batch_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
"""Convert continuoused tensor to batched tensor.
Args:
inputs (Tensor): continuoused tensor.
seq_length (Tensor): length of each sequence.
Return:
Tensor: batched tensor.
"""
from torch.nn.utils.rnn import pad_sequence
end_loc = seq_length.cumsum(0)
start_loc = end_loc - seq_length
inputs = [
inputs[0, sloc:eloc] for sloc, eloc in zip(start_loc, end_loc)
]
inputs = pad_sequence(inputs, batch_first=True)
return inputs
def page_cache(
paged_cache: torch.Tensor,
batched_cache: torch.Tensor,
cache_length: torch.Tensor,
block_offsets: torch.Tensor,
permute_head: bool = True,
):
"""Convert batched cache to paged cache.
Args:
paged_cache (Tensor): Output paged cache.
batched_cache (Tensor): Input batched cache.
cache_length (Tensor): length of the cache.
block_offsets (Tensor): Offset of each blocks.
"""
assert block_offsets.dim() == 2
block_size = paged_cache.size(1)
batch_size = batched_cache.size(0)
if permute_head:
batched_cache = batched_cache.permute(0, 2, 1, 3)
for b_idx in range(batch_size):
cache_len = cache_length[b_idx]
b_cache = batched_cache[b_idx]
block_off = block_offsets[b_idx]
block_off_idx = 0
for s_start in range(0, cache_len, block_size):
s_end = min(s_start + block_size, cache_len)
s_len = s_end - s_start
b_off = block_off[block_off_idx]
paged_cache[b_off, :s_len] = b_cache[s_start:s_end]
block_off_idx += 1

@ -1,152 +0,0 @@
from unittest.mock import MagicMock
import pytest
from swarms.structs.agent import Agent
from swarms.structs.majority_voting import MajorityVoting
def test_majority_voting_run_concurrent(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=True,
multithreaded=False,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
def test_majority_voting_run_multithreaded(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=False,
multithreaded=True,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None
@pytest.mark.asyncio
async def test_majority_voting_run_asynchronous(mocker):
# Create mock agents
agent1 = MagicMock(spec=Agent)
agent2 = MagicMock(spec=Agent)
agent3 = MagicMock(spec=Agent)
# Create mock majority voting
mv = MajorityVoting(
agents=[agent1, agent2, agent3],
concurrent=False,
multithreaded=False,
asynchronous=True,
)
# Create mock conversation
conversation = MagicMock()
mv.conversation = conversation
# Create mock results
results = ["Paris", "Paris", "Lyon"]
# Mock agent.run method
agent1.run.return_value = results[0]
agent2.run.return_value = results[1]
agent3.run.return_value = results[2]
# Run majority voting
majority_vote = await mv.run("What is the capital of France?")
# Assert agent.run method was called with the correct task
agent1.run.assert_called_once_with(
"What is the capital of France?"
)
agent2.run.assert_called_once_with(
"What is the capital of France?"
)
agent3.run.assert_called_once_with(
"What is the capital of France?"
)
# Assert conversation.add method was called with the correct responses
conversation.add.assert_any_call(agent1.agent_name, results[0])
conversation.add.assert_any_call(agent2.agent_name, results[1])
conversation.add.assert_any_call(agent3.agent_name, results[2])
# Assert majority vote is correct
assert majority_vote is not None

@ -1,36 +0,0 @@
import json
from abc import ABC, abstractmethod
class JSON(ABC):
def __init__(self, schema_path):
"""
Initializes a JSONSchema object.
Args:
schema_path (str): The path to the JSON schema file.
"""
self.schema_path = schema_path
self.schema = self.load_schema()
def load_schema(self):
"""
Loads the JSON schema from the specified file.
Returns:
dict: The loaded JSON schema.
"""
with open(self.schema_path) as f:
return json.load(f)
@abstractmethod
def validate(self, data):
"""
Validates the given data against the JSON schema.
Args:
data (dict): The data to be validated.
Raises:
NotImplementedError: This method needs to be implemented by the subclass.
"""

@ -1,5 +1,3 @@
# from swarms.telemetry.posthog_utils import posthog
from swarms.telemetry.log_all import log_all_calls, log_calls
from swarms.telemetry.sys_info import (
get_cpu_info,

@ -1,6 +1,7 @@
import subprocess
from swarms.telemetry.check_update import check_for_update
from termcolor import colored
def auto_update():
@ -13,6 +14,6 @@ def auto_update():
)
subprocess.run(["pip", "install", "--upgrade", "swarms"])
else:
print("swarms is up to date!")
colored("swarms is up to date!", "red")
except Exception as e:
print(e)

@ -1,59 +0,0 @@
import datetime
import logging
import platform
import pymongo
class Telemetry:
def __init__(self, db_url, db_name):
self.logger = self.setup_logging()
self.db = self.setup_db(db_url, db_name)
def setup_logging(self):
logger = logging.getLogger("telemetry")
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
)
logger.addHandler(handler)
return logger
def setup_db(self, db_url, db_name):
client = pymongo.MongoClient(db_url)
return client[db_name]
def capture_device_data(self):
data = {
"system": platform.system(),
"node": platform.node(),
"release": platform.release(),
"version": platform.version(),
"machine": platform.machine(),
"processor": platform.processor(),
"time": datetime.datetime.now(),
}
return data
def send_to_db(self, collection_name, data):
collection = self.db[collection_name]
collection.insert_one(data)
def log_and_capture(self, message, level, collection_name):
if level == "info":
self.logger.info(message)
elif level == "error":
self.logger.error(message)
data = self.capture_device_data()
data["log"] = message
self.send_to_db(collection_name, data)
def log_import(self, module_name):
self.logger.info(f"Importing module {module_name}")
module = __import__(module_name, fromlist=["*"])
for k in dir(module):
if not k.startswith("__"):
self.logger.info(f"Imported {k} from {module_name}")

@ -1,6 +0,0 @@
from posthog import Posthog
posthog = Posthog(
project_api_key="phc_Gz6XxldNZIkzW7QnSTGr5HZ28OAYPIfpE7X5A3vUsfO",
host="https://app.posthog.com",
)

@ -86,6 +86,3 @@ def get_user_device_data():
"Swarms [Version]": check_for_package("swarms"),
}
return data
#

@ -40,66 +40,47 @@ from swarms.utils.remove_json_whitespace import (
remove_whitespace_from_yaml,
)
from swarms.utils.save_logs import parse_log_file
from swarms.utils.supervision_masking import (
FeatureType,
compute_mask_iou_vectorized,
filter_masks_by_relative_area,
mask_non_max_suppression,
masks_to_marks,
refine_marks,
)
from swarms.utils.supervision_visualizer import MarkVisualizer
from swarms.utils.token_count_tiktoken import limit_tokens_from_string
from swarms.utils.try_except_wrapper import try_except_wrapper
from swarms.utils.video_to_frames import (
save_frames_as_images,
video_to_frames,
)
from swarms.utils.yaml_output_parser import YamlOutputParser
from swarms.utils.concurrent_utils import execute_concurrently
__all__ = [
"SubprocessCodeInterpreter",
"display_markdown_message",
"extract_code_from_markdown",
"find_image_path",
"limit_tokens_from_string",
"load_model_torch",
"math_eval",
"metrics_decorator",
"pdf_to_text",
"prep_torch_inference",
"print_class_parameters",
"check_device",
"SubprocessCodeInterpreter",
"csv_to_dataframe",
"dataframe_to_strings",
"csv_to_text",
"data_to_text",
"json_to_text",
"txt_to_text",
"data_to_text",
"try_except_wrapper",
"check_device",
"download_img_from_url",
"download_weights_from_url",
"parse_log_file",
"YamlOutputParser",
"ExponentialBackoffMixin",
"load_json",
"sanitize_file_path",
"zip_workspace",
"create_file_in_folder",
"zip_folders",
"find_image_path",
"JsonOutputParser",
"metrics_decorator",
"load_model_torch",
"display_markdown_message",
"math_eval",
"dataframe_to_text",
"extract_code_from_markdown",
"pdf_to_text",
"prep_torch_inference",
"remove_whitespace_from_json",
"remove_whitespace_from_yaml",
"ExponentialBackoffMixin",
"download_img_from_url",
"FeatureType",
"compute_mask_iou_vectorized",
"mask_non_max_suppression",
"filter_masks_by_relative_area",
"masks_to_marks",
"refine_marks",
"parse_log_file",
"MarkVisualizer",
"video_to_frames",
"save_frames_as_images",
"dataframe_to_text",
"zip_workspace",
"sanitize_file_path",
"load_json",
"csv_to_dataframe",
"dataframe_to_strings",
"limit_tokens_from_string",
"try_except_wrapper",
"YamlOutputParser",
"execute_concurrently",
"create_file_in_folder",
"zip_folders",
]
]

@ -1,6 +1,4 @@
from rich.console import Console
from rich.markdown import Markdown
from rich.rule import Rule
from termcolor import colored
def display_markdown_message(message: str, color: str = "cyan"):
@ -9,19 +7,18 @@ def display_markdown_message(message: str, color: str = "cyan"):
Will automatically make single line > tags beautiful.
"""
console = Console()
for line in message.split("\n"):
line = line.strip()
if line == "":
console.print("")
print()
elif line == "---":
console.print(Rule(style=color))
print(colored("-" * 50, color))
else:
console.print(Markdown(line, style=color))
print(colored(line, color))
if "\n" not in message and message.startswith(">"):
# Aesthetic choice. For these tags, they need a space below them
console.print("")
print()
# display_markdown_message("I love you and you are beautiful.", "cyan")

@ -1,43 +0,0 @@
from typing import List
import cv2
def video_to_frames(video_file: str) -> List:
"""
Convert a video into frames.
Args:
video_file (str): The path to the video file.
Returns:
List[np.array]: A list of frames from the video.
"""
# Open the video file
vidcap = cv2.VideoCapture(video_file)
frames = []
success, image = vidcap.read()
while success:
frames.append(image)
success, image = vidcap.read()
return frames
def save_frames_as_images(frames, output_dir) -> None:
"""
Save a list of frames as image files.
Args:
frames (list of np.array): The list of frames.
output_dir (str): The directory where the images will be saved.
"""
for i, frame in enumerate(frames):
cv2.imwrite(f"{output_dir}/frame{i}.jpg", frame)
# out = save_frames_as_images(frames, "playground/demos/security_team/frames")
# print(out)