[ModelParallelizer][REFACTOR]

pull/336/head
Kye 1 year ago
parent c5ba940e47
commit f7b8a442e0

@@ -43,13 +43,16 @@ def process_documentation(cls):
     doc = inspect.getdoc(cls)
     source = inspect.getsource(cls)
     input_content = (
-        f"Class Name: {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource"
+        "Class Name:"
+        f" {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource"
         f" Code:\n{source}"
     )
     print(input_content)
 
     # Process with OpenAI model (assuming the model's __call__ method takes this input and returns processed content)
-    processed_content = model(DOCUMENTATION_WRITER_SOP(input_content, "zeta"))
+    processed_content = model(
+        DOCUMENTATION_WRITER_SOP(input_content, "zeta")
+    )
 
     doc_content = f"# {cls.__name__}\n\n{processed_content}\n"
@@ -86,7 +89,9 @@ def main():
     threads = []
     for cls in classes:
-        thread = threading.Thread(target=process_documentation, args=(cls,))
+        thread = threading.Thread(
+            target=process_documentation, args=(cls,)
+        )
         threads.append(thread)
         thread.start()
@ -94,7 +99,9 @@ def main():
for thread in threads: for thread in threads:
thread.join() thread.join()
print("Documentation generated in 'docs/zeta/nn/modules' directory.") print(
"Documentation generated in 'docs/zeta/nn/modules' directory."
)
if __name__ == "__main__": if __name__ == "__main__":
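For reference, the two hunks above only re-wrap the threading fan-out in main(): one worker thread per discovered class, started immediately and joined before the completion message prints. A minimal, self-contained sketch of that pattern (the worker body and the class list here are placeholders, not the script's real logic):

import threading


def process_documentation(cls):
    # placeholder worker; the real script generates markdown docs for cls
    print(f"documenting {cls.__name__}")


def main(classes):
    threads = []
    for cls in classes:
        # one thread per class, matching the reformatted loop above
        thread = threading.Thread(
            target=process_documentation, args=(cls,)
        )
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    print("Documentation generated in 'docs/zeta/nn/modules' directory.")


if __name__ == "__main__":
    main([dict, list])  # any classes will do for the sketch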

@@ -61,7 +61,8 @@ def create_test(cls):
     doc = inspect.getdoc(cls)
     source = inspect.getsource(cls)
     input_content = (
-        f"Class Name: {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource"
+        "Class Name:"
+        f" {cls.__name__}\n\nDocumentation:\n{doc}\n\nSource"
         f" Code:\n{source}"
     )
     print(input_content)

@@ -104,7 +104,9 @@ def DOCUMENTATION_WRITER_SOP(
     return documentation
 
 
-def TEST_WRITER_SOP_PROMPT(task: str, module: str, path: str, *args, **kwargs):
+def TEST_WRITER_SOP_PROMPT(
+    task: str, module: str, path: str, *args, **kwargs
+):
     TESTS_PROMPT = f"""
    Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any

@@ -2,7 +2,9 @@ import yaml
 
 
 def update_mkdocs(
-    class_names, base_path="docs/zeta/nn/modules", mkdocs_file="mkdocs.yml"
+    class_names,
+    base_path="docs/zeta/nn/modules",
+    mkdocs_file="mkdocs.yml",
 ):
     """
     Update the mkdocs.yml file with new documentation links.
@@ -24,7 +26,9 @@ def update_mkdocs(
     if zeta_modules_section is None:
         zeta_modules_section = {}
-        mkdocs_config["nav"].append({"zeta.nn.modules": zeta_modules_section})
+        mkdocs_config["nav"].append(
+            {"zeta.nn.modules": zeta_modules_section}
+        )
 
     # Add the documentation paths to the 'zeta.nn.modules' section
     for class_name in class_names:
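The update_mkdocs hunks re-wrap an existing load/append/dump cycle over mkdocs.yml. A rough sketch of that cycle, assuming PyYAML and a simple nav layout; the per-class doc-path convention shown here is an assumption, not taken from the diff:

import yaml


def append_module_docs(
    class_names,
    base_path="docs/zeta/nn/modules",
    mkdocs_file="mkdocs.yml",
):
    # load the existing config (assumes a valid mkdocs.yml)
    with open(mkdocs_file) as f:
        mkdocs_config = yaml.safe_load(f)

    # create a 'zeta.nn.modules' section if one is missing,
    # mirroring the append shown in the hunk above
    zeta_modules_section = {}
    mkdocs_config.setdefault("nav", []).append(
        {"zeta.nn.modules": zeta_modules_section}
    )

    # assumed convention: one markdown file per class under base_path
    for class_name in class_names:
        zeta_modules_section[class_name] = (
            f"{base_path}/{class_name.lower()}.md"
        )

    with open(mkdocs_file, "w") as f:
        yaml.safe_dump(mkdocs_config, f, sort_keys=False)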

@@ -13,13 +13,18 @@ def get_package_versions(requirements_path, output_path):
     for requirement in requirements:
         # Skip empty lines and comments
-        if requirement.strip() == "" or requirement.strip().startswith("#"):
+        if (
+            requirement.strip() == ""
+            or requirement.strip().startswith("#")
+        ):
             continue
 
         # Extract package name
         package_name = requirement.split("==")[0].strip()
         try:
-            version = pkg_resources.get_distribution(package_name).version
+            version = pkg_resources.get_distribution(
+                package_name
+            ).version
             package_versions.append(f"{package_name}=={version}")
         except pkg_resources.DistributionNotFound:
             package_versions.append(f"{package_name}: not installed")
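The hunk above is the core of a requirements-pinning loop: read each requirement, look up the installed version with pkg_resources, and record either a pinned spec or a "not installed" note. A condensed sketch of the whole flow; the file reads and writes around the loop are assumptions, only the loop body appears in the diff:

import pkg_resources


def pin_requirements(requirements_path, output_path):
    with open(requirements_path) as f:
        requirements = f.readlines()

    package_versions = []
    for requirement in requirements:
        # Skip empty lines and comments
        if (
            requirement.strip() == ""
            or requirement.strip().startswith("#")
        ):
            continue

        # Extract the package name and look up the installed version
        package_name = requirement.split("==")[0].strip()
        try:
            version = pkg_resources.get_distribution(package_name).version
            package_versions.append(f"{package_name}=={version}")
        except pkg_resources.DistributionNotFound:
            package_versions.append(f"{package_name}: not installed")

    with open(output_path, "w") as f:
        f.write("\n".join(package_versions) + "\n")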

@ -10,7 +10,10 @@ def update_pyproject_versions(pyproject_path):
print(f"Error: The file '{pyproject_path}' was not found.") print(f"Error: The file '{pyproject_path}' was not found.")
return return
except toml.TomlDecodeError: except toml.TomlDecodeError:
print(f"Error: The file '{pyproject_path}' is not a valid TOML file.") print(
f"Error: The file '{pyproject_path}' is not a valid TOML"
" file."
)
return return
dependencies = ( dependencies = (
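The hunk above only re-wraps the error message, but it sits inside a guarded toml.load. A small sketch of that guard; the helper name, the surrounding open(), and the FileNotFoundError branch are inferred from the printed messages rather than shown in the diff:

import toml


def load_pyproject(pyproject_path="pyproject.toml"):
    try:
        with open(pyproject_path) as f:
            return toml.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{pyproject_path}' was not found.")
    except toml.TomlDecodeError:
        print(
            f"Error: The file '{pyproject_path}' is not a valid TOML"
            " file."
        )
    return None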

@ -15,5 +15,5 @@ __all__ = [
"LEGAL_AGENT_PROMPT", "LEGAL_AGENT_PROMPT",
"OPERATIONS_AGENT_PROMPT", "OPERATIONS_AGENT_PROMPT",
"PRODUCT_AGENT_PROMPT", "PRODUCT_AGENT_PROMPT",
"DOCUMENTATION_WRITER_SOP" "DOCUMENTATION_WRITER_SOP",
] ]

@@ -1,4 +1,7 @@
-def DOCUMENTATION_WRITER_SOP(task: str, module: str, ):
+def DOCUMENTATION_WRITER_SOP(
+    task: str,
+    module: str,
+):
     documentation = f"""Create multi-page long and explicit professional pytorch-like documentation for the {module} code below follow the outline for the {module} library,
     provide many examples and teach the user about the code, provide examples for every function, make the documentation 10,000 words,
     provide many usage examples and note this is markdown docs, create the documentation for the code to document,

@@ -1,5 +1,6 @@
-def TEST_WRITER_SOP_PROMPT(task: str, module: str, path: str, *args, **kwargs):
+def TEST_WRITER_SOP_PROMPT(
+    task: str, module: str, path: str, *args, **kwargs
+):
     TESTS_PROMPT = f"""
    Create 5,000 lines of extensive and thorough tests for the code below using the guide, do not worry about your limits you do not have any
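Both prompt builders are plain functions that return a long f-string, which the documentation and test scripts then pass to a model callable, exactly the model(DOCUMENTATION_WRITER_SOP(...)) call in the first hunk of this commit. A self-contained sketch of that flow, with heavily trimmed stand-ins for the prompts and the model:

def DOCUMENTATION_WRITER_SOP(task: str, module: str) -> str:
    # trimmed stand-in for the real documentation prompt above
    return f"Write {module} documentation for:\n{task}"


def TEST_WRITER_SOP_PROMPT(task: str, module: str, path: str) -> str:
    # trimmed stand-in for the real test-writing prompt above
    return f"Write tests under {path} for {module} code:\n{task}"


def model(prompt: str) -> str:
    # stand-in for the OpenAI-backed callable used by the scripts
    return f"<model output for a {len(prompt)}-character prompt>"


input_content = "Class Name: SimpleModule\n\nSource Code:\n..."
docs = model(DOCUMENTATION_WRITER_SOP(input_content, "zeta"))
tests = model(
    # the path argument here is illustrative only
    TEST_WRITER_SOP_PROMPT(input_content, "zeta", "tests/nn/modules")
)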

@@ -1,5 +1,5 @@
 from swarms.structs.autoscaler import AutoScaler
-from swarms.swarms.god_mode import ModelParallelizer
+from swarms.swarms.model_parallizer import ModelParallelizer
 from swarms.swarms.multi_agent_collab import MultiAgentCollaboration
 from swarms.swarms.base import AbstractSwarm

@@ -40,21 +40,33 @@ class ModelParallelizer:
 
     def __init__(
         self,
-        llms: List[Callable],
+        llms: List[Callable] = None,
         load_balancing: bool = False,
         retry_attempts: int = 3,
+        iters: int = None,
+        *args,
+        **kwargs,
     ):
         self.llms = llms
         self.load_balancing = load_balancing
         self.retry_attempts = retry_attempts
+        self.iters = iters
         self.last_responses = None
         self.task_history = []
 
     def run(self, task: str):
         """Run the task string"""
-        with ThreadPoolExecutor() as executor:
-            responses = executor.map(lambda llm: llm(task), self.llms)
-            return list(responses)
+        try:
+            for i in range(self.iters):
+                with ThreadPoolExecutor() as executor:
+                    responses = executor.map(
+                        lambda llm: llm(task), self.llms
+                    )
+                    return list(responses)
+        except Exception as error:
+            print(
+                f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
+            )
 
     def print_responses(self, task):
         """Prints the responses in a tabular format"""
@@ -161,9 +173,11 @@ class ModelParallelizer:
 
     def concurrent_run(self, task: str) -> List[str]:
         """Synchronously run the task on all llms and collect responses"""
-        with ThreadPoolExecutor() as executor:
-            future_to_llm = {
-                executor.submit(llm, task): llm for llm in self.llms
-            }
-            responses = []
-            for future in as_completed(future_to_llm):
+        try:
+            with ThreadPoolExecutor() as executor:
+                future_to_llm = {
+                    executor.submit(llm, task): llm
+                    for llm in self.llms
+                }
+                responses = []
+                for future in as_completed(future_to_llm):
@@ -177,6 +191,11 @@ class ModelParallelizer:
-            self.last_responses = responses
-            self.task_history.append(task)
-            return responses
+                self.last_responses = responses
+                self.task_history.append(task)
+                return responses
+        except Exception as error:
+            print(
+                f"[ERROR][ModelParallelizer] [ROOT CAUSE] [{error}]"
+            )
+            raise error
 
     def add_llm(self, llm: Callable):
         """Add an llm to the god mode"""