@ -40,21 +40,33 @@ class ModelParallelizer:
def __init__ (
self ,
llms : List [ Callable ] ,
llms : List [ Callable ] = None ,
load_balancing : bool = False ,
retry_attempts : int = 3 ,
iters : int = None ,
* args ,
* * kwargs ,
) :
self . llms = llms
self . load_balancing = load_balancing
self . retry_attempts = retry_attempts
self . iters = iters
self . last_responses = None
self . task_history = [ ]
def run ( self , task : str ) :
""" Run the task string """
with ThreadPoolExecutor ( ) as executor :
responses = executor . map ( lambda llm : llm ( task ) , self . llms )
return list ( responses )
try :
for i in range ( self . iters ) :
with ThreadPoolExecutor ( ) as executor :
responses = executor . map (
lambda llm : llm ( task ) , self . llms
)
return list ( responses )
except Exception as error :
print (
f " [ERROR][ModelParallelizer] [ROOT CAUSE] [ { error } ] "
)
def print_responses ( self , task ) :
""" Prints the responses in a tabular format """
@ -161,22 +173,29 @@ class ModelParallelizer:
def concurrent_run ( self , task : str ) - > List [ str ] :
""" Synchronously run the task on all llms and collect responses """
with ThreadPoolExecutor ( ) as executor :
future_to_llm = {
executor . submit ( llm , task ) : llm for llm in self . llms
}
responses = [ ]
for future in as_completed ( future_to_llm ) :
try :
responses . append ( future . result ( ) )
except Exception as error :
print (
f " { future_to_llm [ future ] } generated an "
f " exception: { error } "
)
self . last_responses = responses
self . task_history . append ( task )
return responses
try :
with ThreadPoolExecutor ( ) as executor :
future_to_llm = {
executor . submit ( llm , task ) : llm
for llm in self . llms
}
responses = [ ]
for future in as_completed ( future_to_llm ) :
try :
responses . append ( future . result ( ) )
except Exception as error :
print (
f " { future_to_llm [ future ] } generated an "
f " exception: { error } "
)
self . last_responses = responses
self . task_history . append ( task )
return responses
except Exception as error :
print (
f " [ERROR][ModelParallelizer] [ROOT CAUSE] [ { error } ] "
)
raise error
def add_llm ( self , llm : Callable ) :
""" Add an llm to the god mode """