You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
417 lines
14 KiB
417 lines
14 KiB
import datetime
|
|
import json
|
|
from typing import Optional
|
|
|
|
from termcolor import colored
|
|
|
|
from swarms.memory.base_db import AbstractDatabase
|
|
from swarms.structs.base_structure import BaseStructure
|
|
from typing import Any
|
|
|
|
|
|
class Conversation(BaseStructure):
    """
    Store, edit, persist and display the message history of a chat.

    Each message is a dict with a ``"role"`` and a ``"content"`` key, plus a
    ``"timestamp"`` key when ``time_enabled`` is set. Messages are kept in
    insertion order in ``self.conversation_history``.

    Args:
        system_prompt (Optional[str]): If given, added as the first message
            under the role ``"System: "``. Default is None.
        time_enabled (bool): Whether to attach a timestamp to every message
            added through ``add``. Default is False.
        database (AbstractDatabase): Backend used by the ``*_database``
            methods. Default is None.
        autosave (bool): Whether ``add`` writes the history to
            ``save_filepath`` as JSON after every message. Default is False.
        save_filepath (str): Destination used by autosave. Default is None.
        tokenizer (Any): Object exposing ``count_tokens(text=...)``. When
            given, the history is truncated once at construction time.
            Default is None.
        context_length (int): Token budget used when truncating.
            Default is 8192.
        rules (str): If given, added as a message under the ``user`` role.
            Default is None.
        custom_rules_prompt (str): If given, added as a message under the
            ``user`` role. Default is None.
        user (str): Role label used for ``rules`` and
            ``custom_rules_prompt``. Default is ``"User:"``.
        auto_save (bool): Stored on the instance; not read by any method of
            this class. Default is True.
        save_as_yaml (bool): Stored on the instance; not read by any method
            of this class. Default is True.
        save_as_json (bool): Stored on the instance as
            ``save_as_json_bool``; not read by any method of this class.
            Default is False.

    Methods:
        add(role, content): Add a message to the conversation history.
        delete(index): Delete a message from the conversation history.
        update(index, role, content): Replace a message in the history.
        query(index): Return a message from the conversation history.
        search(keyword): Return the messages whose content has the keyword.
        display_conversation(detailed=False): Print the history in colour.
        export_conversation(filename): Write the history to a text file.
        import_conversation(filename): Read messages from a text file.
        count_messages_by_role(): Count the number of messages by role.
        return_history_as_string(): Return the history as one string.
        save_as_json(filename): Save the history as a JSON file.
        load_from_json(filename): Load the history from a JSON file.
        search_keyword_in_conversation(keyword): Alias of ``search``.
        pretty_print_conversation(messages): Print the given messages.
        add_to_database(): Add the history to the database.
        query_from_database(query): Query the history from the database.
        delete_from_database(): Delete the history from the database.
        update_from_database(): Update the history in the database.
        get_from_database(): Get the history from the database.
        execute_query_from_database(query): Execute a query on the database.
        fetch_all_from_database(): Fetch all from the database.
        fetch_one_from_database(): Fetch one from the database.
        truncate_memory_with_tokenizer(): Trim the history to the budget.
        clear(): Remove every message.

    Examples:
        >>> from swarms import Conversation
        >>> conversation = Conversation()
        >>> conversation.add("user", "Hello, how are you?")
        >>> conversation.add("assistant", "I am doing well, thanks.")
        >>> conversation.display_conversation()
        user: Hello, how are you?
        assistant: I am doing well, thanks.
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        time_enabled: bool = False,
        database: AbstractDatabase = None,
        autosave: bool = False,
        save_filepath: str = None,
        tokenizer: Any = None,
        context_length: int = 8192,
        rules: str = None,
        custom_rules_prompt: str = None,
        user: str = "User:",
        auto_save: bool = True,
        save_as_yaml: bool = True,
        save_as_json: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.system_prompt = system_prompt
        self.time_enabled = time_enabled
        self.database = database
        self.autosave = autosave
        self.save_filepath = save_filepath
        self.conversation_history = []
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.rules = rules
        self.custom_rules_prompt = custom_rules_prompt
        self.user = user
        self.auto_save = auto_save
        self.save_as_yaml = save_as_yaml
        # The flag must NOT be stored as ``self.save_as_json``: an instance
        # attribute of that name shadows the ``save_as_json`` method, and the
        # autosave path in ``add`` would then try to call a bool.
        self.save_as_json_bool = save_as_json

        # Seed the history with the optional system prompt and rules.
        if self.system_prompt is not None:
            self.add("System: ", self.system_prompt)

        if self.rules is not None:
            self.add(user, rules)

        if custom_rules_prompt is not None:
            self.add(user, custom_rules_prompt)

        # With a tokenizer available, trim the seeded history to the budget.
        if tokenizer is not None:
            self.truncate_memory_with_tokenizer()

    @staticmethod
    def _normalize_role(role) -> str:
        """Reduce a role label such as ``"System: "`` to ``"system"``."""
        return str(role).strip().rstrip(":").strip().lower()

    def add(self, role: str, content: str, *args, **kwargs):
        """Add a message to the conversation history.

        Extra positional and keyword arguments are accepted and ignored.

        Args:
            role (str): The role of the speaker
            content (str): The content of the message
        """
        message = {
            "role": role,
            "content": content,
        }
        if self.time_enabled:
            now = datetime.datetime.now()
            message["timestamp"] = now.strftime("%Y-%m-%d %H:%M:%S")

        self.conversation_history.append(message)

        if self.autosave:
            self.save_as_json(self.save_filepath)

    def delete(self, index: int):
        """Delete a message from the conversation history.

        Args:
            index (int): position of the message to delete

        Raises:
            IndexError: if ``index`` is out of range
        """
        self.conversation_history.pop(index)

    def update(self, index: int, role, content):
        """Replace a message in the conversation history.

        The stored message is replaced by a new dict holding only ``role``
        and ``content``; any other key of the old message is dropped.

        Args:
            index (int): position of the message to update
            role: role of the speaker
            content: content of the message

        Raises:
            IndexError: if ``index`` is out of range
        """
        self.conversation_history[index] = {
            "role": role,
            "content": content,
        }

    def query(self, index: int):
        """Return a message from the conversation history.

        Args:
            index (int): position of the message to return

        Returns:
            dict: the stored message

        Raises:
            IndexError: if ``index`` is out of range
        """
        return self.conversation_history[index]

    def search(self, keyword: str):
        """Search the conversation history for a keyword.

        Args:
            keyword (str): Keyword to search for

        Returns:
            list: every message whose ``"content"`` contains ``keyword``,
            in history order
        """
        return [
            msg
            for msg in self.conversation_history
            if keyword in msg["content"]
        ]

    def display_conversation(self, detailed: bool = False):
        """Print every message of the history, coloured by role.

        Role labels are matched case-insensitively and without surrounding
        whitespace or a trailing colon, so the ``"System: "`` and ``"User:"``
        labels written by ``__init__`` get the system and user colours.
        Roles without a known colour are printed uncoloured.

        Args:
            detailed (bool, optional): accepted but currently unused.
                Defaults to False.
        """
        role_to_color = {
            "system": "red",
            "user": "green",
            "assistant": "blue",
            "function": "magenta",
        }
        for message in self.conversation_history:
            role = message["role"]
            # ``None`` tells ``colored`` to leave the text uncoloured.
            color = role_to_color.get(self._normalize_role(role))
            print(colored(f"{role}: {message['content']}\n\n", color))

    def export_conversation(self, filename: str, *args, **kwargs):
        """Export the conversation history to a text file.

        One ``role: content`` line is written per message. Extra positional
        and keyword arguments are accepted and ignored.

        Args:
            filename (str): filename to export to
        """
        with open(filename, "w") as f:
            for message in self.conversation_history:
                f.write(f"{message['role']}: {message['content']}\n")

    def import_conversation(self, filename: str):
        """Import a conversation history from a text file.

        Each non-blank line must look like ``role: content``; it is split at
        the first ``": "`` and appended through ``add``. Blank lines are
        skipped.

        Args:
            filename (str): filename to import from

        Raises:
            ValueError: if a non-blank line has no ``": "`` separator
        """
        with open(filename) as f:
            for line_number, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                role, separator, content = line.partition(": ")
                if not separator:
                    raise ValueError(
                        f"Line {line_number} of {filename!r} is not in"
                        " 'role: content' format"
                    )
                self.add(role, content.strip())

    def count_messages_by_role(self):
        """Count the number of messages by role.

        Returns:
            dict: message count per role. The keys ``"system"``, ``"user"``,
            ``"assistant"`` and ``"function"`` are always present; any other
            role found in the history is counted under its exact label.
        """
        counts = {
            "system": 0,
            "user": 0,
            "assistant": 0,
            "function": 0,
        }
        for message in self.conversation_history:
            role = message["role"]
            counts[role] = counts.get(role, 0) + 1
        return counts

    def return_history_as_string(self):
        """Return the conversation history as a string.

        Returns:
            str: every message rendered as ``role: content`` followed by two
            newlines, the rendered messages joined by one more newline
        """
        return "\n".join(
            f"{message['role']}: {message['content']}\n\n"
            for message in self.conversation_history
        )

    def save_as_json(self, filename: str = None):
        """Save the conversation history as a JSON file.

        Nothing is written when ``filename`` is None.

        Args:
            filename (str): path of the JSON file to write
        """
        if filename is not None:
            with open(filename, "w") as f:
                json.dump(self.conversation_history, f)

    def load_from_json(self, filename: str):
        """Load the conversation history from a JSON file.

        The current history is replaced by the file's content. Nothing
        happens when ``filename`` is None.

        Args:
            filename (str): filename to load from
        """
        if filename is not None:
            with open(filename) as f:
                self.conversation_history = json.load(f)

    def search_keyword_in_conversation(self, keyword: str):
        """Search the conversation history for a keyword.

        Alias of ``search``.

        Args:
            keyword (str): keyword to search for

        Returns:
            list: every message whose ``"content"`` contains ``keyword``
        """
        return self.search(keyword)

    def pretty_print_conversation(self, messages):
        """Print the given messages, coloured by role.

        Only the roles ``"system"``, ``"user"``, ``"assistant"`` and
        ``"tool"`` are printed; messages with any other role are skipped.
        An assistant message with a truthy ``"function_call"`` entry shows
        that entry instead of its content. A tool message must carry a
        ``"name"`` key.

        Args:
            messages (list): message dicts to print
        """
        role_to_color = {
            "system": "red",
            "user": "green",
            "assistant": "blue",
            "tool": "magenta",
        }

        for message in messages:
            role = message["role"]
            if role in ("system", "user"):
                text = f"{role}: {message['content']}\n"
            elif role == "assistant":
                shown = message.get("function_call") or message["content"]
                text = f"assistant: {shown}\n"
            elif role == "tool":
                text = (
                    f"function ({message['name']}):"
                    f" {message['content']}\n"
                )
            else:
                continue
            print(colored(text, role_to_color[role]))

    # The methods below forward to ``self.database``. They raise
    # AttributeError when no database was supplied (``database=None``).

    def add_to_database(self, *args, **kwargs):
        """Add the conversation history to the database"""
        self.database.add("conversation", self.conversation_history)

    def query_from_database(self, query, *args, **kwargs):
        """Query the conversation history from the database"""
        return self.database.query("conversation", query)

    def delete_from_database(self, *args, **kwargs):
        """Delete the conversation history from the database"""
        self.database.delete("conversation")

    def update_from_database(self, *args, **kwargs):
        """Update the conversation history from the database"""
        self.database.update("conversation", self.conversation_history)

    def get_from_database(self, *args, **kwargs):
        """Get the conversation history from the database"""
        return self.database.get("conversation")

    def execute_query_from_database(self, query, *args, **kwargs):
        """Execute a query on the database"""
        return self.database.execute_query(query)

    def fetch_all_from_database(self, *args, **kwargs):
        """Fetch all from the database"""
        return self.database.fetch_all()

    def fetch_one_from_database(self, *args, **kwargs):
        """Fetch one from the database"""
        return self.database.fetch_one()

    def truncate_memory_with_tokenizer(self):
        """
        Truncate the conversation history to ``self.context_length`` tokens.

        Messages are kept in order while the running token total, as counted
        by ``self.tokenizer.count_tokens(text=...)``, stays within the
        budget. The first message that exceeds it is cut short and every
        later message is dropped. The cut message keeps only its role and
        content.

        Returns:
            None
        """
        total_tokens = 0
        truncated_history = []

        for message in self.conversation_history:
            role = message.get("role")
            content = message.get("content")
            count = self.tokenizer.count_tokens(text=content)
            total_tokens += count

            if total_tokens <= self.context_length:
                truncated_history.append(message)
                continue

            # Tokens still available before this message was counted.
            remaining_tokens = self.context_length - (
                total_tokens - count
            )
            # NOTE(review): the slice uses the remaining *token* count as a
            # length in sequence items (characters, presumably, since content
            # looks like a str) — an approximation; confirm it is intended.
            truncated_history.append(
                {
                    "role": role,
                    "content": content[:remaining_tokens],
                }
            )
            break

        self.conversation_history = truncated_history

    def clear(self):
        """Remove every message from the conversation history."""
        self.conversation_history = []