Former-commit-id: ccb239dc2b
discord-bot-framework
parent
9092be3105
commit
927b95662a
@ -1,19 +0,0 @@
|
||||
from swarms.models.bing_chat import BingChat
|
||||
from swarms.workers.worker import Worker
|
||||
from swarms.tools.autogpt import EdgeGPTTool, tool
|
||||
|
||||
|
||||
# Initialize the language model,
|
||||
# This model can be swapped out with Anthropic, ETC, Huggingface Models like Mistral, ETC
|
||||
llm = BingChat(cookies_path="./cookies.json")
|
||||
|
||||
# Initialize the Worker with the custom tool
|
||||
worker = Worker(
|
||||
llm=llm,
|
||||
ai_name="EdgeGPT Worker",
|
||||
)
|
||||
|
||||
# Use the worker to process a task
|
||||
task = "Hello, my name is ChatGPT"
|
||||
response = worker.run(task)
|
||||
print(response)
|
@ -1,6 +0,0 @@
|
||||
[
|
||||
{
|
||||
"name": "cookie1",
|
||||
"value": "1GJjj1-tM6Jlo4HFtnbocQ3r0QbQ9Aq_R65dqbcSWKzKxnN8oEMW1xa4RlsJ_nGyNjFlXQRzMWRR2GK11bve8-6n_bjF0zTczYcQQ8oDB8W66jgpIWSL7Hr4hneB0R9dIt-OQ4cVPs4eehL2lcRCObWQr0zkG14MHlH5EMwAKthv_NNIQSfThq4Ey2Hmzhq9sRuyS04JveHdLC9gfthJ8xk3J12yr7j4HsynpzmvFUcA"
|
||||
}
|
||||
]
|
@ -0,0 +1,437 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.memory import ChatMessageHistory
|
||||
from langchain.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
)
|
||||
from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
Document,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.vectorstore import VectorStoreRetriever
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.human.tool import HumanInputRun
|
||||
from langchain_experimental.autonomous_agents.autogpt.output_parser import (
|
||||
AutoGPTOutputParser,
|
||||
BaseAutoGPTOutputParser,
|
||||
)
|
||||
from langchain_experimental.autonomous_agents.autogpt.prompt import AutoGPTPrompt
|
||||
from langchain_experimental.autonomous_agents.autogpt.prompt_generator import (
|
||||
FINISH_NAME,
|
||||
get_prompt,
|
||||
)
|
||||
from langchain_experimental.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
|
||||
# PROMPT
|
||||
FINISH_NAME = "finish"
|
||||
|
||||
|
||||
# This class has a metaclass conflict: both `BaseChatPromptTemplate` and `BaseModel`
|
||||
# define a metaclass to use, and the two metaclasses attempt to define
|
||||
# the same functions but in mutually-incompatible ways.
|
||||
# It isn't clear how to resolve this, and this code predates mypy
|
||||
# beginning to perform that check.
|
||||
#
|
||||
# Mypy errors:
|
||||
# ```
|
||||
# Definition of "__private_attributes__" in base class "BaseModel" is
|
||||
# incompatible with definition in base class "BaseModel" [misc]
|
||||
# Definition of "__repr_name__" in base class "Representation" is
|
||||
# incompatible with definition in base class "BaseModel" [misc]
|
||||
# Definition of "__pretty__" in base class "Representation" is
|
||||
# incompatible with definition in base class "BaseModel" [misc]
|
||||
# Definition of "__repr_str__" in base class "Representation" is
|
||||
# incompatible with definition in base class "BaseModel" [misc]
|
||||
# Definition of "__rich_repr__" in base class "Representation" is
|
||||
# incompatible with definition in base class "BaseModel" [misc]
|
||||
# Metaclass conflict: the metaclass of a derived class must be
|
||||
# a (non-strict) subclass of the metaclasses of all its bases [misc]
|
||||
# ```
|
||||
#
|
||||
# TODO: look into refactoring this class in a way that avoids the mypy type errors
|
||||
class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
|
||||
"""Prompt for AutoGPT."""
|
||||
|
||||
ai_name: str
|
||||
ai_role: str
|
||||
tools: List[BaseTool]
|
||||
token_counter: Callable[[str], int]
|
||||
send_token_limit: int = 4196
|
||||
|
||||
def construct_full_prompt(self, goals: List[str]) -> str:
|
||||
prompt_start = (
|
||||
"Your decisions must always be made independently "
|
||||
"without seeking user assistance.\n"
|
||||
"Play to your strengths as an LLM and pursue simple "
|
||||
"strategies with no legal complications.\n"
|
||||
"If you have completed all your tasks, make sure to "
|
||||
'use the "finish" command.'
|
||||
)
|
||||
# Construct full prompt
|
||||
full_prompt = (
|
||||
f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
|
||||
)
|
||||
for i, goal in enumerate(goals):
|
||||
full_prompt += f"{i+1}. {goal}\n"
|
||||
|
||||
full_prompt += f"\n\n{get_prompt(self.tools)}"
|
||||
return full_prompt
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"]))
|
||||
time_prompt = SystemMessage(
|
||||
content=f"The current time and date is {time.strftime('%c')}"
|
||||
)
|
||||
used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
|
||||
time_prompt.content
|
||||
)
|
||||
memory: VectorStoreRetriever = kwargs["memory"]
|
||||
previous_messages = kwargs["messages"]
|
||||
relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
|
||||
relevant_memory = [d.page_content for d in relevant_docs]
|
||||
relevant_memory_tokens = sum(
|
||||
[self.token_counter(doc) for doc in relevant_memory]
|
||||
)
|
||||
while used_tokens + relevant_memory_tokens > 2500:
|
||||
relevant_memory = relevant_memory[:-1]
|
||||
relevant_memory_tokens = sum(
|
||||
[self.token_counter(doc) for doc in relevant_memory]
|
||||
)
|
||||
content_format = (
|
||||
f"This reminds you of these events "
|
||||
f"from your past:\n{relevant_memory}\n\n"
|
||||
)
|
||||
memory_message = SystemMessage(content=content_format)
|
||||
used_tokens += self.token_counter(memory_message.content)
|
||||
historical_messages: List[BaseMessage] = []
|
||||
for message in previous_messages[-10:][::-1]:
|
||||
message_tokens = self.token_counter(message.content)
|
||||
if used_tokens + message_tokens > self.send_token_limit - 1000:
|
||||
break
|
||||
historical_messages = [message] + historical_messages
|
||||
used_tokens += message_tokens
|
||||
input_message = HumanMessage(content=kwargs["user_input"])
|
||||
messages: List[BaseMessage] = [base_prompt, time_prompt, memory_message]
|
||||
messages += historical_messages
|
||||
messages.append(input_message)
|
||||
return messages
|
||||
|
||||
|
||||
class PromptGenerator:
|
||||
"""A class for generating custom prompt strings.
|
||||
|
||||
Does this based on constraints, commands, resources, and performance evaluations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the PromptGenerator object.
|
||||
|
||||
Starts with empty lists of constraints, commands, resources,
|
||||
and performance evaluations.
|
||||
"""
|
||||
self.constraints: List[str] = []
|
||||
self.commands: List[BaseTool] = []
|
||||
self.resources: List[str] = []
|
||||
self.performance_evaluation: List[str] = []
|
||||
self.response_format = {
|
||||
"thoughts": {
|
||||
"text": "thought",
|
||||
"reasoning": "reasoning",
|
||||
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||
"criticism": "constructive self-criticism",
|
||||
"speak": "thoughts summary to say to user",
|
||||
},
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
Add a constraint to the constraints list.
|
||||
|
||||
Args:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
self.constraints.append(constraint)
|
||||
|
||||
def add_tool(self, tool: BaseTool) -> None:
|
||||
self.commands.append(tool)
|
||||
|
||||
def _generate_command_string(self, tool: BaseTool) -> str:
|
||||
output = f"{tool.name}: {tool.description}"
|
||||
output += f", args json schema: {json.dumps(tool.args)}"
|
||||
return output
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Args:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
self.resources.append(resource)
|
||||
|
||||
def add_performance_evaluation(self, evaluation: str) -> None:
|
||||
"""
|
||||
Add a performance evaluation item to the performance_evaluation list.
|
||||
|
||||
Args:
|
||||
evaluation (str): The evaluation item to be added.
|
||||
"""
|
||||
self.performance_evaluation.append(evaluation)
|
||||
|
||||
def _generate_numbered_list(self, items: list, item_type: str = "list") -> str:
|
||||
"""
|
||||
Generate a numbered list from given items based on the item_type.
|
||||
|
||||
Args:
|
||||
items (list): A list of items to be numbered.
|
||||
item_type (str, optional): The type of items in the list.
|
||||
Defaults to 'list'.
|
||||
|
||||
Returns:
|
||||
str: The formatted numbered list.
|
||||
"""
|
||||
if item_type == "command":
|
||||
command_strings = [
|
||||
f"{i + 1}. {self._generate_command_string(item)}"
|
||||
for i, item in enumerate(items)
|
||||
]
|
||||
finish_description = (
|
||||
"use this to signal that you have finished all your objectives"
|
||||
)
|
||||
finish_args = (
|
||||
'"response": "final response to let '
|
||||
'people know you have finished your objectives"'
|
||||
)
|
||||
finish_string = (
|
||||
f"{len(items) + 1}. {FINISH_NAME}: "
|
||||
f"{finish_description}, args: {finish_args}"
|
||||
)
|
||||
return "\n".join(command_strings + [finish_string])
|
||||
else:
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
|
||||
|
||||
def generate_prompt_string(self) -> str:
|
||||
"""Generate a prompt string.
|
||||
|
||||
Returns:
|
||||
str: The generated prompt string.
|
||||
"""
|
||||
formatted_response_format = json.dumps(self.response_format, indent=4)
|
||||
prompt_string = (
|
||||
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
|
||||
f"Commands:\n"
|
||||
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
|
||||
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
|
||||
f"Performance Evaluation:\n"
|
||||
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
|
||||
f"You should only respond in JSON format as described below "
|
||||
f"\nResponse Format: \n{formatted_response_format} "
|
||||
f"\nEnsure the response can be parsed by Python json.loads"
|
||||
)
|
||||
|
||||
return prompt_string
|
||||
|
||||
|
||||
def get_prompt(tools: List[BaseTool]) -> str:
|
||||
"""Generates a prompt string.
|
||||
|
||||
It includes various constraints, commands, resources, and performance evaluations.
|
||||
|
||||
Returns:
|
||||
str: The generated prompt string.
|
||||
"""
|
||||
|
||||
# Initialize the PromptGenerator object
|
||||
prompt_generator = PromptGenerator()
|
||||
|
||||
# Add constraints to the PromptGenerator object
|
||||
prompt_generator.add_constraint(
|
||||
"~16000 word limit for short term memory. "
|
||||
"Your short term memory is short, "
|
||||
"so immediately save important information to files."
|
||||
)
|
||||
prompt_generator.add_constraint(
|
||||
"If you are unsure how you previously did something "
|
||||
"or want to recall past events, "
|
||||
"thinking about similar events will help you remember."
|
||||
)
|
||||
prompt_generator.add_constraint("No user assistance")
|
||||
prompt_generator.add_constraint(
|
||||
'Exclusively use the commands listed in double quotes e.g. "command name"'
|
||||
)
|
||||
|
||||
# Add commands to the PromptGenerator object
|
||||
for tool in tools:
|
||||
prompt_generator.add_tool(tool)
|
||||
|
||||
# Add resources to the PromptGenerator object
|
||||
prompt_generator.add_resource(
|
||||
"Internet access for searches and information gathering."
|
||||
)
|
||||
prompt_generator.add_resource("Long Term memory management.")
|
||||
prompt_generator.add_resource(
|
||||
"GPT-3.5 powered Agents for delegation of simple tasks."
|
||||
)
|
||||
prompt_generator.add_resource("File output.")
|
||||
|
||||
# Add performance evaluations to the PromptGenerator object
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Continuously review and analyze your actions "
|
||||
"to ensure you are performing to the best of your abilities."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Constructively self-criticize your big-picture behavior constantly."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Reflect on past decisions and strategies to refine your approach."
|
||||
)
|
||||
prompt_generator.add_performance_evaluation(
|
||||
"Every command has a cost, so be smart and efficient. "
|
||||
"Aim to complete tasks in the least number of steps."
|
||||
)
|
||||
|
||||
# Generate the prompt string
|
||||
prompt_string = prompt_generator.generate_prompt_string()
|
||||
|
||||
return prompt_string
|
||||
|
||||
|
||||
class AutoGPT:
|
||||
"""
|
||||
AutoAgent:
|
||||
|
||||
|
||||
Args:
|
||||
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ai_name: str,
|
||||
memory: VectorStoreRetriever,
|
||||
chain: LLMChain,
|
||||
output_parser: BaseAutoGPTOutputParser,
|
||||
tools: List[BaseTool],
|
||||
feedback_tool: Optional[HumanInputRun] = None,
|
||||
chat_history_memory: Optional[BaseChatMessageHistory] = None,
|
||||
):
|
||||
self.ai_name = ai_name
|
||||
self.memory = memory
|
||||
self.next_action_count = 0
|
||||
self.chain = chain
|
||||
self.output_parser = output_parser
|
||||
self.tools = tools
|
||||
self.feedback_tool = feedback_tool
|
||||
self.chat_history_memory = chat_history_memory or ChatMessageHistory()
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
ai_name: str,
|
||||
ai_role: str,
|
||||
memory: VectorStoreRetriever,
|
||||
tools: List[BaseTool],
|
||||
llm: BaseChatModel,
|
||||
human_in_the_loop: bool = False,
|
||||
output_parser: Optional[BaseAutoGPTOutputParser] = None,
|
||||
chat_history_memory: Optional[BaseChatMessageHistory] = None,
|
||||
) -> AutoGPT:
|
||||
prompt = AutoGPTPrompt(
|
||||
ai_name=ai_name,
|
||||
ai_role=ai_role,
|
||||
tools=tools,
|
||||
input_variables=["memory", "messages", "goals", "user_input"],
|
||||
token_counter=llm.get_num_tokens,
|
||||
)
|
||||
human_feedback_tool = HumanInputRun() if human_in_the_loop else None
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(
|
||||
ai_name,
|
||||
memory,
|
||||
chain,
|
||||
output_parser or AutoGPTOutputParser(),
|
||||
tools,
|
||||
feedback_tool=human_feedback_tool,
|
||||
chat_history_memory=chat_history_memory,
|
||||
)
|
||||
|
||||
def run(self, goals: List[str]) -> str:
|
||||
user_input = (
|
||||
"Determine which next command to use, "
|
||||
"and respond using the format specified above:"
|
||||
)
|
||||
# Interaction Loop
|
||||
loop_count = 0
|
||||
while True:
|
||||
# Discontinue if continuous limit is reached
|
||||
loop_count += 1
|
||||
|
||||
# Send message to AI, get response
|
||||
assistant_reply = self.chain.run(
|
||||
goals=goals,
|
||||
messages=self.chat_history_memory.messages,
|
||||
memory=self.memory,
|
||||
user_input=user_input,
|
||||
)
|
||||
|
||||
# Print Assistant thoughts
|
||||
print(assistant_reply)
|
||||
self.chat_history_memory.add_message(HumanMessage(content=user_input))
|
||||
self.chat_history_memory.add_message(AIMessage(content=assistant_reply))
|
||||
|
||||
# Get command name and arguments
|
||||
action = self.output_parser.parse(assistant_reply)
|
||||
tools = {t.name: t for t in self.tools}
|
||||
if action.name == FINISH_NAME:
|
||||
return action.args["response"]
|
||||
if action.name in tools:
|
||||
tool = tools[action.name]
|
||||
try:
|
||||
observation = tool.run(action.args)
|
||||
except ValidationError as e:
|
||||
observation = (
|
||||
f"Validation Error in args: {str(e)}, args: {action.args}"
|
||||
)
|
||||
except Exception as e:
|
||||
observation = (
|
||||
f"Error: {str(e)}, {type(e).__name__}, args: {action.args}"
|
||||
)
|
||||
result = f"Command {tool.name} returned: {observation}"
|
||||
elif action.name == "ERROR":
|
||||
result = f"Error: {action.args}. "
|
||||
else:
|
||||
result = (
|
||||
f"Unknown command '{action.name}'. "
|
||||
f"Please refer to the 'COMMANDS' list for available "
|
||||
f"commands and only respond in the specified JSON format."
|
||||
)
|
||||
|
||||
memory_to_add = (
|
||||
f"Assistant Reply: {assistant_reply} " f"\nResult: {result} "
|
||||
)
|
||||
if self.feedback_tool is not None:
|
||||
feedback = f"\n{self.feedback_tool.run('Input: ')}"
|
||||
if feedback in {"q", "stop"}:
|
||||
print("EXITING")
|
||||
return "EXITING"
|
||||
memory_to_add += feedback
|
||||
|
||||
self.memory.add_documents([Document(page_content=memory_to_add)])
|
||||
self.chat_history_memory.add_message(SystemMessage(content=result))
|
@ -0,0 +1,7 @@
|
||||
"""
|
||||
Data Loaders for APPS
|
||||
|
||||
|
||||
TODO: Clean up all the llama index stuff, remake the logic from scratch
|
||||
|
||||
"""
|
@ -0,0 +1,103 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_index.readers.base import BaseReader
|
||||
from llama_index.readers.schema.base import Document
|
||||
|
||||
|
||||
class AsanaReader(BaseReader):
|
||||
"""Asana reader. Reads data from an Asana workspace.
|
||||
|
||||
Args:
|
||||
asana_token (str): Asana token.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, asana_token: str) -> None:
|
||||
"""Initialize Asana reader."""
|
||||
import asana
|
||||
|
||||
self.client = asana.Client.access_token(asana_token)
|
||||
|
||||
def load_data(
|
||||
self, workspace_id: Optional[str] = None, project_id: Optional[str] = None
|
||||
) -> List[Document]:
|
||||
"""Load data from the workspace.
|
||||
|
||||
Args:
|
||||
workspace_id (Optional[str], optional): Workspace ID. Defaults to None.
|
||||
project_id (Optional[str], optional): Project ID. Defaults to None.
|
||||
Returns:
|
||||
List[Document]: List of documents.
|
||||
"""
|
||||
|
||||
if workspace_id is None and project_id is None:
|
||||
raise ValueError("Either workspace_id or project_id must be provided")
|
||||
|
||||
if workspace_id is not None and project_id is not None:
|
||||
raise ValueError(
|
||||
"Only one of workspace_id or project_id should be provided"
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
if workspace_id is not None:
|
||||
workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"]
|
||||
projects = self.client.projects.find_all({"workspace": workspace_id})
|
||||
|
||||
# Case: Only project_id is provided
|
||||
else: # since we've handled the other cases, this means project_id is not None
|
||||
projects = [self.client.projects.find_by_id(project_id)]
|
||||
workspace_name = projects[0]["workspace"]["name"]
|
||||
|
||||
for project in projects:
|
||||
tasks = self.client.tasks.find_all(
|
||||
{
|
||||
"project": project["gid"],
|
||||
"opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields",
|
||||
}
|
||||
)
|
||||
for task in tasks:
|
||||
stories = self.client.tasks.stories(task["gid"], opt_fields="type,text")
|
||||
comments = "\n".join(
|
||||
[
|
||||
story["text"]
|
||||
for story in stories
|
||||
if story.get("type") == "comment" and "text" in story
|
||||
]
|
||||
)
|
||||
|
||||
task_metadata = {
|
||||
"task_id": task.get("gid", ""),
|
||||
"name": task.get("name", ""),
|
||||
"assignee": (task.get("assignee") or {}).get("name", ""),
|
||||
"completed_on": task.get("completed_at", ""),
|
||||
"completed_by": (task.get("completed_by") or {}).get("name", ""),
|
||||
"project_name": project.get("name", ""),
|
||||
"custom_fields": [
|
||||
i["display_value"]
|
||||
for i in task.get("custom_fields")
|
||||
if task.get("custom_fields") is not None
|
||||
],
|
||||
"workspace_name": workspace_name,
|
||||
"url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}",
|
||||
}
|
||||
|
||||
if task.get("followers") is not None:
|
||||
task_metadata["followers"] = [
|
||||
i.get("name") for i in task.get("followers") if "name" in i
|
||||
]
|
||||
else:
|
||||
task_metadata["followers"] = []
|
||||
|
||||
results.append(
|
||||
Document(
|
||||
text=task.get("name", "")
|
||||
+ " "
|
||||
+ task.get("notes", "")
|
||||
+ " "
|
||||
+ comments,
|
||||
extra_info=task_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
@ -0,0 +1,622 @@
|
||||
"""Base schema for data structures."""
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from enum import Enum, auto
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
from llama_index.utils import SAMPLE_TEXT, truncate_text
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.schema import Document as HaystackDocument
|
||||
from semantic_kernel.memory.memory_record import MemoryRecord
|
||||
|
||||
|
||||
####
|
||||
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
|
||||
DEFAULT_METADATA_TMPL = "{key}: {value}"
|
||||
# NOTE: for pretty printing
|
||||
TRUNCATE_LENGTH = 350
|
||||
WRAP_WIDTH = 70
|
||||
|
||||
|
||||
class BaseComponent(BaseModel):
|
||||
"""Base component object to capture class names."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def class_name(cls) -> str:
|
||||
"""
|
||||
Get the class name, used as a unique ID in serialization.
|
||||
|
||||
This provides a key that makes serialization robust against actual class
|
||||
name changes.
|
||||
"""
|
||||
|
||||
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
data = self.dict(**kwargs)
|
||||
data["class_name"] = self.class_name()
|
||||
return data
|
||||
|
||||
def to_json(self, **kwargs: Any) -> str:
|
||||
data = self.to_dict(**kwargs)
|
||||
return json.dumps(data)
|
||||
|
||||
# TODO: return type here not supported by current mypy version
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
|
||||
if isinstance(kwargs, dict):
|
||||
data.update(kwargs)
|
||||
|
||||
data.pop("class_name", None)
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore
|
||||
data = json.loads(data_str)
|
||||
return cls.from_dict(data, **kwargs)
|
||||
|
||||
|
||||
class NodeRelationship(str, Enum):
|
||||
"""Node relationships used in `BaseNode` class.
|
||||
|
||||
Attributes:
|
||||
SOURCE: The node is the source document.
|
||||
PREVIOUS: The node is the previous node in the document.
|
||||
NEXT: The node is the next node in the document.
|
||||
PARENT: The node is the parent node in the document.
|
||||
CHILD: The node is a child node in the document.
|
||||
|
||||
"""
|
||||
|
||||
SOURCE = auto()
|
||||
PREVIOUS = auto()
|
||||
NEXT = auto()
|
||||
PARENT = auto()
|
||||
CHILD = auto()
|
||||
|
||||
|
||||
class ObjectType(str, Enum):
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
INDEX = auto()
|
||||
DOCUMENT = auto()
|
||||
|
||||
|
||||
class MetadataMode(str, Enum):
|
||||
ALL = auto()
|
||||
EMBED = auto()
|
||||
LLM = auto()
|
||||
NONE = auto()
|
||||
|
||||
|
||||
class RelatedNodeInfo(BaseComponent):
|
||||
node_id: str
|
||||
node_type: Optional[ObjectType] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
hash: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "RelatedNodeInfo"
|
||||
|
||||
|
||||
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
|
||||
|
||||
|
||||
# Node classes for indexes
|
||||
class BaseNode(BaseComponent):
|
||||
"""Base node Object.
|
||||
|
||||
Generic abstract interface for retrievable nodes
|
||||
|
||||
"""
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
|
||||
id_: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
|
||||
)
|
||||
embedding: Optional[List[float]] = Field(
|
||||
default=None, description="Embedding of the node."
|
||||
)
|
||||
|
||||
""""
|
||||
metadata fields
|
||||
- injected as part of the text shown to LLMs as context
|
||||
- injected as part of the text for generating embeddings
|
||||
- used by vector DBs for metadata filtering
|
||||
|
||||
"""
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="A flat dictionary of metadata fields",
|
||||
alias="extra_info",
|
||||
)
|
||||
excluded_embed_metadata_keys: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Metadata keys that are excluded from text for the embed model.",
|
||||
)
|
||||
excluded_llm_metadata_keys: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Metadata keys that are excluded from text for the LLM.",
|
||||
)
|
||||
relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
|
||||
default_factory=dict,
|
||||
description="A mapping of relationships to other node information.",
|
||||
)
|
||||
hash: str = Field(default="", description="Hash of the node content.")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Get Object type."""
|
||||
|
||||
@abstractmethod
|
||||
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
|
||||
"""Get object content."""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
|
||||
"""Metadata string."""
|
||||
|
||||
@abstractmethod
|
||||
def set_content(self, value: Any) -> None:
|
||||
"""Set the content of the node."""
|
||||
|
||||
@property
|
||||
def node_id(self) -> str:
|
||||
return self.id_
|
||||
|
||||
@node_id.setter
|
||||
def node_id(self, value: str) -> None:
|
||||
self.id_ = value
|
||||
|
||||
@property
|
||||
def source_node(self) -> Optional[RelatedNodeInfo]:
|
||||
"""Source object node.
|
||||
|
||||
Extracted from the relationships field.
|
||||
|
||||
"""
|
||||
if NodeRelationship.SOURCE not in self.relationships:
|
||||
return None
|
||||
|
||||
relation = self.relationships[NodeRelationship.SOURCE]
|
||||
if isinstance(relation, list):
|
||||
raise ValueError("Source object must be a single RelatedNodeInfo object")
|
||||
return relation
|
||||
|
||||
@property
|
||||
def prev_node(self) -> Optional[RelatedNodeInfo]:
|
||||
"""Prev node."""
|
||||
if NodeRelationship.PREVIOUS not in self.relationships:
|
||||
return None
|
||||
|
||||
relation = self.relationships[NodeRelationship.PREVIOUS]
|
||||
if not isinstance(relation, RelatedNodeInfo):
|
||||
raise ValueError("Previous object must be a single RelatedNodeInfo object")
|
||||
return relation
|
||||
|
||||
@property
|
||||
def next_node(self) -> Optional[RelatedNodeInfo]:
|
||||
"""Next node."""
|
||||
if NodeRelationship.NEXT not in self.relationships:
|
||||
return None
|
||||
|
||||
relation = self.relationships[NodeRelationship.NEXT]
|
||||
if not isinstance(relation, RelatedNodeInfo):
|
||||
raise ValueError("Next object must be a single RelatedNodeInfo object")
|
||||
return relation
|
||||
|
||||
@property
|
||||
def parent_node(self) -> Optional[RelatedNodeInfo]:
|
||||
"""Parent node."""
|
||||
if NodeRelationship.PARENT not in self.relationships:
|
||||
return None
|
||||
|
||||
relation = self.relationships[NodeRelationship.PARENT]
|
||||
if not isinstance(relation, RelatedNodeInfo):
|
||||
raise ValueError("Parent object must be a single RelatedNodeInfo object")
|
||||
return relation
|
||||
|
||||
@property
|
||||
def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
|
||||
"""Child nodes."""
|
||||
if NodeRelationship.CHILD not in self.relationships:
|
||||
return None
|
||||
|
||||
relation = self.relationships[NodeRelationship.CHILD]
|
||||
if not isinstance(relation, list):
|
||||
raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
|
||||
return relation
|
||||
|
||||
@property
|
||||
def ref_doc_id(self) -> Optional[str]:
|
||||
"""Deprecated: Get ref doc id."""
|
||||
source_node = self.source_node
|
||||
if source_node is None:
|
||||
return None
|
||||
return source_node.node_id
|
||||
|
||||
@property
|
||||
def extra_info(self) -> Dict[str, Any]:
|
||||
"""TODO: DEPRECATED: Extra info."""
|
||||
return self.metadata
|
||||
|
||||
def __str__(self) -> str:
|
||||
source_text_truncated = truncate_text(
|
||||
self.get_content().strip(), TRUNCATE_LENGTH
|
||||
)
|
||||
source_text_wrapped = textwrap.fill(
|
||||
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
|
||||
)
|
||||
return f"Node ID: {self.node_id}\n{source_text_wrapped}"
|
||||
|
||||
def get_embedding(self) -> List[float]:
|
||||
"""Get embedding.
|
||||
|
||||
Errors if embedding is None.
|
||||
|
||||
"""
|
||||
if self.embedding is None:
|
||||
raise ValueError("embedding not set.")
|
||||
return self.embedding
|
||||
|
||||
def as_related_node_info(self) -> RelatedNodeInfo:
|
||||
"""Get node as RelatedNodeInfo."""
|
||||
return RelatedNodeInfo(
|
||||
node_id=self.node_id,
|
||||
node_type=self.get_type(),
|
||||
metadata=self.metadata,
|
||||
hash=self.hash,
|
||||
)
|
||||
|
||||
|
||||
class TextNode(BaseNode):
|
||||
text: str = Field(default="", description="Text content of the node.")
|
||||
start_char_idx: Optional[int] = Field(
|
||||
default=None, description="Start char index of the node."
|
||||
)
|
||||
end_char_idx: Optional[int] = Field(
|
||||
default=None, description="End char index of the node."
|
||||
)
|
||||
text_template: str = Field(
|
||||
default=DEFAULT_TEXT_NODE_TMPL,
|
||||
description=(
|
||||
"Template for how text is formatted, with {content} and "
|
||||
"{metadata_str} placeholders."
|
||||
),
|
||||
)
|
||||
metadata_template: str = Field(
|
||||
default=DEFAULT_METADATA_TMPL,
|
||||
description=(
|
||||
"Template for how metadata is formatted, with {key} and "
|
||||
"{value} placeholders."
|
||||
),
|
||||
)
|
||||
metadata_seperator: str = Field(
|
||||
default="\n",
|
||||
description="Separator between metadata fields when converting to string.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "TextNode"
|
||||
|
||||
@root_validator
|
||||
def _check_hash(cls, values: dict) -> dict:
|
||||
"""Generate a hash to represent the node."""
|
||||
text = values.get("text", "")
|
||||
metadata = values.get("metadata", {})
|
||||
doc_identity = str(text) + str(metadata)
|
||||
values["hash"] = str(
|
||||
sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
|
||||
)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Get Object type."""
|
||||
return ObjectType.TEXT
|
||||
|
||||
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
|
||||
"""Get object content."""
|
||||
metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
|
||||
if not metadata_str:
|
||||
return self.text
|
||||
|
||||
return self.text_template.format(
|
||||
content=self.text, metadata_str=metadata_str
|
||||
).strip()
|
||||
|
||||
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
|
||||
"""Metadata info string."""
|
||||
if mode == MetadataMode.NONE:
|
||||
return ""
|
||||
|
||||
usable_metadata_keys = set(self.metadata.keys())
|
||||
if mode == MetadataMode.LLM:
|
||||
for key in self.excluded_llm_metadata_keys:
|
||||
if key in usable_metadata_keys:
|
||||
usable_metadata_keys.remove(key)
|
||||
elif mode == MetadataMode.EMBED:
|
||||
for key in self.excluded_embed_metadata_keys:
|
||||
if key in usable_metadata_keys:
|
||||
usable_metadata_keys.remove(key)
|
||||
|
||||
return self.metadata_seperator.join(
|
||||
[
|
||||
self.metadata_template.format(key=key, value=str(value))
|
||||
for key, value in self.metadata.items()
|
||||
if key in usable_metadata_keys
|
||||
]
|
||||
)
|
||||
|
||||
def set_content(self, value: str) -> None:
|
||||
"""Set the content of the node."""
|
||||
self.text = value
|
||||
|
||||
def get_node_info(self) -> Dict[str, Any]:
|
||||
"""Get node info."""
|
||||
return {"start": self.start_char_idx, "end": self.end_char_idx}
|
||||
|
||||
def get_text(self) -> str:
|
||||
return self.get_content(metadata_mode=MetadataMode.NONE)
|
||||
|
||||
@property
|
||||
def node_info(self) -> Dict[str, Any]:
|
||||
"""Deprecated: Get node info."""
|
||||
return self.get_node_info()
|
||||
|
||||
|
||||
# TODO: legacy backport of old Node class
|
||||
Node = TextNode
|
||||
|
||||
|
||||
class ImageNode(TextNode):
|
||||
"""Node with image."""
|
||||
|
||||
# TODO: store reference instead of actual image
|
||||
# base64 encoded image str
|
||||
image: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
return ObjectType.IMAGE
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ImageNode"
|
||||
|
||||
|
||||
class IndexNode(TextNode):
|
||||
"""Node with reference to any object.
|
||||
|
||||
This can include other indices, query engines, retrievers.
|
||||
|
||||
This can also include other nodes (though this is overlapping with `relationships`
|
||||
on the Node class).
|
||||
|
||||
"""
|
||||
|
||||
index_id: str
|
||||
|
||||
@classmethod
|
||||
def from_text_node(
|
||||
cls,
|
||||
node: TextNode,
|
||||
index_id: str,
|
||||
) -> "IndexNode":
|
||||
"""Create index node from text node."""
|
||||
# copy all attributes from text node, add index id
|
||||
return cls(
|
||||
**node.dict(),
|
||||
index_id=index_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
return ObjectType.INDEX
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "IndexNode"
|
||||
|
||||
|
||||
class NodeWithScore(BaseComponent):
|
||||
node: BaseNode
|
||||
score: Optional[float] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.node}\nScore: {self.score: 0.3f}\n"
|
||||
|
||||
def get_score(self, raise_error: bool = False) -> float:
|
||||
"""Get score."""
|
||||
if self.score is None:
|
||||
if raise_error:
|
||||
raise ValueError("Score not set.")
|
||||
else:
|
||||
return 0.0
|
||||
else:
|
||||
return self.score
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "NodeWithScore"
|
||||
|
||||
##### pass through methods to BaseNode #####
|
||||
@property
|
||||
def node_id(self) -> str:
|
||||
return self.node.node_id
|
||||
|
||||
@property
|
||||
def id_(self) -> str:
|
||||
return self.node.id_
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
if isinstance(self.node, TextNode):
|
||||
return self.node.text
|
||||
else:
|
||||
raise ValueError("Node must be a TextNode to get text.")
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
return self.node.metadata
|
||||
|
||||
@property
|
||||
def embedding(self) -> Optional[List[float]]:
|
||||
return self.node.embedding
|
||||
|
||||
def get_text(self) -> str:
|
||||
if isinstance(self.node, TextNode):
|
||||
return self.node.get_text()
|
||||
else:
|
||||
raise ValueError("Node must be a TextNode to get text.")
|
||||
|
||||
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
|
||||
return self.node.get_content(metadata_mode=metadata_mode)
|
||||
|
||||
def get_embedding(self) -> List[float]:
|
||||
return self.node.get_embedding()
|
||||
|
||||
|
||||
# Document Classes for Readers
|
||||
|
||||
|
||||
class Document(TextNode):
|
||||
"""Generic interface for a data document.
|
||||
|
||||
This document connects to data sources.
|
||||
|
||||
"""
|
||||
|
||||
# TODO: A lot of backwards compatibility logic here, clean up
|
||||
id_: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Unique ID of the node.",
|
||||
alias="doc_id",
|
||||
)
|
||||
|
||||
_compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Get Document type."""
|
||||
return ObjectType.DOCUMENT
|
||||
|
||||
@property
|
||||
def doc_id(self) -> str:
|
||||
"""Get document ID."""
|
||||
return self.id_
|
||||
|
||||
def __str__(self) -> str:
|
||||
source_text_truncated = truncate_text(
|
||||
self.get_content().strip(), TRUNCATE_LENGTH
|
||||
)
|
||||
source_text_wrapped = textwrap.fill(
|
||||
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
|
||||
)
|
||||
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
|
||||
|
||||
def get_doc_id(self) -> str:
|
||||
"""TODO: Deprecated: Get document ID."""
|
||||
return self.id_
|
||||
|
||||
def __setattr__(self, name: str, value: object) -> None:
|
||||
if name in self._compat_fields:
|
||||
name = self._compat_fields[name]
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def to_langchain_format(self) -> "LCDocument":
|
||||
"""Convert struct to LangChain document format."""
|
||||
from llama_index.bridge.langchain import Document as LCDocument
|
||||
|
||||
metadata = self.metadata or {}
|
||||
return LCDocument(page_content=self.text, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def from_langchain_format(cls, doc: "LCDocument") -> "Document":
|
||||
"""Convert struct from LangChain document format."""
|
||||
return cls(text=doc.page_content, metadata=doc.metadata)
|
||||
|
||||
def to_haystack_format(self) -> "HaystackDocument":
|
||||
"""Convert struct to Haystack document format."""
|
||||
from haystack.schema import Document as HaystackDocument
|
||||
|
||||
return HaystackDocument(
|
||||
content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_haystack_format(cls, doc: "HaystackDocument") -> "Document":
|
||||
"""Convert struct from Haystack document format."""
|
||||
return cls(
|
||||
text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id
|
||||
)
|
||||
|
||||
def to_embedchain_format(self) -> Dict[str, Any]:
|
||||
"""Convert struct to EmbedChain document format."""
|
||||
return {
|
||||
"doc_id": self.id_,
|
||||
"data": {"content": self.text, "meta_data": self.metadata},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document":
|
||||
"""Convert struct from EmbedChain document format."""
|
||||
return cls(
|
||||
text=doc["data"]["content"],
|
||||
metadata=doc["data"]["meta_data"],
|
||||
id_=doc["doc_id"],
|
||||
)
|
||||
|
||||
def to_semantic_kernel_format(self) -> "MemoryRecord":
|
||||
"""Convert struct to Semantic Kernel document format."""
|
||||
import numpy as np
|
||||
from semantic_kernel.memory.memory_record import MemoryRecord
|
||||
|
||||
return MemoryRecord(
|
||||
id=self.id_,
|
||||
text=self.text,
|
||||
additional_metadata=self.get_metadata_str(),
|
||||
embedding=np.array(self.embedding) if self.embedding else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document":
|
||||
"""Convert struct from Semantic Kernel document format."""
|
||||
return cls(
|
||||
text=doc._text,
|
||||
metadata={"additional_metadata": doc._additional_metadata},
|
||||
embedding=doc._embedding.tolist() if doc._embedding is not None else None,
|
||||
id_=doc._id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def example(cls) -> "Document":
|
||||
return Document(
|
||||
text=SAMPLE_TEXT,
|
||||
metadata={"filename": "README.md", "category": "codebase"},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "Document"
|
||||
|
||||
|
||||
class ImageDocument(Document):
|
||||
"""Data document containing an image."""
|
||||
|
||||
# base64 encoded image str
|
||||
image: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ImageDocument"
|
@ -1,2 +1,2 @@
|
||||
# from swarms.structs.workflow import Workflow
|
||||
# from swarms.structs.task import Task
|
||||
from swarms.structs.workflow import Workflow
|
||||
from swarms.structs.task import Task
|
||||
|
@ -0,0 +1,225 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union, Callable
|
||||
from swarms.models import OpenAIChat
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
||||
# Custome stopping condition
|
||||
def stop_when_repeats(response: str) -> bool:
|
||||
# Stop if the word stop appears in the response
|
||||
return "Stop" in response.lower()
|
||||
|
||||
|
||||
# class Flow:
|
||||
# def __init__(
|
||||
# self,
|
||||
# llm: Any,
|
||||
# template: str,
|
||||
# max_loops: int = 1,
|
||||
# stopping_condition: Optional[Callable[[str], bool]] = None,
|
||||
# **kwargs: Any
|
||||
# ):
|
||||
# self.llm = llm
|
||||
# self.template = template
|
||||
# self.max_loops = max_loops
|
||||
# self.stopping_condition = stopping_condition
|
||||
# self.feedback = []
|
||||
# self.history = []
|
||||
|
||||
# def __call__(
|
||||
# self,
|
||||
# prompt,
|
||||
# **kwargs
|
||||
# ) -> str:
|
||||
# """Invoke the flow by providing a template and it's variables"""
|
||||
# response = self.llm(prompt, **kwargs)
|
||||
# return response
|
||||
|
||||
# def _check_stopping_condition(self, response: str) -> bool:
|
||||
# """Check if the stopping condition is met"""
|
||||
# if self.stopping_condition:
|
||||
# return self.stopping_condition(response)
|
||||
# return False
|
||||
|
||||
|
||||
# def provide_feedback(self, feedback: str) -> None:
|
||||
# """Allow users to to provide feedback on the responses"""
|
||||
# feedback = self.feedback.append(feedback)
|
||||
# return feedback
|
||||
|
||||
# def format_prompt(self, **kwargs: Any) -> str:
|
||||
# """Format the template with the provided kwargs using f string interpolation"""
|
||||
# return self.template.format(**kwargs)
|
||||
|
||||
# def _generate(self, formatted_prompts: str) -> str:
|
||||
# """
|
||||
# Generate a result using the lm
|
||||
|
||||
# """
|
||||
# return self.llm(formatted_prompts)
|
||||
|
||||
# def run(self, **kwargs: Any) -> str:
|
||||
# """Generate a result using the provided keyword args"""
|
||||
# prompt = self.format_prompt(**kwargs)
|
||||
# response = self._generate(prompt)
|
||||
# return response
|
||||
|
||||
# def bulk_run(
|
||||
# self,
|
||||
# inputs: List[Dict[str, Any]]
|
||||
# ) -> List[str]:
|
||||
# """Generate responses for multiple input sets"""
|
||||
# return [self.run(**input_data) for input_data in inputs]
|
||||
|
||||
# @staticmethod
|
||||
# def from_llm_and_template(llm: Any, template: str) -> "Flow":
|
||||
# """Create FlowStream from LLM and a string template"""
|
||||
# return Flow(llm=llm, template=template)
|
||||
|
||||
# @staticmethod
|
||||
# def from_llm_and_template_file(llm: Any, template_file: str) -> "Flow":
|
||||
# """Create FlowStream from LLM and a template file"""
|
||||
# with open(template_file, "r") as f:
|
||||
# template = f.read()
|
||||
|
||||
# return Flow(llm=llm, template=template)
|
||||
|
||||
|
||||
|
||||
class Flow:
|
||||
def __init__(
|
||||
self,
|
||||
llm: Any,
|
||||
template: str,
|
||||
max_loops: int = 1,
|
||||
stopping_condition: Optional[Callable[[str], bool]] = None,
|
||||
loop_interval: int = 1,
|
||||
retry_attempts: int = 3,
|
||||
retry_interval: int = 1,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.llm = llm
|
||||
self.template = template
|
||||
self.max_loops = max_loops
|
||||
self.stopping_condition = stopping_condition
|
||||
self.loop_interval = loop_interval
|
||||
self.retry_attempts = retry_attempts
|
||||
self.retry_interval = retry_interval
|
||||
self.feedback = []
|
||||
|
||||
def provide_feedback(self, feedback: str) -> None:
|
||||
"""Allow users to provide feedback on the responses."""
|
||||
self.feedback.append(feedback)
|
||||
logging.info(f"Feedback received: {feedback}")
|
||||
|
||||
def _check_stopping_condition(self, response: str) -> bool:
|
||||
"""Check if the stopping condition is met."""
|
||||
if self.stopping_condition:
|
||||
return self.stopping_condition(response)
|
||||
return False
|
||||
|
||||
def __call__(self, prompt, **kwargs) -> str:
|
||||
"""Invoke the flow by providing a template and its variables."""
|
||||
response = self.llm(prompt, **kwargs)
|
||||
return response
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> str:
|
||||
"""Format the template with the provided kwargs using f-string interpolation."""
|
||||
return self.template.format(**kwargs)
|
||||
|
||||
def _generate(self, task: str, formatted_prompts: str) -> str:
|
||||
"""
|
||||
Generate a result using the lm with optional query loops and stopping conditions.
|
||||
"""
|
||||
response = formatted_prompts
|
||||
history = [task]
|
||||
for _ in range(self.max_loops):
|
||||
if self._check_stopping_condition(response):
|
||||
break
|
||||
attempt = 0
|
||||
while attempt < self.retry_attempts:
|
||||
try:
|
||||
response = self.llm(response)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating response: {e}")
|
||||
attempt += 1
|
||||
time.sleep(self.retry_interval)
|
||||
logging.info(f"Generated response: {response}")
|
||||
history.append(response)
|
||||
time.sleep(self.loop_interval)
|
||||
return response, history
|
||||
|
||||
def run(self, **kwargs: Any) -> str:
|
||||
"""Generate a result using the provided keyword args."""
|
||||
task = self.format_prompt(**kwargs)
|
||||
response, history = self._generate(task, task)
|
||||
logging.info(f"Message history: {history}")
|
||||
return response
|
||||
|
||||
def bulk_run(self, inputs: List[Dict[str, Any]]) -> List[str]:
|
||||
"""Generate responses for multiple input sets."""
|
||||
return [self.run(**input_data) for input_data in inputs]
|
||||
|
||||
@staticmethod
|
||||
def from_llm_and_template(llm: Any, template: str) -> "Flow":
|
||||
"""Create FlowStream from LLM and a string template."""
|
||||
return Flow(llm=llm, template=template)
|
||||
|
||||
@staticmethod
|
||||
def from_llm_and_template_file(llm: Any, template_file: str) -> "Flow":
|
||||
"""Create FlowStream from LLM and a template file."""
|
||||
with open(template_file, "r") as f:
|
||||
template = f.read()
|
||||
return Flow(llm=llm, template=template)
|
||||
|
||||
|
||||
# # Configure logging
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# llm = OpenAIChat(
|
||||
# api_key="YOUR_API_KEY",
|
||||
# max_tokens=1000,
|
||||
# temperature=0.9,
|
||||
# )
|
||||
|
||||
|
||||
# def main():
|
||||
# # Initialize the Flow class with parameters
|
||||
# flow = Flow(
|
||||
# llm=llm,
|
||||
# template="Translate this to backwards: {sentence}",
|
||||
# max_loops=3,
|
||||
# stopping_condition=stop_when_repeats,
|
||||
# loop_interval=2, # Wait 2 seconds between loops
|
||||
# retry_attempts=2,
|
||||
# retry_interval=1, # Wait 1 second between retries
|
||||
# )
|
||||
|
||||
# # Predict using the Flow
|
||||
# response = flow.run(sentence="Hello, World!")
|
||||
# print("Response:", response)
|
||||
# time.sleep(1) # Pause for demonstration purposes
|
||||
|
||||
# # Provide feedback on the result
|
||||
# flow.provide_feedback("The translation was interesting!")
|
||||
# time.sleep(1) # Pause for demonstration purposes
|
||||
|
||||
# # Bulk run
|
||||
# inputs = [
|
||||
# {"sentence": "This is a test."},
|
||||
# {"sentence": "OpenAI is great."},
|
||||
# {"sentence": "GPT models are powerful."},
|
||||
# {"sentence": "stop and check if our stopping condition works."},
|
||||
# ]
|
||||
|
||||
# responses = flow.bulk_run(inputs=inputs)
|
||||
# for idx, res in enumerate(responses):
|
||||
# print(f"Input: {inputs[idx]['sentence']}, Response: {res}")
|
||||
# time.sleep(1) # Pause for demonstration purposes
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
Loading…
Reference in new issue