`InMemoryTaskDB`

Former-commit-id: 56be6060e2
group-chat
Kye 1 year ago
parent babdca8351
commit 9ce26a1ba8

@ -2,26 +2,50 @@ from abc import ABC, abstractmethod
from agent_protocol import Agent, Step, Task from agent_protocol import Agent, Step, Task
class AbstractAgent(ABC): class AbstractAgent:
#absrtact agent class @staticmethod
async def plan(step: Step) -> Step:
@classmethod task = await Agent.db.get_task(step.task_id)
def __init__( steps = generate_steps(task.input)
self,
ai_name: str = None,
ai_role: str = None,
memory = None,
tools = None,
llm = None,
human_in_the_loop=None,
output_parser = None,
chat_history_memory=None,
*args,
**kwargs
):
pass
@abstractmethod
def run(self, goals=None):
pass
last_step = steps[-1]
for step in steps[:-1]:
await Agent.db.create_step(
task_id=task.task_id,
name=step,
pass
)
await Agent.db.create_step(
task_id=task.task_id,
name=last_step,
is_last=True
)
step.output = steps
return step
@staticmethod
async def execute(step: Step) -> Step:
# Use tools, websearch, etc.
...
@staticmethod
async def task_handler(task: Task) -> None:
await Agent.db.create_step(
task_id=task.task_id,
name="plan",
pass
)
@staticmethod
async def step_handler(step: Step) -> Step:
if step.name == "plan":
await AbstractAgent.plan(step)
else:
await AbstractAgent.execute(step)
return step
@staticmethod
def start_agent():
Agent.setup_agent(AbstractAgent.task_handler, AbstractAgent.step_handler).start()

@ -0,0 +1,172 @@
import uuid
from abc import ABC
from typing import Dict, List, Optional, Any
from .models import Task as APITask, Step as APIStep, Artifact, Status
class Step(APIStep):
additional_properties: Optional[Dict[str, str]] = None
class Task(APITask):
steps: List[Step] = []
class NotFoundException(Exception):
"""
Exception raised when a resource is not found.
"""
def __init__(self, item_name: str, item_id: str):
self.item_name = item_name
self.item_id = item_id
super().__init__(f"{item_name} with {item_id} not found.")
class TaskDB(ABC):
async def create_task(
self,
input: Optional[str],
additional_input: Any = None,
artifacts: Optional[List[Artifact]] = None,
steps: Optional[List[Step]] = None,
) -> Task:
raise NotImplementedError
async def create_step(
self,
task_id: str,
name: Optional[str] = None,
input: Optional[str] = None,
is_last: bool = False,
additional_properties: Optional[Dict[str, str]] = None,
) -> Step:
raise NotImplementedError
async def create_artifact(
self,
task_id: str,
file_name: str,
relative_path: Optional[str] = None,
step_id: Optional[str] = None,
) -> Artifact:
raise NotImplementedError
async def get_task(self, task_id: str) -> Task:
raise NotImplementedError
async def get_step(self, task_id: str, step_id: str) -> Step:
raise NotImplementedError
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
raise NotImplementedError
async def list_tasks(self) -> List[Task]:
raise NotImplementedError
async def list_steps(
self, task_id: str, status: Optional[Status] = None
) -> List[Step]:
raise NotImplementedError
class InMemoryTaskDB(TaskDB):
_tasks: Dict[str, Task] = {}
async def create_task(
self,
input: Optional[str],
additional_input: Any = None,
artifacts: Optional[List[Artifact]] = None,
steps: Optional[List[Step]] = None,
) -> Task:
if not steps:
steps = []
if not artifacts:
artifacts = []
task_id = str(uuid.uuid4())
task = Task(
task_id=task_id,
input=input,
steps=steps,
artifacts=artifacts,
additional_input=additional_input,
)
self._tasks[task_id] = task
return task
async def create_step(
self,
task_id: str,
name: Optional[str] = None,
input: Optional[str] = None,
is_last=False,
additional_properties: Optional[Dict[str, Any]] = None,
) -> Step:
step_id = str(uuid.uuid4())
step = Step(
task_id=task_id,
step_id=step_id,
name=name,
input=input,
status=Status.created,
is_last=is_last,
additional_properties=additional_properties,
)
task = await self.get_task(task_id)
task.steps.append(step)
return step
async def get_task(self, task_id: str) -> Task:
task = self._tasks.get(task_id, None)
if not task:
raise NotFoundException("Task", task_id)
return task
async def get_step(self, task_id: str, step_id: str) -> Step:
task = await self.get_task(task_id)
step = next(filter(lambda s: s.task_id == task_id, task.steps), None)
if not step:
raise NotFoundException("Step", step_id)
return step
async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
task = await self.get_task(task_id)
artifact = next(
filter(lambda a: a.artifact_id == artifact_id, task.artifacts), None
)
if not artifact:
raise NotFoundException("Artifact", artifact_id)
return artifact
async def create_artifact(
self,
task_id: str,
file_name: str,
relative_path: Optional[str] = None,
step_id: Optional[str] = None,
) -> Artifact:
artifact_id = str(uuid.uuid4())
artifact = Artifact(
artifact_id=artifact_id, file_name=file_name, relative_path=relative_path
)
task = await self.get_task(task_id)
task.artifacts.append(artifact)
if step_id:
step = await self.get_step(task_id, step_id)
step.artifacts.append(artifact)
return artifact
async def list_tasks(self) -> List[Task]:
return [task for task in self._tasks.values()]
async def list_steps(
self, task_id: str, status: Optional[Status] = None
) -> List[Step]:
task = await self.get_task(task_id)
steps = task.steps
if status:
steps = list(filter(lambda s: s.status == status, steps))
return steps

@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
import pprint
import json import json
import pprint
from typing import Any, Optional
from typing import Optional, Any
from pydantic import BaseModel, Field, StrictStr, StrictStr, conlist
from artifacts.main import Artifact from artifacts.main import Artifact
from pydantic import BaseModel, Field, StrictStr, conlist
class Task(BaseModel): class Task(BaseModel):
input: Optional[StrictStr] = Field( input: Optional[StrictStr] = Field(

Loading…
Cancel
Save