commit bb0d18eacc
@@ -0,0 +1,4 @@
"""
This tutorial shows you how to integrate Swarms with LangChain.

"""
@@ -0,0 +1,10 @@
from swarms import Kosmos

# Initialize the model
model = Kosmos()

# Generate a response for the given prompt and image
out = model.run("Analyze the receipts in this image", "docs.jpg")

# Print the output
print(out)
@@ -0,0 +1,26 @@
from swarms.structs.agent import Agent


def agent_wrapper(ClassToWrap):
    """
    Takes a class 'ClassToWrap' and returns a new class that inherits from
    both 'ClassToWrap' and 'Agent'. The new class overrides '__init__' so
    that the constructors of both 'Agent' and 'ClassToWrap' are called.

    Args:
        ClassToWrap (type): The class to be wrapped and made to inherit from 'Agent'.

    Returns:
        type: The new class that inherits from both 'ClassToWrap' and 'Agent'.
    """

    class WrappedClass(ClassToWrap, Agent):
        def __init__(self, *args, **kwargs):
            try:
                # Initialize both parent classes with the same arguments
                Agent.__init__(self, *args, **kwargs)
                ClassToWrap.__init__(self, *args, **kwargs)
            except Exception as e:
                print(f"Error initializing WrappedClass: {e}")
                raise

    return WrappedClass
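Below is a minimal usage sketch for agent_wrapper (not part of the commit). The wrapped class 'MyTool' and the keyword arguments passed through to Agent are illustrative assumptions, not the library's documented API.

# Hypothetical example: wrap an arbitrary class so it also behaves as an Agent.
class MyTool:
    def __init__(self, *args, **kwargs):
        # Accepts the same kwargs that Agent receives (assumption for illustration)
        self.label = kwargs.get("agent_name", "my_tool")

WrappedTool = agent_wrapper(MyTool)
tool = WrappedTool(agent_name="wrapped-tool")  # runs both Agent.__init__ and MyTool.__init__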
@@ -0,0 +1,141 @@
from abc import ABC, abstractmethod


class AbstractVectorDatabase(ABC):
    """
    Abstract base class for a vector database.

    This class defines the interface for interacting with a database.
    Subclasses must implement the abstract methods to provide the
    specific implementation details for connecting to a database,
    executing queries, and performing CRUD operations.
    """

    @abstractmethod
    def connect(self):
        """
        Connect to the database.

        This method establishes a connection to the database.
        """
        pass

    @abstractmethod
    def close(self):
        """
        Close the database connection.

        This method closes the connection to the database.
        """
        pass

    @abstractmethod
    def query(self, query: str):
        """
        Execute a database query.

        This method executes the given query on the database.

        Parameters:
            query (str): The query to be executed.
        """
        pass

    @abstractmethod
    def fetch_all(self):
        """
        Fetch all rows from the result set.

        This method retrieves all rows from the result set of a query.

        Returns:
            list: A list of dictionaries representing the rows.
        """
        pass

    @abstractmethod
    def fetch_one(self):
        """
        Fetch one row from the result set.

        This method retrieves one row from the result set of a query.

        Returns:
            dict: A dictionary representing the row.
        """
        pass

    @abstractmethod
    def add(self, doc: str):
        """
        Add a new record to the database.

        This method adds the given document to the database.

        Parameters:
            doc (str): The document to be added.
        """
        pass

    @abstractmethod
    def get(self, query: str):
        """
        Get a record from the database.

        This method retrieves a record from the database based on the given query.

        Parameters:
            query (str): The query used to look up the record.

        Returns:
            dict: A dictionary representing the retrieved record.
        """
        pass

    @abstractmethod
    def update(self, doc):
        """
        Update a record in the database.

        This method updates an existing record in the database with the given document.

        Parameters:
            doc: The updated document.
        """
        pass

    @abstractmethod
    def delete(self, message):
        """
        Delete a record from the database.

        This method deletes the record matching the given message from the database.

        Parameters:
            message: The record to be deleted.
        """
        pass
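Below is a minimal in-memory subclass sketch (not part of the commit) showing one way the abstract interface could be satisfied. The list-backed storage and naive substring matching are assumptions purely for illustration.

# Hypothetical minimal implementation backed by a plain Python list.
class InMemoryVectorDatabase(AbstractVectorDatabase):
    def connect(self):
        self.docs = []          # the "connection" is just an empty store here
        self.results = []

    def close(self):
        self.docs = None

    def query(self, query: str):
        # Naive substring match instead of a real similarity search
        self.results = [d for d in self.docs if query in d]

    def fetch_all(self):
        return [{"doc": d} for d in self.results]

    def fetch_one(self):
        return {"doc": self.results[0]} if self.results else None

    def add(self, doc: str):
        self.docs.append(doc)

    def get(self, query: str):
        self.query(query)
        return self.fetch_one()

    def update(self, doc):
        pass  # no-op in this sketch

    def delete(self, message):
        self.docs = [d for d in self.docs if d != message]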
@@ -1,58 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Dict


class VectorDatabase(ABC):
    @abstractmethod
    def add(
        self, vector: Dict[str, Any], metadata: Dict[str, Any]
    ) -> None:
        """
        Add a vector into the database.

        Args:
            vector (Dict[str, Any]): The vector to add.
            metadata (Dict[str, Any]): Metadata associated with the vector.
        """
        pass

    @abstractmethod
    def query(self, text: str, num_results: int) -> Dict[str, Any]:
        """
        Query the database for vectors similar to the given text.

        Args:
            text (str): The query text to compare against.
            num_results (int): The number of similar vectors to return.

        Returns:
            Dict[str, Any]: The most similar vectors and their associated metadata.
        """
        pass

    @abstractmethod
    def delete(self, vector_id: str) -> None:
        """
        Delete a vector from the database.

        Args:
            vector_id (str): The ID of the vector to delete.
        """
        pass

    @abstractmethod
    def update(
        self,
        vector_id: str,
        vector: Dict[str, Any],
        metadata: Dict[str, Any],
    ) -> None:
        """
        Update a vector in the database.

        Args:
            vector_id (str): The ID of the vector to update.
            vector (Dict[str, Any]): The new vector.
            metadata (Dict[str, Any]): The new metadata.
        """
        pass
@@ -0,0 +1,183 @@
from io import BytesIO

import requests
import torch
from PIL import Image
from torchvision.transforms import GaussianBlur
from transformers import CLIPModel, CLIPProcessor


class CLIPQ:
    """
    CLIPQ is a CLIP-based model that can be used to embed image regions and
    generate captions for images.

    Attributes:
        model_name (str): The name of the model to be used.
        query_text (str): The query text to be used for the model.

    Args:
        model_name (str): The name of the model to be used.
        query_text (str): The query text to be used for the model.
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch16",
        query_text: str = "A photo ",
        *args,
        **kwargs,
    ):
        self.model = CLIPModel.from_pretrained(
            model_name, *args, **kwargs
        )
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.query_text = query_text

    def fetch_image_from_url(self, url="https://picsum.photos/800"):
        """Fetches an image from the given URL"""
        response = requests.get(url)
        if response.status_code != 200:
            raise Exception("Failed to fetch an image")
        image = Image.open(BytesIO(response.content))
        return image

    def load_image_from_path(self, path):
        """Loads an image from the given path"""
        return Image.open(path)

    def split_image(
        self, image, h_splits: int = 2, v_splits: int = 2
    ):
        """Splits the given image into h_splits x v_splits parts"""
        width, height = image.size
        w_step, h_step = width // h_splits, height // v_splits
        slices = []

        for i in range(v_splits):
            for j in range(h_splits):
                tile = image.crop(
                    (
                        j * w_step,
                        i * h_step,
                        (j + 1) * w_step,
                        (i + 1) * h_step,
                    )
                )
                slices.append(tile)
        return slices

    def get_vectors(
        self,
        image,
        h_splits: int = 2,
        v_splits: int = 2,
    ):
        """Gets the embedding vectors for each tile of the given image"""
        slices = self.split_image(image, h_splits, v_splits)
        vectors = []

        for tile in slices:
            inputs = self.processor(
                text=self.query_text,
                images=tile,
                return_tensors="pt",
                padding=True,
            )
            outputs = self.model(**inputs)
            vectors.append(
                outputs.image_embeds.squeeze().detach().numpy()
            )
        return vectors

    def run_from_url(
        self,
        url: str = "https://picsum.photos/800",
        h_splits: int = 2,
        v_splits: int = 2,
    ):
        """Runs the model on the image fetched from the given URL"""
        image = self.fetch_image_from_url(url)
        return self.get_vectors(image, h_splits, v_splits)

    def check_hard_chunking(self, quadrants):
        """Check if the chunking is hard by measuring the variance of edge pixels"""
        variances = []
        for quadrant in quadrants:
            # Concatenate the first and last rows of the quadrant as its edge pixels
            edge_pixels = torch.cat(
                [
                    quadrant[0, :],
                    quadrant[-1, :],
                ]
            )
            variances.append(torch.var(edge_pixels).item())
        return variances

    def embed_whole_image(self, image):
        """Embed the entire image"""
        inputs = self.processor(
            images=image,
            return_tensors="pt",
        )
        with torch.no_grad():
            image_embeds = self.model.get_image_features(**inputs)
        return image_embeds.squeeze()

    def apply_noise_reduction(self, image, kernel_size: int = 5):
        """Apply a Gaussian blur to reduce noise and soften tiling artifacts"""
        blur = GaussianBlur(kernel_size)
        return blur(image)

    def run_from_path(
        self, path: str = None, h_splits: int = 2, v_splits: int = 2
    ):
        """Runs the model on the image loaded from the given path"""
        image = self.load_image_from_path(path)
        return self.get_vectors(image, h_splits, v_splits)

    def get_captions(self, image, candidate_captions):
        """Get the best caption for the given image"""
        inputs_image = self.processor(
            images=image,
            return_tensors="pt",
        )

        inputs_text = self.processor(
            text=candidate_captions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        image_embeds = self.model.get_image_features(
            pixel_values=inputs_image.pixel_values
        )
        text_embeds = self.model.get_text_features(
            input_ids=inputs_text.input_ids,
            attention_mask=inputs_text.attention_mask,
        )

        # Calculate similarity between image and text
        similarities = (image_embeds @ text_embeds.T).squeeze(0)
        best_caption_index = similarities.argmax().item()

        return candidate_captions[best_caption_index]

    def get_and_concat_captions(
        self, image, candidate_captions, h_splits=2, v_splits=2
    ):
        """Get the best caption for each tile of the image and concatenate them"""
        slices = self.split_image(image, h_splits, v_splits)
        captions = [
            self.get_captions(tile, candidate_captions)
            for tile in slices
        ]
        concatenated_captions = "".join(captions)
        return concatenated_captions
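Below is a brief usage sketch for CLIPQ (not part of the commit). The image path and the candidate captions are illustrative assumptions; running it also requires downloading the CLIP weights.

# Hypothetical usage of CLIPQ: embed image tiles and pick the best caption.
clipq = CLIPQ(query_text="A photo of")

# Embed a 2x2 grid of tiles from a random web image
vectors = clipq.run_from_url("https://picsum.photos/800", h_splits=2, v_splits=2)
print(len(vectors))  # 4 tile embeddings

# Caption a local image against a few candidate captions ("docs.jpg" is an assumed path)
image = clipq.load_image_from_path("docs.jpg")
caption = clipq.get_captions(image, ["a receipt", "a cat", "a landscape"])
print(caption)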
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import List, Optional

from swarms.memory.base_vectordatabase import AbstractVectorDatabase
from swarms.structs.agent import Agent


@dataclass
class MultiAgentRag:
    """
    Represents a multi-agent RAG (Retrieval-Augmented Generation) structure.

    Attributes:
        agents (List[Agent]): List of agents in the multi-agent RAG.
        db (AbstractVectorDatabase): Database used for querying.
        verbose (bool): Flag indicating whether to print verbose output.
    """

    agents: List[Agent]
    db: AbstractVectorDatabase
    verbose: bool = False

    def query_database(self, query: str):
        """
        Queries the long-term memory of every agent using the given query string.

        Args:
            query (str): The query string.

        Returns:
            List: The combined list of results from all agents.
        """
        results = []
        for agent in self.agents:
            agent_results = agent.long_term_memory_prompt(query)
            results.extend(agent_results)
        return results

    def get_agent_by_id(self, agent_id) -> Optional[Agent]:
        """
        Retrieves an agent from the multi-agent RAG by its ID.

        Args:
            agent_id: The ID of the agent to retrieve.

        Returns:
            Agent or None: The agent with the specified ID, or None if not found.
        """
        for agent in self.agents:
            if agent.agent_id == agent_id:
                return agent
        return None

    def add_message(
        self, sender: Agent, message: str, *args, **kwargs
    ):
        """
        Adds a message to the database.

        Args:
            sender (Agent): The agent sending the message.
            message (str): The message to add.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The result of adding the document to the database.
        """
        doc = f"{sender.ai_name}: {message}"

        return self.db.add(doc)

    def query(self, message: str, *args, **kwargs):
        """
        Queries the database using the given message.

        Args:
            message (str): The message to query.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List: The list of results from the database.
        """
        return self.db.query(message)
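Below is a minimal wiring sketch for MultiAgentRag (not part of the commit). The concrete database class is the illustrative InMemoryVectorDatabase sketched earlier, and the Agent constructor arguments (including the assumed my_llm object) are assumptions; a real swarms Agent typically needs an LLM and further configuration.

# Hypothetical setup: two agents sharing one vector database.
db = InMemoryVectorDatabase()   # illustrative subclass from the earlier sketch
db.connect()

researcher = Agent(agent_name="researcher", llm=my_llm)  # my_llm is assumed to exist
writer = Agent(agent_name="writer", llm=my_llm)

rag = MultiAgentRag(agents=[researcher, writer], db=db, verbose=True)
rag.add_message(researcher, "The receipts total $42.")
print(rag.query("receipts"))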
@@ -0,0 +1,96 @@
import json
import re
from typing import Type, TypeVar

from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)


class JsonParsingException(Exception):
    """Custom exception for errors in JSON parsing."""


class JsonOutputParser:
    """Parse JSON output using a Pydantic model.

    This parser is designed to extract JSON formatted data from a given string
    and parse it using a specified Pydantic model for validation.

    Attributes:
        pydantic_object: A Pydantic model class for parsing and validation.
        pattern: A regex pattern to match JSON code blocks.

    Examples:
        >>> from pydantic import BaseModel
        >>> from swarms.utils.json_output_parser import JsonOutputParser
        >>> class MyModel(BaseModel):
        ...     name: str
        ...     age: int
        ...
        >>> parser = JsonOutputParser(MyModel)
        >>> text = "```json\n{\"name\": \"John\", \"age\": 42}\n```"
        >>> model = parser.parse(text)
        >>> model.name
        'John'
    """

    def __init__(self, pydantic_object: Type[T]):
        self.pydantic_object = pydantic_object
        self.pattern = re.compile(
            r"^```(?:json)?(?P<json>[^`]*)", re.MULTILINE | re.DOTALL
        )

    def parse(self, text: str) -> T:
        """Parse the provided text to extract and validate JSON data.

        Args:
            text: A string containing potential JSON data.

        Returns:
            An instance of the specified Pydantic model with parsed data.

        Raises:
            JsonParsingException: If parsing or validation fails.
        """
        try:
            match = re.search(self.pattern, text.strip())
            json_str = match.group("json") if match else text

            json_object = json.loads(json_str)
            return self.pydantic_object.parse_obj(json_object)

        except (json.JSONDecodeError, ValidationError) as e:
            name = self.pydantic_object.__name__
            msg = (
                f"Failed to parse {name} from text '{text}'."
                f" Error: {e}"
            )
            raise JsonParsingException(msg) from e

    def get_format_instructions(self) -> str:
        """Generate formatting instructions based on the Pydantic model schema.

        Returns:
            A string containing formatting instructions.
        """
        schema = self.pydantic_object.schema()
        reduced_schema = {
            k: v
            for k, v in schema.items()
            if k not in ["title", "type"]
        }
        schema_str = json.dumps(reduced_schema, indent=4)

        format_instructions = (
            f"JSON Formatting Instructions:\n{schema_str}"
        )
        return format_instructions


# # Example usage
# class ExampleModel(BaseModel):
#     field1: int
#     field2: str

# parser = JsonOutputParser(ExampleModel)
# # Use parser.parse(text) to parse JSON data
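Below is a short runnable sketch (not part of the commit) showing both parsing and format instructions end to end; the Receipt model and the sample LLM reply are assumptions for illustration.

# Hypothetical end-to-end use of JsonOutputParser.
from pydantic import BaseModel

class Receipt(BaseModel):
    vendor: str
    total: float

parser = JsonOutputParser(Receipt)
print(parser.get_format_instructions())   # schema the LLM output should follow

llm_reply = '```json\n{"vendor": "ACME", "total": 19.99}\n```'
receipt = parser.parse(llm_reply)
print(receipt.vendor, receipt.total)       # ACME 19.99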
@@ -0,0 +1,50 @@
import json

import yaml


def remove_whitespace_from_json(json_string: str) -> str:
    """
    Removes unnecessary whitespace from a JSON string.

    This function parses the JSON string into a Python object and then
    serializes it back into a JSON string without unnecessary whitespace.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with whitespace removed.
    """
    parsed = json.loads(json_string)
    return json.dumps(parsed, separators=(",", ":"))


# # Example usage for JSON
# json_string = '{"field1": 123, "field2": "example text"}'
# print(remove_whitespace_from_json(json_string))


def remove_whitespace_from_yaml(yaml_string: str) -> str:
    """
    Removes unnecessary whitespace from a YAML string.

    This function parses the YAML string into a Python object and then
    serializes it back into a YAML string with minimized whitespace.
    Note: this might change the representation style of the YAML data.

    Args:
        yaml_string (str): The YAML string.

    Returns:
        str: The YAML string with whitespace reduced.
    """
    parsed = yaml.safe_load(yaml_string)
    return yaml.dump(parsed, default_flow_style=True)


# # Example usage for YAML
# yaml_string = """
# field1: 123
# field2: example text
# """
# print(remove_whitespace_from_yaml(yaml_string))
@@ -0,0 +1,89 @@
import json
import re
from typing import Type, TypeVar

import yaml
from pydantic import BaseModel, ValidationError

T = TypeVar("T", bound=BaseModel)


class YamlParsingException(Exception):
    """Custom exception for errors in YAML parsing."""


class YamlOutputParser:
    """Parse YAML output using a Pydantic model.

    This parser is designed to extract YAML formatted data from a given string
    and parse it using a specified Pydantic model for validation.

    Attributes:
        pydantic_object: A Pydantic model class for parsing and validation.
        pattern: A regex pattern to match YAML code blocks.

    Examples:
        >>> from pydantic import BaseModel
        >>> from swarms.utils.yaml_output_parser import YamlOutputParser
        >>> class MyModel(BaseModel):
        ...     name: str
        ...     age: int
        ...
        >>> parser = YamlOutputParser(MyModel)
        >>> text = "```yaml\nname: John\nage: 42\n```"
        >>> model = parser.parse(text)
        >>> model.name
        'John'
    """

    def __init__(self, pydantic_object: Type[T]):
        self.pydantic_object = pydantic_object
        self.pattern = re.compile(
            r"^```(?:ya?ml)?(?P<yaml>[^`]*)", re.MULTILINE | re.DOTALL
        )

    def parse(self, text: str) -> T:
        """Parse the provided text to extract and validate YAML data.

        Args:
            text: A string containing potential YAML data.

        Returns:
            An instance of the specified Pydantic model with parsed data.

        Raises:
            YamlParsingException: If parsing or validation fails.
        """
        try:
            match = re.search(self.pattern, text.strip())
            yaml_str = match.group("yaml") if match else text

            yaml_object = yaml.safe_load(yaml_str)
            return self.pydantic_object.parse_obj(yaml_object)

        except (yaml.YAMLError, ValidationError) as e:
            name = self.pydantic_object.__name__
            msg = (
                f"Failed to parse {name} from text '{text}'."
                f" Error: {e}"
            )
            raise YamlParsingException(msg) from e

    def get_format_instructions(self) -> str:
        """Generate formatting instructions based on the Pydantic model schema.

        Returns:
            A string containing formatting instructions.
        """
        schema = self.pydantic_object.schema()
        reduced_schema = {
            k: v
            for k, v in schema.items()
            if k not in ["title", "type"]
        }
        schema_str = json.dumps(reduced_schema, indent=4)

        format_instructions = (
            f"YAML Formatting Instructions:\n{schema_str}"
        )
        return format_instructions
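Below is a short usage sketch mirroring the JSON parser above (not part of the commit); the Task model and the sample reply are assumptions for illustration.

# Hypothetical use of YamlOutputParser on a YAML code block.
from pydantic import BaseModel

class Task(BaseModel):
    title: str
    priority: int

yaml_parser = YamlOutputParser(Task)
reply = "```yaml\ntitle: Summarize receipts\npriority: 1\n```"
task = yaml_parser.parse(reply)
print(task.title, task.priority)  # Summarize receipts 1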