You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
310 lines
9.1 KiB
310 lines
9.1 KiB
import datetime
|
|
import json
|
|
|
|
from termcolor import colored
|
|
|
|
from swarms.memory.base_db import AbstractDatabase
|
|
from swarms.structs.base import BaseStructure
|
|
|
|
|
|
class Conversation(BaseStructure):
    """In-memory log of chat messages with optional file/database persistence.

    Every message is stored as a dict with ``"role"`` and ``"content"`` keys,
    plus a ``"timestamp"`` key (``"%Y-%m-%d %H:%M:%S"``, local time) when
    ``time_enabled`` is true.

    Attributes:
        time_enabled (bool): whether new messages are stamped with the
            current time.
        database (AbstractDatabase): backend used by the ``*_database``
            helper methods; ``None`` when no backend was supplied.
        autosave (bool): whether :meth:`add` rewrites ``save_filepath``
            after every appended message.
        save_filepath (str): JSON file that autosave writes to.
        conversation_history (list): the messages, oldest first.

    Examples:
        >>> conv = Conversation(autosave=False)
        >>> conv.add("user", "Hello, world!")
        >>> conv.add("assistant", "Hello, user!")
        >>> conv.display_conversation()
        user: Hello, world!
        assistant: Hello, user!
    """

    # Colour used for each known role when printing. Roles that are missing
    # from this table are printed without a colour by display_conversation.
    _ROLE_COLORS = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
        "tool": "magenta",
    }

    def __init__(
        self,
        time_enabled: bool = False,
        database: AbstractDatabase = None,
        autosave: bool = True,
        save_filepath: str = "/runs/conversation.json",
        *args,
        **kwargs,
    ):
        """Create an empty conversation.

        Args:
            time_enabled (bool): stamp each message with the current time.
            database (AbstractDatabase): optional backend for the
                ``*_database`` helpers.
            autosave (bool): save the history to ``save_filepath`` after
                every :meth:`add`.
            save_filepath (str): destination of the autosaved JSON file.
            *args: accepted for compatibility; not used here.
            **kwargs: accepted for compatibility; not used here.
        """
        super().__init__()
        self.time_enabled = time_enabled
        self.database = database
        self.autosave = autosave
        self.save_filepath = save_filepath
        self.conversation_history = []

    def _build_message(self, role, content):
        """Return a message dict, timestamped when ``time_enabled`` is set."""
        message = {"role": role, "content": content}
        if self.time_enabled:
            now = datetime.datetime.now()
            message["timestamp"] = now.strftime("%Y-%m-%d %H:%M:%S")
        return message

    def add(self, role: str, content: str, *args, **kwargs):
        """Append a message to the conversation history.

        When ``autosave`` is enabled the whole history is rewritten to
        ``save_filepath`` afterwards.

        Args:
            role (str): The role of the speaker
            content (str): The content of the message
            *args: ignored.
            **kwargs: ignored.
        """
        self.conversation_history.append(self._build_message(role, content))

        if self.autosave:
            self.save_as_json(self.save_filepath)

    def delete(self, index: int):
        """Delete a message from the conversation history.

        Args:
            index (int): position of the message to delete

        Raises:
            IndexError: if ``index`` is out of range.
        """
        self.conversation_history.pop(index)

    def update(self, index: int, role, content):
        """Replace the message stored at ``index``.

        The replacement is built the same way as in :meth:`add`, so it
        carries a fresh timestamp when ``time_enabled`` is set.

        Args:
            index (int): position of the message to update
            role (str): role of the speaker
            content (str): content of the message

        Raises:
            IndexError: if ``index`` is out of range.
        """
        self.conversation_history[index] = self._build_message(role, content)

    def query(self, index: int):
        """Return the message stored at ``index``.

        Args:
            index (int): position of the message to return

        Returns:
            dict: the message

        Raises:
            IndexError: if ``index`` is out of range.
        """
        return self.conversation_history[index]

    def search(self, keyword: str):
        """Find every message whose content contains ``keyword``.

        Args:
            keyword (str): Keyword to search for

        Returns:
            list: the matching message dicts, in conversation order
        """
        return [
            msg
            for msg in self.conversation_history
            if keyword in msg["content"]
        ]

    def display_conversation(self, detailed: bool = False):
        """Print every message as ``role: content``, coloured by role.

        Roles without an entry in the colour table are printed uncoloured
        instead of raising ``KeyError``.

        Args:
            detailed (bool, optional): currently unused; kept so existing
                callers keep working. Defaults to False.
        """
        for message in self.conversation_history:
            role = message["role"]
            print(
                colored(
                    f"{role}: {message['content']}\n\n",
                    self._ROLE_COLORS.get(role),
                )
            )

    def export_conversation(self, filename: str, *args, **kwargs):
        """Write the history to a text file, one ``role: content`` per line.

        Timestamps are not exported.

        Args:
            filename (str): filename to export to
            *args: ignored.
            **kwargs: ignored.
        """
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(
                f"{message['role']}: {message['content']}\n"
                for message in self.conversation_history
            )

    def import_conversation(self, filename: str):
        """Append messages read from a file written by export_conversation.

        Each ``role: content`` line becomes one message, added through
        :meth:`add` (so timestamping and autosave apply). Blank lines are
        skipped. A line without a ``": "`` separator is treated as a
        continuation of the previous message, which is what multi-line
        content looks like after export; such a line before any message
        is ignored.

        Args:
            filename (str): filename to import from
        """
        parsed = []  # [role, content] pairs, mutable so lines can be merged
        with open(filename, "r", encoding="utf-8") as f:
            for raw_line in f:
                # Only drop the newline here: stripping spaces first would
                # destroy the ": " separator of an empty-content message.
                line = raw_line.rstrip("\n")
                if not line.strip():
                    continue
                role, separator, content = line.partition(": ")
                if separator:
                    parsed.append([role, content.strip()])
                elif parsed:
                    parsed[-1][1] += "\n" + line.strip()

        for role, content in parsed:
            self.add(role, content)

    def count_messages_by_role(self):
        """Count the number of messages by role.

        Returns:
            dict: role -> number of messages. "system", "user", "assistant"
            and "function" are always present (possibly 0); any other role
            found in the history is counted as well.
        """
        counts = {
            "system": 0,
            "user": 0,
            "assistant": 0,
            "function": 0,
        }
        for message in self.conversation_history:
            role = message["role"]
            counts[role] = counts.get(role, 0) + 1
        return counts

    def return_history_as_string(self):
        """Return the conversation history as a single string.

        Each message is rendered as ``role: content`` followed by two
        newlines, and the rendered messages are joined with one more
        newline.

        Returns:
            str: the conversation history
        """
        return "\n".join(
            f"{message['role']}: {message['content']}\n\n"
            for message in self.conversation_history
        )

    def save_as_json(self, filename: str):
        """Save the conversation history as a JSON file.

        The parent directory is created when it does not exist yet, so
        autosaving to a nested path works on first use.

        Args:
            filename (str): path of the JSON file to write
        """
        import os

        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(self.conversation_history, f)

    def load_from_json(self, filename: str):
        """Replace the conversation history with the content of a JSON file.

        Args:
            filename (str): filename to load from
        """
        with open(filename, "r", encoding="utf-8") as f:
            self.conversation_history = json.load(f)

    def search_keyword_in_conversation(self, keyword: str):
        """Search for a keyword in the conversation history.

        Same result as :meth:`search`.

        Args:
            keyword (str): keyword to search for

        Returns:
            list: the matching message dicts, in conversation order
        """
        return self.search(keyword)

    def pretty_print_conversation(self, messages):
        """Pretty print the given messages, coloured by role.

        Prints ``messages`` as passed in; the stored history is not read.
        Messages whose role is not "system", "user", "assistant" or
        "tool" are skipped.

        Args:
            messages (list): message dicts to print. Assistant messages
                with a truthy "function_call" entry show that entry instead
                of their content; "tool" messages must carry a "name" key.
        """
        for message in messages:
            role = message["role"]
            if role in ("system", "user"):
                text = f"{role}: {message['content']}\n"
            elif role == "assistant":
                function_call = message.get("function_call")
                if function_call:
                    text = f"assistant: {function_call}\n"
                else:
                    text = f"assistant: {message['content']}\n"
            elif role == "tool":
                text = (
                    f"function ({message['name']}):"
                    f" {message['content']}\n"
                )
            else:
                continue
            print(colored(text, self._ROLE_COLORS[role]))

    # The helpers below forward to ``self.database`` and raise
    # AttributeError when no database was passed to the constructor
    # (``self.database`` is None). Extra *args/**kwargs are ignored.

    def add_to_database(self, *args, **kwargs):
        """Add the conversation history to the database"""
        self.database.add("conversation", self.conversation_history)

    def query_from_database(self, query, *args, **kwargs):
        """Query the conversation history from the database"""
        return self.database.query("conversation", query)

    def delete_from_database(self, *args, **kwargs):
        """Delete the conversation history from the database"""
        self.database.delete("conversation")

    def update_from_database(self, *args, **kwargs):
        """Send the current conversation history to the database's update().

        Despite the name, nothing is read back into this object.
        """
        self.database.update(
            "conversation", self.conversation_history
        )

    def get_from_database(self, *args, **kwargs):
        """Get the conversation history from the database"""
        return self.database.get("conversation")

    def execute_query_from_database(self, query, *args, **kwargs):
        """Execute a query on the database"""
        return self.database.execute_query(query)

    def fetch_all_from_database(self, *args, **kwargs):
        """Fetch all from the database"""
        return self.database.fetch_all()

    def fetch_one_from_database(self, *args, **kwargs):
        """Fetch one from the database"""
        return self.database.fetch_one()
|