You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
416 lines
14 KiB
416 lines
14 KiB
import datetime
import json
import os
from typing import TYPE_CHECKING
from typing import Any, Optional

import yaml

from swarms.structs.base_structure import BaseStructure
from swarms.utils.formatter import formatter

if TYPE_CHECKING:
    from swarms.structs.agent import (
        Agent,
    )  # Only imported during type checking


class Conversation(BaseStructure):
    """
    Store, inspect and persist the message history of a chat-style conversation.

    The history is a list of dicts kept in ``self.conversation_history``. Every
    message added through :meth:`add` has the keys ``"role"`` and ``"content"``,
    plus ``"timestamp"`` when ``time_enabled`` is set.

    Args:
        system_prompt (Optional[str]): If given, added as the first message
            under the role ``"System: "``.
        time_enabled (bool): Attach a ``"%Y-%m-%d %H:%M:%S"`` timestamp to each
            message added through :meth:`add`. Default is False.
        autosave (bool): Write the history to ``save_filepath`` as JSON after
            every :meth:`add`. Default is False.
        save_filepath (Optional[str]): Target file used by ``autosave``.
            Nothing is written while this is None.
        tokenizer (Any): Object exposing ``count_tokens(text=...)``. When
            given, the initial history is truncated to ``context_length``.
        context_length (int): Token budget used when truncating. Default 8192.
        rules (Optional[str]): If given, added as a message under role ``"User"``.
        custom_rules_prompt (Optional[str]): If given, added as a message under
            the role ``user`` (or ``"User"`` when ``user`` is empty).
        user (str): Role label for ``custom_rules_prompt``. Default ``"User:"``.
        auto_save (bool): Stored on the instance; not read by this class.
        save_as_yaml (bool): Stored on the instance; not read by this class.
        save_as_json_bool (bool): Stored on the instance; not read by this class.

    Methods:
        add(role, content): Append a message.
        delete(index): Remove the message at ``index``.
        update(index, role, content): Replace the message at ``index``.
        query(index): Return the message at ``index``.
        search(keyword): Return the messages whose content contains ``keyword``.
        display_conversation(detailed=False): Print every message.
        export_conversation(filename): Write the history as ``role: content`` lines.
        import_conversation(filename): Read ``role: content`` lines back in.
        count_messages_by_role(): Count messages per role.
        return_history_as_string() / get_str(): Render the history as one string.
        save_as_json(filename) / load_from_json(filename): JSON persistence.
        search_keyword_in_conversation(keyword): Alias of ``search``.
        pretty_print_conversation(messages): Print messages with a per-role color.
        truncate_memory_with_tokenizer(): Cut the history to the token budget.
        clear(): Drop all messages.
        to_json() / to_dict() / to_yaml(): Serialize the history.
        get_visible_messages(agent, turn): Filter messages by turn and visibility.

    Examples:
        >>> from swarms import Conversation
        >>> conversation = Conversation()
        >>> conversation.add("user", "Hello, how are you?")
        >>> conversation.add("assistant", "I am doing well, thanks.")
        >>> conversation.display_conversation()
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        time_enabled: bool = False,
        autosave: bool = False,
        save_filepath: Optional[str] = None,
        tokenizer: Any = None,
        context_length: int = 8192,
        rules: Optional[str] = None,
        custom_rules_prompt: Optional[str] = None,
        user: str = "User:",
        auto_save: bool = True,
        save_as_yaml: bool = True,
        save_as_json_bool: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.system_prompt = system_prompt
        self.time_enabled = time_enabled
        self.autosave = autosave
        self.save_filepath = save_filepath
        self.conversation_history = []
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.rules = rules
        self.custom_rules_prompt = custom_rules_prompt
        self.user = user
        self.auto_save = auto_save
        self.save_as_yaml = save_as_yaml
        self.save_as_json_bool = save_as_json_bool

        # Seed the history, in this order: system prompt, rules, custom rules.
        if self.system_prompt is not None:
            self.add("System: ", self.system_prompt)

        if self.rules is not None:
            self.add("User", rules)

        if custom_rules_prompt is not None:
            self.add(user or "User", custom_rules_prompt)

        # With a tokenizer available, make the seeded history fit the budget.
        if tokenizer is not None:
            self.truncate_memory_with_tokenizer()

    def add(self, role: str, content: str, *args, **kwargs):
        """Append a message to the conversation history.

        Extra positional and keyword arguments are accepted and ignored.

        Args:
            role (str): The role of the speaker.
            content (str): The content of the message.
        """
        message = {"role": role, "content": content}
        if self.time_enabled:
            message["timestamp"] = datetime.datetime.now().strftime(
                "%Y-%m-%d %H:%M:%S"
            )

        self.conversation_history.append(message)

        if self.autosave:
            self.save_as_json(self.save_filepath)

    def delete(self, index: int):
        """Remove the message at ``index`` from the conversation history.

        Args:
            index (int): Position of the message to delete.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        self.conversation_history.pop(index)

    def update(self, index: int, role, content):
        """Replace the message at ``index`` with a new role/content pair.

        The replacement holds only ``"role"`` and ``"content"``; any other keys
        of the previous message (such as ``"timestamp"``) are not carried over.

        Args:
            index (int): Position of the message to update.
            role: Role of the speaker.
            content: Content of the message.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        self.conversation_history[index] = {
            "role": role,
            "content": content,
        }

    def query(self, index: int):
        """Return the message stored at ``index``.

        Args:
            index (int): Position of the message to return.

        Returns:
            dict: The stored message (the same object, not a copy).

        Raises:
            IndexError: If ``index`` is out of range.
        """
        return self.conversation_history[index]

    def search(self, keyword: str):
        """Find the messages whose content contains ``keyword``.

        Args:
            keyword (str): Substring to look for (case-sensitive).

        Returns:
            list: The matching message dicts, in history order.
        """
        return [
            msg
            for msg in self.conversation_history
            if keyword in msg["content"]
        ]

    def display_conversation(self, detailed: bool = False):
        """Print every message through ``formatter.print_panel``.

        Args:
            detailed (bool, optional): Accepted for compatibility; it does not
                currently change the output. Defaults to False.
        """
        for message in self.conversation_history:
            formatter.print_panel(
                f"{message['role']}: {message['content']}\n\n"
            )

    def export_conversation(self, filename: str, *args, **kwargs):
        """Write the history to a text file, one ``role: content`` entry per message.

        Args:
            filename (str): Path of the file to write (overwritten if present).
        """
        # Explicit UTF-8 so non-ASCII content does not depend on the locale.
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(
                f"{message['role']}: {message['content']}\n"
                for message in self.conversation_history
            )

    def import_conversation(self, filename: str):
        """Append the messages found in a ``role: content`` text file.

        A line containing ``": "`` starts a new message; the text before the
        first ``": "`` is the role. A line without that separator is treated
        as a continuation of the previous message, so multi-line content
        written by :meth:`export_conversation` no longer raises ``ValueError``.
        Lines without a separator that appear before any message are skipped.

        Args:
            filename (str): Path of the file to read.
        """
        parsed = []  # (role, [content lines]) pairs in file order
        with open(filename, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.rstrip("\n")
                role, separator, content = line.partition(": ")
                if separator:
                    parsed.append((role, [content]))
                elif parsed:
                    parsed[-1][1].append(line)

        for role, content_lines in parsed:
            self.add(role, "\n".join(content_lines).strip())

    def count_messages_by_role(self):
        """Count the number of messages per role.

        The result always contains the keys ``"system"``, ``"user"``,
        ``"assistant"`` and ``"function"``. Any other role present in the
        history (for example the ``"System: "`` role added by ``__init__``)
        gets its own key instead of raising ``KeyError``.

        Returns:
            dict: Mapping of role to message count.
        """
        counts = {
            "system": 0,
            "user": 0,
            "assistant": 0,
            "function": 0,
        }
        for message in self.conversation_history:
            role = message["role"]
            counts[role] = counts.get(role, 0) + 1
        return counts

    def return_history_as_string(self):
        """Render the conversation history as a single string.

        Returns:
            str: Every message as ``"role: content\\n\\n"``, joined with newlines.
        """
        return "\n".join(
            f"{message['role']}: {message['content']}\n\n"
            for message in self.conversation_history
        )

    def get_str(self):
        """Return the history as a string; alias of ``return_history_as_string``."""
        return self.return_history_as_string()

    def save_as_json(self, filename: str = None):
        """Save the conversation history as a JSON file.

        Does nothing when ``filename`` is None. Missing parent directories are
        created first.

        Args:
            filename (str): Path of the JSON file to write.
        """
        if filename is None:
            return

        # A bare filename has no directory part; makedirs("") would fail.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(self.conversation_history, f)

    def load_from_json(self, filename: str):
        """Replace the conversation history with the contents of a JSON file.

        Does nothing when ``filename`` is None.

        Args:
            filename (str): Path of the JSON file to read.
        """
        if filename is None:
            return

        with open(filename, encoding="utf-8") as f:
            self.conversation_history = json.load(f)

    def search_keyword_in_conversation(self, keyword: str):
        """Find the messages whose content contains ``keyword``.

        Delegates to :meth:`search`.

        Args:
            keyword (str): Substring to look for (case-sensitive).

        Returns:
            list: The matching message dicts, in history order.
        """
        return self.search(keyword)

    def pretty_print_conversation(self, messages):
        """Print ``messages`` through ``formatter.print_panel`` with a per-role color.

        Only the roles ``system``, ``user``, ``assistant`` and ``tool`` are
        printed; messages with any other role are skipped. An assistant message
        with a truthy ``"function_call"`` prints that value instead of its
        content, and a tool message is prefixed with its ``"name"``.

        Args:
            messages: Iterable of message dicts to print.
        """
        role_to_color = {
            "system": "red",
            "user": "green",
            "assistant": "blue",
            "tool": "magenta",
        }

        for message in messages:
            role = message["role"]
            if role not in role_to_color:
                continue

            if role == "tool":
                text = (
                    f"function ({message['name']}):"
                    f" {message['content']}\n"
                )
            elif role == "assistant" and message.get("function_call"):
                text = f"assistant: {message['function_call']}\n"
            else:
                text = f"{role}: {message['content']}\n"

            formatter.print_panel(text, role_to_color[role])

    def truncate_memory_with_tokenizer(self):
        """
        Truncate the conversation history to ``self.context_length`` tokens.

        Messages are kept from the start of the history while the running
        total from ``self.tokenizer.count_tokens`` fits the budget. The first
        message that overflows is cut and every later message is dropped. The
        cut message keeps all of its other keys (e.g. ``"timestamp"``) and is
        omitted entirely when nothing of its content would remain.

        NOTE: the overflowing content is sliced by characters using the
        remaining *token* count, as the original code did; a token-exact cut
        would need an encode/decode API that is not visible here.

        Returns:
            None
        """
        total_tokens = 0
        truncated_history = []

        for message in self.conversation_history:
            content = message.get("content")
            count = self.tokenizer.count_tokens(text=content)

            if total_tokens + count <= self.context_length:
                total_tokens += count
                truncated_history.append(message)
                continue

            remaining_tokens = self.context_length - total_tokens
            truncated_content = (
                content[:remaining_tokens] if remaining_tokens > 0 else ""
            )
            if truncated_content:
                truncated_history.append(
                    {**message, "content": truncated_content}
                )
            break

        self.conversation_history = truncated_history

    def clear(self):
        """Drop every message from the conversation history."""
        self.conversation_history = []

    def to_json(self):
        """Return the conversation history serialized as a JSON string."""
        return json.dumps(self.conversation_history)

    def to_dict(self):
        """Return the conversation history list itself (not a copy)."""
        return self.conversation_history

    def to_yaml(self):
        """Return the conversation history serialized with ``yaml.dump``."""
        return yaml.dump(self.conversation_history)

    def get_visible_messages(self, agent: "Agent", turn: int):
        """
        Get the messages from earlier turns that ``agent`` may see.

        A message qualifies when its ``"turn"`` is lower than ``turn`` and its
        ``"visible_to"`` is either ``"all"`` or contains ``agent.agent_name``.
        Messages created by :meth:`add` carry neither key, so callers must
        supply them by other means.

        Args:
            agent (Agent): The agent.
            turn (int): The turn number.

        Returns:
            list: The visible message dicts, in history order.

        Raises:
            KeyError: If a message lacks the ``"turn"`` or ``"visible_to"`` key.
        """
        return [
            message
            for message in self.conversation_history
            if message["turn"] < turn
            and (
                message["visible_to"] == "all"
                or agent.agent_name in message["visible_to"]
            )
        ]
|
|
|
|
|
|
# # Example usage
|
|
# conversation = Conversation()
|
|
# conversation.add("user", "Hello, how are you?")
|
|
# conversation.add("assistant", "I am doing well, thanks.")
|
|
# # print(conversation.to_json())
|
|
# print(type(conversation.to_dict()))
|
|
# # print(conversation.to_yaml())
|