cleanup of bingchat, revgpt, cookies, json, and schemas in swarms

pull/67/head
Kye 1 year ago
parent f5f95f3c0b
commit ccb239dc2b

@ -1,19 +0,0 @@
from swarms.models.bing_chat import BingChat
from swarms.workers.worker import Worker
from swarms.tools.autogpt import EdgeGPTTool, tool
# Initialize the language model;
# this model can be swapped out for Anthropic, Hugging Face models like Mistral, etc.
llm = BingChat(cookies_path="./cookies.json")
# Initialize the Worker with the custom tool
worker = Worker(
llm=llm,
ai_name="EdgeGPT Worker",
)
# Use the worker to process a task
task = "Hello, my name is ChatGPT"
response = worker.run(task)
print(response)

@ -1,6 +0,0 @@
[
{
"name": "cookie1",
"value": "1GJjj1-tM6Jlo4HFtnbocQ3r0QbQ9Aq_R65dqbcSWKzKxnN8oEMW1xa4RlsJ_nGyNjFlXQRzMWRR2GK11bve8-6n_bjF0zTczYcQQ8oDB8W66jgpIWSL7Hr4hneB0R9dIt-OQ4cVPs4eehL2lcRCObWQr0zkG14MHlH5EMwAKthv_NNIQSfThq4Ey2Hmzhq9sRuyS04JveHdLC9gfthJ8xk3J12yr7j4HsynpzmvFUcA"
}
]

@ -10,15 +10,12 @@ config = {
    "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")],
    "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True",
    "PUID": os.getenv("REVGPT_PUID"),
    "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")],
}

llm = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config)
worker = Worker(ai_name="Optimus Prime", llm=llm)

task = "What were the winning boston marathon times for the past 5 years (ending in 2022)? Generate a table of the year, name, country of origin, and times."
response = worker.run(task)

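The config above is driven entirely by REVGPT_* environment variables; a minimal sketch of supplying them from Python before building the model (all values are illustrative placeholders, not real credentials):

import os

# Illustrative placeholders only; real values come from your ChatGPT account.
os.environ["ACCESS_TOKEN"] = "<your-access-token>"
os.environ["REVGPT_PLUGIN_IDS"] = "<plugin-id>"
os.environ["REVGPT_DISABLE_HISTORY"] = "True"
os.environ["REVGPT_PUID"] = "<puid-cookie>"
os.environ["REVGPT_UNVERIFIED_PLUGIN_DOMAINS"] = "example.com"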
@ -2,18 +2,20 @@ from swarms.models.bing_chat import BingChat
from swarms.workers.worker import Worker
from swarms.tools.autogpt import EdgeGPTTool, tool
from swarms.models import OpenAIChat
import os

api_key = os.getenv("OPENAI_API_KEY")

# Initialize the EdgeGPT model
edgegpt = BingChat(cookies_path="./cookies.txt")

@tool
def edgegpt(task: str = None):
    """A tool to run inference on the EdgeGPT model"""
    return EdgeGPTTool.run(task)

# Initialize the language model;
# this model can be swapped out for Anthropic, Hugging Face models like Mistral, etc.
llm = OpenAIChat(
@ -22,11 +24,7 @@ llm = OpenAIChat(
)

# Initialize the Worker with the custom tool
worker = Worker(llm=llm, ai_name="EdgeGPT Worker", external_tools=[edgegpt])

# Use the worker to process a task
task = "Hello, my name is ChatGPT"

@ -14,7 +14,7 @@ config = {
    "plugin_ids": [os.getenv("REVGPT_PLUGIN_IDS")],
    "disable_history": os.getenv("REVGPT_DISABLE_HISTORY") == "True",
    "PUID": os.getenv("REVGPT_PUID"),
    "unverified_plugin_domains": [os.getenv("REVGPT_UNVERIFIED_PLUGIN_DOMAINS")],
}
# For v1 model

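The `# For v1 model` comment introduces the v1 wrapper; assuming `RevChatGPTModelv1.run(task)` behaves as the hunks later in this commit suggest, a hedged usage sketch:

# Hypothetical usage, mirroring the RevChatGPTModel example above;
# assumes the same `config` dict and environment variables are in scope.
llm = RevChatGPTModelv1(access_token=os.getenv("ACCESS_TOKEN"), **config)
print(llm.run("Summarize the Boston Marathon results task in one sentence."))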
@ -0,0 +1,437 @@
from __future__ import annotations
import json
import time
from typing import Any, Callable, List, Optional
from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ChatMessageHistory
from langchain.prompts.chat import (
BaseChatPromptTemplate,
)
from langchain.schema import (
BaseChatMessageHistory,
Document,
)
from langchain.schema.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.vectorstore import VectorStoreRetriever
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain_experimental.autonomous_agents.autogpt.output_parser import (
AutoGPTOutputParser,
BaseAutoGPTOutputParser,
)
from langchain_experimental.autonomous_agents.autogpt.prompt import AutoGPTPrompt
from langchain_experimental.autonomous_agents.autogpt.prompt_generator import (
FINISH_NAME,
get_prompt,
)
from langchain_experimental.pydantic_v1 import BaseModel, ValidationError
# PROMPT
FINISH_NAME = "finish"
# This class has a metaclass conflict: both `BaseChatPromptTemplate` and `BaseModel`
# define a metaclass to use, and the two metaclasses attempt to define
# the same functions but in mutually-incompatible ways.
# It isn't clear how to resolve this, and this code predates mypy
# beginning to perform that check.
#
# Mypy errors:
# ```
# Definition of "__private_attributes__" in base class "BaseModel" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__repr_name__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__pretty__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__repr_str__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Definition of "__rich_repr__" in base class "Representation" is
# incompatible with definition in base class "BaseModel" [misc]
# Metaclass conflict: the metaclass of a derived class must be
# a (non-strict) subclass of the metaclasses of all its bases [misc]
# ```
#
# TODO: look into refactoring this class in a way that avoids the mypy type errors
class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel): # type: ignore[misc]
"""Prompt for AutoGPT."""
ai_name: str
ai_role: str
tools: List[BaseTool]
token_counter: Callable[[str], int]
send_token_limit: int = 4196
def construct_full_prompt(self, goals: List[str]) -> str:
prompt_start = (
"Your decisions must always be made independently "
"without seeking user assistance.\n"
"Play to your strengths as an LLM and pursue simple "
"strategies with no legal complications.\n"
"If you have completed all your tasks, make sure to "
'use the "finish" command.'
)
# Construct full prompt
full_prompt = (
f"You are {self.ai_name}, {self.ai_role}\n{prompt_start}\n\nGOALS:\n\n"
)
for i, goal in enumerate(goals):
full_prompt += f"{i+1}. {goal}\n"
full_prompt += f"\n\n{get_prompt(self.tools)}"
return full_prompt
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
base_prompt = SystemMessage(content=self.construct_full_prompt(kwargs["goals"]))
time_prompt = SystemMessage(
content=f"The current time and date is {time.strftime('%c')}"
)
used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
time_prompt.content
)
memory: VectorStoreRetriever = kwargs["memory"]
previous_messages = kwargs["messages"]
relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
relevant_memory = [d.page_content for d in relevant_docs]
relevant_memory_tokens = sum(
[self.token_counter(doc) for doc in relevant_memory]
)
while used_tokens + relevant_memory_tokens > 2500:
relevant_memory = relevant_memory[:-1]
relevant_memory_tokens = sum(
[self.token_counter(doc) for doc in relevant_memory]
)
content_format = (
f"This reminds you of these events "
f"from your past:\n{relevant_memory}\n\n"
)
memory_message = SystemMessage(content=content_format)
used_tokens += self.token_counter(memory_message.content)
historical_messages: List[BaseMessage] = []
for message in previous_messages[-10:][::-1]:
message_tokens = self.token_counter(message.content)
if used_tokens + message_tokens > self.send_token_limit - 1000:
break
historical_messages = [message] + historical_messages
used_tokens += message_tokens
input_message = HumanMessage(content=kwargs["user_input"])
messages: List[BaseMessage] = [base_prompt, time_prompt, memory_message]
messages += historical_messages
messages.append(input_message)
return messages
class PromptGenerator:
"""A class for generating custom prompt strings.
Does this based on constraints, commands, resources, and performance evaluations.
"""
def __init__(self) -> None:
"""Initialize the PromptGenerator object.
Starts with empty lists of constraints, commands, resources,
and performance evaluations.
"""
self.constraints: List[str] = []
self.commands: List[BaseTool] = []
self.resources: List[str] = []
self.performance_evaluation: List[str] = []
self.response_format = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
def add_constraint(self, constraint: str) -> None:
"""
Add a constraint to the constraints list.
Args:
constraint (str): The constraint to be added.
"""
self.constraints.append(constraint)
def add_tool(self, tool: BaseTool) -> None:
self.commands.append(tool)
def _generate_command_string(self, tool: BaseTool) -> str:
output = f"{tool.name}: {tool.description}"
output += f", args json schema: {json.dumps(tool.args)}"
return output
def add_resource(self, resource: str) -> None:
"""
Add a resource to the resources list.
Args:
resource (str): The resource to be added.
"""
self.resources.append(resource)
def add_performance_evaluation(self, evaluation: str) -> None:
"""
Add a performance evaluation item to the performance_evaluation list.
Args:
evaluation (str): The evaluation item to be added.
"""
self.performance_evaluation.append(evaluation)
def _generate_numbered_list(self, items: list, item_type: str = "list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = [
f"{i + 1}. {self._generate_command_string(item)}"
for i, item in enumerate(items)
]
finish_description = (
"use this to signal that you have finished all your objectives"
)
finish_args = (
'"response": "final response to let '
'people know you have finished your objectives"'
)
finish_string = (
f"{len(items) + 1}. {FINISH_NAME}: "
f"{finish_description}, args: {finish_args}"
)
return "\n".join(command_strings + [finish_string])
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def generate_prompt_string(self) -> str:
"""Generate a prompt string.
Returns:
str: The generated prompt string.
"""
formatted_response_format = json.dumps(self.response_format, indent=4)
prompt_string = (
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
f"Commands:\n"
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
f"Performance Evaluation:\n"
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
f"You should only respond in JSON format as described below "
f"\nResponse Format: \n{formatted_response_format} "
f"\nEnsure the response can be parsed by Python json.loads"
)
return prompt_string
def get_prompt(tools: List[BaseTool]) -> str:
"""Generates a prompt string.
It includes various constraints, commands, resources, and performance evaluations.
Returns:
str: The generated prompt string.
"""
# Initialize the PromptGenerator object
prompt_generator = PromptGenerator()
# Add constraints to the PromptGenerator object
prompt_generator.add_constraint(
"~16000 word limit for short term memory. "
"Your short term memory is short, "
"so immediately save important information to files."
)
prompt_generator.add_constraint(
"If you are unsure how you previously did something "
"or want to recall past events, "
"thinking about similar events will help you remember."
)
prompt_generator.add_constraint("No user assistance")
prompt_generator.add_constraint(
'Exclusively use the commands listed in double quotes e.g. "command name"'
)
# Add commands to the PromptGenerator object
for tool in tools:
prompt_generator.add_tool(tool)
# Add resources to the PromptGenerator object
prompt_generator.add_resource(
"Internet access for searches and information gathering."
)
prompt_generator.add_resource("Long Term memory management.")
prompt_generator.add_resource(
"GPT-3.5 powered Agents for delegation of simple tasks."
)
prompt_generator.add_resource("File output.")
# Add performance evaluations to the PromptGenerator object
prompt_generator.add_performance_evaluation(
"Continuously review and analyze your actions "
"to ensure you are performing to the best of your abilities."
)
prompt_generator.add_performance_evaluation(
"Constructively self-criticize your big-picture behavior constantly."
)
prompt_generator.add_performance_evaluation(
"Reflect on past decisions and strategies to refine your approach."
)
prompt_generator.add_performance_evaluation(
"Every command has a cost, so be smart and efficient. "
"Aim to complete tasks in the least number of steps."
)
# Generate the prompt string
prompt_string = prompt_generator.generate_prompt_string()
return prompt_string
class AutoGPT:
"""Autonomous agent that runs an LLM-driven command loop.
Args:
ai_name: name of the agent.
memory: vector store retriever used as long-term memory.
chain: LLMChain that produces the next action.
output_parser: parser that extracts the next command from the reply.
tools: tools the agent may invoke.
feedback_tool: optional human-in-the-loop feedback tool.
chat_history_memory: optional chat message history store.
"""
def __init__(
self,
ai_name: str,
memory: VectorStoreRetriever,
chain: LLMChain,
output_parser: BaseAutoGPTOutputParser,
tools: List[BaseTool],
feedback_tool: Optional[HumanInputRun] = None,
chat_history_memory: Optional[BaseChatMessageHistory] = None,
):
self.ai_name = ai_name
self.memory = memory
self.next_action_count = 0
self.chain = chain
self.output_parser = output_parser
self.tools = tools
self.feedback_tool = feedback_tool
self.chat_history_memory = chat_history_memory or ChatMessageHistory()
@classmethod
def from_llm_and_tools(
cls,
ai_name: str,
ai_role: str,
memory: VectorStoreRetriever,
tools: List[BaseTool],
llm: BaseChatModel,
human_in_the_loop: bool = False,
output_parser: Optional[BaseAutoGPTOutputParser] = None,
chat_history_memory: Optional[BaseChatMessageHistory] = None,
) -> AutoGPT:
prompt = AutoGPTPrompt(
ai_name=ai_name,
ai_role=ai_role,
tools=tools,
input_variables=["memory", "messages", "goals", "user_input"],
token_counter=llm.get_num_tokens,
)
human_feedback_tool = HumanInputRun() if human_in_the_loop else None
chain = LLMChain(llm=llm, prompt=prompt)
return cls(
ai_name,
memory,
chain,
output_parser or AutoGPTOutputParser(),
tools,
feedback_tool=human_feedback_tool,
chat_history_memory=chat_history_memory,
)
def run(self, goals: List[str]) -> str:
user_input = (
"Determine which next command to use, "
"and respond using the format specified above:"
)
# Interaction Loop
loop_count = 0
while True:
# Discontinue if continuous limit is reached
loop_count += 1
# Send message to AI, get response
assistant_reply = self.chain.run(
goals=goals,
messages=self.chat_history_memory.messages,
memory=self.memory,
user_input=user_input,
)
# Print Assistant thoughts
print(assistant_reply)
self.chat_history_memory.add_message(HumanMessage(content=user_input))
self.chat_history_memory.add_message(AIMessage(content=assistant_reply))
# Get command name and arguments
action = self.output_parser.parse(assistant_reply)
tools = {t.name: t for t in self.tools}
if action.name == FINISH_NAME:
return action.args["response"]
if action.name in tools:
tool = tools[action.name]
try:
observation = tool.run(action.args)
except ValidationError as e:
observation = (
f"Validation Error in args: {str(e)}, args: {action.args}"
)
except Exception as e:
observation = (
f"Error: {str(e)}, {type(e).__name__}, args: {action.args}"
)
result = f"Command {tool.name} returned: {observation}"
elif action.name == "ERROR":
result = f"Error: {action.args}. "
else:
result = (
f"Unknown command '{action.name}'. "
f"Please refer to the 'COMMANDS' list for available "
f"commands and only respond in the specified JSON format."
)
memory_to_add = (
f"Assistant Reply: {assistant_reply} " f"\nResult: {result} "
)
if self.feedback_tool is not None:
feedback = f"\n{self.feedback_tool.run('Input: ')}"
if feedback in {"q", "stop"}:
print("EXITING")
return "EXITING"
memory_to_add += feedback
self.memory.add_documents([Document(page_content=memory_to_add)])
self.chat_history_memory.add_message(SystemMessage(content=result))

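For orientation, a minimal sketch of wiring up the AutoGPT class above; the FAISS vector store, file tool, and ChatOpenAI model are illustrative assumptions drawn from common LangChain usage, not part of this commit:

# Hypothetical wiring; assumes langchain's FAISS, OpenAIEmbeddings, ChatOpenAI,
# and WriteFileTool are installed (plus faiss-cpu), and OPENAI_API_KEY is set.
import faiss
from langchain.chat_models import ChatOpenAI
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools.file_management.write import WriteFileTool
from langchain.vectorstores import FAISS

embeddings = OpenAIEmbeddings()
index = faiss.IndexFlatL2(1536)  # OpenAI embeddings are 1536-dimensional
vectorstore = FAISS(embeddings.embed_query, index, InMemoryDocstore({}), {})

# AutoGPT here is the class defined above.
agent = AutoGPT.from_llm_and_tools(
    ai_name="Researcher",
    ai_role="Assistant",
    memory=vectorstore.as_retriever(),
    tools=[WriteFileTool()],
    llm=ChatOpenAI(temperature=0),
)
agent.run(["Write a one-line summary of today's weather to weather.txt"])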
@ -0,0 +1,7 @@
"""
Data Loaders for APPS
TODO: Clean up all the llama index stuff, remake the logic from scratch
"""

@ -0,0 +1,103 @@
from typing import List, Optional
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
class AsanaReader(BaseReader):
"""Asana reader. Reads data from an Asana workspace.
Args:
asana_token (str): Asana token.
"""
def __init__(self, asana_token: str) -> None:
"""Initialize Asana reader."""
import asana
self.client = asana.Client.access_token(asana_token)
def load_data(
self, workspace_id: Optional[str] = None, project_id: Optional[str] = None
) -> List[Document]:
"""Load data from the workspace.
Args:
workspace_id (Optional[str], optional): Workspace ID. Defaults to None.
project_id (Optional[str], optional): Project ID. Defaults to None.
Returns:
List[Document]: List of documents.
"""
if workspace_id is None and project_id is None:
raise ValueError("Either workspace_id or project_id must be provided")
if workspace_id is not None and project_id is not None:
raise ValueError(
"Only one of workspace_id or project_id should be provided"
)
results = []
if workspace_id is not None:
workspace_name = self.client.workspaces.find_by_id(workspace_id)["name"]
projects = self.client.projects.find_all({"workspace": workspace_id})
# Case: Only project_id is provided
else: # since we've handled the other cases, this means project_id is not None
projects = [self.client.projects.find_by_id(project_id)]
workspace_name = projects[0]["workspace"]["name"]
for project in projects:
tasks = self.client.tasks.find_all(
{
"project": project["gid"],
"opt_fields": "name,notes,completed,completed_at,completed_by,assignee,followers,custom_fields",
}
)
for task in tasks:
stories = self.client.tasks.stories(task["gid"], opt_fields="type,text")
comments = "\n".join(
[
story["text"]
for story in stories
if story.get("type") == "comment" and "text" in story
]
)
task_metadata = {
"task_id": task.get("gid", ""),
"name": task.get("name", ""),
"assignee": (task.get("assignee") or {}).get("name", ""),
"completed_on": task.get("completed_at", ""),
"completed_by": (task.get("completed_by") or {}).get("name", ""),
"project_name": project.get("name", ""),
"custom_fields": [
i["display_value"] for i in (task.get("custom_fields") or [])
],
"workspace_name": workspace_name,
"url": f"https://app.asana.com/0/{project['gid']}/{task['gid']}",
}
if task.get("followers") is not None:
task_metadata["followers"] = [
i.get("name") for i in task.get("followers") if "name" in i
]
else:
task_metadata["followers"] = []
results.append(
Document(
text=task.get("name", "")
+ " "
+ task.get("notes", "")
+ " "
+ comments,
extra_info=task_metadata,
)
)
return results

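A short, hedged usage sketch for the reader above; the token and project ID are placeholders:

import os

# Placeholder credentials; a real personal access token comes from Asana's developer console.
reader = AsanaReader(asana_token=os.environ["ASANA_TOKEN"])

# Load every task in one project; pass workspace_id instead to walk all projects.
documents = reader.load_data(project_id="1234567890")
for doc in documents:
    print(doc.extra_info["name"], "-", doc.extra_info["url"])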
@ -0,0 +1,622 @@
"""Base schema for data structures."""
import json
import textwrap
import uuid
from abc import abstractmethod
from enum import Enum, auto
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator
from llama_index.utils import SAMPLE_TEXT, truncate_text
from typing_extensions import Self
if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument
from semantic_kernel.memory.memory_record import MemoryRecord
####
DEFAULT_TEXT_NODE_TMPL = "{metadata_str}\n\n{content}"
DEFAULT_METADATA_TMPL = "{key}: {value}"
# NOTE: for pretty printing
TRUNCATE_LENGTH = 350
WRAP_WIDTH = 70
class BaseComponent(BaseModel):
"""Base component object to capture class names."""
@classmethod
@abstractmethod
def class_name(cls) -> str:
"""
Get the class name, used as a unique ID in serialization.
This provides a key that makes serialization robust against actual class
name changes.
"""
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
data = self.dict(**kwargs)
data["class_name"] = self.class_name()
return data
def to_json(self, **kwargs: Any) -> str:
data = self.to_dict(**kwargs)
return json.dumps(data)
# TODO: return type here not supported by current mypy version
@classmethod
def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore
if isinstance(kwargs, dict):
data.update(kwargs)
data.pop("class_name", None)
return cls(**data)
@classmethod
def from_json(cls, data_str: str, **kwargs: Any) -> Self: # type: ignore
data = json.loads(data_str)
return cls.from_dict(data, **kwargs)
class NodeRelationship(str, Enum):
"""Node relationships used in `BaseNode` class.
Attributes:
SOURCE: The node is the source document.
PREVIOUS: The node is the previous node in the document.
NEXT: The node is the next node in the document.
PARENT: The node is the parent node in the document.
CHILD: The node is a child node in the document.
"""
SOURCE = auto()
PREVIOUS = auto()
NEXT = auto()
PARENT = auto()
CHILD = auto()
class ObjectType(str, Enum):
TEXT = auto()
IMAGE = auto()
INDEX = auto()
DOCUMENT = auto()
class MetadataMode(str, Enum):
ALL = auto()
EMBED = auto()
LLM = auto()
NONE = auto()
class RelatedNodeInfo(BaseComponent):
node_id: str
node_type: Optional[ObjectType] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
hash: Optional[str] = None
@classmethod
def class_name(cls) -> str:
return "RelatedNodeInfo"
RelatedNodeType = Union[RelatedNodeInfo, List[RelatedNodeInfo]]
# Node classes for indexes
class BaseNode(BaseComponent):
"""Base node Object.
Generic abstract interface for retrievable nodes
"""
class Config:
allow_population_by_field_name = True
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the node."
)
embedding: Optional[List[float]] = Field(
default=None, description="Embedding of the node."
)
""""
metadata fields
- injected as part of the text shown to LLMs as context
- injected as part of the text for generating embeddings
- used by vector DBs for metadata filtering
"""
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="A flat dictionary of metadata fields",
alias="extra_info",
)
excluded_embed_metadata_keys: List[str] = Field(
default_factory=list,
description="Metadata keys that are excluded from text for the embed model.",
)
excluded_llm_metadata_keys: List[str] = Field(
default_factory=list,
description="Metadata keys that are excluded from text for the LLM.",
)
relationships: Dict[NodeRelationship, RelatedNodeType] = Field(
default_factory=dict,
description="A mapping of relationships to other node information.",
)
hash: str = Field(default="", description="Hash of the node content.")
@classmethod
@abstractmethod
def get_type(cls) -> str:
"""Get Object type."""
@abstractmethod
def get_content(self, metadata_mode: MetadataMode = MetadataMode.ALL) -> str:
"""Get object content."""
@abstractmethod
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""Metadata string."""
@abstractmethod
def set_content(self, value: Any) -> None:
"""Set the content of the node."""
@property
def node_id(self) -> str:
return self.id_
@node_id.setter
def node_id(self, value: str) -> None:
self.id_ = value
@property
def source_node(self) -> Optional[RelatedNodeInfo]:
"""Source object node.
Extracted from the relationships field.
"""
if NodeRelationship.SOURCE not in self.relationships:
return None
relation = self.relationships[NodeRelationship.SOURCE]
if isinstance(relation, list):
raise ValueError("Source object must be a single RelatedNodeInfo object")
return relation
@property
def prev_node(self) -> Optional[RelatedNodeInfo]:
"""Prev node."""
if NodeRelationship.PREVIOUS not in self.relationships:
return None
relation = self.relationships[NodeRelationship.PREVIOUS]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Previous object must be a single RelatedNodeInfo object")
return relation
@property
def next_node(self) -> Optional[RelatedNodeInfo]:
"""Next node."""
if NodeRelationship.NEXT not in self.relationships:
return None
relation = self.relationships[NodeRelationship.NEXT]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Next object must be a single RelatedNodeInfo object")
return relation
@property
def parent_node(self) -> Optional[RelatedNodeInfo]:
"""Parent node."""
if NodeRelationship.PARENT not in self.relationships:
return None
relation = self.relationships[NodeRelationship.PARENT]
if not isinstance(relation, RelatedNodeInfo):
raise ValueError("Parent object must be a single RelatedNodeInfo object")
return relation
@property
def child_nodes(self) -> Optional[List[RelatedNodeInfo]]:
"""Child nodes."""
if NodeRelationship.CHILD not in self.relationships:
return None
relation = self.relationships[NodeRelationship.CHILD]
if not isinstance(relation, list):
raise ValueError("Child objects must be a list of RelatedNodeInfo objects.")
return relation
@property
def ref_doc_id(self) -> Optional[str]:
"""Deprecated: Get ref doc id."""
source_node = self.source_node
if source_node is None:
return None
return source_node.node_id
@property
def extra_info(self) -> Dict[str, Any]:
"""TODO: DEPRECATED: Extra info."""
return self.metadata
def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
)
source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Node ID: {self.node_id}\n{source_text_wrapped}"
def get_embedding(self) -> List[float]:
"""Get embedding.
Errors if embedding is None.
"""
if self.embedding is None:
raise ValueError("embedding not set.")
return self.embedding
def as_related_node_info(self) -> RelatedNodeInfo:
"""Get node as RelatedNodeInfo."""
return RelatedNodeInfo(
node_id=self.node_id,
node_type=self.get_type(),
metadata=self.metadata,
hash=self.hash,
)
class TextNode(BaseNode):
text: str = Field(default="", description="Text content of the node.")
start_char_idx: Optional[int] = Field(
default=None, description="Start char index of the node."
)
end_char_idx: Optional[int] = Field(
default=None, description="End char index of the node."
)
text_template: str = Field(
default=DEFAULT_TEXT_NODE_TMPL,
description=(
"Template for how text is formatted, with {content} and "
"{metadata_str} placeholders."
),
)
metadata_template: str = Field(
default=DEFAULT_METADATA_TMPL,
description=(
"Template for how metadata is formatted, with {key} and "
"{value} placeholders."
),
)
metadata_seperator: str = Field(
default="\n",
description="Separator between metadata fields when converting to string.",
)
@classmethod
def class_name(cls) -> str:
return "TextNode"
@root_validator
def _check_hash(cls, values: dict) -> dict:
"""Generate a hash to represent the node."""
text = values.get("text", "")
metadata = values.get("metadata", {})
doc_identity = str(text) + str(metadata)
values["hash"] = str(
sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()
)
return values
@classmethod
def get_type(cls) -> str:
"""Get Object type."""
return ObjectType.TEXT
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
"""Get object content."""
metadata_str = self.get_metadata_str(mode=metadata_mode).strip()
if not metadata_str:
return self.text
return self.text_template.format(
content=self.text, metadata_str=metadata_str
).strip()
def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
"""Metadata info string."""
if mode == MetadataMode.NONE:
return ""
usable_metadata_keys = set(self.metadata.keys())
if mode == MetadataMode.LLM:
for key in self.excluded_llm_metadata_keys:
if key in usable_metadata_keys:
usable_metadata_keys.remove(key)
elif mode == MetadataMode.EMBED:
for key in self.excluded_embed_metadata_keys:
if key in usable_metadata_keys:
usable_metadata_keys.remove(key)
return self.metadata_seperator.join(
[
self.metadata_template.format(key=key, value=str(value))
for key, value in self.metadata.items()
if key in usable_metadata_keys
]
)
def set_content(self, value: str) -> None:
"""Set the content of the node."""
self.text = value
def get_node_info(self) -> Dict[str, Any]:
"""Get node info."""
return {"start": self.start_char_idx, "end": self.end_char_idx}
def get_text(self) -> str:
return self.get_content(metadata_mode=MetadataMode.NONE)
@property
def node_info(self) -> Dict[str, Any]:
"""Deprecated: Get node info."""
return self.get_node_info()
# TODO: legacy backport of old Node class
Node = TextNode
class ImageNode(TextNode):
"""Node with image."""
# TODO: store reference instead of actual image
# base64 encoded image str
image: Optional[str] = None
@classmethod
def get_type(cls) -> str:
return ObjectType.IMAGE
@classmethod
def class_name(cls) -> str:
return "ImageNode"
class IndexNode(TextNode):
"""Node with reference to any object.
This can include other indices, query engines, retrievers.
This can also include other nodes (though this is overlapping with `relationships`
on the Node class).
"""
index_id: str
@classmethod
def from_text_node(
cls,
node: TextNode,
index_id: str,
) -> "IndexNode":
"""Create index node from text node."""
# copy all attributes from text node, add index id
return cls(
**node.dict(),
index_id=index_id,
)
@classmethod
def get_type(cls) -> str:
return ObjectType.INDEX
@classmethod
def class_name(cls) -> str:
return "IndexNode"
class NodeWithScore(BaseComponent):
node: BaseNode
score: Optional[float] = None
def __str__(self) -> str:
return f"{self.node}\nScore: {self.score: 0.3f}\n"
def get_score(self, raise_error: bool = False) -> float:
"""Get score."""
if self.score is None:
if raise_error:
raise ValueError("Score not set.")
else:
return 0.0
else:
return self.score
@classmethod
def class_name(cls) -> str:
return "NodeWithScore"
##### pass through methods to BaseNode #####
@property
def node_id(self) -> str:
return self.node.node_id
@property
def id_(self) -> str:
return self.node.id_
@property
def text(self) -> str:
if isinstance(self.node, TextNode):
return self.node.text
else:
raise ValueError("Node must be a TextNode to get text.")
@property
def metadata(self) -> Dict[str, Any]:
return self.node.metadata
@property
def embedding(self) -> Optional[List[float]]:
return self.node.embedding
def get_text(self) -> str:
if isinstance(self.node, TextNode):
return self.node.get_text()
else:
raise ValueError("Node must be a TextNode to get text.")
def get_content(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
return self.node.get_content(metadata_mode=metadata_mode)
def get_embedding(self) -> List[float]:
return self.node.get_embedding()
# Document Classes for Readers
class Document(TextNode):
"""Generic interface for a data document.
This document connects to data sources.
"""
# TODO: A lot of backwards compatibility logic here, clean up
id_: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Unique ID of the node.",
alias="doc_id",
)
_compat_fields = {"doc_id": "id_", "extra_info": "metadata"}
@classmethod
def get_type(cls) -> str:
"""Get Document type."""
return ObjectType.DOCUMENT
@property
def doc_id(self) -> str:
"""Get document ID."""
return self.id_
def __str__(self) -> str:
source_text_truncated = truncate_text(
self.get_content().strip(), TRUNCATE_LENGTH
)
source_text_wrapped = textwrap.fill(
f"Text: {source_text_truncated}\n", width=WRAP_WIDTH
)
return f"Doc ID: {self.doc_id}\n{source_text_wrapped}"
def get_doc_id(self) -> str:
"""TODO: Deprecated: Get document ID."""
return self.id_
def __setattr__(self, name: str, value: object) -> None:
if name in self._compat_fields:
name = self._compat_fields[name]
super().__setattr__(name, value)
def to_langchain_format(self) -> "LCDocument":
"""Convert struct to LangChain document format."""
from llama_index.bridge.langchain import Document as LCDocument
metadata = self.metadata or {}
return LCDocument(page_content=self.text, metadata=metadata)
@classmethod
def from_langchain_format(cls, doc: "LCDocument") -> "Document":
"""Convert struct from LangChain document format."""
return cls(text=doc.page_content, metadata=doc.metadata)
def to_haystack_format(self) -> "HaystackDocument":
"""Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument
return HaystackDocument(
content=self.text, meta=self.metadata, embedding=self.embedding, id=self.id_
)
@classmethod
def from_haystack_format(cls, doc: "HaystackDocument") -> "Document":
"""Convert struct from Haystack document format."""
return cls(
text=doc.content, metadata=doc.meta, embedding=doc.embedding, id_=doc.id
)
def to_embedchain_format(self) -> Dict[str, Any]:
"""Convert struct to EmbedChain document format."""
return {
"doc_id": self.id_,
"data": {"content": self.text, "meta_data": self.metadata},
}
@classmethod
def from_embedchain_format(cls, doc: Dict[str, Any]) -> "Document":
"""Convert struct from EmbedChain document format."""
return cls(
text=doc["data"]["content"],
metadata=doc["data"]["meta_data"],
id_=doc["doc_id"],
)
def to_semantic_kernel_format(self) -> "MemoryRecord":
"""Convert struct to Semantic Kernel document format."""
import numpy as np
from semantic_kernel.memory.memory_record import MemoryRecord
return MemoryRecord(
id=self.id_,
text=self.text,
additional_metadata=self.get_metadata_str(),
embedding=np.array(self.embedding) if self.embedding else None,
)
@classmethod
def from_semantic_kernel_format(cls, doc: "MemoryRecord") -> "Document":
"""Convert struct from Semantic Kernel document format."""
return cls(
text=doc._text,
metadata={"additional_metadata": doc._additional_metadata},
embedding=doc._embedding.tolist() if doc._embedding is not None else None,
id_=doc._id,
)
@classmethod
def example(cls) -> "Document":
return Document(
text=SAMPLE_TEXT,
metadata={"filename": "README.md", "category": "codebase"},
)
@classmethod
def class_name(cls) -> str:
return "Document"
class ImageDocument(Document):
"""Data document containing an image."""
# base64 encoded image str
image: Optional[str] = None
@classmethod
def class_name(cls) -> str:
return "ImageDocument"

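To make the node and document plumbing above concrete, a small sketch that links two TextNodes and round-trips a Document through the LangChain format (values illustrative; to_langchain_format() assumes the llama_index LangChain bridge is installed):

node_a = TextNode(text="The quick brown fox.", metadata={"source": "example.txt"})
node_b = TextNode(text="It jumped over the lazy dog.")

# Link the nodes through the relationships mapping defined on BaseNode.
node_a.relationships[NodeRelationship.NEXT] = node_b.as_related_node_info()
node_b.relationships[NodeRelationship.PREVIOUS] = node_a.as_related_node_info()
assert node_a.next_node.node_id == node_b.node_id

doc = Document(text="Full source text.", metadata={"filename": "example.txt"})
lc_doc = doc.to_langchain_format()
assert Document.from_langchain_format(lc_doc).text == doc.text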
@ -46,6 +46,7 @@ from swarms.utils.revutils import get_input
bcolors = t.Colors()

def generate_random_hex(length: int = 17) -> str:
    """Generate a random hex string
@ -121,7 +122,6 @@ def logger(is_timed: bool) -> function:
BASE_URL = environ.get("CHATGPT_BASE_URL", "http://bypass.bzff.cn:9090/")

def captcha_solver(images: list[str], challenge_details: dict) -> int:
    # Create tempfile
    with tempfile.TemporaryDirectory() as tempdir:
@ -197,40 +197,40 @@ def get_arkose_token(
        raise Exception("Failed to verify captcha")
    return resp_json.get("token")
    # else:
    #     working_endpoints: list[str] = []
    #     # Check uptime for different endpoints via gatus
    #     resp2: list[dict] = requests.get(
    #         "https://stats.churchless.tech/api/v1/endpoints/statuses?page=1"
    #     ).json()
    #     for endpoint in resp2:
    #         # print(endpoint.get("name"))
    #         if endpoint.get("group") != "Arkose Labs":
    #             continue
    #         # Check the last 5 results
    #         results: list[dict] = endpoint.get("results", [])[-5:-1]
    #         # print(results)
    #         if not results:
    #             print(f"Endpoint {endpoint.get('name')} has no results")
    #             continue
    #         # Check if all the results are up
    #         if all(result.get("success") == True for result in results):
    #             working_endpoints.append(endpoint.get("name"))
    #     if not working_endpoints:
    #         print("No working endpoints found. Please solve the captcha manually.")
    #         return get_arkose_token(download_images=True, captcha_supported=False)
    #     # Choose a random endpoint
    #     endpoint = random.choice(working_endpoints)
    #     resp: requests.Response = requests.get(endpoint)
    #     if resp.status_code != 200:
    #         if resp.status_code != 511:
    #             raise Exception("Failed to get captcha token")
    #         else:
    #             print("Captcha required; please solve it manually.")
    #             return get_arkose_token(download_images=True, captcha_supported=True)
    #     try:
    #         return resp.json().get("token")
    #     except Exception:
    #         return resp.text

class Chatbot:
@ -1751,6 +1751,7 @@ if __name__ == "__main__":
    )
    main(configure())

class RevChatGPTModelv1:
    def __init__(self, access_token=None, **kwargs):
        super().__init__()
@ -1764,7 +1765,7 @@ class RevChatGPTModelv1:
        self.start_time = time.time()
        prev_text = ""
        for data in self.chatbot.ask(task, fileinfo=None):
            message = data["message"][len(prev_text) :]
            prev_text = data["message"]
        self.end_time = time.time()
        return prev_text
@ -1779,11 +1780,16 @@ class RevChatGPTModelv1:
    def list_plugins(self):
        return self.chatbot.get_plugins()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Manage RevChatGPT plugins.")
    parser.add_argument("--enable", metavar="plugin_id", help="the plugin to enable")
    parser.add_argument(
        "--list", action="store_true", help="list all available plugins"
    )
    parser.add_argument(
        "--access_token", required=True, help="access token for RevChatGPT"
    )
    args = parser.parse_args()
@ -1795,4 +1801,3 @@ if __name__ == "__main__":
    plugins = model.list_plugins()
    for plugin in plugins:
        print(f"Plugin ID: {plugin['id']}, Name: {plugin['name']}")

@ -1,4 +1,4 @@
# 4v image recognition
"""
Standard ChatGPT
"""
@ -26,6 +26,7 @@ from pathlib import Path
import tempfile
import random
import os

# Import function type
import httpx
@ -46,7 +47,7 @@ from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings

from schemas.typings import Colors

bindings = KeyBindings()
@ -56,6 +57,7 @@ BASE_URL = environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/")
bcolors = t.Colors()

def create_keybindings(key: str = "c-@") -> KeyBindings:
    """
    Create keybindings for prompt_toolkit. Default key is ctrl+space.
@ -136,6 +138,7 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any:
    # Only return specified keys that are in class_keys
    return {key for key in keys if key in class_keys}

def generate_random_hex(length: int = 17) -> str:
    """Generate a random hex string

    Args:
@ -202,8 +205,6 @@ def logger(is_timed: bool):
    return decorator

bcolors = Colors()
@ -284,7 +285,6 @@ def get_arkose_token(
    return resp_json.get("token")

class Chatbot:
    """
    Chatbot class for ChatGPT
@ -636,7 +636,7 @@ class Chatbot:
                yield {
                    "author": author,
                    "message": message,
                    "conversation_id": cid + "***************************",
                    "parent_id": pid,
                    "model": model,
                    "finish_details": finish_details,
@ -711,7 +711,6 @@ class Chatbot:
        if not conversation_id and not parent_id:
            parent_id = str(uuid.uuid4())

        if conversation_id and not parent_id:
            if conversation_id not in self.conversation_mapping:
                print(conversation_id)
@ -735,8 +734,8 @@ class Chatbot:
            print(
                "Warning: Invalid conversation_id provided, treat as a new conversation",
            )
            # conversation_id = None
            conversation_id = str(uuid.uuid4())
            print(conversation_id)
            parent_id = str(uuid.uuid4())
        model = model or self.config.get("model") or "text-davinci-002-render-sha"
@ -762,7 +761,7 @@ class Chatbot:
    def ask(
        self,
        prompt: str,
        fileinfo: dict,
        conversation_id: str | None = None,
        parent_id: str = "",
        model: str = "",
@ -795,7 +794,10 @@ class Chatbot:
                "id": str(uuid.uuid4()),
                "role": "user",
                "author": {"role": "user"},
                "content": {
                    "content_type": "multimodal_text",
                    "parts": [prompt, fileinfo],
                },
            },
        ]
@ -871,7 +873,7 @@ class Chatbot:
                parent_id = self.conversation_mapping[conversation_id]
            else:  # invalid conversation_id provided, treat as a new conversation
                conversation_id = None
                conversation_id = str(uuid.uuid4())
                parent_id = str(uuid.uuid4())
        model = model or self.config.get("model") or "text-davinci-002-render-sha"
        data = {
@ -1304,7 +1306,7 @@ class AsyncChatbot(Chatbot):
            print(
                "Warning: Invalid conversation_id provided, treat as a new conversation",
            )
            # conversation_id = None
            conversation_id = str(uuid.uuid4())
            print(conversation_id)
            parent_id = str(uuid.uuid4())
@ -1363,12 +1365,18 @@ class AsyncChatbot(Chatbot):
            {
                "id": str(uuid.uuid4()),
                "author": {"role": "user"},
                "content": {
                    "content_type": "multimodal_text",
                    "parts": [
                        prompt,
                        {
                            "asset_pointer": "file-service://file-V9IZRkWQnnk1HdHsBKAdoiGf",
                            "size_bytes": 239505,
                            "width": 1706,
                            "height": 1280,
                        },
                    ],
                },
            },
        ]
@ -1763,6 +1771,7 @@ if __name__ == "__main__":
    )
    main(configure())

class RevChatGPTModelv4:
    def __init__(self, access_token=None, **kwargs):
        super().__init__()
@ -1776,7 +1785,7 @@ class RevChatGPTModelv4:
        self.start_time = time.time()
        prev_text = ""
        for data in self.chatbot.ask(task, fileinfo=None):
            message = data["message"][len(prev_text) :]
            prev_text = data["message"]
        self.end_time = time.time()
        return prev_text
@ -1791,11 +1800,16 @@ class RevChatGPTModelv4:
    def list_plugins(self):
        return self.chatbot.get_plugins()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Manage RevChatGPT plugins.")
    parser.add_argument("--enable", metavar="plugin_id", help="the plugin to enable")
    parser.add_argument(
        "--list", action="store_true", help="list all available plugins"
    )
    parser.add_argument(
        "--access_token", required=True, help="access token for RevChatGPT"
    )
    args = parser.parse_args()

@ -1,2 +1,2 @@
from swarms.structs.workflow import Workflow
from swarms.structs.task import Task

@ -0,0 +1,225 @@
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Tuple
from swarms.models import OpenAIChat
# Custom stopping condition
def stop_when_repeats(response: str) -> bool:
# Stop if the word "stop" appears in the response
return "stop" in response.lower()
# class Flow:
# def __init__(
# self,
# llm: Any,
# template: str,
# max_loops: int = 1,
# stopping_condition: Optional[Callable[[str], bool]] = None,
# **kwargs: Any
# ):
# self.llm = llm
# self.template = template
# self.max_loops = max_loops
# self.stopping_condition = stopping_condition
# self.feedback = []
# self.history = []
# def __call__(
# self,
# prompt,
# **kwargs
# ) -> str:
# """Invoke the flow by providing a template and it's variables"""
# response = self.llm(prompt, **kwargs)
# return response
# def _check_stopping_condition(self, response: str) -> bool:
# """Check if the stopping condition is met"""
# if self.stopping_condition:
# return self.stopping_condition(response)
# return False
# def provide_feedback(self, feedback: str) -> None:
# """Allow users to to provide feedback on the responses"""
# feedback = self.feedback.append(feedback)
# return feedback
# def format_prompt(self, **kwargs: Any) -> str:
# """Format the template with the provided kwargs using f string interpolation"""
# return self.template.format(**kwargs)
# def _generate(self, formatted_prompts: str) -> str:
# """
# Generate a result using the lm
# """
# return self.llm(formatted_prompts)
# def run(self, **kwargs: Any) -> str:
# """Generate a result using the provided keyword args"""
# prompt = self.format_prompt(**kwargs)
# response = self._generate(prompt)
# return response
# def bulk_run(
# self,
# inputs: List[Dict[str, Any]]
# ) -> List[str]:
# """Generate responses for multiple input sets"""
# return [self.run(**input_data) for input_data in inputs]
# @staticmethod
# def from_llm_and_template(llm: Any, template: str) -> "Flow":
# """Create FlowStream from LLM and a string template"""
# return Flow(llm=llm, template=template)
# @staticmethod
# def from_llm_and_template_file(llm: Any, template_file: str) -> "Flow":
# """Create FlowStream from LLM and a template file"""
# with open(template_file, "r") as f:
# template = f.read()
# return Flow(llm=llm, template=template)
class Flow:
def __init__(
self,
llm: Any,
template: str,
max_loops: int = 1,
stopping_condition: Optional[Callable[[str], bool]] = None,
loop_interval: int = 1,
retry_attempts: int = 3,
retry_interval: int = 1,
**kwargs: Any,
):
self.llm = llm
self.template = template
self.max_loops = max_loops
self.stopping_condition = stopping_condition
self.loop_interval = loop_interval
self.retry_attempts = retry_attempts
self.retry_interval = retry_interval
self.feedback = []
def provide_feedback(self, feedback: str) -> None:
"""Allow users to provide feedback on the responses."""
self.feedback.append(feedback)
logging.info(f"Feedback received: {feedback}")
def _check_stopping_condition(self, response: str) -> bool:
"""Check if the stopping condition is met."""
if self.stopping_condition:
return self.stopping_condition(response)
return False
def __call__(self, prompt, **kwargs) -> str:
"""Invoke the flow by providing a template and its variables."""
response = self.llm(prompt, **kwargs)
return response
def format_prompt(self, **kwargs: Any) -> str:
"""Format the template with the provided kwargs using f-string interpolation."""
return self.template.format(**kwargs)
def _generate(self, task: str, formatted_prompts: str) -> Tuple[str, List[str]]:
"""
Generate a result using the llm, with optional query loops and stopping conditions.
"""
response = formatted_prompts
history = [task]
for _ in range(self.max_loops):
if self._check_stopping_condition(response):
break
attempt = 0
while attempt < self.retry_attempts:
try:
response = self.llm(response)
break
except Exception as e:
logging.error(f"Error generating response: {e}")
attempt += 1
time.sleep(self.retry_interval)
logging.info(f"Generated response: {response}")
history.append(response)
time.sleep(self.loop_interval)
return response, history
def run(self, **kwargs: Any) -> str:
"""Generate a result using the provided keyword args."""
task = self.format_prompt(**kwargs)
response, history = self._generate(task, task)
logging.info(f"Message history: {history}")
return response
def bulk_run(self, inputs: List[Dict[str, Any]]) -> List[str]:
"""Generate responses for multiple input sets."""
return [self.run(**input_data) for input_data in inputs]
@staticmethod
def from_llm_and_template(llm: Any, template: str) -> "Flow":
"""Create a Flow from an LLM and a string template."""
return Flow(llm=llm, template=template)
@staticmethod
def from_llm_and_template_file(llm: Any, template_file: str) -> "Flow":
"""Create a Flow from an LLM and a template file."""
with open(template_file, "r") as f:
template = f.read()
return Flow(llm=llm, template=template)
# # Configure logging
# logging.basicConfig(level=logging.INFO)
# llm = OpenAIChat(
# api_key="YOUR_API_KEY",
# max_tokens=1000,
# temperature=0.9,
# )
# def main():
# # Initialize the Flow class with parameters
# flow = Flow(
# llm=llm,
# template="Translate this to backwards: {sentence}",
# max_loops=3,
# stopping_condition=stop_when_repeats,
# loop_interval=2, # Wait 2 seconds between loops
# retry_attempts=2,
# retry_interval=1, # Wait 1 second between retries
# )
# # Predict using the Flow
# response = flow.run(sentence="Hello, World!")
# print("Response:", response)
# time.sleep(1) # Pause for demonstration purposes
# # Provide feedback on the result
# flow.provide_feedback("The translation was interesting!")
# time.sleep(1) # Pause for demonstration purposes
# # Bulk run
# inputs = [
# {"sentence": "This is a test."},
# {"sentence": "OpenAI is great."},
# {"sentence": "GPT models are powerful."},
# {"sentence": "stop and check if our stopping condition works."},
# ]
# responses = flow.bulk_run(inputs=inputs)
# for idx, res in enumerate(responses):
# print(f"Input: {inputs[idx]['sentence']}, Response: {res}")
# time.sleep(1) # Pause for demonstration purposes
# if __name__ == "__main__":
# main()

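Complementing the commented-out demo above, a minimal sketch that exercises Flow without any API key by substituting a stub callable for the LLM:

# Stub LLM: Flow only requires a callable that takes a prompt and returns text.
def echo_llm(prompt: str, **kwargs) -> str:
    return f"(echo) {prompt}"

flow = Flow(
    llm=echo_llm,
    template="Translate this to backwards: {sentence}",
    max_loops=2,
    stopping_condition=stop_when_repeats,
    loop_interval=0,
)
print(flow.run(sentence="Hello, World!"))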
@ -142,8 +142,9 @@ class WebpageQATool(BaseTool):
    async def _arun(self, url: str, question: str) -> str:
        raise NotImplementedError

class EdgeGPTTool:
    # Initialize the custom tool
    def __init__(
        self,
        model,
@ -152,10 +153,11 @@ class EdgeGPTTool:
    ):
        super().__init__(name=name, description=description)
        self.model = model

    def _run(self, prompt):
        return self.model.__call__(prompt)

@tool
def VQAinference(self, inputs):
    """

@ -11,7 +11,7 @@ from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings

from schemas.typings import Colors

bindings = KeyBindings()
@ -19,6 +19,7 @@ bindings = KeyBindings()
BASE_URL = os.environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/")
# BASE_URL = environ.get("CHATGPT_BASE_URL", "https://bypass.churchless.tech/")

def create_keybindings(key: str = "c-@") -> KeyBindings:
    """
    Create keybindings for prompt_toolkit. Default key is ctrl+space.
@ -99,6 +100,7 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> any:
    # Only return specified keys that are in class_keys
    return {key for key in keys if key in class_keys}

def generate_random_hex(length: int = 17) -> str:
    """Generate a random hex string

    Args:
@ -163,4 +165,3 @@ def logger(is_timed: bool):
        return wrapper

    return decorator

@ -5,12 +5,12 @@ import os
# Assuming the BingChat class is in a file named "bing_chat.py"
from bing_chat import BingChat, ConversationStyle

class TestBingChat(unittest.TestCase):
    def setUp(self):
        # Path to a mock cookies file for testing
        self.mock_cookies_path = "./mock_cookies.json"
        with open(self.mock_cookies_path, "w") as file:
            json.dump({"mock_cookie": "mock_value"}, file)
        self.chat = BingChat(cookies_path=self.mock_cookies_path)
@ -33,10 +33,10 @@ class TestBingChat(unittest.TestCase):
class MockImageGen: class MockImageGen:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
def get_images(self, *args, **kwargs): def get_images(self, *args, **kwargs):
return [{"path": "mock_image.png"}] return [{"path": "mock_image.png"}]
@staticmethod @staticmethod
def save_images(*args, **kwargs): def save_images(*args, **kwargs):
pass pass
@ -54,5 +54,6 @@ class TestBingChat(unittest.TestCase):
BingChat.set_cookie_dir_path(test_path) BingChat.set_cookie_dir_path(test_path)
self.assertEqual(BingChat.Cookie.dir_path, test_path) self.assertEqual(BingChat.Cookie.dir_path, test_path)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
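setUp above writes a real file to disk; a matching tearDown (not present in this diff) keeps repeated runs from leaving mock_cookies.json behind. A sketch:

import os
import unittest


class TestBingChat(unittest.TestCase):
    # ... setUp and tests as above ...

    def tearDown(self):
        # Remove the mock cookies file created in setUp, if it exists.
        if os.path.exists(self.mock_cookies_path):
            os.remove(self.mock_cookies_path)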

@@ -6,12 +6,11 @@ import torch
from transformers import BioGptForCausalLM, BioGptTokenizer


# Fixture for BioGPT instance
@pytest.fixture
def biogpt_instance():
    from swarms.models import (
        BioGPT,
    )

    return BioGPT()
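A sketch of a test that would consume the fixture above; treating the model instance as callable on a text prompt is an assumption, not confirmed by this diff.

def test_biogpt_generates_text(biogpt_instance):
    # Assumes BioGPT instances are callable on a prompt and return text.
    output = biogpt_instance("COVID-19 is")
    assert isinstance(output, str)
    assert len(output) > 0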

@@ -9,6 +9,7 @@ from swarms.models.kosmos_two import Kosmos, is_overlapping
# A placeholder image URL for testing
TEST_IMAGE_URL = "https://images.unsplash.com/photo-1673267569891-ca4246caafd7?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDM1fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"

# Mock the response for the test image
@pytest.fixture
def mock_image_request():
@@ -18,12 +19,14 @@ def mock_image_request():
    with patch.object(requests, "get", return_value=mock_resp) as _fixture:
        yield _fixture


# Test utility function
def test_is_overlapping():
-    assert is_overlapping((1,1,3,3), (2,2,4,4)) == True
-    assert is_overlapping((1,1,2,2), (3,3,4,4)) == False
-    assert is_overlapping((0,0,1,1), (1,1,2,2)) == False
-    assert is_overlapping((0,0,2,2), (1,1,2,2)) == True
+    assert is_overlapping((1, 1, 3, 3), (2, 2, 4, 4)) == True
+    assert is_overlapping((1, 1, 2, 2), (3, 3, 4, 4)) == False
+    assert is_overlapping((0, 0, 1, 1), (1, 1, 2, 2)) == False
+    assert is_overlapping((0, 0, 2, 2), (1, 1, 2, 2)) == True


# Test model initialization
def test_kosmos_init():
@@ -31,38 +34,49 @@ def test_kosmos_init():
    assert kosmos.model is not None
    assert kosmos.processor is not None


# Test image fetching functionality
def test_get_image(mock_image_request):
    kosmos = Kosmos()
    image = kosmos.get_image(TEST_IMAGE_URL)
    assert image is not None


# Test multimodal grounding
def test_multimodal_grounding(mock_image_request):
    kosmos = Kosmos()
    kosmos.multimodal_grounding("Find the red apple in the image.", TEST_IMAGE_URL)
    # TODO: Validate the result if possible


# Test referring expression comprehension
def test_referring_expression_comprehension(mock_image_request):
    kosmos = Kosmos()
-    kosmos.referring_expression_comprehension("Show me the green bottle.", TEST_IMAGE_URL)
+    kosmos.referring_expression_comprehension(
+        "Show me the green bottle.", TEST_IMAGE_URL
+    )
    # TODO: Validate the result if possible


# ... (continue with other functions in the same manner) ...


# Test error scenarios - Example
-@pytest.mark.parametrize("phrase, image_url", [
-    (None, TEST_IMAGE_URL),
-    ("Find the red apple in the image.", None),
-    ("", TEST_IMAGE_URL),
-    ("Find the red apple in the image.", ""),
-])
+@pytest.mark.parametrize(
+    "phrase, image_url",
+    [
+        (None, TEST_IMAGE_URL),
+        ("Find the red apple in the image.", None),
+        ("", TEST_IMAGE_URL),
+        ("Find the red apple in the image.", ""),
+    ],
+)
def test_kosmos_error_scenarios(phrase, image_url):
    kosmos = Kosmos()
    with pytest.raises(Exception):
        kosmos.multimodal_grounding(phrase, image_url)


# ... (Add more tests for different edge cases and functionalities) ...


# Sample test image URLs
@@ -72,6 +86,7 @@ IMG_URL3 = "https://images.unsplash.com/photo-1696900004042-60bcc200aca0?auto=fo
IMG_URL4 = "https://images.unsplash.com/photo-1676156340083-fd49e4e53a21?auto=format&fit=crop&q=60&w=400&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHx0b3BpYy1mZWVkfDc4fEpwZzZLaWRsLUhrfHxlbnwwfHx8fHw%3D"
IMG_URL5 = "https://images.unsplash.com/photo-1696862761045-0a65acbede8f?auto=format&fit=crop&q=80&w=1287&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"

# Mock response for requests.get()
class MockResponse:
    @staticmethod
@@ -82,57 +97,69 @@ class MockResponse:
    def raw(self):
        return open("tests/sample_image.jpg", "rb")


# Test the Kosmos class
@pytest.fixture
def kosmos():
    return Kosmos()


# Mocking the requests.get() method
@pytest.fixture
def mock_request_get(monkeypatch):
-    monkeypatch.setattr(requests, 'get', lambda url, **kwargs: MockResponse())
+    monkeypatch.setattr(requests, "get", lambda url, **kwargs: MockResponse())


@pytest.mark.usefixtures("mock_request_get")
def test_multimodal_grounding(kosmos):
    kosmos.multimodal_grounding("Find the red apple in the image.", IMG_URL1)


@pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_comprehension(kosmos):
    kosmos.referring_expression_comprehension("Show me the green bottle.", IMG_URL2)


@pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_generation(kosmos):
    kosmos.referring_expression_generation("It is on the table.", IMG_URL3)


@pytest.mark.usefixtures("mock_request_get")
def test_grounded_vqa(kosmos):
    kosmos.grounded_vqa("What is the color of the car?", IMG_URL4)


@pytest.mark.usefixtures("mock_request_get")
def test_grounded_image_captioning(kosmos):
    kosmos.grounded_image_captioning(IMG_URL5)


@pytest.mark.usefixtures("mock_request_get")
def test_grounded_image_captioning_detailed(kosmos):
    kosmos.grounded_image_captioning_detailed(IMG_URL1)


@pytest.mark.usefixtures("mock_request_get")
def test_multimodal_grounding_2(kosmos):
    kosmos.multimodal_grounding("Find the yellow fruit in the image.", IMG_URL2)


@pytest.mark.usefixtures("mock_request_get")
def test_referring_expression_comprehension_2(kosmos):
    kosmos.referring_expression_comprehension("Where is the water bottle?", IMG_URL3)


@pytest.mark.usefixtures("mock_request_get")
def test_grounded_vqa_2(kosmos):
    kosmos.grounded_vqa("How many cars are in the image?", IMG_URL4)


@pytest.mark.usefixtures("mock_request_get")
def test_grounded_image_captioning_2(kosmos):
    kosmos.grounded_image_captioning(IMG_URL2)


@pytest.mark.usefixtures("mock_request_get")
def test_grounded_image_captioning_detailed_2(kosmos):
    kosmos.grounded_image_captioning_detailed(IMG_URL3)
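The four assertions in test_is_overlapping above pin down box-overlap semantics with exclusive edges: (0, 0, 1, 1) merely touching (1, 1, 2, 2) does not count. One implementation consistent with those assertions, assuming boxes are (x1, y1, x2, y2) tuples; the actual function in swarms.models.kosmos_two may differ.

def is_overlapping(box1, box2):
    # Boxes are (x1, y1, x2, y2); shared edges do not count as overlap.
    x1a, y1a, x2a, y2a = box1
    x1b, y1b, x2b, y2b = box2
    return x1a < x2b and x1b < x2a and y1a < y2b and y1b < y2a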

@@ -2,12 +2,12 @@ import unittest
from unittest.mock import patch
from swarms.models.revgptv1 import RevChatGPTModelv1


class TestRevChatGPT(unittest.TestCase):
    def setUp(self):
        self.access_token = "<your_access_token>"
        self.model = RevChatGPTModelv1(access_token=self.access_token)

    def test_run(self):
        prompt = "What is the capital of France?"
        response = self.model.run(prompt)
@@ -21,7 +21,7 @@ class TestRevChatGPT(unittest.TestCase):
    def test_generate_summary(self):
        text = "This is a sample text to summarize. It has multiple sentences and details. The summary should be concise."
        summary = self.model.generate_summary(text)
-        self.assertLess(len(summary), len(text)/2)
+        self.assertLess(len(summary), len(text) / 2)

    def test_enable_plugin(self):
        plugin_id = "some_plugin_id"
@@ -39,9 +39,9 @@ class TestRevChatGPT(unittest.TestCase):
        conversations = self.model.chatbot.get_conversations()
        self.assertIsInstance(conversations, list)

    @patch("RevChatGPTModelv1.Chatbot.get_msg_history")
    def test_get_msg_history(self, mock_get_msg_history):
        conversation_id = "convo_id"
        self.model.chatbot.get_msg_history(conversation_id)
        mock_get_msg_history.assert_called_with(conversation_id)
@@ -78,5 +78,6 @@ class TestRevChatGPT(unittest.TestCase):
        self.model.chatbot.rollback_conversation(1)
        self.assertNotEqual(original_convo_id, self.model.chatbot.conversation_id)


if __name__ == "__main__":
    unittest.main()
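One thing worth flagging in the hunk above: @patch("RevChatGPTModelv1.Chatbot.get_msg_history") only resolves if RevChatGPTModelv1 is an importable module, yet the class is imported from swarms.models.revgptv1. unittest.mock resolves patch targets by importing the dotted path where the name is looked up. A sketch under the assumption that Chatbot lives in that module:

from unittest.mock import patch

from swarms.models.revgptv1 import RevChatGPTModelv1  # module layout assumed


@patch("swarms.models.revgptv1.Chatbot.get_msg_history")
def test_get_msg_history(mock_get_msg_history):
    model = RevChatGPTModelv1(access_token="<your_access_token>")
    model.chatbot.get_msg_history("convo_id")
    mock_get_msg_history.assert_called_with("convo_id")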

@@ -2,16 +2,16 @@ import unittest
from unittest.mock import patch
from RevChatGPTModelv4 import RevChatGPTModelv4


class TestRevChatGPT(unittest.TestCase):
    def setUp(self):
        self.access_token = "123"
        self.model = RevChatGPTModelv4(access_token=self.access_token)

    def test_run(self):
        prompt = "What is the capital of France?"
        self.model.start_time = 10
        self.model.end_time = 20
        response = self.model.run(prompt)
        self.assertEqual(response, "The capital of France is Paris.")
        self.assertEqual(self.model.start_time, 10)
@@ -44,7 +44,7 @@ class TestRevChatGPT(unittest.TestCase):
    @patch("RevChatGPTModelv4.Chatbot.get_msg_history")
    def test_get_msg_history(self, mock_get_msg_history):
        convo_id = "123"
        self.model.chatbot.get_msg_history(convo_id)
        mock_get_msg_history.assert_called_with(convo_id)

    @patch("RevChatGPTModelv4.Chatbot.share_conversation")
@@ -52,7 +52,7 @@ class TestRevChatGPT(unittest.TestCase):
        self.model.chatbot.share_conversation()
        mock_share_conversation.assert_called()

    @patch("RevChatGPTModelv4.Chatbot.gen_title")
    def test_gen_title(self, mock_gen_title):
        convo_id = "123"
        message_id = "456"
@@ -77,7 +77,7 @@ class TestRevChatGPT(unittest.TestCase):
        self.model.chatbot.clear_conversations()
        mock_clear_conversations.assert_called()

    @patch("RevChatGPTModelv4.Chatbot.rollback_conversation")
    def test_rollback_conversation(self, mock_rollback_conversation):
        num = 2
        self.model.chatbot.rollback_conversation(num)
@@ -88,5 +88,6 @@ class TestRevChatGPT(unittest.TestCase):
        self.model.chatbot.reset_chat()
        mock_reset_chat.assert_called()


if __name__ == "__main__":
    unittest.main()
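Both RevChatGPT test modules end with a unittest.main() guard; discovery runs them together in one pass. A minimal sketch (the tests directory and file pattern are assumptions, not from the diff):

import unittest

if __name__ == "__main__":
    suite = unittest.TestLoader().discover("tests", pattern="test_revgpt*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)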
