From 03296b4d0fc0c9c963fdbdb135313e7f5960546b Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 30 Jul 2023 09:08:44 -0400 Subject: [PATCH] iteratisons on agent class, moving away from langchain Former-commit-id: 874eaa12e1439a7713ac8521acda895a7a719afb --- swarms/agents/base.py | 9 ++++----- swarms/agents/models/base.py | 5 ----- swarms/agents/models/petals.py | 1 - 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/swarms/agents/base.py b/swarms/agents/base.py index bc8edd01..ed30fa1e 100644 --- a/swarms/agents/base.py +++ b/swarms/agents/base.py @@ -7,12 +7,11 @@ from pydantic import ValidationError from swarms.agents.utils.Agent import AgentOutputParser from swarms.agents.utils.human_input import HumanInputRun from swarms.agents.prompts.prompt_generator import FINISH_NAME +from swarms.agents.models.base import AbstractModel from langchain.chains.llm import LLMChain -from langchain.chat_models.base import BaseChatModel from langchain.memory import ChatMessageHistory - from langchain.schema import (BaseChatMessageHistory, Document,) from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.tools.base import BaseTool @@ -25,17 +24,17 @@ class Agent: def __init__( self, ai_name: str, - memory: VectorStoreRetriever, chain: LLMChain, + memory: VectorStoreRetriever, output_parser: BaseAgentOutputParser, tools: List[BaseTool], feedback_tool: Optional[HumanInputRun] = None, chat_history_memory: Optional[BaseChatMessageHistory] = None, ): self.ai_name = ai_name + self.chain = chain self.memory = memory self.next_action_count = 0 - self.chain = chain self.output_parser = output_parser self.tools = tools self.feedback_tool = feedback_tool @@ -48,7 +47,7 @@ class Agent: ai_role: str, memory: VectorStoreRetriever, tools: List[BaseTool], - llm: BaseChatModel, + llm: AbstractModel, human_in_the_loop: bool = False, output_parser: Optional[BaseAgentOutputParser] = None, chat_history_memory: Optional[BaseChatMessageHistory] = None, diff --git a/swarms/agents/models/base.py b/swarms/agents/models/base.py index 63277925..5db5cb31 100644 --- a/swarms/agents/models/base.py +++ b/swarms/agents/models/base.py @@ -2,11 +2,6 @@ from abc import ABC, abstractmethod class AbstractModel(ABC): #abstract base class for language models - - @abstractmethod - def __init__(self, model_name **kwargs): - self.model_name = model_name - @abstractmethod def generate(self, prompt): #generate text using language model diff --git a/swarms/agents/models/petals.py b/swarms/agents/models/petals.py index c5c25803..56b10ef1 100644 --- a/swarms/agents/models/petals.py +++ b/swarms/agents/models/petals.py @@ -1,4 +1,3 @@ -import os from transformers import AutoTokenizer, AutoModelForCausalLM class Petals: