From 9c14a59ab7655da991cd3f7de63d35bc00ca3027 Mon Sep 17 00:00:00 2001
From: Kye <kye@apacmediasolutions.com>
Date: Fri, 15 Sep 2023 11:45:04 -0400
Subject: [PATCH] `InMemoryTaskDB`

Former-commit-id: 04708f18bfb121b19c7f43ae91a1bf18943d4938
---
 swarms/agents/base.py  |  68 ++++++++++------
 swarms/memory/db.py    | 172 +++++++++++++++++++++++++++++++++++++++++
 swarms/structs/task.py |   8 +-
 3 files changed, 223 insertions(+), 25 deletions(-)
 create mode 100644 swarms/memory/db.py

diff --git a/swarms/agents/base.py b/swarms/agents/base.py
index 135f0b2b..9e0a641c 100644
--- a/swarms/agents/base.py
+++ b/swarms/agents/base.py
@@ -2,26 +2,50 @@ from abc import ABC, abstractmethod
 from agent_protocol import Agent, Step, Task
 
 
-class AbstractAgent(ABC):
-    #absrtact agent class
-    
-    @classmethod
-    def __init__(
-            self,
-            ai_name: str = None,
-            ai_role: str = None,
-            memory = None,
-            tools = None,
-            llm = None,
-            human_in_the_loop=None,
-            output_parser = None,
-            chat_history_memory=None,
-            *args,
-            **kwargs
-    ):
-        pass
-
-    @abstractmethod
-    def run(self, goals=None):
-        pass
+class AbstractAgent:
+    @staticmethod
+    async def plan(step: Step) -> Step:
+        task = await Agent.db.get_task(step.task_id)
+        steps = generate_steps(task.input)
 
+        last_step = steps[-1]
+        for step in steps[:-1]:
+            await Agent.db.create_step(
+                task_id=task.task_id, 
+                name=step, 
+                pass
+            )
+
+        await Agent.db.create_step(
+            task_id=task.task_id, 
+            name=last_step, 
+            is_last=True
+        )
+        step.output = steps
+        return step
+
+    @staticmethod
+    async def execute(step: Step) -> Step:
+        # Use tools, websearch, etc.
+        ...
+
+    @staticmethod
+    async def task_handler(task: Task) -> None:
+        await Agent.db.create_step(
+            task_id=task.task_id, 
+            name="plan", 
+            pass
+        )
+
+    @staticmethod
+    async def step_handler(step: Step) -> Step:
+        if step.name == "plan":
+            await AbstractAgent.plan(step)
+        else:
+            await AbstractAgent.execute(step)
+
+        return step
+
+    @staticmethod
+    def start_agent():
+        Agent.setup_agent(AbstractAgent.task_handler, AbstractAgent.step_handler).start()
\ No newline at end of file
diff --git a/swarms/memory/db.py b/swarms/memory/db.py
new file mode 100644
index 00000000..728fa4fd
--- /dev/null
+++ b/swarms/memory/db.py
@@ -0,0 +1,172 @@
+import uuid
+from abc import ABC
+from typing import Dict, List, Optional, Any
+from .models import Task as APITask, Step as APIStep, Artifact, Status
+
+
+class Step(APIStep):
+    additional_properties: Optional[Dict[str, str]] = None
+
+
+class Task(APITask):
+    steps: List[Step] = []
+
+
+class NotFoundException(Exception):
+    """
+    Exception raised when a resource is not found.
+    """
+
+    def __init__(self, item_name: str, item_id: str):
+        self.item_name = item_name
+        self.item_id = item_id
+        super().__init__(f"{item_name} with {item_id} not found.")
+
+
+class TaskDB(ABC):
+    async def create_task(
+        self,
+        input: Optional[str],
+        additional_input: Any = None,
+        artifacts: Optional[List[Artifact]] = None,
+        steps: Optional[List[Step]] = None,
+    ) -> Task:
+        raise NotImplementedError
+
+    async def create_step(
+        self,
+        task_id: str,
+        name: Optional[str] = None,
+        input: Optional[str] = None,
+        is_last: bool = False,
+        additional_properties: Optional[Dict[str, str]] = None,
+    ) -> Step:
+        raise NotImplementedError
+
+    async def create_artifact(
+        self,
+        task_id: str,
+        file_name: str,
+        relative_path: Optional[str] = None,
+        step_id: Optional[str] = None,
+    ) -> Artifact:
+        raise NotImplementedError
+
+    async def get_task(self, task_id: str) -> Task:
+        raise NotImplementedError
+
+    async def get_step(self, task_id: str, step_id: str) -> Step:
+        raise NotImplementedError
+
+    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
+        raise NotImplementedError
+
+    async def list_tasks(self) -> List[Task]:
+        raise NotImplementedError
+
+    async def list_steps(
+        self, task_id: str, status: Optional[Status] = None
+    ) -> List[Step]:
+        raise NotImplementedError
+
+
+class InMemoryTaskDB(TaskDB):
+    _tasks: Dict[str, Task] = {}
+
+    async def create_task(
+        self,
+        input: Optional[str],
+        additional_input: Any = None,
+        artifacts: Optional[List[Artifact]] = None,
+        steps: Optional[List[Step]] = None,
+    ) -> Task:
+        if not steps:
+            steps = []
+        if not artifacts:
+            artifacts = []
+        task_id = str(uuid.uuid4())
+        task = Task(
+            task_id=task_id,
+            input=input,
+            steps=steps,
+            artifacts=artifacts,
+            additional_input=additional_input,
+        )
+        self._tasks[task_id] = task
+        return task
+
+    async def create_step(
+        self,
+        task_id: str,
+        name: Optional[str] = None,
+        input: Optional[str] = None,
+        is_last=False,
+        additional_properties: Optional[Dict[str, Any]] = None,
+    ) -> Step:
+        step_id = str(uuid.uuid4())
+        step = Step(
+            task_id=task_id,
+            step_id=step_id,
+            name=name,
+            input=input,
+            status=Status.created,
+            is_last=is_last,
+            additional_properties=additional_properties,
+        )
+        task = await self.get_task(task_id)
+        task.steps.append(step)
+        return step
+
+    async def get_task(self, task_id: str) -> Task:
+        task = self._tasks.get(task_id, None)
+        if not task:
+            raise NotFoundException("Task", task_id)
+        return task
+
+    async def get_step(self, task_id: str, step_id: str) -> Step:
+        task = await self.get_task(task_id)
+        step = next(filter(lambda s: s.task_id == task_id, task.steps), None)
+        if not step:
+            raise NotFoundException("Step", step_id)
+        return step
+
+    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
+        task = await self.get_task(task_id)
+        artifact = next(
+            filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
+        )
+        if not artifact:
+            raise NotFoundException("Artifact", artifact_id)
+        return artifact
+
+    async def create_artifact(
+        self,
+        task_id: str,
+        file_name: str,
+        relative_path: Optional[str] = None,
+        step_id: Optional[str] = None,
+    ) -> Artifact:
+        artifact_id = str(uuid.uuid4())
+        artifact = Artifact(
+            artifact_id=artifact_id, file_name=file_name, relative_path=relative_path
+        )
+        task = await self.get_task(task_id)
+        task.artifacts.append(artifact)
+
+        if step_id:
+            step = await self.get_step(task_id, step_id)
+            step.artifacts.append(artifact)
+
+        return artifact
+
+    async def list_tasks(self) -> List[Task]:
+        return [task for task in self._tasks.values()]
+
+    async def list_steps(
+        self, task_id: str, status: Optional[Status] = None
+    ) -> List[Step]:
+        task = await self.get_task(task_id)
+        steps = task.steps
+        if status:
+            steps = list(filter(lambda s: s.status == status, steps))
+        return steps
\ No newline at end of file
diff --git a/swarms/structs/task.py b/swarms/structs/task.py
index 77928a4f..8618234a 100644
--- a/swarms/structs/task.py
+++ b/swarms/structs/task.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
-import pprint
+
 import json
+import pprint
+from typing import Any, Optional
 
-from typing import Optional, Any
-from pydantic import BaseModel, Field, StrictStr, StrictStr, conlist
 from artifacts.main import Artifact
+from pydantic import BaseModel, Field, StrictStr, conlist
+
 
 class Task(BaseModel):
     input: Optional[StrictStr] = Field(