You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/examples/agents/3rd_party_agents/langchain.py

83 lines
2.2 KiB

from typing import List, Optional
from langchain.agents import AgentExecutor, LLMSingleActionAgent, Tool
from langchain.chains import LLMChain
from langchain_community.llms import OpenAI
from langchain.prompts import StringPromptTemplate
from langchain.tools import DuckDuckGoSearchRun
from swarms import Agent
class LangchainAgentWrapper(Agent):
    """
    Expose a LangChain zero-shot ReAct agent through the swarms ``Agent`` API.

    The wrapper builds a LangChain agent over ``tools`` and an LLM, wraps it in
    an ``AgentExecutor``, and runs it via :meth:`run`.

    Args:
        name (str): The name of the agent; it is also mentioned in the prompt.
        tools (List[Tool]): The list of tools available to the agent. Each tool
            needs a description, since it is rendered into the prompt.
        llm (Optional[OpenAI], optional): The OpenAI language model to use.
            Defaults to None, in which case ``OpenAI(temperature=0)`` is used.
        *args: Forwarded unchanged to ``Agent.__init__``.
        **kwargs: Forwarded unchanged to ``Agent.__init__``.
    """

    def __init__(
        self,
        name: str,
        tools: List[Tool],
        llm: Optional[OpenAI] = None,
        *args,
        **kwargs,
    ):
        # Local import keeps the file's top-level import block untouched;
        # ZeroShotAgent comes from the same package as AgentExecutor.
        from langchain.agents import ZeroShotAgent

        super().__init__(*args, **kwargs)
        self.name = name
        self.tools = tools
        self.llm = llm if llm is not None else OpenAI(temperature=0)

        # Previously a hand-built "{name}/{question}" prompt was paired with
        # output_parser=None. That cannot run: AgentExecutor.run(task) needs a
        # single "input" key, and LLMSingleActionAgent requires a real output
        # parser. ZeroShotAgent builds a matching ReAct prompt (tool list,
        # "{input}", "{agent_scratchpad}"), its output parser, allowed_tools
        # and stop sequence in one step.
        # The prefix is embedded in a format template, so escape any braces
        # in the agent name to keep them from being read as variables.
        safe_name = str(name).replace("{", "{{").replace("}", "}}")
        prefix = (
            f"You are {safe_name}, an AI assistant. Answer the following "
            "question as best you can. You have access to the following tools:"
        )
        self.agent = ZeroShotAgent.from_llm_and_tools(
            llm=self.llm,
            tools=self.tools,
            prefix=prefix,
        )
        self.agent_executor = AgentExecutor.from_agent_and_tools(
            agent=self.agent, tools=self.tools, verbose=True
        )

    def run(self, task: str, *args, **kwargs):
        """
        Run the agent with the given task.

        Args:
            task (str): The task to be performed by the agent.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Any: The agent's final answer. If execution raises, the error is
            printed and None is returned instead of propagating.
        """
        try:
            return self.agent_executor.run(task)
        except Exception as e:
            # Best-effort: report the failure and return None to the caller.
            print(f"An error occurred: {e}")
            return None
# Example: expose DuckDuckGo search as a single LangChain tool, then ask the
# wrapped agent one question and print its answer.
# NOTE(review): this runs (including network calls) whenever the module is
# imported; consider moving it under an `if __name__ == "__main__":` guard.
search_tool = DuckDuckGoSearchRun()
tools = [
    Tool(
        name="Search",
        func=search_tool.run,
        description="Useful for searching the internet",
    ),
]
langchain_wrapper = LangchainAgentWrapper(name="LangchainAssistant", tools=tools)
result = langchain_wrapper.run(task="What is the capital of France?")
print(result)