From f74f73124ff8ea4ecb848ccd62b23a0465760afe Mon Sep 17 00:00:00 2001
From: Occupying-Mars
Date: Thu, 5 Dec 2024 20:24:25 +0530
Subject: [PATCH 01/18] handle artifact bug

---
 swarms/structs/agent.py | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index c9160b1b..bdeac3e3 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -2388,36 +2388,42 @@ class Agent:
     ) -> None:
         """Handle creating and saving artifacts with error handling."""
         try:
-            logger.info(
-                f"Creating artifact for file: {file_output_path}"
-            )
+            # Ensure file_extension starts with a dot
+            if not file_extension.startswith('.'):
+                file_extension = '.' + file_extension
+
+            # If file_output_path doesn't have an extension, treat it as a directory
+            # and create a default filename based on timestamp
+            if not os.path.splitext(file_output_path)[1]:
+                timestamp = time.strftime("%Y%m%d_%H%M%S")
+                filename = f"artifact_{timestamp}{file_extension}"
+                full_path = os.path.join(file_output_path, filename)
+            else:
+                full_path = file_output_path
+
+            # Create the directory if it doesn't exist
+            os.makedirs(os.path.dirname(full_path), exist_ok=True)
+
+            logger.info(f"Creating artifact for file: {full_path}")
             artifact = Artifact(
-                file_path=file_output_path,
+                file_path=full_path,
                 file_type=file_extension,
                 contents=text,
                 edit_count=0,
             )
-            logger.info(
-                f"Saving artifact with extension: {file_extension}"
-            )
+            logger.info(f"Saving artifact with extension: {file_extension}")
             artifact.save_as(file_extension)
-            logger.success(
-                f"Successfully saved artifact to {file_output_path}"
-            )
+            logger.success(f"Successfully saved artifact to {full_path}")
         except ValueError as e:
-            logger.error(
-                f"Invalid input values for artifact: {str(e)}"
-            )
+            logger.error(f"Invalid input values for artifact: {str(e)}")
             raise
         except IOError as e:
             logger.error(f"Error saving artifact to file: {str(e)}")
             raise
         except Exception as e:
-            logger.error(
-                f"Unexpected error handling artifact: {str(e)}"
-            )
+            logger.error(f"Unexpected error handling artifact: {str(e)}")
             raise

     def showcase_config(self):
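A quick sanity check of the path handling this patch introduces. The sketch below mirrors the patched logic as a standalone helper; the function name and example paths are illustrative only, not part of the commit:

```python
import os
import time


def resolve_artifact_path(file_output_path: str, file_extension: str) -> str:
    # Mirrors the patched Agent logic: normalize the extension first
    if not file_extension.startswith("."):
        file_extension = "." + file_extension
    # A path with no extension is treated as a directory and gets a
    # timestamped default filename
    if not os.path.splitext(file_output_path)[1]:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        return os.path.join(file_output_path, f"artifact_{timestamp}{file_extension}")
    return file_output_path


print(resolve_artifact_path("outputs", "md"))             # e.g. outputs/artifact_20241205_202425.md
print(resolve_artifact_path("outputs/report.md", ".md"))  # outputs/report.md
```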
From 88da59e614b40c083d50579923a0a29a6511ad06 Mon Sep 17 00:00:00 2001
From: Occupying-Mars
Date: Thu, 5 Dec 2024 20:44:45 +0530
Subject: [PATCH 02/18] return appended entire schema

---
 swarms/tools/base_tool.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/swarms/tools/base_tool.py b/swarms/tools/base_tool.py
index dcb81974..09b3c506 100644
--- a/swarms/tools/base_tool.py
+++ b/swarms/tools/base_tool.py
@@ -387,6 +387,8 @@ class BaseTool(BaseModel):
             "Converting tools into OpenAI function calling schema"
         )

+        tool_schemas = []
+
         for tool in self.tools:
             # Transform the tool into an OpenAI function calling schema
             if self.check_func_if_have_docs(tool):
                 name = tool.__name__
                 description = tool.__doc__

                 logger.info(
                     f"Converting tool: {name} into an OpenAI-certified function calling schema. Add documentation and type hints."
                 )
-                tool_schema_list = (
+                tool_schema = (
                     get_openai_function_schema_from_func(
                         tool, name=name, description=description
                     )
@@ -408,18 +410,21 @@ class BaseTool(BaseModel):
                     f"Tool {name} converted successfully into OpenAI schema"
                 )

-                # Transform the dictionary to a string
-                tool_schema_list = json.dumps(
-                    tool_schema_list, indent=4
-                )
-
-                return tool_schema_list
+                tool_schemas.append(tool_schema)
             else:
                 logger.error(
                     f"Tool {tool.__name__} does not have documentation or type hints, please add them to make the tool execution reliable."
                 )

-        return tool_schema_list
+        # Combine all tool schemas into a single schema
+        if tool_schemas:
+            combined_schema = {
+                "type": "function",
+                "functions": [schema["function"] for schema in tool_schemas]
+            }
+            return json.dumps(combined_schema, indent=4)
+
+        return None

     def check_func_if_have_docs(self, func: callable):
         if func.__doc__ is not None:
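For reference, the string this method now returns aggregates every converted tool. A sketch of the output, assuming two documented tools named `add` and `search` (hypothetical names; note the `{"type": "function", "functions": [...]}` wrapper is this method's own aggregate format, not the per-tool schema the OpenAI API consumes directly):

```python
import json

# Hypothetical result of the combined-schema return path, with two tools
combined_schema = {
    "type": "function",
    "functions": [
        {"name": "add", "description": "Add two integers.", "parameters": {"type": "object", "properties": {}}},
        {"name": "search", "description": "Search the web.", "parameters": {"type": "object", "properties": {}}},
    ],
}
print(json.dumps(combined_schema, indent=4))
```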
From 0af33010874474bc87cbbeeceb8f7e0328120811 Mon Sep 17 00:00:00 2001
From: Kye Gomez <98760976+kyegomez@users.noreply.github.com>
Date: Sun, 8 Dec 2024 20:38:06 -0800
Subject: [PATCH 03/18] Update README.md

---
 README.md | 212 ++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 165 insertions(+), 47 deletions(-)

diff --git a/README.md b/README.md
index 071b1991..70b31e6e 100644
--- a/README.md
+++ b/README.md
@@ -453,8 +453,8 @@ agent.run(task, img)
 
 ----
 
-### `ToolAgent`
-ToolAgent is an agent that can use tools through JSON function calling. It intakes any open source model from huggingface and is extremely modular and plug in and play. We need help adding general support to all models soon.
+### Local Agent `ToolAgent`
+ToolAgent is a fully local agent that can use tools through JSON function calling. It takes in any open-source model from Hugging Face and is extremely modular and plug-and-play. We need help adding general support to all models soon.
 
 
 ```python
@@ -774,6 +774,8 @@ print(
 
 The `AgentRearrange` orchestration technique, inspired by Einops and einsum, allows you to define and map out the relationships between various agents. It provides a powerful tool for orchestrating complex workflows, enabling you to specify linear and sequential relationships such as `a -> a1 -> a2 -> a3`, or concurrent relationships where the first agent sends a message to 3 agents simultaneously: `a -> a1, a2, a3`. This level of customization allows for the creation of highly efficient and dynamic workflows, where agents can work in parallel or in sequence as needed. The `AgentRearrange` technique is a valuable addition to the swarms library, providing a new level of flexibility and control over the orchestration of agents. For more detailed information and examples, please refer to the [official documentation](https://docs.swarms.world/en/latest/swarms/structs/agent_rearrange/).
 
+[Check out my video on agent rearrange!](https://youtu.be/Rq8wWQ073mg)
+
 
 ### Methods
 
@@ -799,68 +801,184 @@ The `run` method returns the final output after all agents have processed the in
 
 
 ```python
-from swarms import Agent, AgentRearrange
-
+from datetime import datetime
 
-# Initialize the director agent
+from swarms import Agent, AgentRearrange, create_file_in_folder
 
-director = Agent(
-    agent_name="Director",
-    system_prompt="Directs the tasks for the workers",
-    model_name="claude-2",
+chief_medical_officer = Agent(
+    agent_name="Chief Medical Officer",
+    system_prompt="""You are the Chief Medical Officer coordinating a team of medical specialists for viral disease diagnosis.
+    Your responsibilities include:
+    - Gathering initial patient symptoms and medical history
+    - Coordinating with specialists to form differential diagnoses
+    - Synthesizing different specialist opinions into a cohesive diagnosis
+    - Ensuring all relevant symptoms and test results are considered
+    - Making final diagnostic recommendations
+    - Suggesting treatment plans based on team input
+    - Identifying when additional specialists need to be consulted
+
+    Guidelines:
+    1. Always start with a comprehensive patient history
+    2. Consider both common and rare viral conditions
+    3. Factor in patient demographics and risk factors
+    4. Document your reasoning process clearly
+    5. Highlight any critical or emergency symptoms
+    6. Note any limitations or uncertainties in the diagnosis
+
+    Format all responses with clear sections for:
+    - Initial Assessment
+    - Differential Diagnoses
+    - Specialist Consultations Needed
+    - Recommended Next Steps""",
+    model_name="gpt-4o",  # Models from litellm -> claude-2
     max_loops=1,
-    dashboard=False,
-    streaming_on=True,
-    verbose=True,
-    stopping_token="",
-    state_save_file_type="json",
-    saved_state_path="director.json",
 )
 
+# Viral Disease Specialist
+virologist = Agent(
+    agent_name="Virologist",
+    system_prompt="""You are a specialist in viral diseases with expertise in:
+    - Respiratory viruses (Influenza, Coronavirus, RSV)
+    - Systemic viral infections (EBV, CMV, HIV)
+    - Childhood viral diseases (Measles, Mumps, Rubella)
+    - Emerging viral threats
+
+    Your role involves:
+    1. Analyzing symptoms specific to viral infections
+    2. Distinguishing between different viral pathogens
+    3. Assessing viral infection patterns and progression
+    4. Recommending specific viral tests
+    5. Evaluating epidemiological factors
+
+    For each case, consider:
+    - Incubation periods
+    - Transmission patterns
+    - Seasonal factors
+    - Geographic prevalence
+    - Patient immune status
+    - Current viral outbreaks
+
+    Provide detailed analysis of:
+    - Characteristic viral symptoms
+    - Disease progression timeline
+    - Risk factors for severe disease
+    - Potential complications""",
+    model_name="gpt-4o",
+    max_loops=1,
+)
 
-# Initialize worker 1
+# Internal Medicine Specialist
+internist = Agent(
+    agent_name="Internist",
+    system_prompt="""You are an Internal Medicine specialist responsible for:
+    - Comprehensive system-based evaluation
+    - Integration of symptoms across organ systems
+    - Identification of systemic manifestations
+    - Assessment of comorbidities
+
+    For each case, analyze:
+    1. Vital signs and their implications
+    2. System-by-system review (cardiovascular, respiratory, etc.)
+    3. Impact of existing medical conditions
+    4. Medication interactions and contraindications
+    5. Risk stratification
+
+    Consider these aspects:
+    - Age-related factors
+    - Chronic disease impact
+    - Medication history
+    - Social and environmental factors
+
+    Document:
+    - Physical examination findings
+    - System-specific symptoms
+    - Relevant lab abnormalities
+    - Risk factors for complications""",
+    model_name="gpt-4o",
+    max_loops=1,
+)
 
-worker1 = Agent(
-    agent_name="Worker1",
-    system_prompt="Generates a transcript for a youtube video on what swarms are",
-    model_name="claude-2",
+# Diagnostic Synthesizer
+synthesizer = Agent(
+    agent_name="Diagnostic Synthesizer",
+    system_prompt="""You are responsible for synthesizing all specialist inputs to create a final diagnostic assessment:
+
+    Core responsibilities:
+    1. Integrate findings from all specialists
+    2. Identify patterns and correlations
+    3. Resolve conflicting opinions
+    4. Generate probability-ranked differential diagnoses
+    5. Recommend additional testing if needed
+
+    Analysis framework:
+    - Weight evidence based on reliability and specificity
+    - Consider epidemiological factors
+    - Evaluate diagnostic certainty
+    - Account for test limitations
+
+    Provide structured output including:
+    1. Primary diagnosis with confidence level
+    2. Supporting evidence summary
+    3. Alternative diagnoses to consider
+    4. Recommended confirmatory tests
+    5. Red flags or warning signs
+    6. Follow-up recommendations
+
+    Documentation requirements:
+    - Clear reasoning chain
+    - Evidence quality assessment
+    - Confidence levels for each diagnosis
+    - Knowledge gaps identified
+    - Risk assessment""",
+    model_name="gpt-4o",
     max_loops=1,
-    dashboard=False,
-    streaming_on=True,
-    verbose=True,
-    stopping_token="",
-    state_save_file_type="json",
-    saved_state_path="worker1.json",
 )
 
+# Create agent list
+agents = [chief_medical_officer, virologist, internist, synthesizer]
+
+# Define diagnostic flow
+flow = f"""{chief_medical_officer.agent_name} -> {virologist.agent_name} -> {internist.agent_name} -> {synthesizer.agent_name}"""
 
-# Initialize worker 2
-worker2 = Agent(
-    agent_name="Worker2",
-    system_prompt="Summarizes the transcript generated by Worker1",
-    model_name="claude-2",
+# Create the swarm system
+diagnosis_system = AgentRearrange(
+    name="Medical-nlp-diagnosis-swarm",
+    description="natural language symptoms to diagnosis report",
+    agents=agents,
+    flow=flow,
     max_loops=1,
-    dashboard=False,
-    streaming_on=True,
-    verbose=True,
-    stopping_token="",
-    state_save_file_type="json",
-    saved_state_path="worker2.json",
+    output_type="all",
 )
 
-# Create a list of agents
-agents = [director, worker1, worker2]
+# Example usage
+if __name__ == "__main__":
+    # Example patient case
+    patient_case = """
+    Patient: 45-year-old female
+    Presenting symptoms:
+    - Fever (101.5°F) for 3 days
+    - Dry cough
+    - Fatigue
+    - Mild shortness of breath
+    Medical history:
+    - Controlled hypertension
+    - No recent travel
+    - Fully vaccinated for COVID-19
+    - No known sick contacts
+    """
 
-# Define the flow pattern
-flow = "Director -> Worker1 -> Worker2"
+    # Add timestamp to the patient case
+    case_info = f"Timestamp: {datetime.now()}\nPatient Information: {patient_case}"
+
+    # Run the diagnostic process
+    diagnosis = diagnosis_system.run(case_info)
+
+    # Create a folder and file called reports
+    create_file_in_folder(
+        "reports", "medical_analysis_agent_rearrange.md", diagnosis
+    )
 
-# Using AgentRearrange class
-agent_system = AgentRearrange(agents=agents, flow=flow)
-output = agent_system.run(
-    "Create a format to express and communicate swarms of llms in a structured manner for youtube"
-)
-print(output)
```

From d070f7e317adb94f88757551419c00719d02be2b Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 9 Dec 2024 10:52:58 +0000
Subject: [PATCH 04/18] Bump returntocorp/semgrep-action

Bumps [returntocorp/semgrep-action](https://github.com/returntocorp/semgrep-action) from fcd5ab7459e8d91cb1777481980d1b18b4fc6735 to 713efdd345f3035192eaa63f56867b88e63e4e5d.
- [Changelog](https://github.com/returntocorp/semgrep-action/blob/develop/CHANGELOG.md) - [Commits](https://github.com/returntocorp/semgrep-action/compare/fcd5ab7459e8d91cb1777481980d1b18b4fc6735...713efdd345f3035192eaa63f56867b88e63e4e5d) --- updated-dependencies: - dependency-name: returntocorp/semgrep-action dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- .github/workflows/semgrep.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 1e78a687..4a122c7b 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -35,7 +35,7 @@ jobs: - uses: actions/checkout@v4 # Scan code using project's configuration on https://semgrep.dev/manage - - uses: returntocorp/semgrep-action@fcd5ab7459e8d91cb1777481980d1b18b4fc6735 + - uses: returntocorp/semgrep-action@713efdd345f3035192eaa63f56867b88e63e4e5d with: publishToken: ${{ secrets.SEMGREP_APP_TOKEN }} publishDeployment: ${{ secrets.SEMGREP_DEPLOYMENT_ID }} From 8af39867ec819598fff549dcf6c5f5a661e4d8aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:53:01 +0000 Subject: [PATCH 05/18] Bump actions/setup-python from 3 to 5 Bumps [actions/setup-python](https://github.com/actions/setup-python) from 3 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v3...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/python-package-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index f3586044..51c99bba 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.10' - name: Add conda to system path From 66fcea3b5a3420d1930c66be531b7c71df7a5d19 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:53:03 +0000 Subject: [PATCH 06/18] Bump facebook/pyre-action from 0.0.1 to 0.0.2 Bumps [facebook/pyre-action](https://github.com/facebook/pyre-action) from 0.0.1 to 0.0.2. - [Release notes](https://github.com/facebook/pyre-action/releases) - [Commits](https://github.com/facebook/pyre-action/compare/60697a7858f7cc8470d8cc494a3cf2ad6b06560d...12b8d923443ea66cb657facc2e5faac1c8c86e64) --- updated-dependencies: - dependency-name: facebook/pyre-action dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- .github/workflows/pyre.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml index 2e4713d3..53aca44d 100644 --- a/.github/workflows/pyre.yml +++ b/.github/workflows/pyre.yml @@ -38,7 +38,7 @@ jobs: submodules: true - name: Run Pyre - uses: facebook/pyre-action@60697a7858f7cc8470d8cc494a3cf2ad6b06560d + uses: facebook/pyre-action@12b8d923443ea66cb657facc2e5faac1c8c86e64 with: # To customize these inputs: # See https://github.com/facebook/pyre-action#inputs From 4c1c143fe531bdbf50162c193175cbf47d299d9a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:53:06 +0000 Subject: [PATCH 07/18] Bump aquasecurity/trivy-action from 0.5.0 to 0.29.0 Bumps [aquasecurity/trivy-action](https://github.com/aquasecurity/trivy-action) from 0.5.0 to 0.29.0. - [Release notes](https://github.com/aquasecurity/trivy-action/releases) - [Commits](https://github.com/aquasecurity/trivy-action/compare/7b7aa264d83dc58691451798b4d117d53d21edfe...18f2510ee396bbf400402947b394f2dd8c87dbb0) --- updated-dependencies: - dependency-name: aquasecurity/trivy-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/trivy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/trivy.yml b/.github/workflows/trivy.yml index d9e6c82b..112bdf93 100644 --- a/.github/workflows/trivy.yml +++ b/.github/workflows/trivy.yml @@ -34,7 +34,7 @@ jobs: docker build -t docker.io/my-organization/my-app:${{ github.sha }} . - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@7b7aa264d83dc58691451798b4d117d53d21edfe + uses: aquasecurity/trivy-action@18f2510ee396bbf400402947b394f2dd8c87dbb0 with: image-ref: 'docker.io/my-organization/my-app:${{ github.sha }}' format: 'template' From 57b14987e4c4f3cbf6466894b2a6c0a308369a29 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:19:36 +0000 Subject: [PATCH 08/18] Bump pypdf from 4.3.1 to 5.1.0 Bumps [pypdf](https://github.com/py-pdf/pypdf) from 4.3.1 to 5.1.0. - [Release notes](https://github.com/py-pdf/pypdf/releases) - [Changelog](https://github.com/py-pdf/pypdf/blob/main/CHANGELOG.md) - [Commits](https://github.com/py-pdf/pypdf/compare/4.3.1...5.1.0) --- updated-dependencies: - dependency-name: pypdf dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e5375a0d..5d05e3e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ torch>=2.1.1,<3.0 transformers>=4.39.0,<5.0.0 asyncio>=3.4.3,<4.0 toml -pypdf==4.3.1 +pypdf==5.1.0 ratelimit==2.2.1 loguru pydantic==2.8.2 From 1a85dd33416b1ec3b4a802b1bf383db18131d4fe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:19:58 +0000 Subject: [PATCH 09/18] Update ruff requirement from >=0.5.1,<0.8.2 to >=0.5.1,<0.8.3 --- updated-dependencies: - dependency-name: ruff dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0cc0a373..1d1fcaed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ swarms = "swarms.cli.main:main" [tool.poetry.group.lint.dependencies] black = ">=23.1,<25.0" -ruff = ">=0.5.1,<0.8.2" +ruff = ">=0.5.1,<0.8.3" types-toml = "^0.10.8.1" types-pytz = ">=2023.3,<2025.0" types-chardet = "^5.0.4.6" From 0e626a686e600152e4d81b194143f20585ae182e Mon Sep 17 00:00:00 2001 From: Kye Gomez <98760976+kyegomez@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:13:41 -0800 Subject: [PATCH 10/18] Delete byte.py --- byte.py | 898 -------------------------------------------------------- 1 file changed, 898 deletions(-) delete mode 100644 byte.py diff --git a/byte.py b/byte.py deleted file mode 100644 index d0a5a92f..00000000 --- a/byte.py +++ /dev/null @@ -1,898 +0,0 @@ -from enum import Enum -from typing import Union, Optional -import io -from PIL import Image -import numpy as np -import torch -import struct - - -from enum import auto -from typing import List, Dict, Tuple -import wave -from dataclasses import dataclass -import torch.nn as nn -import torch.nn.functional as F -from loguru import logger -from einops import rearrange -from torch import Tensor - - -@dataclass -class ModelConfig: - """Configuration for the enhanced BytePredictor model.""" - - vocab_size: int = 256 # Standard byte range - hidden_size: int = 1024 - num_layers: int = 12 - num_key_value_heads: int = 8 # For multi-query attention - num_query_heads: int = 32 # More query heads than kv heads - dropout: float = 0.1 - max_sequence_length: int = 8192 - rope_theta: float = 10000.0 - layer_norm_eps: float = 1e-5 - vocab_parallel: bool = False - qk_norm: bool = True - qk_norm_scale: float = None - attention_bias: bool = False - - -class MultiQueryAttention(nn.Module): - """Fixed Multi-Query Attention implementation.""" - - def __init__(self, config: ModelConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.num_query_heads = config.num_query_heads - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_query_heads - self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5) - - self.q_proj = nn.Linear( - config.hidden_size, config.num_query_heads * self.head_dim - ) - self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - ) - self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - ) - self.o_proj = nn.Linear( - config.num_query_heads * self.head_dim, config.hidden_size - ) - - self.qk_norm = config.qk_norm - if self.qk_norm: - self.q_norm = nn.LayerNorm(self.head_dim) - self.k_norm = nn.LayerNorm(self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - batch_size, seq_length, _ = hidden_states.shape - - # Project and reshape - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Reshape to [seq_len, batch, heads, head_dim] - q = q.view( - batch_size, - seq_length, - self.num_query_heads, - self.head_dim, - ).permute(1, 0, 2, 3) - k = k.view( - batch_size, - seq_length, - self.num_key_value_heads, - self.head_dim, - ).permute(1, 0, 2, 3) - v = v.view( - batch_size, - seq_length, - self.num_key_value_heads, - self.head_dim, - ).permute(1, 0, 2, 3) - - # Apply rotary embeddings - # q, k = 
self.rotary(q, k, seq_length) - - # Apply QK normalization if enabled - if self.qk_norm: - q = self.q_norm(q) - k = self.k_norm(k) - - # Handle MQA head expansion - if self.num_key_value_heads != self.num_query_heads: - k = k.repeat_interleave( - self.num_query_heads // self.num_key_value_heads, - dim=2, - ) - v = v.repeat_interleave( - self.num_query_heads // self.num_key_value_heads, - dim=2, - ) - - # Compute attention - # Reshape for matmul: [batch, heads, seq_length, head_dim] - q = q.permute(1, 2, 0, 3) - k = k.permute(1, 2, 0, 3) - v = v.permute(1, 2, 0, 3) - - attn_weights = ( - torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = F.softmax(attn_weights, dim=-1) - - output = torch.matmul(attn_weights, v) - - # Reshape back to [batch, seq_length, hidden_size] - output = ( - output.transpose(1, 2) - .contiguous() - .view(batch_size, seq_length, -1) - ) - output = self.o_proj(output) - - return output - - -class EnhancedBytePredictor(nn.Module): - """Enhanced byte prediction model with state-of-the-art features.""" - - def __init__(self, config: ModelConfig): - super().__init__() - self.config = config - - # Token embeddings - self.tok_embeddings = nn.Embedding( - config.vocab_size, config.hidden_size - ) - - # Transformer layers - self.layers = nn.ModuleList( - [ - nn.ModuleDict( - { - "attention": MultiQueryAttention(config), - "attention_norm": nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - ), - "feed_forward": nn.Sequential( - nn.Linear( - config.hidden_size, - 4 * config.hidden_size, - ), - nn.GELU(), - nn.Linear( - 4 * config.hidden_size, - config.hidden_size, - ), - ), - "feed_forward_norm": nn.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - ), - } - ) - for _ in range(config.num_layers) - ] - ) - - self.norm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - self.output = nn.Linear( - config.hidden_size, config.vocab_size, bias=False - ) - - # Initialize weights - self.apply(self._init_weights) - - def _init_weights(self, module: nn.Module) -> None: - """Initialize weights with scaled normal distribution.""" - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass of the model. 
- - Args: - input_ids: Tensor of shape (batch_size, sequence_length) - attention_mask: Optional attention mask - - Returns: - Tensor of logits with shape (batch_size, sequence_length, vocab_size) - """ - hidden_states = self.tok_embeddings(input_ids) - - # Create causal mask if needed - if attention_mask is None: - attention_mask = torch.triu( - torch.ones( - (input_ids.size(1), input_ids.size(1)), - device=input_ids.device, - dtype=torch.bool, - ), - diagonal=1, - ) - attention_mask = attention_mask.masked_fill( - attention_mask == 1, float("-inf") - ) - - # Apply transformer layers - for layer in self.layers: - # Attention block - hidden_states = hidden_states + layer["attention"]( - layer["attention_norm"](hidden_states), attention_mask - ) - - # Feed-forward block - hidden_states = hidden_states + layer["feed_forward"]( - layer["feed_forward_norm"](hidden_states) - ) - - hidden_states = self.norm(hidden_states) - logits = self.output(hidden_states) - - return logits - - def compute_loss( - self, - input_ids: torch.Tensor, - target_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Compute cross entropy loss. - - Args: - input_ids: Input token ids - target_ids: Target token ids - attention_mask: Optional attention mask - - Returns: - Loss value - """ - logits = self(input_ids, attention_mask) - loss = F.cross_entropy( - rearrange(logits, "b s v -> (b s) v"), - rearrange(target_ids, "b s -> (b s)"), - ) - return loss - - @torch.no_grad() - def _generate( - self, - input_ids: torch.Tensor, - max_new_tokens: int = 100, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - repetition_penalty: float = 1.0, - ) -> torch.Tensor: - """ - Generate new tokens autoregressively. 
- - Args: - input_ids: Starting sequence - max_new_tokens: Number of tokens to generate - temperature: Sampling temperature - top_k: K for top-k sampling - top_p: P for nucleus sampling - repetition_penalty: Penalty for repeating tokens - - Returns: - Generated sequence - """ - batch_size, seq_length = input_ids.shape - generated = input_ids.clone() - - for _ in range(max_new_tokens): - if generated.size(1) >= self.config.max_sequence_length: - break - - # Forward pass - logits = self(generated)[:, -1, :] - - # Apply temperature - logits = logits / temperature - - # Apply repetition penalty - if repetition_penalty != 1.0: - for i in range(batch_size): - for token_id in set(generated[i].tolist()): - logits[i, token_id] /= repetition_penalty - - # Apply top-k sampling - if top_k is not None: - indices_to_remove = ( - logits - < torch.topk(logits, top_k)[0][..., -1, None] - ) - logits[indices_to_remove] = float("-inf") - - # Apply nucleus (top-p) sampling - if top_p is not None: - sorted_logits, sorted_indices = torch.sort( - logits, descending=True - ) - cumulative_probs = torch.cumsum( - F.softmax(sorted_logits, dim=-1), dim=-1 - ) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = ( - sorted_indices_to_remove[..., :-1].clone() - ) - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = torch.zeros_like( - logits, dtype=torch.bool - ) - indices_to_remove.scatter_( - 1, sorted_indices, sorted_indices_to_remove - ) - logits[indices_to_remove] = float("-inf") - - # Sample next token - probs = F.softmax(logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - - # Append to sequence - generated = torch.cat([generated, next_token], dim=1) - - return generated - - def generate( - self, - input_ids: torch.Tensor, - max_new_tokens: int = 100, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - repetition_penalty: float = 1.0, - ): - tensor_data = self._generate( - input_ids=input_ids, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - ) - - return tensor_to_data(tensor_data) - - -# import torch -# from typing import Optional - - -class DataType(Enum): - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - BINARY = "binary" - - -class ByteDetokenizer: - """Utility class for converting model output bytes back to original data formats.""" - - @staticmethod - def tensor_to_bytes(tensor: torch.Tensor) -> bytes: - """Convert model output tensor to bytes.""" - # Convert logits/probabilities to byte values - if tensor.dim() > 1: - # If we have logits, convert to byte indices - byte_indices = tensor.argmax(dim=-1) - else: - byte_indices = tensor - - # Convert to Python bytes - return bytes( - byte_indices.cpu().numpy().astype(np.uint8).tolist() - ) - - @staticmethod - def decode_text(byte_sequence: bytes) -> str: - """Convert bytes to text.""" - try: - return byte_sequence.decode("utf-8") - except UnicodeDecodeError: - # Try with error handling - return byte_sequence.decode("utf-8", errors="replace") - - @staticmethod - def decode_image( - byte_sequence: bytes, - mode: str = "RGB", - size: Optional[tuple] = None, - ) -> Image.Image: - """Convert bytes to image. - - Args: - byte_sequence: Raw image bytes - mode: Image mode (RGB, RGBA, L, etc.) 
- size: Optional tuple of (width, height) - """ - try: - # Try to load as-is first (for standard image formats) - img = Image.open(io.BytesIO(byte_sequence)) - if size: - img = img.resize(size) - return img - except: - # If failed, assume raw pixel data - if not size: - # Try to determine size from byte count - pixel_count = len(byte_sequence) // len(mode) - size = ( - int(np.sqrt(pixel_count)), - int(np.sqrt(pixel_count)), - ) - - # Convert raw bytes to pixel array - pixels = np.frombuffer(byte_sequence, dtype=np.uint8) - pixels = pixels.reshape((*size, len(mode))) - - return Image.fromarray(pixels, mode=mode) - - @staticmethod - def decode_audio( - byte_sequence: bytes, - sample_rate: int = 44100, - channels: int = 2, - sample_width: int = 2, - ) -> np.ndarray: - """Convert bytes to audio samples. - - Args: - byte_sequence: Raw audio bytes - sample_rate: Audio sample rate in Hz - channels: Number of audio channels - sample_width: Bytes per sample (1, 2, or 4) - """ - # Determine format string based on sample width - format_str = { - 1: "b", # signed char - 2: "h", # short - 4: "i", # int - }[sample_width] - - # Unpack bytes to samples - sample_count = len(byte_sequence) // (channels * sample_width) - samples = struct.unpack( - f"<{sample_count * channels}{format_str}", byte_sequence - ) - - # Reshape to [samples, channels] - return np.array(samples).reshape(-1, channels) - - def decode_data( - self, - model_output: Union[torch.Tensor, bytes], - data_type: DataType, - **kwargs, - ) -> Union[str, Image.Image, np.ndarray, bytes]: - """Main method to decode model output to desired format. - - Args: - model_output: Either tensor from model or raw bytes - data_type: Type of data to decode to - **kwargs: Additional parameters for specific decoders - - Returns: - Decoded data in specified format - """ - # Convert tensor to bytes if needed - if isinstance(model_output, torch.Tensor): - byte_sequence = self.tensor_to_bytes(model_output) - else: - byte_sequence = model_output - - # Decode based on type - if data_type == DataType.TEXT: - return self.decode_text(byte_sequence) - elif data_type == DataType.IMAGE: - return self.decode_image(byte_sequence, **kwargs) - elif data_type == DataType.AUDIO: - return self.decode_audio(byte_sequence, **kwargs) - elif data_type == DataType.VIDEO: - raise NotImplementedError( - "Video decoding not yet implemented" - ) - else: # BINARY - return byte_sequence - - -# Usage example - - -class Modality(Enum): - TEXT = auto() - IMAGE = auto() - AUDIO = auto() - VIDEO = auto() - BINARY = auto() - MULTIMODAL = auto() - - -@dataclass -class ModalityInfo: - """Information about detected modality.""" - - modality: Modality - confidence: float - metadata: Dict[str, any] - sub_modalities: Optional[List["ModalityInfo"]] = None - - -class ModalityDetector: - """Detects data modalities from byte sequences.""" - - # Common file signatures (magic numbers) - SIGNATURES = { - # Images - b"\xFF\xD8\xFF": "JPEG", - b"\x89PNG\r\n\x1a\n": "PNG", - b"GIF87a": "GIF", - b"GIF89a": "GIF", - b"RIFF": "WEBP", - # Audio - b"RIFF....WAVE": "WAV", - b"ID3": "MP3", - b"\xFF\xFB": "MP3", - b"OggS": "OGG", - # Video - b"\x00\x00\x00\x18ftypmp42": "MP4", - b"\x00\x00\x00\x1Cftypav01": "MP4", - b"\x1A\x45\xDF\xA3": "WEBM", - } - - def __init__(self): - self.magic = magic.Magic(mime=True) - - def _check_text_probability(self, data: bytes) -> float: - """Estimate probability that data is text.""" - # Check if data is valid UTF-8 - try: - data.decode("utf-8") - # Count printable ASCII characters - 
printable = sum(1 for b in data if 32 <= b <= 126) - return printable / len(data) - except UnicodeDecodeError: - return 0.0 - - def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]: - """Check if data is a valid image and extract metadata.""" - try: - with io.BytesIO(data) as bio: - img = Image.open(bio) - return True, { - "format": img.format, - "size": img.size, - "mode": img.mode, - } - except: - return False, {} - - def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]: - """Check if data is valid audio and extract metadata.""" - try: - with io.BytesIO(data) as bio: - # Try to parse as WAV - with wave.open(bio) as wav: - return True, { - "channels": wav.getnchannels(), - "sample_width": wav.getsampwidth(), - "framerate": wav.getframerate(), - "frames": wav.getnframes(), - } - except: - # Check for other audio signatures - for sig in [b"ID3", b"\xFF\xFB", b"OggS"]: - if data.startswith(sig): - return True, {"format": "compressed_audio"} - return False, {} - - def _detect_boundaries( - self, data: bytes - ) -> List[Tuple[int, int, Modality]]: - """Detect boundaries between different modalities in the data.""" - boundaries = [] - current_pos = 0 - - while current_pos < len(data): - # Look for known signatures - for sig, format_type in self.SIGNATURES.items(): - if data[current_pos:].startswith(sig): - # Found a signature, determine its length - if format_type in ["JPEG", "PNG", "GIF"]: - # Find image end - try: - with io.BytesIO( - data[current_pos:] - ) as bio: - img = Image.open(bio) - img.verify() - size = bio.tell() - boundaries.append( - ( - current_pos, - current_pos + size, - Modality.IMAGE, - ) - ) - current_pos += size - continue - except: - pass - - # Check for text sections - text_prob = self._check_text_probability( - data[current_pos : current_pos + 1024] - ) - if text_prob > 0.8: - # Look for end of text section - end_pos = current_pos + 1 - while end_pos < len(data): - if ( - self._check_text_probability( - data[end_pos : end_pos + 32] - ) - < 0.5 - ): - break - end_pos += 1 - boundaries.append( - (current_pos, end_pos, Modality.TEXT) - ) - current_pos = end_pos - continue - - current_pos += 1 - - return boundaries - - def detect_modality(self, data: bytes) -> ModalityInfo: - """Detect modality of byte sequence.""" - # First check for single modality - mime_type = self.magic.from_buffer(data) - - # Check text - text_prob = self._check_text_probability(data) - if text_prob > 0.9: - return ModalityInfo( - modality=Modality.TEXT, - confidence=text_prob, - metadata={"mime_type": mime_type}, - ) - - # Check image - is_image, image_meta = self._check_image_validity(data) - if is_image: - return ModalityInfo( - modality=Modality.IMAGE, - confidence=1.0, - metadata={**image_meta, "mime_type": mime_type}, - ) - - # Check audio - is_audio, audio_meta = self._check_audio_validity(data) - if is_audio: - return ModalityInfo( - modality=Modality.AUDIO, - confidence=1.0, - metadata={**audio_meta, "mime_type": mime_type}, - ) - - # Check for multimodal content - boundaries = self._detect_boundaries(data) - if len(boundaries) > 1: - sub_modalities = [] - for start, end, modality in boundaries: - chunk_data = data[start:end] - sub_info = self.detect_modality(chunk_data) - if sub_info.modality != Modality.BINARY: - sub_modalities.append(sub_info) - - if sub_modalities: - return ModalityInfo( - modality=Modality.MULTIMODAL, - confidence=0.8, - metadata={"mime_type": "multipart/mixed"}, - sub_modalities=sub_modalities, - ) - - # Default to binary - return ModalityInfo( 
- modality=Modality.BINARY, - confidence=0.5, - metadata={"mime_type": mime_type}, - ) - - def split_modalities( - self, data: bytes - ) -> List[Tuple[Modality, bytes, Dict]]: - """Split multimodal data into separate modalities.""" - boundaries = self._detect_boundaries(data) - result = [] - - for start, end, modality in boundaries: - chunk = data[start:end] - info = self.detect_modality(chunk) - result.append((modality, chunk, info.metadata)) - - return result - - -class AutoDetectBytesDecoder: - """Decoder that automatically detects and decodes different modalities.""" - - def __init__(self): - self.detector = ModalityDetector() - self.text_decoder = ByteDetokenizer() # From previous example - - def decode( - self, data: bytes - ) -> Union[str, Image.Image, np.ndarray, List[any]]: - """Automatically detect and decode byte sequence.""" - info = self.detector.detect_modality(data) - - if info.modality == Modality.MULTIMODAL: - # Handle multimodal content - parts = self.detector.split_modalities(data) - return [ - self.decode(chunk) for modality, chunk, _ in parts - ] - - if info.modality == Modality.TEXT: - return self.text_decoder.decode_text(data) - elif info.modality == Modality.IMAGE: - return self.text_decoder.decode_image(data) - elif info.modality == Modality.AUDIO: - return self.text_decoder.decode_audio(data) - else: - return data - - -# # Example usage -# def demo_auto_detection(): -# """Demonstrate auto modality detection.""" -# # Create mixed content -# text = "Hello, World!".encode('utf-8') - -# # Create a small test image -# img = Image.new('RGB', (100, 100), color='red') -# img_bytes = io.BytesIO() -# img.save(img_bytes, format='PNG') - -# # Combine into multimodal content -# mixed_content = text + img_bytes.getvalue() - -# # Initialize decoder -# decoder = AutoDetectBytesDecoder() - -# # Decode -# result = decoder.decode(mixed_content) - -# if isinstance(result, list): -# print("Detected multimodal content:") -# for i, part in enumerate(result): -# print(f"Part {i+1}: {type(part)}") - -# if __name__ == "__main__": -# demo_auto_detection() - - -def tensor_to_data(tensor: Tensor): - byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor) - - # Initialize auto-detector - decoder = AutoDetectBytesDecoder() - - # Decode with automatic detection - result = decoder.decode(byte_sequence) - - return result - - -def demo_byte_predictor(): - """Demo with smaller dimensions to test.""" - # Initialize model configuration with adjusted dimensions - config = ModelConfig( - vocab_size=256, - hidden_size=128, # Smaller for testing - num_layers=2, # Fewer layers for testing - num_key_value_heads=2, - num_query_heads=4, - dropout=0.1, - max_sequence_length=1024, - ) - - # Initialize model - model = EnhancedBytePredictor(config) - logger.info("Model initialized") - - # Move to GPU if available - device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - model = model.to(device) - logger.info(f"Using device: {device}") - - # Create sample input data - batch_size = 2 - seq_length = 16 # Shorter sequence for testing - input_ids = torch.randint( - 0, config.vocab_size, (batch_size, seq_length), device=device - ) - logger.info(f"Created input tensor of shape: {input_ids.shape}") - - # Test forward pass - try: - logits = model(input_ids) - logger.info( - f"Forward pass successful! 
Output shape: {logits.shape}" - ) - - # Test loss computation - target_ids = torch.randint( - 0, - config.vocab_size, - (batch_size, seq_length), - device=device, - ) - loss = model.compute_loss(input_ids, target_ids) - logger.info( - f"Loss computation successful! Loss value: {loss.item():.4f}" - ) - - # Test generation - prompt = torch.randint( - 0, - config.vocab_size, - (1, 4), # Very short prompt for testing - device=device, - ) - generated = model.generate( - prompt, max_new_tokens=8, temperature=0.8, top_k=50 - ) - logger.info( - f"Generation successful! Generated shape: {generated.shape}" - ) - - except Exception as e: - logger.error(f"Error during execution: {str(e)}") - raise - - -if __name__ == "__main__": - # Set up logging - # logger.remove() # Remove default handler - # logger.add(sys.stderr, format="{time:HH:mm:ss} | {level} | {message}") - - demo_byte_predictor() From 55018c636a3c0f7d49a21c1a7e5a69be388dc20f Mon Sep 17 00:00:00 2001 From: mike dupont Date: Sat, 7 Dec 2024 12:54:06 -0500 Subject: [PATCH 11/18] adding emacs --- .gitignore | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/.gitignore b/.gitignore index 9f6e25b6..65ce495c 100644 --- a/.gitignore +++ b/.gitignore @@ -224,3 +224,52 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .vscode/settings.json +# -*- mode: gitignore; -*- +*~ +\#*\# +/.emacs.desktop +/.emacs.desktop.lock +*.elc +auto-save-list +tramp +.\#* + +# Org-mode +.org-id-locations +*_archive + +# flymake-mode +*_flymake.* + +# eshell files +/eshell/history +/eshell/lastdir + +# elpa packages +/elpa/ + +# reftex files +*.rel + +# AUCTeX auto folder +/auto/ + +# cask packages +.cask/ +dist/ + +# Flycheck +flycheck_*.el + +# server auth directory +/server/ + +# projectiles files +.projectile + +# directory configuration +.dir-locals.el + +# network security +/network-security.data + From 449b2db79ed82532a3e2c915510079ebdf24fd6d Mon Sep 17 00:00:00 2001 From: mike dupont Date: Sat, 7 Dec 2024 16:00:03 -0500 Subject: [PATCH 12/18] adding main --- api/agent_api.py | 3 + api/main.py | 638 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 641 insertions(+) create mode 100644 api/main.py diff --git a/api/agent_api.py b/api/agent_api.py index d1968d9d..83d05101 100644 --- a/api/agent_api.py +++ b/api/agent_api.py @@ -619,6 +619,7 @@ def create_app() -> FastAPI: if __name__ == "__main__": # Configure uvicorn logging + print("in main") logger.info("API Starting") uvicorn.run( "main:create_app", @@ -627,3 +628,5 @@ if __name__ == "__main__": reload=True, workers=4, ) +else: + print("not in main") diff --git a/api/main.py b/api/main.py new file mode 100644 index 00000000..768e8d96 --- /dev/null +++ b/api/main.py @@ -0,0 +1,638 @@ +import os +from fastapi import ( + FastAPI, + HTTPException, + status, + Query, + BackgroundTasks, +) +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from typing import Optional, Dict, Any, List +from loguru import logger +import uvicorn +from datetime import datetime, timedelta +from uuid import UUID, uuid4 +from enum import Enum +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor +import traceback + +from swarms import Agent +from dotenv import load_dotenv + +print ("starting") +# Load environment variables +load_dotenv() + +# Configure Loguru +logger.add( + "logs/api_{time}.log", + rotation="500 MB", + retention="10 days", + 
level="INFO", + format="{time} {level} {message}", + backtrace=True, + diagnose=True, +) + + +class AgentStatus(str, Enum): + """Enum for agent status.""" + + IDLE = "idle" + PROCESSING = "processing" + ERROR = "error" + MAINTENANCE = "maintenance" + + +class AgentConfig(BaseModel): + """Configuration model for creating a new agent.""" + + agent_name: str = Field(..., description="Name of the agent") + model_name: str = Field( + ..., + description="Name of the llm you want to use provided by litellm", + ) + description: str = Field( + default="", description="Description of the agent's purpose" + ) + system_prompt: str = Field( + ..., description="System prompt for the agent" + ) + model_name: str = Field( + default="gpt-4", description="Model name to use" + ) + temperature: float = Field( + default=0.1, + ge=0.0, + le=2.0, + description="Temperature for the model", + ) + max_loops: int = Field( + default=1, ge=1, description="Maximum number of loops" + ) + autosave: bool = Field( + default=True, description="Enable autosave" + ) + dashboard: bool = Field( + default=False, description="Enable dashboard" + ) + verbose: bool = Field( + default=True, description="Enable verbose output" + ) + dynamic_temperature_enabled: bool = Field( + default=True, description="Enable dynamic temperature" + ) + user_name: str = Field( + default="default_user", description="Username for the agent" + ) + retry_attempts: int = Field( + default=1, ge=1, description="Number of retry attempts" + ) + context_length: int = Field( + default=200000, ge=1000, description="Context length" + ) + output_type: str = Field( + default="string", description="Output type (string or json)" + ) + streaming_on: bool = Field( + default=False, description="Enable streaming" + ) + tags: List[str] = Field( + default_factory=list, + description="Tags for categorizing the agent", + ) + + +class AgentUpdate(BaseModel): + """Model for updating agent configuration.""" + + description: Optional[str] = None + system_prompt: Optional[str] = None + temperature: Optional[float] = None + max_loops: Optional[int] = None + tags: Optional[List[str]] = None + status: Optional[AgentStatus] = None + + +class AgentSummary(BaseModel): + """Summary model for agent listing.""" + + agent_id: UUID + agent_name: str + description: str + created_at: datetime + last_used: datetime + total_completions: int + tags: List[str] + status: AgentStatus + + +class AgentMetrics(BaseModel): + """Model for agent performance metrics.""" + + total_completions: int + average_response_time: float + error_rate: float + last_24h_completions: int + total_tokens_used: int + uptime_percentage: float + success_rate: float + peak_tokens_per_minute: int + + +class CompletionRequest(BaseModel): + """Model for completion requests.""" + + prompt: str = Field(..., description="The prompt to process") + agent_id: UUID = Field(..., description="ID of the agent to use") + max_tokens: Optional[int] = Field( + None, description="Maximum tokens to generate" + ) + temperature_override: Optional[float] = None + stream: bool = Field( + default=False, description="Enable streaming response" + ) + + +class CompletionResponse(BaseModel): + """Model for completion responses.""" + + agent_id: UUID + response: str + metadata: Dict[str, Any] + timestamp: datetime + processing_time: float + token_usage: Dict[str, int] + + +class AgentStore: + """Enhanced store for managing agents.""" + + def __init__(self): + self.agents: Dict[UUID, Agent] = {} + self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} + 
self.executor = ThreadPoolExecutor(max_workers=4) + self._ensure_directories() + + def _ensure_directories(self): + """Ensure required directories exist.""" + Path("logs").mkdir(exist_ok=True) + Path("states").mkdir(exist_ok=True) + + async def create_agent(self, config: AgentConfig) -> UUID: + """Create a new agent with the given configuration.""" + try: + + agent = Agent( + agent_name=config.agent_name, + system_prompt=config.system_prompt, + model_name=config.model_name, + max_loops=config.max_loops, + autosave=config.autosave, + dashboard=config.dashboard, + verbose=config.verbose, + dynamic_temperature_enabled=config.dynamic_temperature_enabled, + saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", + user_name=config.user_name, + retry_attempts=config.retry_attempts, + context_length=config.context_length, + return_step_meta=True, + output_type="str", + streaming_on=config.streaming_on, + ) + + agent_id = uuid4() + self.agents[agent_id] = agent + self.agent_metadata[agent_id] = { + "description": config.description, + "created_at": datetime.utcnow(), + "last_used": datetime.utcnow(), + "total_completions": 0, + "tags": config.tags, + "total_tokens": 0, + "error_count": 0, + "response_times": [], + "status": AgentStatus.IDLE, + "start_time": datetime.utcnow(), + "downtime": timedelta(), + "successful_completions": 0, + } + + logger.info(f"Created agent with ID: {agent_id}") + return agent_id + + except Exception as e: + logger.error(f"Error creating agent: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create agent: {str(e)}", + ) + + async def get_agent(self, agent_id: UUID) -> Agent: + """Retrieve an agent by ID.""" + agent = self.agents.get(agent_id) + if not agent: + logger.error(f"Agent not found: {agent_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + return agent + + async def update_agent( + self, agent_id: UUID, update: AgentUpdate + ) -> None: + """Update agent configuration.""" + agent = await self.get_agent(agent_id) + metadata = self.agent_metadata[agent_id] + + if update.system_prompt: + agent.system_prompt = update.system_prompt + if update.temperature is not None: + agent.llm.temperature = update.temperature + if update.max_loops is not None: + agent.max_loops = update.max_loops + if update.tags is not None: + metadata["tags"] = update.tags + if update.description is not None: + metadata["description"] = update.description + if update.status is not None: + metadata["status"] = update.status + if update.status == AgentStatus.MAINTENANCE: + metadata["downtime"] += ( + datetime.utcnow() - metadata["last_used"] + ) + + logger.info(f"Updated agent {agent_id}") + + async def list_agents( + self, + tags: Optional[List[str]] = None, + status: Optional[AgentStatus] = None, + ) -> List[AgentSummary]: + """List all agents, optionally filtered by tags and status.""" + summaries = [] + for agent_id, agent in self.agents.items(): + metadata = self.agent_metadata[agent_id] + + # Apply filters + if tags and not any( + tag in metadata["tags"] for tag in tags + ): + continue + if status and metadata["status"] != status: + continue + + summaries.append( + AgentSummary( + agent_id=agent_id, + agent_name=agent.agent_name, + description=metadata["description"], + created_at=metadata["created_at"], + last_used=metadata["last_used"], + total_completions=metadata["total_completions"], + tags=metadata["tags"], + 
status=metadata["status"], + ) + ) + return summaries + + async def get_agent_metrics(self, agent_id: UUID) -> AgentMetrics: + """Get performance metrics for an agent.""" + metadata = self.agent_metadata[agent_id] + response_times = metadata["response_times"] + + # Calculate metrics + total_time = datetime.utcnow() - metadata["start_time"] + uptime = total_time - metadata["downtime"] + uptime_percentage = ( + uptime.total_seconds() / total_time.total_seconds() + ) * 100 + + success_rate = ( + metadata["successful_completions"] + / metadata["total_completions"] + * 100 + if metadata["total_completions"] > 0 + else 0 + ) + + return AgentMetrics( + total_completions=metadata["total_completions"], + average_response_time=( + sum(response_times) / len(response_times) + if response_times + else 0 + ), + error_rate=( + metadata["error_count"] + / metadata["total_completions"] + if metadata["total_completions"] > 0 + else 0 + ), + last_24h_completions=sum( + 1 + for t in response_times + if (datetime.utcnow() - t).days < 1 + ), + total_tokens_used=metadata["total_tokens"], + uptime_percentage=uptime_percentage, + success_rate=success_rate, + peak_tokens_per_minute=max( + metadata.get("tokens_per_minute", [0]) + ), + ) + + async def clone_agent( + self, agent_id: UUID, new_name: str + ) -> UUID: + """Clone an existing agent with a new name.""" + original_agent = await self.get_agent(agent_id) + original_metadata = self.agent_metadata[agent_id] + + config = AgentConfig( + agent_name=new_name, + description=f"Clone of {original_agent.agent_name}", + system_prompt=original_agent.system_prompt, + model_name=original_agent.llm.model_name, + temperature=original_agent.llm.temperature, + max_loops=original_agent.max_loops, + tags=original_metadata["tags"], + ) + + return await self.create_agent(config) + + async def delete_agent(self, agent_id: UUID) -> None: + """Delete an agent.""" + if agent_id not in self.agents: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + + # Clean up any resources + agent = self.agents[agent_id] + if agent.autosave and os.path.exists(agent.saved_state_path): + os.remove(agent.saved_state_path) + + del self.agents[agent_id] + del self.agent_metadata[agent_id] + logger.info(f"Deleted agent {agent_id}") + + async def process_completion( + self, + agent: Agent, + prompt: str, + agent_id: UUID, + max_tokens: Optional[int] = None, + temperature_override: Optional[float] = None, + ) -> CompletionResponse: + """Process a completion request using the specified agent.""" + start_time = datetime.utcnow() + metadata = self.agent_metadata[agent_id] + + try: + # Update agent status + metadata["status"] = AgentStatus.PROCESSING + metadata["last_used"] = start_time + + # Apply temporary overrides if specified + original_temp = agent.llm.temperature + if temperature_override is not None: + agent.llm.temperature = temperature_override + + # Process the completion + response = agent.run(prompt) + + # Reset overrides + if temperature_override is not None: + agent.llm.temperature = original_temp + + # Update metrics + processing_time = ( + datetime.utcnow() - start_time + ).total_seconds() + metadata["response_times"].append(processing_time) + metadata["total_completions"] += 1 + metadata["successful_completions"] += 1 + + # Estimate token usage (this is a rough estimate) + prompt_tokens = len(prompt.split()) * 1.3 + completion_tokens = len(response.split()) * 1.3 + total_tokens = int(prompt_tokens + completion_tokens) + 
metadata["total_tokens"] += total_tokens + + # Update tokens per minute tracking + current_minute = datetime.utcnow().replace( + second=0, microsecond=0 + ) + if "tokens_per_minute" not in metadata: + metadata["tokens_per_minute"] = {} + metadata["tokens_per_minute"][current_minute] = ( + metadata["tokens_per_minute"].get(current_minute, 0) + + total_tokens + ) + + return CompletionResponse( + agent_id=agent_id, + response=response, + metadata={ + "agent_name": agent.agent_name, + "model_name": agent.llm.model_name, + "temperature": agent.llm.temperature, + }, + timestamp=datetime.utcnow(), + processing_time=processing_time, + token_usage={ + "prompt_tokens": int(prompt_tokens), + "completion_tokens": int(completion_tokens), + "total_tokens": total_tokens, + }, + ) + + except Exception as e: + metadata["error_count"] += 1 + metadata["status"] = AgentStatus.ERROR + logger.error( + f"Error in completion processing: {str(e)}\n{traceback.format_exc()}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing completion: {str(e)}", + ) + finally: + metadata["status"] = AgentStatus.IDLE + + +class SwarmsAPI: + """Enhanced API class for Swarms agent integration.""" + + def __init__(self): + self.app = FastAPI( + title="Swarms Agent API", + description="Production-grade API for Swarms agent interaction", + version="1.0.0", + docs_url="/v1/docs", + redoc_url="/v1/redoc", + ) + self.store = AgentStore() + # Configure CORS + self.app.add_middleware( + CORSMiddleware, + allow_origins=[ + "*" + ], # Configure appropriately for production + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + self._setup_routes() + + def _setup_routes(self): + """Set up API routes.""" + + @self.app.post("/v1/agent", response_model=Dict[str, UUID]) + async def create_agent(config: AgentConfig): + """Create a new agent with the specified configuration.""" + agent_id = await self.store.create_agent(config) + return {"agent_id": agent_id} + + @self.app.get("/v1/agents", response_model=List[AgentSummary]) + async def list_agents( + tags: Optional[List[str]] = Query(None), + status: Optional[AgentStatus] = None, + ): + """List all agents, optionally filtered by tags and status.""" + return await self.store.list_agents(tags, status) + + @self.app.patch( + "/v1/agent/{agent_id}", response_model=Dict[str, str] + ) + async def update_agent(agent_id: UUID, update: AgentUpdate): + """Update an existing agent's configuration.""" + await self.store.update_agent(agent_id, update) + return {"status": "updated"} + + @self.app.get( + "/v1/agent/{agent_id}/metrics", + response_model=AgentMetrics, + ) + async def get_agent_metrics(agent_id: UUID): + """Get performance metrics for a specific agent.""" + return await self.store.get_agent_metrics(agent_id) + + @self.app.post( + "/v1/agent/{agent_id}/clone", + response_model=Dict[str, UUID], + ) + async def clone_agent(agent_id: UUID, new_name: str): + """Clone an existing agent with a new name.""" + new_id = await self.store.clone_agent(agent_id, new_name) + return {"agent_id": new_id} + + @self.app.delete("/v1/agent/{agent_id}") + async def delete_agent(agent_id: UUID): + """Delete an agent.""" + await self.store.delete_agent(agent_id) + return {"status": "deleted"} + + @self.app.post( + "/v1/agent/completions", response_model=CompletionResponse + ) + async def create_completion( + request: CompletionRequest, + background_tasks: BackgroundTasks, + ): + """Process a completion request with the specified agent.""" + 
try: + agent = await self.store.get_agent(request.agent_id) + + # Process completion + response = await self.store.process_completion( + agent, + request.prompt, + request.agent_id, + request.max_tokens, + request.temperature_override, + ) + + # Schedule background cleanup + background_tasks.add_task( + self._cleanup_old_metrics, request.agent_id + ) + + return response + + except Exception as e: + logger.error(f"Error processing completion: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing completion: {str(e)}", + ) + + @self.app.get("/v1/agent/{agent_id}/status") + async def get_agent_status(agent_id: UUID): + """Get the current status of an agent.""" + metadata = self.store.agent_metadata.get(agent_id) + if not metadata: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent {agent_id} not found", + ) + return { + "agent_id": agent_id, + "status": metadata["status"], + "last_used": metadata["last_used"], + "total_completions": metadata["total_completions"], + "error_count": metadata["error_count"], + } + + async def _cleanup_old_metrics(self, agent_id: UUID): + """Clean up old metrics data to prevent memory bloat.""" + metadata = self.store.agent_metadata.get(agent_id) + if metadata: + # Keep only last 24 hours of response times + cutoff = datetime.utcnow() - timedelta(days=1) + metadata["response_times"] = [ + t + for t in metadata["response_times"] + if isinstance(t, (int, float)) + and t > cutoff.timestamp() + ] + + # Clean up old tokens per minute data + if "tokens_per_minute" in metadata: + metadata["tokens_per_minute"] = { + k: v + for k, v in metadata["tokens_per_minute"].items() + if k > cutoff + } + + +def create_app() -> FastAPI: + """Create and configure the FastAPI application.""" + print("create app") + api = SwarmsAPI() + return api.app + + +#if __name__ == "__main__": +if __name__ == '__main__': + #freeze_support() + print("yes in main") + # Configure uvicorn logging + logger.info("API Starting") + + uvicorn.run( + "main:create_app", + host="0.0.0.0", + port=8000, + # reload=True, + # workers=4, + ) +else: + print("not in main") + From 823051a9f4dedabf9b5a0d4f79a553b5ea8b81fd Mon Sep 17 00:00:00 2001 From: mike dupont Date: Sat, 7 Dec 2024 16:01:00 -0500 Subject: [PATCH 13/18] remove the agent api renamed to main --- api/agent_api.py | 632 ----------------------------------------------- 1 file changed, 632 deletions(-) delete mode 100644 api/agent_api.py diff --git a/api/agent_api.py b/api/agent_api.py deleted file mode 100644 index 83d05101..00000000 --- a/api/agent_api.py +++ /dev/null @@ -1,632 +0,0 @@ -import os -from fastapi import ( - FastAPI, - HTTPException, - status, - Query, - BackgroundTasks, -) -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field -from typing import Optional, Dict, Any, List -from loguru import logger -import uvicorn -from datetime import datetime, timedelta -from uuid import UUID, uuid4 -from enum import Enum -from pathlib import Path -from concurrent.futures import ThreadPoolExecutor -import traceback - -from swarms import Agent -from dotenv import load_dotenv - -# Load environment variables -load_dotenv() - -# Configure Loguru -logger.add( - "logs/api_{time}.log", - rotation="500 MB", - retention="10 days", - level="INFO", - format="{time} {level} {message}", - backtrace=True, - diagnose=True, -) - - -class AgentStatus(str, Enum): - """Enum for agent status.""" - - IDLE = "idle" - PROCESSING = "processing" - ERROR 
= "error" - MAINTENANCE = "maintenance" - - -class AgentConfig(BaseModel): - """Configuration model for creating a new agent.""" - - agent_name: str = Field(..., description="Name of the agent") - model_name: str = Field( - ..., - description="Name of the llm you want to use provided by litellm", - ) - description: str = Field( - default="", description="Description of the agent's purpose" - ) - system_prompt: str = Field( - ..., description="System prompt for the agent" - ) - model_name: str = Field( - default="gpt-4", description="Model name to use" - ) - temperature: float = Field( - default=0.1, - ge=0.0, - le=2.0, - description="Temperature for the model", - ) - max_loops: int = Field( - default=1, ge=1, description="Maximum number of loops" - ) - autosave: bool = Field( - default=True, description="Enable autosave" - ) - dashboard: bool = Field( - default=False, description="Enable dashboard" - ) - verbose: bool = Field( - default=True, description="Enable verbose output" - ) - dynamic_temperature_enabled: bool = Field( - default=True, description="Enable dynamic temperature" - ) - user_name: str = Field( - default="default_user", description="Username for the agent" - ) - retry_attempts: int = Field( - default=1, ge=1, description="Number of retry attempts" - ) - context_length: int = Field( - default=200000, ge=1000, description="Context length" - ) - output_type: str = Field( - default="string", description="Output type (string or json)" - ) - streaming_on: bool = Field( - default=False, description="Enable streaming" - ) - tags: List[str] = Field( - default_factory=list, - description="Tags for categorizing the agent", - ) - - -class AgentUpdate(BaseModel): - """Model for updating agent configuration.""" - - description: Optional[str] = None - system_prompt: Optional[str] = None - temperature: Optional[float] = None - max_loops: Optional[int] = None - tags: Optional[List[str]] = None - status: Optional[AgentStatus] = None - - -class AgentSummary(BaseModel): - """Summary model for agent listing.""" - - agent_id: UUID - agent_name: str - description: str - created_at: datetime - last_used: datetime - total_completions: int - tags: List[str] - status: AgentStatus - - -class AgentMetrics(BaseModel): - """Model for agent performance metrics.""" - - total_completions: int - average_response_time: float - error_rate: float - last_24h_completions: int - total_tokens_used: int - uptime_percentage: float - success_rate: float - peak_tokens_per_minute: int - - -class CompletionRequest(BaseModel): - """Model for completion requests.""" - - prompt: str = Field(..., description="The prompt to process") - agent_id: UUID = Field(..., description="ID of the agent to use") - max_tokens: Optional[int] = Field( - None, description="Maximum tokens to generate" - ) - temperature_override: Optional[float] = None - stream: bool = Field( - default=False, description="Enable streaming response" - ) - - -class CompletionResponse(BaseModel): - """Model for completion responses.""" - - agent_id: UUID - response: str - metadata: Dict[str, Any] - timestamp: datetime - processing_time: float - token_usage: Dict[str, int] - - -class AgentStore: - """Enhanced store for managing agents.""" - - def __init__(self): - self.agents: Dict[UUID, Agent] = {} - self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} - self.executor = ThreadPoolExecutor(max_workers=4) - self._ensure_directories() - - def _ensure_directories(self): - """Ensure required directories exist.""" - Path("logs").mkdir(exist_ok=True) - 
Path("states").mkdir(exist_ok=True) - - async def create_agent(self, config: AgentConfig) -> UUID: - """Create a new agent with the given configuration.""" - try: - - agent = Agent( - agent_name=config.agent_name, - system_prompt=config.system_prompt, - model_name=config.model_name, - max_loops=config.max_loops, - autosave=config.autosave, - dashboard=config.dashboard, - verbose=config.verbose, - dynamic_temperature_enabled=config.dynamic_temperature_enabled, - saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", - user_name=config.user_name, - retry_attempts=config.retry_attempts, - context_length=config.context_length, - return_step_meta=True, - output_type="str", - streaming_on=config.streaming_on, - ) - - agent_id = uuid4() - self.agents[agent_id] = agent - self.agent_metadata[agent_id] = { - "description": config.description, - "created_at": datetime.utcnow(), - "last_used": datetime.utcnow(), - "total_completions": 0, - "tags": config.tags, - "total_tokens": 0, - "error_count": 0, - "response_times": [], - "status": AgentStatus.IDLE, - "start_time": datetime.utcnow(), - "downtime": timedelta(), - "successful_completions": 0, - } - - logger.info(f"Created agent with ID: {agent_id}") - return agent_id - - except Exception as e: - logger.error(f"Error creating agent: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create agent: {str(e)}", - ) - - async def get_agent(self, agent_id: UUID) -> Agent: - """Retrieve an agent by ID.""" - agent = self.agents.get(agent_id) - if not agent: - logger.error(f"Agent not found: {agent_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Agent {agent_id} not found", - ) - return agent - - async def update_agent( - self, agent_id: UUID, update: AgentUpdate - ) -> None: - """Update agent configuration.""" - agent = await self.get_agent(agent_id) - metadata = self.agent_metadata[agent_id] - - if update.system_prompt: - agent.system_prompt = update.system_prompt - if update.temperature is not None: - agent.llm.temperature = update.temperature - if update.max_loops is not None: - agent.max_loops = update.max_loops - if update.tags is not None: - metadata["tags"] = update.tags - if update.description is not None: - metadata["description"] = update.description - if update.status is not None: - metadata["status"] = update.status - if update.status == AgentStatus.MAINTENANCE: - metadata["downtime"] += ( - datetime.utcnow() - metadata["last_used"] - ) - - logger.info(f"Updated agent {agent_id}") - - async def list_agents( - self, - tags: Optional[List[str]] = None, - status: Optional[AgentStatus] = None, - ) -> List[AgentSummary]: - """List all agents, optionally filtered by tags and status.""" - summaries = [] - for agent_id, agent in self.agents.items(): - metadata = self.agent_metadata[agent_id] - - # Apply filters - if tags and not any( - tag in metadata["tags"] for tag in tags - ): - continue - if status and metadata["status"] != status: - continue - - summaries.append( - AgentSummary( - agent_id=agent_id, - agent_name=agent.agent_name, - description=metadata["description"], - created_at=metadata["created_at"], - last_used=metadata["last_used"], - total_completions=metadata["total_completions"], - tags=metadata["tags"], - status=metadata["status"], - ) - ) - return summaries - - async def get_agent_metrics(self, agent_id: UUID) -> AgentMetrics: - """Get performance metrics for an agent.""" - metadata = 
self.agent_metadata[agent_id] - response_times = metadata["response_times"] - - # Calculate metrics - total_time = datetime.utcnow() - metadata["start_time"] - uptime = total_time - metadata["downtime"] - uptime_percentage = ( - uptime.total_seconds() / total_time.total_seconds() - ) * 100 - - success_rate = ( - metadata["successful_completions"] - / metadata["total_completions"] - * 100 - if metadata["total_completions"] > 0 - else 0 - ) - - return AgentMetrics( - total_completions=metadata["total_completions"], - average_response_time=( - sum(response_times) / len(response_times) - if response_times - else 0 - ), - error_rate=( - metadata["error_count"] - / metadata["total_completions"] - if metadata["total_completions"] > 0 - else 0 - ), - last_24h_completions=sum( - 1 - for t in response_times - if (datetime.utcnow() - t).days < 1 - ), - total_tokens_used=metadata["total_tokens"], - uptime_percentage=uptime_percentage, - success_rate=success_rate, - peak_tokens_per_minute=max( - metadata.get("tokens_per_minute", [0]) - ), - ) - - async def clone_agent( - self, agent_id: UUID, new_name: str - ) -> UUID: - """Clone an existing agent with a new name.""" - original_agent = await self.get_agent(agent_id) - original_metadata = self.agent_metadata[agent_id] - - config = AgentConfig( - agent_name=new_name, - description=f"Clone of {original_agent.agent_name}", - system_prompt=original_agent.system_prompt, - model_name=original_agent.llm.model_name, - temperature=original_agent.llm.temperature, - max_loops=original_agent.max_loops, - tags=original_metadata["tags"], - ) - - return await self.create_agent(config) - - async def delete_agent(self, agent_id: UUID) -> None: - """Delete an agent.""" - if agent_id not in self.agents: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Agent {agent_id} not found", - ) - - # Clean up any resources - agent = self.agents[agent_id] - if agent.autosave and os.path.exists(agent.saved_state_path): - os.remove(agent.saved_state_path) - - del self.agents[agent_id] - del self.agent_metadata[agent_id] - logger.info(f"Deleted agent {agent_id}") - - async def process_completion( - self, - agent: Agent, - prompt: str, - agent_id: UUID, - max_tokens: Optional[int] = None, - temperature_override: Optional[float] = None, - ) -> CompletionResponse: - """Process a completion request using the specified agent.""" - start_time = datetime.utcnow() - metadata = self.agent_metadata[agent_id] - - try: - # Update agent status - metadata["status"] = AgentStatus.PROCESSING - metadata["last_used"] = start_time - - # Apply temporary overrides if specified - original_temp = agent.llm.temperature - if temperature_override is not None: - agent.llm.temperature = temperature_override - - # Process the completion - response = agent.run(prompt) - - # Reset overrides - if temperature_override is not None: - agent.llm.temperature = original_temp - - # Update metrics - processing_time = ( - datetime.utcnow() - start_time - ).total_seconds() - metadata["response_times"].append(processing_time) - metadata["total_completions"] += 1 - metadata["successful_completions"] += 1 - - # Estimate token usage (this is a rough estimate) - prompt_tokens = len(prompt.split()) * 1.3 - completion_tokens = len(response.split()) * 1.3 - total_tokens = int(prompt_tokens + completion_tokens) - metadata["total_tokens"] += total_tokens - - # Update tokens per minute tracking - current_minute = datetime.utcnow().replace( - second=0, microsecond=0 - ) - if "tokens_per_minute" not in metadata: - 
metadata["tokens_per_minute"] = {} - metadata["tokens_per_minute"][current_minute] = ( - metadata["tokens_per_minute"].get(current_minute, 0) - + total_tokens - ) - - return CompletionResponse( - agent_id=agent_id, - response=response, - metadata={ - "agent_name": agent.agent_name, - "model_name": agent.llm.model_name, - "temperature": agent.llm.temperature, - }, - timestamp=datetime.utcnow(), - processing_time=processing_time, - token_usage={ - "prompt_tokens": int(prompt_tokens), - "completion_tokens": int(completion_tokens), - "total_tokens": total_tokens, - }, - ) - - except Exception as e: - metadata["error_count"] += 1 - metadata["status"] = AgentStatus.ERROR - logger.error( - f"Error in completion processing: {str(e)}\n{traceback.format_exc()}" - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error processing completion: {str(e)}", - ) - finally: - metadata["status"] = AgentStatus.IDLE - - -class SwarmsAPI: - """Enhanced API class for Swarms agent integration.""" - - def __init__(self): - self.app = FastAPI( - title="Swarms Agent API", - description="Production-grade API for Swarms agent interaction", - version="1.0.0", - docs_url="/v1/docs", - redoc_url="/v1/redoc", - ) - self.store = AgentStore() - # Configure CORS - self.app.add_middleware( - CORSMiddleware, - allow_origins=[ - "*" - ], # Configure appropriately for production - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - self._setup_routes() - - def _setup_routes(self): - """Set up API routes.""" - - @self.app.post("/v1/agent", response_model=Dict[str, UUID]) - async def create_agent(config: AgentConfig): - """Create a new agent with the specified configuration.""" - agent_id = await self.store.create_agent(config) - return {"agent_id": agent_id} - - @self.app.get("/v1/agents", response_model=List[AgentSummary]) - async def list_agents( - tags: Optional[List[str]] = Query(None), - status: Optional[AgentStatus] = None, - ): - """List all agents, optionally filtered by tags and status.""" - return await self.store.list_agents(tags, status) - - @self.app.patch( - "/v1/agent/{agent_id}", response_model=Dict[str, str] - ) - async def update_agent(agent_id: UUID, update: AgentUpdate): - """Update an existing agent's configuration.""" - await self.store.update_agent(agent_id, update) - return {"status": "updated"} - - @self.app.get( - "/v1/agent/{agent_id}/metrics", - response_model=AgentMetrics, - ) - async def get_agent_metrics(agent_id: UUID): - """Get performance metrics for a specific agent.""" - return await self.store.get_agent_metrics(agent_id) - - @self.app.post( - "/v1/agent/{agent_id}/clone", - response_model=Dict[str, UUID], - ) - async def clone_agent(agent_id: UUID, new_name: str): - """Clone an existing agent with a new name.""" - new_id = await self.store.clone_agent(agent_id, new_name) - return {"agent_id": new_id} - - @self.app.delete("/v1/agent/{agent_id}") - async def delete_agent(agent_id: UUID): - """Delete an agent.""" - await self.store.delete_agent(agent_id) - return {"status": "deleted"} - - @self.app.post( - "/v1/agent/completions", response_model=CompletionResponse - ) - async def create_completion( - request: CompletionRequest, - background_tasks: BackgroundTasks, - ): - """Process a completion request with the specified agent.""" - try: - agent = await self.store.get_agent(request.agent_id) - - # Process completion - response = await self.store.process_completion( - agent, - request.prompt, - request.agent_id, - request.max_tokens, 
- request.temperature_override, - ) - - # Schedule background cleanup - background_tasks.add_task( - self._cleanup_old_metrics, request.agent_id - ) - - return response - - except Exception as e: - logger.error(f"Error processing completion: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error processing completion: {str(e)}", - ) - - @self.app.get("/v1/agent/{agent_id}/status") - async def get_agent_status(agent_id: UUID): - """Get the current status of an agent.""" - metadata = self.store.agent_metadata.get(agent_id) - if not metadata: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Agent {agent_id} not found", - ) - return { - "agent_id": agent_id, - "status": metadata["status"], - "last_used": metadata["last_used"], - "total_completions": metadata["total_completions"], - "error_count": metadata["error_count"], - } - - async def _cleanup_old_metrics(self, agent_id: UUID): - """Clean up old metrics data to prevent memory bloat.""" - metadata = self.store.agent_metadata.get(agent_id) - if metadata: - # Keep only last 24 hours of response times - cutoff = datetime.utcnow() - timedelta(days=1) - metadata["response_times"] = [ - t - for t in metadata["response_times"] - if isinstance(t, (int, float)) - and t > cutoff.timestamp() - ] - - # Clean up old tokens per minute data - if "tokens_per_minute" in metadata: - metadata["tokens_per_minute"] = { - k: v - for k, v in metadata["tokens_per_minute"].items() - if k > cutoff - } - - -def create_app() -> FastAPI: - """Create and configure the FastAPI application.""" - api = SwarmsAPI() - return api.app - - -if __name__ == "__main__": - # Configure uvicorn logging - print("in main") - logger.info("API Starting") - uvicorn.run( - "main:create_app", - host="0.0.0.0", - port=8000, - reload=True, - workers=4, - ) -else: - print("not in main") From dc4ff7df4528286de3ce641ca388b0b041419643 Mon Sep 17 00:00:00 2001 From: Kye Gomez <98760976+kyegomez@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:45:44 -0800 Subject: [PATCH 14/18] Create requirements.txt --- api/requirements.txt | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 api/requirements.txt diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 00000000..4bd48f33 --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,6 @@ +fastapi +uvicorn +pydantic +loguru +python-dotenv +swarms # Specify the version or source if it's not on PyPI From e6e989de275b7f21040cabcfdc4d8690b5507335 Mon Sep 17 00:00:00 2001 From: Kye Gomez <98760976+kyegomez@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:48:37 -0800 Subject: [PATCH 15/18] Create skypilot.yaml --- api/skypilot.yaml | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 api/skypilot.yaml diff --git a/api/skypilot.yaml b/api/skypilot.yaml new file mode 100644 index 00000000..8cd25d90 --- /dev/null +++ b/api/skypilot.yaml @@ -0,0 +1,41 @@ +service: + readiness_probe: + path: /docs + initial_delay_seconds: 300 + timeout_seconds: 30 + + replica_policy: + min_replicas: 1 + max_replicas: 50 + target_qps_per_replica: 5 + upscale_delay_seconds: 180 + downscale_delay_seconds: 600 + +resources: + ports: 8000 # FastAPI default port + cpus: 16 + memory: 64 + disk_size: 100 + use_spot: true + +workdir: /app + +setup: | + git clone https://github.com/kyegomez/swarms.git + cd swarms/api + pip install -r requirements.txt + pip install swarms + +run: | + cd swarms/api + uvicorn main:app --host 0.0.0.0 
--port 8000 --workers 4
+
+# env:
+#   PYTHONPATH: /app/swarms
+#   LOG_LEVEL: "INFO"
+#   # MAX_WORKERS: "4"
+
+# metadata:
+#   name: swarms-api-service
+#   version: "1.0.0"
+#   environment: production

From 770b4a15fd52a9b36e4aa3e7aa5836aca20c307f Mon Sep 17 00:00:00 2001
From: Kye Gomez <98760976+kyegomez@users.noreply.github.com>
Date: Thu, 12 Dec 2024 10:00:51 -0800
Subject: [PATCH 16/18] Update auto_swarm_builder.py

---
 swarms/structs/auto_swarm_builder.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/swarms/structs/auto_swarm_builder.py b/swarms/structs/auto_swarm_builder.py
index 93e542fd..16e1f5b9 100644
--- a/swarms/structs/auto_swarm_builder.py
+++ b/swarms/structs/auto_swarm_builder.py
@@ -50,13 +50,11 @@ class SwarmConfig(BaseModel):
                 name="Research-Agent",
                 description="Gathers information",
                 system_prompt="You are a research agent...",
-                max_loops=2,
             ),
             AgentConfig(
                 name="Writing-Agent",
                 description="Writes content",
                 system_prompt="You are a writing agent...",
-                max_loops=1,
             ),
         ],
     )
@@ -195,7 +193,7 @@ class AutoSwarmBuilder:
        self.name = agents_dictionary.name
        self.description = agents_dictionary.description
        self.max_loops = getattr(
-            agents_dictionary, "max_loops", 1
+            agents_dictionary, "max_loops", 1
        )  # Default to 1 if not set

        logger.info(
@@ -213,7 +211,6 @@ class AutoSwarmBuilder:
            agent_name=agent_config.name,
            agent_description=agent_config.description,
            agent_system_prompt=agent_config.system_prompt,
-            # max_loops=agent_config.max_loops,
        )
        agents.append(agent)

From a564fd27e4c7f8fc3f7947009974fbe5e7be0c21 Mon Sep 17 00:00:00 2001
From: Kye Gomez <98760976+kyegomez@users.noreply.github.com>
Date: Thu, 12 Dec 2024 11:46:44 -0800
Subject: [PATCH 17/18] Update main.py

---
 api/main.py | 288 +++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 240 insertions(+), 48 deletions(-)

diff --git a/api/main.py b/api/main.py
index 768e8d96..cfc5e1b2 100644
--- a/api/main.py
+++ b/api/main.py
@@ -1,41 +1,34 @@
 import os
+import secrets
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime, timedelta
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+from uuid import UUID, uuid4
+
+import uvicorn
+from dotenv import load_dotenv
 from fastapi import (
+    BackgroundTasks,
+    Depends,
     FastAPI,
+    Header,
     HTTPException,
-    status,
     Query,
-    BackgroundTasks,
+    Request,
+    status,
 )
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
-from typing import Optional, Dict, Any, List
 from loguru import logger
-import uvicorn
-from datetime import datetime, timedelta
-from uuid import UUID, uuid4
-from enum import Enum
-from pathlib import Path
-from concurrent.futures import ThreadPoolExecutor
-import traceback
+from pydantic import BaseModel, Field

 from swarms import Agent
-from dotenv import load_dotenv

-print ("starting")
 # Load environment variables
 load_dotenv()

-# Configure Loguru
-logger.add(
-    "logs/api_{time}.log",
-    rotation="500 MB",
-    retention="10 days",
-    level="INFO",
-    format="{time} {level} {message}",
-    backtrace=True,
-    diagnose=True,
-)
-

 class AgentStatus(str, Enum):
     """Enum for agent status."""
@@ -44,6 +37,28 @@ class AgentStatus(str, Enum):
     PROCESSING = "processing"
     ERROR = "error"
     MAINTENANCE = "maintenance"
+
+
+# Security configurations
+API_KEY_LENGTH = 32  # Length of generated API keys
+
+class APIKey(BaseModel):
+    key: str
+    name: str
+    created_at: datetime
+    last_used: datetime
+    is_active: bool = True
+
+class 
APIKeyCreate(BaseModel):
+    name: str  # A friendly name for the API key
+
+class User(BaseModel):
+    id: UUID
+    username: str
+    is_active: bool = True
+    is_admin: bool = False
+    api_keys: Dict[str, APIKey] = {}  # key -> APIKey object
+

 
 class AgentConfig(BaseModel):
@@ -105,6 +120,7 @@ class AgentConfig(BaseModel):
     )
 
 
+
 class AgentUpdate(BaseModel):
     """Model for updating agent configuration."""
 
@@ -173,6 +189,9 @@ class AgentStore:
     def __init__(self):
         self.agents: Dict[UUID, Agent] = {}
         self.agent_metadata: Dict[UUID, Dict[str, Any]] = {}
+        self.users: Dict[UUID, User] = {}  # user_id -> User
+        self.api_keys: Dict[str, UUID] = {}  # api_key -> user_id
+        self.user_agents: Dict[UUID, List[UUID]] = {}  # user_id -> [agent_ids]
         self.executor = ThreadPoolExecutor(max_workers=4)
         self._ensure_directories()
 
@@ -180,8 +199,56 @@ class AgentStore:
         """Ensure required directories exist."""
         Path("logs").mkdir(exist_ok=True)
         Path("states").mkdir(exist_ok=True)
+
+    def create_api_key(self, user_id: UUID, key_name: str) -> APIKey:
+        """Create a new API key for a user."""
+        if user_id not in self.users:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="User not found"
+            )
 
-    async def create_agent(self, config: AgentConfig) -> UUID:
+        # Generate a secure random API key
+        api_key = secrets.token_urlsafe(API_KEY_LENGTH)
+
+        # Create the API key object
+        key_object = APIKey(
+            key=api_key,
+            name=key_name,
+            created_at=datetime.utcnow(),
+            last_used=datetime.utcnow()
+        )
+
+        # Store the API key
+        self.users[user_id].api_keys[api_key] = key_object
+        self.api_keys[api_key] = user_id
+
+        return key_object
+
+    async def verify_agent_access(self, agent_id: UUID, user_id: UUID) -> bool:
+        """Verify if a user has access to an agent."""
+        if agent_id not in self.agents:
+            return False
+        return (
+            self.agent_metadata[agent_id]["owner_id"] == user_id
+            or self.users[user_id].is_admin
+        )
+
+    def validate_api_key(self, api_key: str) -> Optional[UUID]:
+        """Validate an API key and return the associated user ID."""
+        user_id = self.api_keys.get(api_key)
+        if not user_id or api_key not in self.users[user_id].api_keys:
+            return None
+
+        key_object = self.users[user_id].api_keys[api_key]
+        if not key_object.is_active:
+            return None
+
+        # Update last used timestamp
+        key_object.last_used = datetime.utcnow()
+        return user_id
+
+    async def create_agent(self, config: AgentConfig, user_id: UUID) -> UUID:
         """Create a new agent with the given configuration."""
         try:
 
@@ -220,7 +287,12 @@ class AgentStore:
                 "successful_completions": 0,
             }
 
-            logger.info(f"Created agent with ID: {agent_id}")
+            # Record ownership (read by verify_agent_access) and track the agent
+            self.agent_metadata[agent_id]["owner_id"] = user_id
+            if user_id not in self.user_agents:
+                self.user_agents[user_id] = []
+            self.user_agents[user_id].append(agent_id)
+
             return agent_id
 
         except Exception as e:
@@ -465,6 +536,35 @@ class AgentStore:
         finally:
             metadata["status"] = AgentStatus.IDLE
 
+class StoreManager:
+    _instance = None
+
+    @classmethod
+    def get_instance(cls) -> 'AgentStore':
+        if cls._instance is None:
+            cls._instance = AgentStore()
+        return cls._instance
+
+# Dependency that returns the shared AgentStore instance
+def get_store() -> AgentStore:
+    """Dependency to get the AgentStore instance."""
+    return StoreManager.get_instance()
+
+# Security utility function using the new dependency
+async def get_current_user(
+    api_key: str = Header(..., description="API key for authentication"),
+    store: AgentStore = Depends(get_store)
+) -> User:
+    """Validate API key and return current user."""
+    user_id = store.validate_api_key(api_key)
+    if not user_id:
+        raise 
HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + return store.users[user_id] + class SwarmsAPI: """Enhanced API class for Swarms agent integration.""" @@ -477,7 +577,9 @@ class SwarmsAPI: docs_url="/v1/docs", redoc_url="/v1/redoc", ) - self.store = AgentStore() + # Initialize the store using the singleton manager + self.store = StoreManager.get_instance() + # Configure CORS self.app.add_middleware( CORSMiddleware, @@ -493,11 +595,102 @@ class SwarmsAPI: def _setup_routes(self): """Set up API routes.""" + + # In your API code + @self.app.post("/v1/users", response_model=Dict[str, Any]) + async def create_user(request: Request): + """Create a new user and initial API key.""" + try: + body = await request.json() + username = body.get("username") + if not username or len(username) < 3: + raise HTTPException(status_code=400, detail="Invalid username") + + user_id = uuid4() + user = User(id=user_id, username=username) + self.store.users[user_id] = user + initial_key = self.store.create_api_key(user_id, "Initial Key") + return {"user_id": user_id, "api_key": initial_key.key} + except Exception as e: + logger.error(f"Error creating user: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + + + + @self.app.post("/v1/users/{user_id}/api-keys", response_model=APIKey) + async def create_api_key( + user_id: UUID, + key_create: APIKeyCreate, + current_user: User = Depends(get_current_user) + ): + """Create a new API key for a user.""" + if current_user.id != user_id and not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to create API keys for this user" + ) + + return self.store.create_api_key(user_id, key_create.name) + @self.app.get("/v1/users/{user_id}/api-keys", response_model=List[APIKey]) + async def list_api_keys( + user_id: UUID, + current_user: User = Depends(get_current_user) + ): + """List all API keys for a user.""" + if current_user.id != user_id and not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view API keys for this user" + ) + + return list(self.store.users[user_id].api_keys.values()) + + @self.app.delete("/v1/users/{user_id}/api-keys/{key}") + async def revoke_api_key( + user_id: UUID, + key: str, + current_user: User = Depends(get_current_user) + ): + """Revoke an API key.""" + if current_user.id != user_id and not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to revoke API keys for this user" + ) + + if key in self.store.users[user_id].api_keys: + self.store.users[user_id].api_keys[key].is_active = False + del self.store.api_keys[key] + return {"status": "API key revoked"} + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="API key not found" + ) + + @self.app.get("/v1/users/me/agents", response_model=List[AgentSummary]) + async def list_user_agents( + current_user: User = Depends(get_current_user), + tags: Optional[List[str]] = Query(None), + status: Optional[AgentStatus] = None, + ): + """List all agents owned by the current user.""" + user_agents = self.store.user_agents.get(current_user.id, []) + return [ + agent for agent in await self.store.list_agents(tags, status) + if agent.agent_id in user_agents + ] + + + # Modify existing routes to use API key authentication @self.app.post("/v1/agent", response_model=Dict[str, UUID]) - async def 
create_agent(config: AgentConfig):
+        async def create_agent(
+            config: AgentConfig,
+            current_user: User = Depends(get_current_user)
+        ):
             """Create a new agent with the specified configuration."""
-            agent_id = await self.store.create_agent(config)
+            agent_id = await self.store.create_agent(config, current_user.id)
             return {"agent_id": agent_id}
 
         @self.app.get("/v1/agents", response_model=List[AgentSummary])
@@ -611,28 +804,27 @@ class SwarmsAPI:
                     if k > cutoff
                 }
 
-
 def create_app() -> FastAPI:
     """Create and configure the FastAPI application."""
-    print("create app")
+    logger.info("Creating FastAPI application")
     api = SwarmsAPI()
-    return api.app
+    app = api.app
+    logger.info("FastAPI application created successfully")
+    return app
 
+app = create_app()
 
-#if __name__ == "__main__":
 if __name__ == '__main__':
-    #freeze_support()
-    print("yes in main")
-    # Configure uvicorn logging
-    logger.info("API Starting")
-
-    uvicorn.run(
-        "main:create_app",
-        host="0.0.0.0",
-        port=8000,
-        # reload=True,
-        # workers=4,
-    )
-else:
-    print("not in main")
-
+    try:
+        logger.info("Starting API server...")
+        print("Starting API server on http://0.0.0.0:8000")
+
+        uvicorn.run(
+            app,  # Pass the app instance directly
+            host="0.0.0.0",
+            port=8000,
+            log_level="info"
+        )
+    except Exception as e:
+        logger.error(f"Failed to start API: {str(e)}")
+        print(f"Error starting server: {str(e)}")

From 5ed5af20a7fbf9f74a070e8513466403695bc6aa Mon Sep 17 00:00:00 2001
From: Kye Gomez <98760976+kyegomez@users.noreply.github.com>
Date: Thu, 12 Dec 2024 11:46:58 -0800
Subject: [PATCH 18/18] Update agent_api_test.py

---
 api/agent_api_test.py | 333 +++++++++++++++++++++++++++++++++---------
 1 file changed, 260 insertions(+), 73 deletions(-)

diff --git a/api/agent_api_test.py b/api/agent_api_test.py
index 066efc4f..2ad4e059 100644
--- a/api/agent_api_test.py
+++ b/api/agent_api_test.py
@@ -1,107 +1,294 @@
 import requests
 from loguru import logger
 import time
-
-# Configure loguru
-logger.add(
-    "api_tests_{time}.log",
-    rotation="100 MB",
-    level="DEBUG",
-    format="{time} {level} {message}",
-)
+from typing import Dict, Optional, Tuple
+from uuid import UUID
+from datetime import datetime
+import sys
 
 BASE_URL = "http://localhost:8000/v1"
 
+def check_api_server() -> bool:
+    """Check if the API server is running and accessible."""
+    try:
+        response = requests.get(f"{BASE_URL}/docs")
+        return response.status_code == 200
+    except requests.exceptions.ConnectionError:
+        logger.error(f"API server is not running at {BASE_URL}")
+        logger.error("Please start the API server first with:")
+        logger.error("    python main.py")
+        return False
+    except Exception as e:
+        logger.error(f"Error checking API server: {str(e)}")
+        return False
+
+class TestSession:
+    """Manages test session state and authentication."""
+
+    def __init__(self):
+        self.user_id: Optional[UUID] = None
+        self.api_key: Optional[str] = None
+        self.test_agents: list[UUID] = []
+
+    @property
+    def headers(self) -> Dict[str, str]:
+        """Get headers with authentication."""
+        return {"api-key": self.api_key} if self.api_key else {}
+
+def create_test_user(session: TestSession) -> Tuple[bool, str]:
+    """Create a test user and store credentials in session."""
+    logger.info("Creating test user")
+
+    try:
+        response = requests.post(
+            f"{BASE_URL}/users",
+            json={"username": f"test_user_{int(time.time())}"}
+        )
+
+        if response.status_code == 200:
+            data = response.json()
+            session.user_id = data["user_id"]
+            session.api_key = data["api_key"]
+            logger.success(f"Created user with ID: 
{session.user_id}") + return True, "Success" + else: + logger.error(f"Failed to create user: {response.text}") + return False, response.text + except Exception as e: + logger.exception("Exception during user creation") + return False, str(e) + +def create_additional_api_key(session: TestSession) -> Tuple[bool, str]: + """Test creating an additional API key.""" + logger.info("Creating additional API key") + + try: + response = requests.post( + f"{BASE_URL}/users/{session.user_id}/api-keys", + headers=session.headers, + json={"name": "Test Key"} + ) + + if response.status_code == 200: + logger.success("Created additional API key") + return True, response.json()["key"] + else: + logger.error(f"Failed to create API key: {response.text}") + return False, response.text + except Exception as e: + logger.exception("Exception during API key creation") + return False, str(e) -def test_create_agent(): +def test_create_agent(session: TestSession) -> Tuple[bool, Optional[UUID]]: """Test creating a new agent.""" logger.info("Testing agent creation") payload = { - "agent_name": "Test Agent", + "agent_name": f"Test Agent {int(time.time())}", "system_prompt": "You are a helpful assistant", "model_name": "gpt-4", "description": "Test agent", - "tags": ["test"], + "tags": ["test", "automated"] } - response = requests.post(f"{BASE_URL}/agent", json=payload) - logger.debug(f"Create response: {response.json()}") + try: + response = requests.post( + f"{BASE_URL}/agent", + headers=session.headers, + json=payload + ) + + if response.status_code == 200: + agent_id = response.json()["agent_id"] + session.test_agents.append(agent_id) + logger.success(f"Created agent with ID: {agent_id}") + return True, agent_id + else: + logger.error(f"Failed to create agent: {response.text}") + return False, None + except Exception as e: + logger.exception("Exception during agent creation") + return False, None - if response.status_code == 200: - logger.success("Successfully created agent") - return response.json()["agent_id"] - else: - logger.error(f"Failed to create agent: {response.text}") - return None +def test_list_user_agents(session: TestSession) -> bool: + """Test listing user's agents.""" + logger.info("Testing user agent listing") + try: + response = requests.get( + f"{BASE_URL}/users/me/agents", + headers=session.headers + ) + + if response.status_code == 200: + agents = response.json() + logger.success(f"Found {len(agents)} user agents") + return True + else: + logger.error(f"Failed to list user agents: {response.text}") + return False + except Exception as e: + logger.exception("Exception during agent listing") + return False -def test_list_agents(): - """Test listing all agents.""" - logger.info("Testing agent listing") +def test_agent_operations(session: TestSession, agent_id: UUID) -> bool: + """Test various operations on an agent.""" + logger.info(f"Testing operations for agent {agent_id}") + + # Test update + try: + update_response = requests.patch( + f"{BASE_URL}/agent/{agent_id}", + headers=session.headers, + json={ + "description": "Updated description", + "tags": ["test", "updated"] + } + ) + if update_response.status_code != 200: + logger.error(f"Failed to update agent: {update_response.text}") + return False + + # Test metrics + metrics_response = requests.get( + f"{BASE_URL}/agent/{agent_id}/metrics", + headers=session.headers + ) + if metrics_response.status_code != 200: + logger.error(f"Failed to get agent metrics: {metrics_response.text}") + return False + + logger.success("Successfully performed agent 
operations") + return True + except Exception as e: + logger.exception("Exception during agent operations") + return False - response = requests.get(f"{BASE_URL}/agents") - logger.debug(f"List response: {response.json()}") - - if response.status_code == 200: - logger.success(f"Found {len(response.json())} agents") - else: - logger.error(f"Failed to list agents: {response.text}") - - -def test_completion(agent_id): +def test_completion(session: TestSession, agent_id: UUID) -> bool: """Test running a completion.""" logger.info("Testing completion") payload = { "prompt": "What is the weather like today?", "agent_id": agent_id, + "max_tokens": 100 } - response = requests.post( - f"{BASE_URL}/agent/completions", json=payload - ) - logger.debug(f"Completion response: {response.json()}") - - if response.status_code == 200: - logger.success("Successfully got completion") - else: - logger.error(f"Failed to get completion: {response.text}") + try: + response = requests.post( + f"{BASE_URL}/agent/completions", + headers=session.headers, + json=payload + ) + + if response.status_code == 200: + completion_data = response.json() + logger.success( + f"Got completion, used {completion_data['token_usage']['total_tokens']} tokens" + ) + return True + else: + logger.error(f"Failed to get completion: {response.text}") + return False + except Exception as e: + logger.exception("Exception during completion") + return False +def cleanup_test_resources(session: TestSession): + """Clean up all test resources.""" + logger.info("Cleaning up test resources") + + # Delete test agents + for agent_id in session.test_agents: + try: + response = requests.delete( + f"{BASE_URL}/agent/{agent_id}", + headers=session.headers + ) + if response.status_code == 200: + logger.debug(f"Deleted agent {agent_id}") + else: + logger.warning(f"Failed to delete agent {agent_id}: {response.text}") + except Exception as e: + logger.exception(f"Exception deleting agent {agent_id}") -def test_delete_agent(agent_id): - """Test deleting an agent.""" - logger.info("Testing agent deletion") + # Revoke API keys + if session.user_id: + try: + response = requests.get( + f"{BASE_URL}/users/{session.user_id}/api-keys", + headers=session.headers + ) + if response.status_code == 200: + for key in response.json(): + try: + revoke_response = requests.delete( + f"{BASE_URL}/users/{session.user_id}/api-keys/{key['key']}", + headers=session.headers + ) + if revoke_response.status_code == 200: + logger.debug(f"Revoked API key {key['name']}") + else: + logger.warning(f"Failed to revoke API key {key['name']}") + except Exception as e: + logger.exception(f"Exception revoking API key {key['name']}") + except Exception as e: + logger.exception("Exception getting API keys for cleanup") - response = requests.delete(f"{BASE_URL}/agent/{agent_id}") - logger.debug(f"Delete response: {response.json()}") - - if response.status_code == 200: - logger.success("Successfully deleted agent") - else: - logger.error(f"Failed to delete agent: {response.text}") - - -def run_tests(): - """Run all tests in sequence.""" +def run_test_workflow(): + """Run complete test workflow.""" logger.info("Starting API tests") - - # Create agent and get ID - agent_id = test_create_agent() - if not agent_id: - logger.error("Cannot continue tests without agent ID") - return - - # Wait a bit for agent to be ready - time.sleep(1) - - # Run other tests - test_list_agents() - test_completion(agent_id) - test_delete_agent(agent_id) - - logger.info("Tests completed") - + + # Check if API server is 
running first + if not check_api_server(): + return False + + session = TestSession() + success = True + + try: + # Create user + user_success, message = create_test_user(session) + if not user_success: + logger.error(f"User creation failed: {message}") + return False + + # Create additional API key + key_success, key = create_additional_api_key(session) + if not key_success: + logger.error(f"API key creation failed: {key}") + return False + + # Create agent + agent_success, agent_id = test_create_agent(session) + if not agent_success or not agent_id: + logger.error("Agent creation failed") + return False + + # Test user agent listing + if not test_list_user_agents(session): + logger.error("Agent listing failed") + return False + + # Test agent operations + if not test_agent_operations(session, agent_id): + logger.error("Agent operations failed") + return False + + # Test completion + if not test_completion(session, agent_id): + logger.error("Completion test failed") + return False + + logger.success("All tests completed successfully") + return True + + except Exception as e: + logger.exception("Exception during test workflow") + return False + finally: + cleanup_test_resources(session) if __name__ == "__main__": - run_tests() + success = run_test_workflow() + sys.exit(0 if success else 1)
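
For reference, here is a minimal end-to-end client for the API introduced in PATCH 17. This is a sketch, not part of the patch series: it assumes the final `api/main.py` is running locally on port 8000 (`python api/main.py`), and the `demo_user` name, `Demo-Agent` config, and prompt are placeholders. It mirrors the flow that `api/agent_api_test.py` automates: create a user, authenticate with the returned key, create an agent, then request a completion.

```python
import requests

BASE_URL = "http://localhost:8000/v1"

# 1. Create a user; the response carries the user's initial API key.
user = requests.post(
    f"{BASE_URL}/users", json={"username": "demo_user"}
).json()

# All authenticated routes read the key from the "api-key" header
# (FastAPI maps the api_key Header parameter to "api-key").
headers = {"api-key": user["api_key"]}

# 2. Create an agent owned by that user. agent_name, model_name, and
#    system_prompt are the required AgentConfig fields.
agent = requests.post(
    f"{BASE_URL}/agent",
    headers=headers,
    json={
        "agent_name": "Demo-Agent",
        "model_name": "gpt-4",
        "system_prompt": "You are a helpful assistant",
        "description": "Smoke-test agent",
        "tags": ["demo"],
    },
).json()

# 3. Run a completion against the new agent.
completion = requests.post(
    f"{BASE_URL}/agent/completions",
    headers=headers,
    json={"prompt": "Say hello.", "agent_id": agent["agent_id"]},
).json()
print(completion["response"])
print(completion["token_usage"])
```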