|
|
|
@ -3,7 +3,6 @@ import logging
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
|
from typing import Callable, List
|
|
|
|
|
|
|
|
|
|
from tabulate import tabulate
|
|
|
|
|
from termcolor import colored
|
|
|
|
|
|
|
|
|
|
# Configure logging
|
|
|
|
@ -77,23 +76,6 @@ class ModelParallelizer:
|
|
|
|
|
f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def print_responses(self, task):
|
|
|
|
|
"""Prints the responses in a tabular format"""
|
|
|
|
|
responses = self.run_all(task)
|
|
|
|
|
table = []
|
|
|
|
|
for i, response in enumerate(responses):
|
|
|
|
|
table.append([f"LLM {i+1}", response])
|
|
|
|
|
print(
|
|
|
|
|
colored(
|
|
|
|
|
tabulate(
|
|
|
|
|
table,
|
|
|
|
|
headers=["LLM", "Response"],
|
|
|
|
|
tablefmt="pretty",
|
|
|
|
|
),
|
|
|
|
|
"cyan",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def run_all(self, task):
|
|
|
|
|
"""Run the task on all LLMs"""
|
|
|
|
|
responses = []
|
|
|
|
@ -101,23 +83,7 @@ class ModelParallelizer:
|
|
|
|
|
responses.append(llm(task))
|
|
|
|
|
return responses
|
|
|
|
|
|
|
|
|
|
def print_arun_all(self, task):
|
|
|
|
|
"""Prints the responses in a tabular format"""
|
|
|
|
|
responses = self.arun_all(task)
|
|
|
|
|
table = []
|
|
|
|
|
for i, response in enumerate(responses):
|
|
|
|
|
table.append([f"LLM {i+1}", response])
|
|
|
|
|
print(
|
|
|
|
|
colored(
|
|
|
|
|
tabulate(
|
|
|
|
|
table,
|
|
|
|
|
headers=["LLM", "Response"],
|
|
|
|
|
tablefmt="pretty",
|
|
|
|
|
),
|
|
|
|
|
"cyan",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# New Features
|
|
|
|
|
def save_responses_to_file(self, filename):
|
|
|
|
|
"""Save responses to file"""
|
|
|
|
@ -126,7 +92,7 @@ class ModelParallelizer:
|
|
|
|
|
[f"LLM {i+1}", response]
|
|
|
|
|
for i, response in enumerate(self.last_responses)
|
|
|
|
|
]
|
|
|
|
|
file.write(tabulate(table, headers=["LLM", "Response"]))
|
|
|
|
|
file.write(table)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def load_llms_from_file(cls, filename):
|
|
|
|
@ -151,11 +117,7 @@ class ModelParallelizer:
|
|
|
|
|
]
|
|
|
|
|
print(
|
|
|
|
|
colored(
|
|
|
|
|
tabulate(
|
|
|
|
|
table,
|
|
|
|
|
headers=["LLM", "Response"],
|
|
|
|
|
tablefmt="pretty",
|
|
|
|
|
),
|
|
|
|
|
table,
|
|
|
|
|
"cyan",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|