You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
172 lines
4.9 KiB
172 lines
4.9 KiB
import uuid
|
|
from abc import ABC
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from swarms.memory.schemas import Artifact, Status
|
|
from swarms.memory.schemas import Step as APIStep
|
|
from swarms.memory.schemas import Task as APITask
|
|
|
|
|
|
class Step(APIStep):
    # Local step model: the API-level step schema plus an optional
    # free-form mapping for implementation-specific metadata.
    # Declared as str -> str and defaulting to None when not supplied.
    # NOTE(review): InMemoryTaskDB.create_step annotates this argument as
    # Dict[str, Any]; non-string values may be coerced or rejected by the
    # base schema's validation — confirm which type is intended.
    additional_properties: Optional[Dict[str, str]] = (
        None
    )
|
|
|
|
class Task(APITask):
    # Local task model: re-declares ``steps`` so that it holds the local
    # ``Step`` subclass (which carries ``additional_properties``).
    # NOTE(review): the bare ``[]`` default is only safe if APITask is a
    # pydantic-style model that copies field defaults per instance —
    # presumably the case, verify; otherwise the list would be shared.
    steps: List[Step] = (
        []
    )
|
|
|
|
class NotFoundException(Exception):
    """Signal that a requested resource (e.g. a task, step or artifact) is missing.

    Attributes:
        item_name: Kind of resource that was looked up, e.g. ``"Task"``.
        item_id: Identifier that produced no match.
    """

    def __init__(self, item_name: str, item_id: str):
        # The message is the sole positional argument, so it is what
        # ``str(exc)`` and ``exc.args`` expose.
        message = f"{item_name} with {item_id} not found."
        super().__init__(message)
        self.item_name, self.item_id = item_name, item_id
|
|
|
|
class TaskDB(ABC):
    """Asynchronous storage interface for tasks, their steps and artifacts.

    Every method below is a placeholder that raises ``NotImplementedError``;
    concrete backends (such as ``InMemoryTaskDB`` in this module) override
    them. None of the methods is marked ``@abstractmethod``, so this base
    class can still be instantiated — its methods simply fail when awaited.
    """

    async def create_task(
        self,
        input: Optional[str],
        additional_input: Any = None,
        artifacts: Optional[List[Artifact]] = None,
        steps: Optional[List[Step]] = None,
    ) -> Task:
        """Create and store a new task; must be overridden by a backend."""
        raise NotImplementedError()

    async def create_step(
        self,
        task_id: str,
        name: Optional[str] = None,
        input: Optional[str] = None,
        is_last: bool = False,
        additional_properties: Optional[Dict[str, str]] = None,
    ) -> Step:
        """Add a new step to the task ``task_id``; must be overridden."""
        raise NotImplementedError()

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        relative_path: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> Artifact:
        """Register an artifact on a task (optionally a step); must be overridden."""
        raise NotImplementedError()

    async def get_task(self, task_id: str) -> Task:
        """Return the task with id ``task_id``; must be overridden."""
        raise NotImplementedError()

    async def get_step(self, task_id: str, step_id: str) -> Step:
        """Return step ``step_id`` of task ``task_id``; must be overridden."""
        raise NotImplementedError()

    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
        """Return artifact ``artifact_id`` of task ``task_id``; must be overridden."""
        raise NotImplementedError()

    async def list_tasks(self) -> List[Task]:
        """Return every stored task; must be overridden."""
        raise NotImplementedError()

    async def list_steps(
        self, task_id: str, status: Optional[Status] = None
    ) -> List[Step]:
        """Return the steps of ``task_id``, optionally filtered by status; must be overridden."""
        raise NotImplementedError()
|
|
|
|
|
|
class InMemoryTaskDB(TaskDB):
    """Non-persistent ``TaskDB`` that keeps every task in a plain dict.

    Steps and artifacts are stored inside their owning ``Task`` objects, so
    all data lives only as long as this instance does.
    """

    # Annotation only: the dict itself is created per instance in __init__.
    # A class-level ``= {}`` default would be shared by (and leak tasks
    # between) every InMemoryTaskDB ever constructed.
    _tasks: Dict[str, Task]

    def __init__(self) -> None:
        super().__init__()
        self._tasks = {}

    async def create_task(
        self,
        input: Optional[str],
        additional_input: Any = None,
        artifacts: Optional[List[Artifact]] = None,
        steps: Optional[List[Step]] = None,
    ) -> Task:
        """Create a task with a fresh UUID4 id and store it.

        Args:
            input: Task input text (may be None).
            additional_input: Arbitrary extra input stored on the task.
            artifacts: Initial artifacts; None or empty becomes a new list.
            steps: Initial steps; None or empty becomes a new list.

        Returns:
            The stored ``Task``.
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            input=input,
            steps=steps or [],
            artifacts=artifacts or [],
            additional_input=additional_input,
        )
        self._tasks[task_id] = task
        return task

    async def create_step(
        self,
        task_id: str,
        name: Optional[str] = None,
        input: Optional[str] = None,
        is_last: bool = False,
        additional_properties: Optional[Dict[str, Any]] = None,
    ) -> Step:
        """Append a new step with status ``Status.created`` to a task.

        Returns:
            The newly created ``Step``.

        Raises:
            NotFoundException: if ``task_id`` is unknown.
        """
        # Resolve the task first so an unknown id fails before any work.
        task = await self.get_task(task_id)
        step = Step(
            task_id=task_id,
            step_id=str(uuid.uuid4()),
            name=name,
            input=input,
            status=Status.created,
            is_last=is_last,
            additional_properties=additional_properties,
        )
        task.steps.append(step)
        return step

    async def get_task(self, task_id: str) -> Task:
        """Return the stored task with id ``task_id``.

        Raises:
            NotFoundException: if no such task exists.
        """
        task = self._tasks.get(task_id)
        if task is None:
            raise NotFoundException("Task", task_id)
        return task

    async def get_step(self, task_id: str, step_id: str) -> Step:
        """Return the step ``step_id`` belonging to task ``task_id``.

        Raises:
            NotFoundException: if the task or the step does not exist.
        """
        task = await self.get_task(task_id)
        # Match on the step's own id (the previous filter compared task_id
        # and therefore always returned the task's first step).
        step = next((s for s in task.steps if s.step_id == step_id), None)
        if step is None:
            raise NotFoundException("Step", step_id)
        return step

    async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
        """Return the artifact ``artifact_id`` attached to task ``task_id``.

        Raises:
            NotFoundException: if the task or the artifact does not exist.
        """
        task = await self.get_task(task_id)
        artifact = next(
            (a for a in task.artifacts if a.artifact_id == artifact_id), None
        )
        if artifact is None:
            raise NotFoundException("Artifact", artifact_id)
        return artifact

    async def create_artifact(
        self,
        task_id: str,
        file_name: str,
        relative_path: Optional[str] = None,
        step_id: Optional[str] = None,
    ) -> Artifact:
        """Create an artifact on a task and, if ``step_id`` is given, on that step too.

        Returns:
            The newly created ``Artifact``.

        Raises:
            NotFoundException: if the task, or the given step, does not exist.
                In that case nothing is modified.
        """
        task = await self.get_task(task_id)
        # Look the step up before mutating the task so an unknown step_id
        # does not leave an orphaned artifact on the task.
        step = await self.get_step(task_id, step_id) if step_id else None
        artifact = Artifact(
            artifact_id=str(uuid.uuid4()),
            file_name=file_name,
            relative_path=relative_path,
        )
        task.artifacts.append(artifact)
        if step is not None:
            step.artifacts.append(artifact)
        return artifact

    async def list_tasks(self) -> List[Task]:
        """Return a new list containing every stored task."""
        return list(self._tasks.values())

    async def list_steps(
        self, task_id: str, status: Optional[Status] = None
    ) -> List[Step]:
        """Return the steps of ``task_id``, optionally only those with ``status``.

        A new list is always returned, so callers cannot mutate the task's
        internal step list through it.

        Raises:
            NotFoundException: if ``task_id`` is unknown.
        """
        task = await self.get_task(task_id)
        if status is None:
            return list(task.steps)
        return [s for s in task.steps if s.status == status]