You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
365 lines
12 KiB
365 lines
12 KiB
import json
|
|
import os
|
|
from datetime import datetime
|
|
from typing import List
|
|
|
|
from swarms.structs.agent import Agent
|
|
|
|
# Display names of the three debate roles. The order matters:
# Debate.create_agents assigns roles by position (affirmative, negative,
# moderator).
NAME_LIST = [
    "Affirmative side",
    "Negative side",
    "Moderator",
]
|
|
|
|
|
|
class DebatePlayer(Agent):
    """A debate participant: an ``Agent`` addressed by a display name."""

    def __init__(self, llm, name: str, *args, **kwargs) -> None:
        """Create a player in the debate.

        Args:
            llm: model object, forwarded to ``Agent`` as ``llm``.
            name (str): name of this player, forwarded to ``Agent`` as
                ``agent_name``.
            *args: extra positional arguments forwarded unchanged to
                ``Agent``.
            **kwargs: extra keyword arguments forwarded unchanged to
                ``Agent``.
        """
        # Positional extras are unpacked before the keywords; the call is
        # equivalent to the previous ordering but no longer hides the fact
        # that ``*args`` binds first.
        super().__init__(*args, llm=llm, agent_name=name, **kwargs)
        # Debate looks players up through ``player.name``, so expose the
        # display name under that attribute as well.
        # NOTE(review): assumes Agent allows assigning ``name`` - confirm.
        self.name = name
|
|
|
|
|
|
class Debate:
    """Run a moderated translation debate between two opposing players.

    The constructor loads the prompt templates, asks a baseline player for
    a first translation when the prompt file does not provide one, and
    plays the opening round. ``run`` plays the remaining rounds and falls
    back to a judge when the moderator never settles on a translation.
    Everything is collected in ``self.save_file``, which
    ``save_file_to_json`` writes to disk.

    NOTE(review): players are driven through ``system_prompt(...)``,
    ``add_message_to_memory(...)``, ``ask()``, ``name`` and ``memory_lst``;
    confirm that the ``Agent`` base class really offers all of them.

    Args:
        debate_agents (List[DebatePlayer]): players supplied by the caller;
            the first three take the affirmative, negative and moderator
            roles, in that order.
        temperature (float): recorded in the save file and forwarded to the
            judge player.
        num_players (int): num of players (recorded in the save file)
        save_file_dir (str): dir path to json file
        prompts_path (str): prompts path (json file)
        max_round (int): maximum Rounds of Debate
        sleep_time (float): sleep because of rate limits (forwarded to the
            judge player)
    """

    def __init__(
        self,
        debate_agents: List["DebatePlayer"],
        temperature: float = 0,
        num_players: int = 3,
        save_file_dir: str = None,
        prompts_path: str = None,
        max_round: int = 3,
        sleep_time: float = 0,
    ) -> None:
        """Load the prompts and play the first debate round.

        Raises:
            TypeError: if ``prompts_path`` is left as ``None``.
            ValueError: if fewer agents than debate roles were supplied.
            KeyError: if the prompt file lacks a required template.
        """
        self.debate_agents = debate_agents
        # ``run`` reads the temperature back when it builds the judge.
        self.temperature = temperature
        self.num_players = num_players
        self.save_file_dir = save_file_dir
        self.prompts_path = prompts_path
        self.max_round = max_round
        self.sleep_time = sleep_time

        # init save file
        current_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        self.save_file = {
            "start_time": current_time,
            "end_time": "",
            "temperature": temperature,
            "num_players": num_players,
            "success": False,
            "src_lng": "",
            "tgt_lng": "",
            "source": "",
            "reference": "",
            "base_translation": "",
            "debate_translation": "",
            "Reason": "",
            "Supported Side": "",
            "players": {},
        }
        # Close the handle deterministically and read the JSON as UTF-8
        # regardless of the platform's locale encoding.
        with open(prompts_path, encoding="utf-8") as prompts_file:
            prompts = json.load(prompts_file)
        self.save_file.update(prompts)
        self.init_prompt()

        if self.save_file["base_translation"] == "":
            self.create_base()

        # create & init agents
        self.create_agents()
        self.init_agents()

    def _new_player(self, name: str, **kwargs) -> "DebatePlayer":
        """Build an auxiliary player that shares a supplied agent's model."""
        if not self.debate_agents:
            raise ValueError(
                "debate_agents must contain at least one agent"
            )
        # NOTE(review): assumes every supplied agent exposes its model as
        # ``.llm`` - confirm against the Agent base class.
        llm = self.debate_agents[0].llm
        return DebatePlayer(llm=llm, name=name, **kwargs)

    @staticmethod
    def _take_turn(player, prompt: str) -> str:
        """Give ``player`` a prompt, record its reply in memory, return it."""
        player.add_message_to_memory(prompt)
        answer = player.ask()
        player.add_message_to_memory(answer)
        return answer

    @staticmethod
    def _parse_reply(reply: str):
        """Turn a moderator/judge reply into a Python object."""
        # SECURITY: ``eval`` executes whatever text the model produced. It
        # is kept so that accepted replies stay exactly the same, but it
        # should be replaced by a strict parser (``json.loads`` or
        # ``ast.literal_eval``) once the reply format is pinned down.
        return eval(reply)

    def _moderate(self, round_label: str) -> None:
        """Ask the moderator to assess the current answers of both sides."""
        prompt = (
            self.save_file["moderator_prompt"]
            .replace("##aff_ans##", self.aff_ans)
            .replace("##neg_ans##", self.neg_ans)
            .replace("##round##", round_label)
        )
        reply = self._take_turn(self.moderator, prompt)
        self.mod_ans = self._parse_reply(reply)

    def init_prompt(self):
        """Fill the task placeholders of the prompt templates in place."""
        placeholders = (
            ("##src_lng##", "src_lng"),
            ("##tgt_lng##", "tgt_lng"),
            ("##source##", "source"),
            ("##base_translation##", "base_translation"),
        )
        templates = (
            "base_prompt",
            "player_meta_prompt",
            "moderator_meta_prompt",
            "judge_prompt_last2",
        )
        for key in templates:
            text = self.save_file[key]
            for marker, field in placeholders:
                text = text.replace(marker, self.save_file[field])
            self.save_file[key] = text

    def create_base(self):
        """Ask a baseline player for the translation the debate starts from."""
        print(
            "\n===== Translation Task"
            f" =====\n{self.save_file['base_prompt']}\n"
        )
        agent = self._new_player("Baseline")
        base_translation = self._take_turn(
            agent, self.save_file["base_prompt"]
        )
        self.save_file["base_translation"] = base_translation
        self.save_file["affirmative_prompt"] = self.save_file[
            "affirmative_prompt"
        ].replace("##base_translation##", base_translation)
        self.save_file["players"][agent.name] = agent.memory_lst

    def create_agents(self):
        """Assign the supplied agents to the three debate roles.

        Raises:
            ValueError: if fewer agents than roles were supplied.
        """
        role_count = len(NAME_LIST)
        if len(self.debate_agents) < role_count:
            raise ValueError(
                f"expected at least {role_count} debate agents "
                f"({', '.join(NAME_LIST)}), "
                f"got {len(self.debate_agents)}"
            )
        # Copy so that appending the judge later never mutates the
        # caller's list.
        self.players = list(self.debate_agents[:role_count])
        self.affirmative = self.players[0]
        self.negative = self.players[1]
        self.moderator = self.players[2]

    def init_agents(self):
        """Set the meta prompts and play the opening round."""
        # start: set meta prompt
        player_meta = self.save_file["player_meta_prompt"]
        self.affirmative.system_prompt(player_meta)
        self.negative.system_prompt(player_meta)
        self.moderator.system_prompt(
            self.save_file["moderator_meta_prompt"]
        )

        # start: first round debate, state opinions
        print("===== Debate Round-1 =====\n")
        self.aff_ans = self._take_turn(
            self.affirmative, self.save_file["affirmative_prompt"]
        )
        self.neg_ans = self._take_turn(
            self.negative,
            self.save_file["negative_prompt"].replace(
                "##aff_ans##", self.aff_ans
            ),
        )
        self._moderate("first")

    def round_dct(self, num: int):
        """Return the ordinal used in prompts for round ``num``.

        Rounds 1-10 are spelled out; any other number falls back to a
        numeric ordinal such as ``"11th"`` so that ``max_round`` values
        above ten do not raise ``KeyError``.
        """
        spelled = {
            1: "first",
            2: "second",
            3: "third",
            4: "fourth",
            5: "fifth",
            6: "sixth",
            7: "seventh",
            8: "eighth",
            9: "ninth",
            10: "tenth",
        }
        if num in spelled:
            return spelled[num]
        if 10 <= num % 100 <= 20:
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(num % 10, "th")
        return f"{num}{suffix}"

    def save_file_to_json(self, id):
        """Stamp the end time and write the save file to ``<dir>/<id>.json``.

        Args:
            id: identifier used as the file name stem.

        Raises:
            TypeError: if ``save_file_dir`` was left as ``None``.
        """
        save_file_path = os.path.join(self.save_file_dir, f"{id}.json")
        os.makedirs(self.save_file_dir, exist_ok=True)

        self.save_file["end_time"] = datetime.now().strftime(
            "%Y-%m-%d_%H:%M:%S"
        )
        json_str = json.dumps(
            self.save_file, ensure_ascii=False, indent=4
        )
        # ``ensure_ascii=False`` emits raw non-ASCII characters, so the
        # file must be written as UTF-8 rather than in the locale encoding.
        with open(save_file_path, "w", encoding="utf-8") as f:
            f.write(json_str)

    def broadcast(self, msg: str):
        """Broadcast a message to all players.
        Typical use is for the host to announce public information

        Args:
            msg (str): the message
        """
        for player in self.players:
            player.add_message_to_memory(msg)

    def speak(self, speaker: str, msg: str):
        """The speaker broadcast a message to all other players.

        Args:
            speaker (str): name of the speaker
            msg (str): the message
        """
        prefix = f"{speaker}: "
        if not msg.startswith(prefix):
            msg = prefix + msg
        for player in self.players:
            if player.name != speaker:
                player.add_message_to_memory(msg)

    def ask_and_speak(self, player: "DebatePlayer"):
        """Let ``player`` answer and relay the answer to everybody else."""
        ans = self._take_turn(player, player.ask.__self__ and "") if False else None
        ans = player.ask()
        player.add_message_to_memory(ans)
        self.speak(player.name, ans)

    def run(self):
        """Play the remaining rounds and record the outcome in the save file.

        Stops as soon as the moderator reports a non-empty
        ``debate_translation``. If it never does, a judge player picks the
        final translation from the first-round answers of both sides.
        """
        for round_idx in range(self.max_round - 1):
            if self.mod_ans["debate_translation"] != "":
                break

            round_no = round_idx + 2
            print(f"===== Debate Round-{round_no} =====\n")
            debate_prompt = self.save_file["debate_prompt"]
            self.aff_ans = self._take_turn(
                self.affirmative,
                debate_prompt.replace("##oppo_ans##", self.neg_ans),
            )
            self.neg_ans = self._take_turn(
                self.negative,
                debate_prompt.replace("##oppo_ans##", self.aff_ans),
            )
            self._moderate(self.round_dct(round_no))

        if self.mod_ans["debate_translation"] != "":
            self.save_file.update(self.mod_ans)
            self.save_file["success"] = True
        else:
            # Last resort: the moderator never decided, so a judge does.
            judge_player = self._new_player(
                "Judge",
                temperature=self.temperature,
                sleep_time=self.sleep_time,
            )
            # Index 2 of each memory is taken as that side's first-round
            # answer. NOTE(review): confirm the ``memory_lst`` layout.
            aff_ans = self.affirmative.memory_lst[2]["content"]
            neg_ans = self.negative.memory_lst[2]["content"]

            judge_player.system_prompt(
                self.save_file["moderator_meta_prompt"]
            )

            # extract answer candidates
            self._take_turn(
                judge_player,
                self.save_file["judge_prompt_last1"]
                .replace("##aff_ans##", aff_ans)
                .replace("##neg_ans##", neg_ans),
            )

            # select one from the candidates
            verdict = self._parse_reply(
                self._take_turn(
                    judge_player, self.save_file["judge_prompt_last2"]
                )
            )
            if verdict["debate_translation"] != "":
                self.save_file["success"] = True
            # save file
            self.save_file.update(verdict)
            self.players.append(judge_player)

        for player in self.players:
            self.save_file["players"][player.name] = player.memory_lst
|
|
|
|
|
|
# def parse_args():
|
|
# parser = argparse.ArgumentParser("", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
|
|
# parser.add_argument("-i", "--input-file", type=str, required=True, help="Input file path")
|
|
# parser.add_argument("-o", "--output-dir", type=str, required=True, help="Output file dir")
|
|
# parser.add_argument("-lp", "--lang-pair", type=str, required=True, help="Language pair")
|
|
# parser.add_argument("-k", "--api-key", type=str, required=True, help="OpenAI api key")
|
|
# parser.add_argument("-m", "--model-name", type=str, default="gpt-3.5-turbo", help="Model name")
|
|
# parser.add_argument("-t", "--temperature", type=float, default=0, help="Sampling temperature")
|
|
|
|
# return parser.parse_args()
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# args = parse_args()
|
|
# openai_api_key = args.api_key
|
|
|
|
# current_script_path = os.path.abspath(__file__)
|
|
# MAD_path = current_script_path.rsplit("/", 2)[0]
|
|
|
|
# src_lng, tgt_lng = args.lang_pair.split('-')
|
|
# src_full = Language.make(language=src_lng).display_name()
|
|
# tgt_full = Language.make(language=tgt_lng).display_name()
|
|
|
|
# config = json.load(open(f"{MAD_path}/code/utils/config4tran.json", "r"))
|
|
|
|
# inputs = open(args.input_file, "r").readlines()
|
|
# inputs = [l.strip() for l in inputs]
|
|
|
|
# save_file_dir = args.output_dir
|
|
# if not os.path.exists(save_file_dir):
|
|
# os.mkdir(save_file_dir)
|
|
|
|
# for id, input in enumerate(tqdm(inputs)):
|
|
# # files = os.listdir(save_file_dir)
|
|
# # if f"{id}.json" in files:
|
|
# # continue
|
|
|
|
# prompts_path = f"{save_file_dir}/{id}-config.json"
|
|
|
|
# config['source'] = input.split('\t')[0]
|
|
# config['reference'] = input.split('\t')[1]
|
|
# config['src_lng'] = src_full
|
|
# config['tgt_lng'] = tgt_full
|
|
|
|
# with open(prompts_path, 'w') as file:
|
|
# json.dump(config, file, ensure_ascii=False, indent=4)
|
|
|
|
# debate = Debate(save_file_dir=save_file_dir, num_players=3, openai_api_key=openai_api_key, prompts_path=prompts_path, temperature=0, sleep_time=0)
|
|
# debate.run()
|
|
# debate.save_file_to_json(id)
|