From 24b9e645ef54261e2cfaf10439d56eb3412d6d45 Mon Sep 17 00:00:00 2001 From: Zack Date: Fri, 20 Oct 2023 15:04:19 -0500 Subject: [PATCH] feat: Add v1 for testing --- playground/agents/revgpt_agent.py | 2 +- revgpt.py | 8 +- schemas/typings.py | 198 +++ swarms/models/revgptV1.py | 1798 ++++++++++++++++++++++ swarms/models/{revgpt.py => revgptV4.py} | 71 +- swarms/utils/revutils.py | 166 ++ 6 files changed, 2193 insertions(+), 50 deletions(-) create mode 100644 schemas/typings.py create mode 100644 swarms/models/revgptV1.py rename swarms/models/{revgpt.py => revgptV4.py} (97%) create mode 100644 swarms/utils/revutils.py diff --git a/playground/agents/revgpt_agent.py b/playground/agents/revgpt_agent.py index 66b20fca..e8667e90 100644 --- a/playground/agents/revgpt_agent.py +++ b/playground/agents/revgpt_agent.py @@ -1,6 +1,6 @@ import os from dotenv import load_dotenv -from swarms.models.revgpt import RevChatGPTModel +from swarms.models.revgptV4 import RevChatGPTModel from swarms.workers.worker import Worker load_dotenv() diff --git a/revgpt.py b/revgpt.py index 03a6c81e..4bae7729 100644 --- a/revgpt.py +++ b/revgpt.py @@ -1,7 +1,8 @@ import os import sys from dotenv import load_dotenv -from swarms.models.revgpt import RevChatGPTModel +from swarms.models.revgptV4 import RevChatGPTModelv4 +from swarms.models.revgptV1 import RevChatGPTModelv1 root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(root_dir) @@ -17,10 +18,11 @@ config = { } # For v1 model -# model = RevChatGPTModel(access_token=os.getenv("ACCESS_TOKEN"), **config) +model = RevChatGPTModelv1(access_token=os.getenv("ACCESS_TOKEN"), **config) +# model = RevChatGPTModelv4(access_token=os.getenv("ACCESS_TOKEN"), **config) # For v3 model -model = RevChatGPTModel(access_token=os.getenv("OPENAI_API_KEY"), **config) +# model = RevChatGPTModel(access_token=os.getenv("OPENAI_API_KEY"), **config) task = "Write a cli snake game" response = model.run(task) diff --git a/schemas/typings.py b/schemas/typings.py new file mode 100644 index 00000000..faa902b5 --- /dev/null +++ b/schemas/typings.py @@ -0,0 +1,198 @@ +""" +A module that contains all the types used in this project +""" + +import os +import platform +from enum import Enum +from typing import Union + + +python_version = list(platform.python_version_tuple()) +SUPPORT_ADD_NOTES = int(python_version[0]) >= 3 and int(python_version[1]) >= 11 + + +class ChatbotError(Exception): + """ + Base class for all Chatbot errors in this Project + """ + + def __init__(self, *args: object) -> None: + if SUPPORT_ADD_NOTES: + super().add_note( + "Please check that the input is correct, or you can resolve this issue by filing an issue", + ) + super().add_note("Project URL: https://github.com/acheong08/ChatGPT") + super().__init__(*args) + + +class ActionError(ChatbotError): + """ + Subclass of ChatbotError + + An object that throws an error because the execution of an operation is blocked + """ + + def __init__(self, *args: object) -> None: + if SUPPORT_ADD_NOTES: + super().add_note( + "The current operation is not allowed, which may be intentional", + ) + super().__init__(*args) + + +class ActionNotAllowedError(ActionError): + """ + Subclass of ActionError + + An object that throws an error because the execution of an unalloyed operation is blocked + """ + + +class ActionRefuseError(ActionError): + """ + Subclass of ActionError + + An object that throws an error because the execution of a refused operation is blocked. + """ + + +class CLIError(ChatbotError): + """ + Subclass of ChatbotError + + The error caused by a CLI program error + """ + + +class ErrorType(Enum): + """ + Enumeration class for different types of errors. + """ + + USER_ERROR = -1 + UNKNOWN_ERROR = 0 + SERVER_ERROR = 1 + RATE_LIMIT_ERROR = 2 + INVALID_REQUEST_ERROR = 3 + EXPIRED_ACCESS_TOKEN_ERROR = 4 + INVALID_ACCESS_TOKEN_ERROR = 5 + PROHIBITED_CONCURRENT_QUERY_ERROR = 6 + AUTHENTICATION_ERROR = 7 + CLOUDFLARE_ERROR = 8 + + +class Error(ChatbotError): + """ + Base class for exceptions in V1 module. + """ + + def __init__( + self, + source: str, + message: str, + *args: object, + code: Union[ErrorType, int] = ErrorType.UNKNOWN_ERROR, + ) -> None: + self.source: str = source + self.message: str = message + self.code: ErrorType | int = code + super().__init__(*args) + + def __str__(self) -> str: + return f"{self.source}: {self.message} (code: {self.code})" + + def __repr__(self) -> str: + return f"{self.source}: {self.message} (code: {self.code})" + + +class AuthenticationError(ChatbotError): + """ + Subclass of ChatbotError + + The object of the error thrown by a validation failure or exception + """ + + def __init__(self, *args: object) -> None: + if SUPPORT_ADD_NOTES: + super().add_note( + "Please check if your key is correct, maybe it may not be valid", + ) + super().__init__(*args) + + +class APIConnectionError(ChatbotError): + """ + Subclass of ChatbotError + + An exception object thrown when an API connection fails or fails to connect due to network or + other miscellaneous reasons + """ + + def __init__(self, *args: object) -> None: + if SUPPORT_ADD_NOTES: + super().add_note( + "Please check if there is a problem with your network connection", + ) + super().__init__(*args) + + +class NotAllowRunning(ActionNotAllowedError): + """ + Subclass of ActionNotAllowedError + + Direct startup is not allowed for some reason + """ + + +class ResponseError(APIConnectionError): + """ + Subclass of APIConnectionError + + Error objects caused by API request errors due to network or other miscellaneous reasons + """ + + +class OpenAIError(APIConnectionError): + """ + Subclass of APIConnectionError + + Error objects caused by OpenAI's own server errors + """ + + +class RequestError(APIConnectionError): + """ + Subclass of APIConnectionError + + There is a problem with the API response due to network or other miscellaneous reasons, or there + is no reply to the object that caused the error at all + """ + + +class Colors: + """ + Colors for printing + """ + + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + def __init__(self) -> None: + if os.getenv("NO_COLOR"): + Colors.HEADER = "" + Colors.OKBLUE = "" + Colors.OKCYAN = "" + Colors.OKGREEN = "" + Colors.WARNING = "" + Colors.FAIL = "" + Colors.ENDC = "" + Colors.BOLD = "" + Colors.UNDERLINE = "" diff --git a/swarms/models/revgptV1.py b/swarms/models/revgptV1.py new file mode 100644 index 00000000..a3a4e878 --- /dev/null +++ b/swarms/models/revgptV1.py @@ -0,0 +1,1798 @@ +""" +Standard ChatGPT +""" +from __future__ import annotations +import argparse + +import base64 +import binascii +import contextlib +import json +import logging +import os +import secrets +import subprocess +import sys +import time +import uuid +from functools import wraps +from os import environ +from os import getenv + +try: + from os import startfile +except ImportError: + pass +from pathlib import Path +import tempfile +import random + +# Import function type + +import httpx +import requests +from httpx import AsyncClient +from OpenAIAuth import Auth0 as Authenticator +from rich.live import Live +from rich.markdown import Markdown +import schemas.typings as t +from swarms.utils.revutils import create_completer +from swarms.utils.revutils import create_session +from swarms.utils.revutils import get_input + +BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/") +# BASE_URL = os.environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/") +# BASE_URL = environ.get("CHATGPT_BASE_URL", "https://bypass.churchless.tech/") + +bcolors = t.Colors() + +def generate_random_hex(length: int = 17) -> str: + """Generate a random hex string + + Args: + length (int, optional): Length of the hex string. Defaults to 17. + + Returns: + str: Random hex string + """ + return secrets.token_hex(length) + + +def random_int(min: int, max: int) -> int: + """Generate a random integer + + Args: + min (int): Minimum value + max (int): Maximum value + + Returns: + int: Random integer + """ + return secrets.randbelow(max - min) + min + + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", + ) + +log = logging.getLogger(__name__) + + +def logger(is_timed: bool) -> function: + """Logger decorator + + Args: + is_timed (bool): Whether to include function running time in exit log + + Returns: + _type_: decorated function + """ + + def decorator(func: function) -> function: + wraps(func) + + def wrapper(*args, **kwargs): + log.debug( + "Entering %s with args %s and kwargs %s", + func.__name__, + args, + kwargs, + ) + start = time.time() + out = func(*args, **kwargs) + end = time.time() + if is_timed: + log.debug( + "Exiting %s with return value %s. Took %s seconds.", + func.__name__, + out, + end - start, + ) + else: + log.debug("Exiting %s with return value %s", func.__name__, out) + return out + + return wrapper + + return decorator + + +BASE_URL = environ.get("CHATGPT_BASE_URL", "http://bypass.bzff.cn:9090/") + + + +def captcha_solver(images: list[str], challenge_details: dict) -> int: + # Create tempfile + with tempfile.TemporaryDirectory() as tempdir: + filenames: list[Path] = [] + + for image in images: + filename = Path(tempdir, f"{time.time()}.jpeg") + with open(filename, "wb") as f: + f.write(base64.b64decode(image)) + print(f"Saved captcha image to {filename}") + # If MacOS, open the image + if sys.platform == "darwin": + subprocess.call(["open", filename]) + if sys.platform == "linux": + subprocess.call(["xdg-open", filename]) + if sys.platform == "win32": + startfile(filename) + filenames.append(filename) + + print(f'Captcha instructions: {challenge_details.get("instructions")}') + print( + "Developer instructions: The captcha images have an index starting from 0 from left to right", + ) + print("Enter the index of the images that matches the captcha instructions:") + return int(input()) + + +CAPTCHA_URL = getenv("CAPTCHA_URL", "https://bypass.churchless.tech/captcha/") + + +def get_arkose_token( + download_images: bool = True, + solver: function = captcha_solver, + captcha_supported: bool = True, +) -> str: + """ + The solver function should take in a list of images in base64 and a dict of challenge details + and return the index of the image that matches the challenge details + + Challenge details: + game_type: str - Audio or Image + instructions: str - Instructions for the captcha + URLs: list[str] - URLs of the images or audio files + """ + if captcha_supported: + resp = requests.get( + (CAPTCHA_URL + "start?download_images=true") + if download_images + else CAPTCHA_URL + "start", + ) + resp_json: dict = resp.json() + if resp.status_code == 200: + return resp_json.get("token") + if resp.status_code != 511: + raise Exception(resp_json.get("error", "Unknown error")) + + if resp_json.get("status") != "captcha": + raise Exception("unknown error") + + challenge_details: dict = resp_json.get("session", {}).get("concise_challenge") + if not challenge_details: + raise Exception("missing details") + + images: list[str] = resp_json.get("images") + + index = solver(images, challenge_details) + + resp = requests.post( + CAPTCHA_URL + "verify", + json={"session": resp_json.get("session"), "index": index}, + ) + if resp.status_code != 200: + raise Exception("Failed to verify captcha") + return resp_json.get("token") + # else: + # working_endpoints: list[str] = [] + # # Check uptime for different endpoints via gatus + # resp2: list[dict] = requests.get( + # "https://stats.churchless.tech/api/v1/endpoints/statuses?page=1" + # ).json() + # for endpoint in resp2: + # # print(endpoint.get("name")) + # if endpoint.get("group") != "Arkose Labs": + # continue + # # Check the last 5 results + # results: list[dict] = endpoint.get("results", [])[-5:-1] + # # print(results) + # if not results: + # print(f"Endpoint {endpoint.get('name')} has no results") + # continue + # # Check if all the results are up + # if all(result.get("success") == True for result in results): + # working_endpoints.append(endpoint.get("name")) + # if not working_endpoints: + # print("No working endpoints found. Please solve the captcha manually.\n找不到工作终结点。请手动解决captcha") + # return get_arkose_token(download_images=True, captcha_supported=False) + # # Choose a random endpoint + # endpoint = random.choice(working_endpoints) + # resp: requests.Response = requests.get(endpoint) + # if resp.status_code != 200: + # if resp.status_code != 511: + # raise Exception("Failed to get captcha token") + # else: + # print("需要验证码,请手动解决captcha.") + # return get_arkose_token(download_images=True, captcha_supported=True) + # try: + # return resp.json().get("token") + # except Exception: + # return resp.text + + +class Chatbot: + """ + Chatbot class for ChatGPT + """ + + @logger(is_timed=True) + def __init__( + self, + config: dict[str, str], + conversation_id: str | None = None, + parent_id: str | None = None, + lazy_loading: bool = True, + base_url: str | None = None, + captcha_solver: function = captcha_solver, + captcha_download_images: bool = True, + ) -> None: + """Initialize a chatbot + + Args: + config (dict[str, str]): Login and proxy info. Example: + { + "access_token": "" + "proxy": "", + "model": "", + "plugin": "", + } + More details on these are available at https://github.com/acheong08/ChatGPT#configuration + conversation_id (str | None, optional): Id of the conversation to continue on. Defaults to None. + parent_id (str | None, optional): Id of the previous response message to continue on. Defaults to None. + lazy_loading (bool, optional): Whether to load only the active conversation. Defaults to True. + base_url (str | None, optional): Base URL of the ChatGPT server. Defaults to None. + captcha_solver (function, optional): Function to solve captcha. Defaults to captcha_solver. + captcha_download_images (bool, optional): Whether to download captcha images. Defaults to True. + + Raises: + Exception: _description_ + """ + user_home = getenv("HOME") or getenv("USERPROFILE") + if user_home is None: + user_home = Path().cwd() + self.cache_path = Path(Path().cwd(), ".chatgpt_cache.json") + else: + # mkdir ~/.config/revChatGPT + if not Path(user_home, ".config").exists(): + Path(user_home, ".config").mkdir() + if not Path(user_home, ".config", "revChatGPT").exists(): + Path(user_home, ".config", "revChatGPT").mkdir() + self.cache_path = Path(user_home, ".config", "revChatGPT", "cache.json") + + self.config = config + self.session = requests.Session() + if "email" in config and "password" in config: + try: + cached_access_token = self.__get_cached_access_token( + self.config.get("email", None), + ) + except t.Error as error: + if error.code == 5: + raise + cached_access_token = None + if cached_access_token is not None: + self.config["access_token"] = cached_access_token + + if "proxy" in config: + if not isinstance(config["proxy"], str): + error = TypeError("Proxy must be a string!") + raise error + proxies = { + "http": config["proxy"], + "https": config["proxy"], + } + if isinstance(self.session, AsyncClient): + proxies = { + "http://": config["proxy"], + "https://": config["proxy"], + } + self.session = AsyncClient(proxies=proxies) # type: ignore + else: + self.session.proxies.update(proxies) + + self.conversation_id = conversation_id or config.get("conversation_id") + self.parent_id = parent_id or config.get("parent_id") + self.conversation_mapping = {} + self.conversation_id_prev_queue = [] + self.parent_id_prev_queue = [] + self.lazy_loading = lazy_loading + self.base_url = base_url or BASE_URL + self.disable_history = config.get("disable_history", False) + + self.__check_credentials() + + if self.config.get("plugin_ids", []): + for plugin in self.config.get("plugin_ids"): + self.install_plugin(plugin) + if self.config.get("unverified_plugin_domains", []): + for domain in self.config.get("unverified_plugin_domains"): + if self.config.get("plugin_ids"): + self.config["plugin_ids"].append( + self.get_unverified_plugin(domain, install=True).get("id"), + ) + else: + self.config["plugin_ids"] = [ + self.get_unverified_plugin(domain, install=True).get("id"), + ] + # Get PUID cookie + try: + auth = Authenticator("blah", "blah") + auth.access_token = self.config["access_token"] + puid = auth.get_puid() + self.session.headers.update({"PUID": puid}) + print("Setting PUID (You are a Plus user!): " + puid) + except: + pass + self.captcha_solver = captcha_solver + self.captcha_download_images = captcha_download_images + + @logger(is_timed=True) + def __check_credentials(self) -> None: + """Check login info and perform login + + Any one of the following is sufficient for login. Multiple login info can be provided at the same time and they will be used in the order listed below. + - access_token + - email + password + + Raises: + Exception: _description_ + AuthError: _description_ + """ + if "access_token" in self.config: + self.set_access_token(self.config["access_token"]) + elif "email" not in self.config or "password" not in self.config: + error = t.AuthenticationError("Insufficient login details provided!") + raise error + if "access_token" not in self.config: + try: + self.login() + except Exception as error: + print(error) + raise error + + @logger(is_timed=False) + def set_access_token(self, access_token: str) -> None: + """Set access token in request header and self.config, then cache it to file. + + Args: + access_token (str): access_token + """ + self.session.headers.clear() + self.session.headers.update( + { + "Accept": "text/event-stream", + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36", + }, + ) + + self.config["access_token"] = access_token + + email = self.config.get("email", None) + if email is not None: + self.__cache_access_token(email, access_token) + + @logger(is_timed=False) + def __get_cached_access_token(self, email: str | None) -> str | None: + """Read access token from cache + + Args: + email (str | None): email of the account to get access token + + Raises: + Error: _description_ + Error: _description_ + Error: _description_ + + Returns: + str | None: access token string or None if not found + """ + email = email or "default" + cache = self.__read_cache() + access_token = cache.get("access_tokens", {}).get(email, None) + + # Parse access_token as JWT + if access_token is not None: + try: + # Split access_token into 3 parts + s_access_token = access_token.split(".") + # Add padding to the middle part + s_access_token[1] += "=" * ((4 - len(s_access_token[1]) % 4) % 4) + d_access_token = base64.b64decode(s_access_token[1]) + d_access_token = json.loads(d_access_token) + except binascii.Error: + del cache["access_tokens"][email] + self.__write_cache(cache) + error = t.Error( + source="__get_cached_access_token", + message="Invalid access token", + code=t.ErrorType.INVALID_ACCESS_TOKEN_ERROR, + ) + del cache["access_tokens"][email] + raise error from None + except json.JSONDecodeError: + del cache["access_tokens"][email] + self.__write_cache(cache) + error = t.Error( + source="__get_cached_access_token", + message="Invalid access token", + code=t.ErrorType.INVALID_ACCESS_TOKEN_ERROR, + ) + raise error from None + except IndexError: + del cache["access_tokens"][email] + self.__write_cache(cache) + error = t.Error( + source="__get_cached_access_token", + message="Invalid access token", + code=t.ErrorType.INVALID_ACCESS_TOKEN_ERROR, + ) + raise error from None + + exp = d_access_token.get("exp", None) + if exp is not None and exp < time.time(): + error = t.Error( + source="__get_cached_access_token", + message="Access token expired", + code=t.ErrorType.EXPIRED_ACCESS_TOKEN_ERROR, + ) + raise error + + return access_token + + @logger(is_timed=False) + def __cache_access_token(self, email: str, access_token: str) -> None: + """Write an access token to cache + + Args: + email (str): account email + access_token (str): account access token + """ + email = email or "default" + cache = self.__read_cache() + if "access_tokens" not in cache: + cache["access_tokens"] = {} + cache["access_tokens"][email] = access_token + self.__write_cache(cache) + + @logger(is_timed=False) + def __write_cache(self, info: dict) -> None: + """Write cache info to file + + Args: + info (dict): cache info, current format + { + "access_tokens":{"someone@example.com": 'this account's access token', } + } + """ + dirname = self.cache_path.home() or Path(".") + dirname.mkdir(parents=True, exist_ok=True) + json.dump(info, open(self.cache_path, "w", encoding="utf-8"), indent=4) + + @logger(is_timed=False) + def __read_cache(self) -> dict[str, dict[str, str]]: + try: + cached = json.load(open(self.cache_path, encoding="utf-8")) + except (FileNotFoundError, json.decoder.JSONDecodeError): + cached = {} + return cached + + @logger(is_timed=True) + def login(self) -> None: + """Login to OpenAI by email and password""" + if not self.config.get("email") and not self.config.get("password"): + log.error("Insufficient login details provided!") + error = t.AuthenticationError("Insufficient login details provided!") + raise error + auth = Authenticator( + email_address=self.config.get("email"), + password=self.config.get("password"), + proxy=self.config.get("proxy"), + ) + log.debug("Using authenticator to get access token") + + self.set_access_token(auth.get_access_token()) + + @logger(is_timed=True) + def __send_request( + self, + data: dict, + auto_continue: bool = False, + timeout: float = 360, + **kwargs, + ) -> any: + log.debug("Sending the payload") + + if ( + data.get("model", "").startswith("gpt-4") + and not self.config.get("SERVER_SIDE_ARKOSE") + and not getenv("SERVER_SIDE_ARKOSE") + ): + try: + data["arkose_token"] = get_arkose_token( + self.captcha_download_images, + self.captcha_solver, + captcha_supported=False, + ) + # print(f"Arkose token obtained: {data['arkose_token']}") + except Exception as e: + print(e) + raise + + cid, pid = data["conversation_id"], data["parent_message_id"] + message = "" + + self.conversation_id_prev_queue.append(cid) + self.parent_id_prev_queue.append(pid) + response = self.session.post( + url=f"{self.base_url}conversation", + data=json.dumps(data), + timeout=timeout, + stream=True, + ) + self.__check_response(response) + + finish_details = None + for line in response.iter_lines(): + # remove b' and ' at the beginning and end and ignore case + line = str(line)[2:-1] + if line.lower() == "internal server error": + log.error(f"Internal Server Error: {line}") + error = t.Error( + source="ask", + message="Internal Server Error", + code=t.ErrorType.SERVER_ERROR, + ) + raise error + if not line or line is None: + continue + if "data: " in line: + line = line[6:] + if line == "[DONE]": + break + + # DO NOT REMOVE THIS + line = line.replace('\\"', '"') + line = line.replace("\\'", "'") + line = line.replace("\\\\", "\\") + + try: + line = json.loads(line) + except json.decoder.JSONDecodeError: + continue + if not self.__check_fields(line): + continue + if line.get("message").get("author").get("role") != "assistant": + continue + + cid = line["conversation_id"] + pid = line["message"]["id"] + metadata = line["message"].get("metadata", {}) + message_exists = False + author = {} + if line.get("message"): + author = metadata.get("author", {}) or line["message"].get("author", {}) + if ( + line["message"].get("content") + and line["message"]["content"].get("parts") + and len(line["message"]["content"]["parts"]) > 0 + ): + message_exists = True + message: str = ( + line["message"]["content"]["parts"][0] if message_exists else "" + ) + model = metadata.get("model_slug", None) + finish_details = metadata.get("finish_details", {"type": None})["type"] + yield { + "author": author, + "message": message, + "conversation_id": cid, + "parent_id": pid, + "model": model, + "finish_details": finish_details, + "end_turn": line["message"].get("end_turn", True), + "recipient": line["message"].get("recipient", "all"), + "citations": metadata.get("citations", []), + } + + self.conversation_mapping[cid] = pid + if pid is not None: + self.parent_id = pid + if cid is not None: + self.conversation_id = cid + + if not (auto_continue and finish_details == "max_tokens"): + return + message = message.strip("\n") + for i in self.continue_write( + conversation_id=cid, + model=model, + timeout=timeout, + auto_continue=False, + ): + i["message"] = message + i["message"] + yield i + + @logger(is_timed=True) + def post_messages( + self, + messages: list[dict], + conversation_id: str | None = None, + parent_id: str | None = None, + plugin_ids: list = None, + model: str | None = None, + auto_continue: bool = False, + timeout: float = 360, + **kwargs, + ) -> any: + """Ask a question to the chatbot + Args: + messages (list[dict]): The messages to send + conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. + parent_id (str | None, optional): UUID for the message to continue on. Defaults to None. + model (str | None, optional): The model to use. Defaults to None. + auto_continue (bool, optional): Whether to continue the conversation automatically. Defaults to False. + timeout (float, optional): Timeout for getting the full response, unit is second. Defaults to 360. + + Yields: Generator[dict, None, None] - The response from the chatbot + dict: { + "message": str, + "conversation_id": str, + "parent_id": str, + "model": str, + "finish_details": str, # "max_tokens" or "stop" + "end_turn": bool, + "recipient": str, + "citations": list[dict], + } + """ + if plugin_ids is None: + plugin_ids = [] + if parent_id and not conversation_id: + raise t.Error( + source="User", + message="conversation_id must be set once parent_id is set", + code=t.ErrorType.USER_ERROR, + ) + + if conversation_id and conversation_id != self.conversation_id: + self.parent_id = None + conversation_id = conversation_id or self.conversation_id + parent_id = parent_id or self.parent_id or "" + if not conversation_id and not parent_id: + parent_id = str(uuid.uuid4()) + + if conversation_id and not parent_id: + if conversation_id not in self.conversation_mapping: + if self.lazy_loading: + log.debug( + "Conversation ID %s not found in conversation mapping, try to get conversation history for the given ID", + conversation_id, + ) + try: + history = self.get_msg_history(conversation_id) + self.conversation_mapping[conversation_id] = history[ + "current_node" + ] + except requests.exceptions.HTTPError: + print("Conversation unavailable") + else: + self.__map_conversations() + if conversation_id in self.conversation_mapping: + parent_id = self.conversation_mapping[conversation_id] + else: + print( + "Warning: Invalid conversation_id provided, treat as a new conversation", + ) + conversation_id = None + parent_id = str(uuid.uuid4()) + model = model or self.config.get("model") or "text-davinci-002-render-sha" + data = { + "action": "next", + "messages": messages, + "conversation_id": conversation_id, + "parent_message_id": parent_id, + "model": model, + "history_and_training_disabled": self.disable_history, + } + plugin_ids = self.config.get("plugin_ids", []) or plugin_ids + if len(plugin_ids) > 0 and not conversation_id: + data["plugin_ids"] = plugin_ids + + yield from self.__send_request( + data, + timeout=timeout, + auto_continue=auto_continue, + ) + + @logger(is_timed=True) + def ask( + self, + prompt: str, + conversation_id: str | None = None, + parent_id: str = "", + model: str = "", + plugin_ids: list = None, + auto_continue: bool = False, + timeout: float = 360, + **kwargs, + ) -> any: + """Ask a question to the chatbot + Args: + prompt (str): The question + conversation_id (str, optional): UUID for the conversation to continue on. Defaults to None. + parent_id (str, optional): UUID for the message to continue on. Defaults to "". + model (str, optional): The model to use. Defaults to "". + auto_continue (bool, optional): Whether to continue the conversation automatically. Defaults to False. + timeout (float, optional): Timeout for getting the full response, unit is second. Defaults to 360. + + Yields: The response from the chatbot + dict: { + "message": str, + "conversation_id": str, + "parent_id": str, + "model": str, + "finish_details": str, # "max_tokens" or "stop" + "end_turn": bool, + "recipient": str, + } + """ + if plugin_ids is None: + plugin_ids = [] + messages = [ + { + "id": str(uuid.uuid4()), + "role": "user", + "author": {"role": "user"}, + "content": {"content_type": "text", "parts": [prompt]}, + }, + ] + + yield from self.post_messages( + messages, + conversation_id=conversation_id, + parent_id=parent_id, + plugin_ids=plugin_ids, + model=model, + auto_continue=auto_continue, + timeout=timeout, + ) + + @logger(is_timed=True) + def continue_write( + self, + conversation_id: str | None = None, + parent_id: str = "", + model: str = "", + auto_continue: bool = False, + timeout: float = 360, + ) -> any: + """let the chatbot continue to write. + Args: + conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. + parent_id (str, optional): UUID for the message to continue on. Defaults to None. + model (str, optional): The model to use. Defaults to None. + auto_continue (bool, optional): Whether to continue the conversation automatically. Defaults to False. + timeout (float, optional): Timeout for getting the full response, unit is second. Defaults to 360. + + Yields: + dict: { + "message": str, + "conversation_id": str, + "parent_id": str, + "model": str, + "finish_details": str, # "max_tokens" or "stop" + "end_turn": bool, + "recipient": str, + } + """ + if parent_id and not conversation_id: + raise t.Error( + source="User", + message="conversation_id must be set once parent_id is set", + code=t.ErrorType.USER_ERROR, + ) + + if conversation_id and conversation_id != self.conversation_id: + self.parent_id = None + conversation_id = conversation_id or self.conversation_id + parent_id = parent_id or self.parent_id or "" + if not conversation_id and not parent_id: + parent_id = str(uuid.uuid4()) + + if conversation_id and not parent_id: + if conversation_id not in self.conversation_mapping: + if self.lazy_loading: + log.debug( + "Conversation ID %s not found in conversation mapping, try to get conversation history for the given ID", + conversation_id, + ) + with contextlib.suppress(Exception): + history = self.get_msg_history(conversation_id) + self.conversation_mapping[conversation_id] = history[ + "current_node" + ] + else: + log.debug( + f"Conversation ID {conversation_id} not found in conversation mapping, mapping conversations", + ) + self.__map_conversations() + if conversation_id in self.conversation_mapping: + parent_id = self.conversation_mapping[conversation_id] + else: # invalid conversation_id provided, treat as a new conversation + conversation_id = None + parent_id = str(uuid.uuid4()) + model = model or self.config.get("model") or "text-davinci-002-render-sha" + data = { + "action": "continue", + "conversation_id": conversation_id, + "parent_message_id": parent_id, + "model": model + or self.config.get("model") + or ( + "text-davinci-002-render-paid" + if self.config.get("paid") + else "text-davinci-002-render-sha" + ), + "history_and_training_disabled": self.disable_history, + } + yield from self.__send_request( + data, + timeout=timeout, + auto_continue=auto_continue, + ) + + @logger(is_timed=False) + def __check_fields(self, data: dict) -> bool: + try: + data["message"]["content"] + except (TypeError, KeyError): + return False + return True + + @logger(is_timed=False) + def __check_response(self, response: requests.Response) -> None: + """Make sure response is success + + Args: + response (_type_): _description_ + + Raises: + Error: _description_ + """ + try: + response.raise_for_status() + except requests.exceptions.HTTPError as ex: + error = t.Error( + source="OpenAI", + message=response.text, + code=response.status_code, + ) + raise error from ex + + @logger(is_timed=True) + def get_conversations( + self, + offset: int = 0, + limit: int = 20, + encoding: str | None = None, + ) -> list: + """ + Get conversations + :param offset: Integer + :param limit: Integer + """ + url = f"{self.base_url}conversations?offset={offset}&limit={limit}" + response = self.session.get(url) + self.__check_response(response) + if encoding is not None: + response.encoding = encoding + data = json.loads(response.text) + return data["items"] + + @logger(is_timed=True) + def get_msg_history(self, convo_id: str, encoding: str | None = None) -> list: + """ + Get message history + :param id: UUID of conversation + :param encoding: String + """ + url = f"{self.base_url}conversation/{convo_id}" + response = self.session.get(url) + self.__check_response(response) + if encoding is not None: + response.encoding = encoding + return response.json() + + def share_conversation( + self, + title: str = None, + convo_id: str = None, + node_id: str = None, + anonymous: bool = True, + ) -> str: + """ + Creates a share link to a conversation + :param convo_id: UUID of conversation + :param node_id: UUID of node + :param anonymous: Boolean + :param title: String + + Returns: + str: A URL to the shared link + """ + convo_id = convo_id or self.conversation_id + node_id = node_id or self.parent_id + headers = { + "Content-Type": "application/json", + "origin": "https://chat.openai.com", + "referer": f"https://chat.openai.com/c/{convo_id}", + } + # First create the share + payload = { + "conversation_id": convo_id, + "current_node_id": node_id, + "is_anonymous": anonymous, + } + url = f"{self.base_url}share/create" + response = self.session.post(url, data=json.dumps(payload), headers=headers) + self.__check_response(response) + share_url = response.json().get("share_url") + # Then patch the share to make public + share_id = response.json().get("share_id") + url = f"{self.base_url}share/{share_id}" + payload = { + "share_id": share_id, + "highlighted_message_id": node_id, + "title": title or response.json().get("title", "New chat"), + "is_public": True, + "is_visible": True, + "is_anonymous": True, + } + response = self.session.patch(url, data=json.dumps(payload), headers=headers) + self.__check_response(response) + return share_url + + @logger(is_timed=True) + def gen_title(self, convo_id: str, message_id: str) -> str: + """ + Generate title for conversation + :param id: UUID of conversation + :param message_id: UUID of message + """ + response = self.session.post( + f"{self.base_url}conversation/gen_title/{convo_id}", + data=json.dumps( + {"message_id": message_id, "model": "text-davinci-002-render"}, + ), + ) + self.__check_response(response) + return response.json().get("title", "Error generating title") + + @logger(is_timed=True) + def change_title(self, convo_id: str, title: str) -> None: + """ + Change title of conversation + :param id: UUID of conversation + :param title: String + """ + url = f"{self.base_url}conversation/{convo_id}" + response = self.session.patch(url, data=json.dumps({"title": title})) + self.__check_response(response) + + @logger(is_timed=True) + def delete_conversation(self, convo_id: str) -> None: + """ + Delete conversation + :param id: UUID of conversation + """ + url = f"{self.base_url}conversation/{convo_id}" + response = self.session.patch(url, data='{"is_visible": false}') + self.__check_response(response) + + @logger(is_timed=True) + def clear_conversations(self) -> None: + """ + Delete all conversations + """ + url = f"{self.base_url}conversations" + response = self.session.patch(url, data='{"is_visible": false}') + self.__check_response(response) + + @logger(is_timed=False) + def __map_conversations(self) -> None: + conversations = self.get_conversations() + histories = [self.get_msg_history(x["id"]) for x in conversations] + for x, y in zip(conversations, histories): + self.conversation_mapping[x["id"]] = y["current_node"] + + @logger(is_timed=False) + def reset_chat(self) -> None: + """ + Reset the conversation ID and parent ID. + + :return: None + """ + self.conversation_id = None + self.parent_id = str(uuid.uuid4()) + + @logger(is_timed=False) + def rollback_conversation(self, num: int = 1) -> None: + """ + Rollback the conversation. + :param num: Integer. The number of messages to rollback + :return: None + """ + for _ in range(num): + self.conversation_id = self.conversation_id_prev_queue.pop() + self.parent_id = self.parent_id_prev_queue.pop() + + @logger(is_timed=True) + def get_plugins( + self, + offset: int = 0, + limit: int = 250, + status: str = "approved", + ) -> dict[str, str]: + """ + Get plugins + :param offset: Integer. Offset (Only supports 0) + :param limit: Integer. Limit (Only below 250) + :param status: String. Status of plugin (approved) + """ + url = f"{self.base_url}aip/p?offset={offset}&limit={limit}&statuses={status}" + response = self.session.get(url) + self.__check_response(response) + # Parse as JSON + return json.loads(response.text) + + @logger(is_timed=True) + def install_plugin(self, plugin_id: str) -> None: + """ + Install plugin by ID + :param plugin_id: String. ID of plugin + """ + url = f"{self.base_url}aip/p/{plugin_id}/user-settings" + payload = {"is_installed": True} + response = self.session.patch(url, data=json.dumps(payload)) + self.__check_response(response) + + @logger(is_timed=True) + def get_unverified_plugin(self, domain: str, install: bool = True) -> dict: + """ + Get unverified plugin by domain + :param domain: String. Domain of plugin + :param install: Boolean. Install plugin if found + """ + url = f"{self.base_url}aip/p/domain?domain={domain}" + response = self.session.get(url) + self.__check_response(response) + if install: + self.install_plugin(response.json().get("id")) + return response.json() + + +class AsyncChatbot(Chatbot): + """Async Chatbot class for ChatGPT""" + + def __init__( + self, + config: dict, + conversation_id: str | None = None, + parent_id: str | None = None, + base_url: str | None = None, + lazy_loading: bool = True, + ) -> None: + """ + Same as Chatbot class, but with async methods. + """ + super().__init__( + config=config, + conversation_id=conversation_id, + parent_id=parent_id, + base_url=base_url, + lazy_loading=lazy_loading, + ) + + # overwrite inherited normal session with async + self.session = AsyncClient(headers=self.session.headers) + + async def __send_request( + self, + data: dict, + auto_continue: bool = False, + timeout: float = 360, + **kwargs, + ) -> any: + log.debug("Sending the payload") + + cid, pid = data["conversation_id"], data["parent_message_id"] + message = "" + self.conversation_id_prev_queue.append(cid) + self.parent_id_prev_queue.append(pid) + async with self.session.stream( + "POST", + url=f"{self.base_url}conversation", + data=json.dumps(data), + timeout=timeout, + ) as response: + await self.__check_response(response) + + finish_details = None + async for line in response.aiter_lines(): + if line.lower() == "internal server error": + log.error(f"Internal Server Error: {line}") + error = t.Error( + source="ask", + message="Internal Server Error", + code=t.ErrorType.SERVER_ERROR, + ) + raise error + if not line or line is None: + continue + if "data: " in line: + line = line[6:] + if line == "[DONE]": + break + + try: + line = json.loads(line) + except json.decoder.JSONDecodeError: + continue + + if not self.__check_fields(line): + continue + if line.get("message").get("author").get("role") != "assistant": + continue + + cid = line["conversation_id"] + pid = line["message"]["id"] + metadata = line["message"].get("metadata", {}) + message_exists = False + author = {} + if line.get("message"): + author = metadata.get("author", {}) or line["message"].get( + "author", + {}, + ) + if ( + line["message"].get("content") + and line["message"]["content"].get("parts") + and len(line["message"]["content"]["parts"]) > 0 + ): + message_exists = True + message: str = ( + line["message"]["content"]["parts"][0] if message_exists else "" + ) + model = metadata.get("model_slug", None) + finish_details = metadata.get("finish_details", {"type": None})["type"] + yield { + "author": author, + "message": message, + "conversation_id": cid, + "parent_id": pid, + "model": model, + "finish_details": finish_details, + "end_turn": line["message"].get("end_turn", True), + "recipient": line["message"].get("recipient", "all"), + "citations": metadata.get("citations", []), + } + + self.conversation_mapping[cid] = pid + if pid is not None: + self.parent_id = pid + if cid is not None: + self.conversation_id = cid + + if not (auto_continue and finish_details == "max_tokens"): + return + message = message.strip("\n") + async for i in self.continue_write( + conversation_id=cid, + model=model, + timeout=timeout, + auto_continue=False, + ): + i["message"] = message + i["message"] + yield i + + async def post_messages( + self, + messages: list[dict], + conversation_id: str | None = None, + parent_id: str | None = None, + plugin_ids: list = None, + model: str | None = None, + auto_continue: bool = False, + timeout: float = 360, + **kwargs, + ) -> any: + """Post messages to the chatbot + + Args: + messages (list[dict]): the messages to post + conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. + parent_id (str | None, optional): UUID for the message to continue on. Defaults to None. + model (str | None, optional): The model to use. Defaults to None. + auto_continue (bool, optional): Whether to continue the conversation automatically. Defaults to False. + timeout (float, optional): Timeout for getting the full response, unit is second. Defaults to 360. + + Yields: + AsyncGenerator[dict, None]: The response from the chatbot + { + "message": str, + "conversation_id": str, + "parent_id": str, + "model": str, + "finish_details": str, + "end_turn": bool, + "recipient": str, + "citations": list[dict], + } + """ + if plugin_ids is None: + plugin_ids = [] + if parent_id and not conversation_id: + raise t.Error( + source="User", + message="conversation_id must be set once parent_id is set", + code=t.ErrorType.USER_ERROR, + ) + + if conversation_id and conversation_id != self.conversation_id: + self.parent_id = None + conversation_id = conversation_id or self.conversation_id + parent_id = parent_id or self.parent_id or "" + if not conversation_id and not parent_id: + parent_id = str(uuid.uuid4()) + + if conversation_id and not parent_id: + if conversation_id not in self.conversation_mapping: + if self.lazy_loading: + log.debug( + "Conversation ID %s not found in conversation mapping, try to get conversation history for the given ID", + conversation_id, + ) + try: + history = await self.get_msg_history(conversation_id) + self.conversation_mapping[conversation_id] = history[ + "current_node" + ] + except requests.exceptions.HTTPError: + print("Conversation unavailable") + else: + await self.__map_conversations() + if conversation_id in self.conversation_mapping: + parent_id = self.conversation_mapping[conversation_id] + else: + print( + "Warning: Invalid conversation_id provided, treat as a new conversation", + ) + conversation_id = None + parent_id = str(uuid.uuid4()) + model = model or self.config.get("model") or "text-davinci-002-render-sha" + data = { + "action": "next", + "messages": messages, + "conversation_id": conversation_id, + "parent_message_id": parent_id, + "model": model, + "history_and_training_disabled": self.disable_history, + } + plugin_ids = self.config.get("plugin_ids", []) or plugin_ids + if len(plugin_ids) > 0 and not conversation_id: + data["plugin_ids"] = plugin_ids + async for msg in self.__send_request( + data, + timeout=timeout, + auto_continue=auto_continue, + ): + yield msg + + async def ask( + self, + prompt: str, + conversation_id: str | None = None, + parent_id: str = "", + model: str = "", + plugin_ids: list = None, + auto_continue: bool = False, + timeout: int = 360, + **kwargs, + ) -> any: + """Ask a question to the chatbot + + Args: + prompt (str): The question to ask + conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. + parent_id (str, optional): UUID for the message to continue on. Defaults to "". + model (str, optional): The model to use. Defaults to "". + auto_continue (bool, optional): Whether to continue the conversation automatically. Defaults to False. + timeout (float, optional): Timeout for getting the full response, unit is second. Defaults to 360. + + Yields: + AsyncGenerator[dict, None]: The response from the chatbot + { + "message": str, + "conversation_id": str, + "parent_id": str, + "model": str, + "finish_details": str, + "end_turn": bool, + "recipient": str, + } + """ + + if plugin_ids is None: + plugin_ids = [] + messages = [ + { + "id": str(uuid.uuid4()), + "author": {"role": "user"}, + "content": {"content_type": "text", "parts": [prompt]}, + }, + ] + + async for msg in self.post_messages( + messages=messages, + conversation_id=conversation_id, + parent_id=parent_id, + plugin_ids=plugin_ids, + model=model, + auto_continue=auto_continue, + timeout=timeout, + ): + yield msg + + async def continue_write( + self, + conversation_id: str | None = None, + parent_id: str = "", + model: str = "", + auto_continue: bool = False, + timeout: float = 360, + ) -> any: + """let the chatbot continue to write + Args: + conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. + parent_id (str, optional): UUID for the message to continue on. Defaults to None. + model (str, optional): Model to use. Defaults to None. + auto_continue (bool, optional): Whether to continue writing automatically. Defaults to False. + timeout (float, optional): Timeout for getting the full response, unit is second. Defaults to 360. + + + Yields: + AsyncGenerator[dict, None]: The response from the chatbot + { + "message": str, + "conversation_id": str, + "parent_id": str, + "model": str, + "finish_details": str, + "end_turn": bool, + "recipient": str, + } + """ + if parent_id and not conversation_id: + error = t.Error( + source="User", + message="conversation_id must be set once parent_id is set", + code=t.ErrorType.SERVER_ERROR, + ) + raise error + if conversation_id and conversation_id != self.conversation_id: + self.parent_id = None + conversation_id = conversation_id or self.conversation_id + + parent_id = parent_id or self.parent_id or "" + if not conversation_id and not parent_id: + parent_id = str(uuid.uuid4()) + if conversation_id and not parent_id: + if conversation_id not in self.conversation_mapping: + await self.__map_conversations() + if conversation_id in self.conversation_mapping: + parent_id = self.conversation_mapping[conversation_id] + else: # invalid conversation_id provided, treat as a new conversation + conversation_id = None + parent_id = str(uuid.uuid4()) + model = model or self.config.get("model") or "text-davinci-002-render-sha" + data = { + "action": "continue", + "conversation_id": conversation_id, + "parent_message_id": parent_id, + "model": model + or self.config.get("model") + or ( + "text-davinci-002-render-paid" + if self.config.get("paid") + else "text-davinci-002-render-sha" + ), + "history_and_training_disabled": self.disable_history, + } + async for msg in self.__send_request( + data=data, + auto_continue=auto_continue, + timeout=timeout, + ): + yield msg + + async def get_conversations(self, offset: int = 0, limit: int = 20) -> list: + """ + Get conversations + :param offset: Integer + :param limit: Integer + """ + url = f"{self.base_url}conversations?offset={offset}&limit={limit}" + response = await self.session.get(url) + await self.__check_response(response) + data = json.loads(response.text) + return data["items"] + + async def get_msg_history( + self, + convo_id: str, + encoding: str | None = "utf-8", + ) -> dict: + """ + Get message history + :param id: UUID of conversation + """ + url = f"{self.base_url}conversation/{convo_id}" + response = await self.session.get(url) + if encoding is not None: + response.encoding = encoding + await self.__check_response(response) + return json.loads(response.text) + return None + + async def share_conversation( + self, + title: str = None, + convo_id: str = None, + node_id: str = None, + anonymous: bool = True, + ) -> str: + """ + Creates a share link to a conversation + :param convo_id: UUID of conversation + :param node_id: UUID of node + + Returns: + str: A URL to the shared link + """ + convo_id = convo_id or self.conversation_id + node_id = node_id or self.parent_id + # First create the share + payload = { + "conversation_id": convo_id, + "current_node_id": node_id, + "is_anonymous": anonymous, + } + url = f"{self.base_url}share/create" + response = await self.session.post( + url, + data=json.dumps(payload), + ) + await self.__check_response(response) + share_url = response.json().get("share_url") + # Then patch the share to make public + share_id = response.json().get("share_id") + url = f"{self.base_url}share/{share_id}" + print(url) + payload = { + "share_id": share_id, + "highlighted_message_id": node_id, + "title": title or response.json().get("title", "New chat"), + "is_public": True, + "is_visible": True, + "is_anonymous": True, + } + response = await self.session.patch( + url, + data=json.dumps(payload), + ) + await self.__check_response(response) + return share_url + + async def gen_title(self, convo_id: str, message_id: str) -> None: + """ + Generate title for conversation + """ + url = f"{self.base_url}conversation/gen_title/{convo_id}" + response = await self.session.post( + url, + data=json.dumps( + {"message_id": message_id, "model": "text-davinci-002-render"}, + ), + ) + await self.__check_response(response) + + async def change_title(self, convo_id: str, title: str) -> None: + """ + Change title of conversation + :param convo_id: UUID of conversation + :param title: String + """ + url = f"{self.base_url}conversation/{convo_id}" + response = await self.session.patch(url, data=f'{{"title": "{title}"}}') + await self.__check_response(response) + + async def delete_conversation(self, convo_id: str) -> None: + """ + Delete conversation + :param convo_id: UUID of conversation + """ + url = f"{self.base_url}conversation/{convo_id}" + response = await self.session.patch(url, data='{"is_visible": false}') + await self.__check_response(response) + + async def clear_conversations(self) -> None: + """ + Delete all conversations + """ + url = f"{self.base_url}conversations" + response = await self.session.patch(url, data='{"is_visible": false}') + await self.__check_response(response) + + async def __map_conversations(self) -> None: + conversations = await self.get_conversations() + histories = [await self.get_msg_history(x["id"]) for x in conversations] + for x, y in zip(conversations, histories): + self.conversation_mapping[x["id"]] = y["current_node"] + + def __check_fields(self, data: dict) -> bool: + try: + data["message"]["content"] + except (TypeError, KeyError): + return False + return True + + async def __check_response(self, response: httpx.Response) -> None: + # 改成自带的错误处理 + try: + response.raise_for_status() + except httpx.HTTPStatusError as ex: + await response.aread() + error = t.Error( + source="OpenAI", + message=response.text, + code=response.status_code, + ) + raise error from ex + + +get_input = logger(is_timed=False)(get_input) + + +@logger(is_timed=False) +def configure() -> dict: + """ + Looks for a config file in the following locations: + """ + config_files: list[Path] = [Path("config.json")] + if xdg_config_home := getenv("XDG_CONFIG_HOME"): + config_files.append(Path(xdg_config_home, "revChatGPT/config.json")) + if user_home := getenv("HOME"): + config_files.append(Path(user_home, ".config/revChatGPT/config.json")) + if windows_home := getenv("HOMEPATH"): + config_files.append(Path(f"{windows_home}/.config/revChatGPT/config.json")) + if config_file := next((f for f in config_files if f.exists()), None): + with open(config_file, encoding="utf-8") as f: + config = json.load(f) + else: + print("No config file found.") + raise FileNotFoundError("No config file found.") + return config + + +@logger(is_timed=False) +def main(config: dict) -> any: + """ + Main function for the chatGPT program. + """ + chatbot = Chatbot( + config, + conversation_id=config.get("conversation_id"), + parent_id=config.get("parent_id"), + ) + + def handle_commands(command: str) -> bool: + if command == "!help": + print( + """ + !help - Show this message + !reset - Forget the current conversation + !config - Show the current configuration + !plugins - Show the current plugins + !switch x - Switch to plugin x. Need to reset the conversation to ativate the plugin. + !rollback x - Rollback the conversation (x being the number of messages to rollback) + !setconversation - Changes the conversation + !share - Creates a share link to the current conversation + !exit - Exit this program + """, + ) + elif command == "!reset": + chatbot.reset_chat() + print("Chat session successfully reset.") + elif command == "!config": + print(json.dumps(chatbot.config, indent=4)) + elif command.startswith("!rollback"): + try: + rollback = int(command.split(" ")[1]) + except IndexError: + logging.exception( + "No number specified, rolling back 1 message", + stack_info=True, + ) + rollback = 1 + chatbot.rollback_conversation(rollback) + print(f"Rolled back {rollback} messages.") + elif command.startswith("!setconversation"): + try: + chatbot.conversation_id = chatbot.config[ + "conversation_id" + ] = command.split(" ")[1] + print("Conversation has been changed") + except IndexError: + log.exception( + "Please include conversation UUID in command", + stack_info=True, + ) + print("Please include conversation UUID in command") + elif command.startswith("!continue"): + print() + print(f"{bcolors.OKGREEN + bcolors.BOLD}Chatbot: {bcolors.ENDC}") + prev_text = "" + for data in chatbot.continue_write(): + message = data["message"][len(prev_text) :] + print(message, end="", flush=True) + prev_text = data["message"] + print(bcolors.ENDC) + print() + elif command.startswith("!share"): + print(f"Conversation shared at {chatbot.share_conversation()}") + elif command == "!exit": + if isinstance(chatbot.session, httpx.AsyncClient): + chatbot.session.aclose() + sys.exit() + else: + return False + return True + + session = create_session() + completer = create_completer( + [ + "!help", + "!reset", + "!config", + "!rollback", + "!exit", + "!setconversation", + "!continue", + "!plugins", + "!switch", + "!share", + ], + ) + print() + try: + result = {} + while True: + print(f"{bcolors.OKBLUE + bcolors.BOLD}You: {bcolors.ENDC}") + + prompt = get_input(session=session, completer=completer) + if prompt.startswith("!") and handle_commands(prompt): + continue + + print() + print(f"{bcolors.OKGREEN + bcolors.BOLD}Chatbot: {bcolors.ENDC}") + if chatbot.config.get("model") == "gpt-4-browsing": + print("Browsing takes a while, please wait...") + with Live(Markdown(""), auto_refresh=False) as live: + for data in chatbot.ask(prompt=prompt, auto_continue=True): + if data["recipient"] != "all": + continue + result = data + message = data["message"] + live.update(Markdown(message), refresh=True) + print() + + if result.get("citations", False): + print( + f"{bcolors.WARNING + bcolors.BOLD}Citations: {bcolors.ENDC}", + ) + for citation in result["citations"]: + print( + f'{citation["metadata"]["title"]}: {citation["metadata"]["url"]}', + ) + print() + + except (KeyboardInterrupt, EOFError): + sys.exit() + except Exception as exc: + error = t.CLIError("command line program unknown error") + raise error from exc + + +if __name__ == "__main__": + print( + f""" + ChatGPT - A command-line interface to OpenAI's ChatGPT (https://chat.openai.com/chat) + Repo: github.com/acheong08/ChatGPT + """, + ) + print("Type '!help' to show a full list of commands") + print( + f"{bcolors.BOLD}{bcolors.WARNING}Press Esc followed by Enter or Alt+Enter to send a message.{bcolors.ENDC}", + ) + main(configure()) + +class RevChatGPTModelv1: + def __init__(self, access_token=None, **kwargs): + super().__init__() + self.config = kwargs + if access_token: + self.chatbot = Chatbot(config={"access_token": access_token}) + else: + raise ValueError("access_token must be provided.") + + def run(self, task: str) -> str: + self.start_time = time.time() + prev_text = "" + for data in self.chatbot.ask(task, fileinfo=None): + message = data["message"][len(prev_text):] + prev_text = data["message"] + self.end_time = time.time() + return prev_text + + def generate_summary(self, text: str) -> str: + # Implement this method based on your requirements + pass + + def enable_plugin(self, plugin_id: str): + self.chatbot.install_plugin(plugin_id=plugin_id) + + def list_plugins(self): + return self.chatbot.get_plugins() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Manage RevChatGPT plugins.') + parser.add_argument('--enable', metavar='plugin_id', help='the plugin to enable') + parser.add_argument('--list', action='store_true', help='list all available plugins') + parser.add_argument('--access_token', required=True, help='access token for RevChatGPT') + + args = parser.parse_args() + + model = RevChatGPTModelv1(access_token=args.access_token) + + if args.enable: + model.enable_plugin(args.enable) + if args.list: + plugins = model.list_plugins() + for plugin in plugins: + print(f"Plugin ID: {plugin['id']}, Name: {plugin['name']}") + diff --git a/swarms/models/revgpt.py b/swarms/models/revgptV4.py similarity index 97% rename from swarms/models/revgpt.py rename to swarms/models/revgptV4.py index ad75a4db..c3e2f78c 100644 --- a/swarms/models/revgpt.py +++ b/swarms/models/revgptV4.py @@ -39,43 +39,22 @@ from rich.markdown import Markdown import argparse import re +import schemas.typings as t from prompt_toolkit import prompt from prompt_toolkit import PromptSession from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from prompt_toolkit.completion import WordCompleter from prompt_toolkit.history import InMemoryHistory from prompt_toolkit.key_binding import KeyBindings +from schemas.typings import Colors bindings = KeyBindings() -BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/") +# BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/") +BASE_URL = environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/") +# BASE_URL = environ.get("CHATGPT_BASE_URL", "https://bypass.churchless.tech/") -class Colors: - """ - Colors for printing - """ - - HEADER = "\033[95m" - OKBLUE = "\033[94m" - OKCYAN = "\033[96m" - OKGREEN = "\033[92m" - WARNING = "\033[93m" - FAIL = "\033[91m" - ENDC = "\033[0m" - BOLD = "\033[1m" - UNDERLINE = "\033[4m" - - def __init__(self) -> None: - if os.getenv("NO_COLOR"): - Colors.HEADER = "" - Colors.OKBLUE = "" - Colors.OKCYAN = "" - Colors.OKGREEN = "" - Colors.WARNING = "" - Colors.FAIL = "" - Colors.ENDC = "" - Colors.BOLD = "" - Colors.UNDERLINE = "" +bcolors = t.Colors() def create_keybindings(key: str = "c-@") -> KeyBindings: """ @@ -922,23 +901,23 @@ class Chatbot: return False return True - @logger(is_timed=False) - def __check_response(self, response: requests.Response) -> None: - """Make sure response is success - Args: - response (_type_): _description_ - Raises: - Error: _description_ - """ - try: - response.raise_for_status() - except requests.exceptions.HTTPError as ex: - error = Colors.Error( - source="OpenAI", - message=response.text, - code=response.status_code, - ) - raise error from ex + # @logger(is_timed=False) + # def __check_response(self, response: requests.Response) -> None: + # """Make sure response is success + # Args: + # response (_type_): _description_ + # Raises: + # Error: _description_ + # """ + # try: + # response.raise_for_status() + # except requests.exceptions.HTTPError as ex: + # error = Colors.Error( + # source="OpenAI", + # message=response.text, + # code=response.status_code, + # ) + # raise error from ex @logger(is_timed=True) def get_conversations( @@ -1784,7 +1763,7 @@ if __name__ == "__main__": ) main(configure()) -class RevChatGPTModel: +class RevChatGPTModelv4: def __init__(self, access_token=None, **kwargs): super().__init__() self.config = kwargs @@ -1820,7 +1799,7 @@ if __name__ == "__main__": args = parser.parse_args() - model = RevChatGPTModel(access_token=args.access_token) + model = RevChatGPTModelv4(access_token=args.access_token) if args.enable: model.enable_plugin(args.enable) diff --git a/swarms/utils/revutils.py b/swarms/utils/revutils.py new file mode 100644 index 00000000..8e7e0b75 --- /dev/null +++ b/swarms/utils/revutils.py @@ -0,0 +1,166 @@ +from functools import wraps +import logging +import os +import re +import secrets + +import time +from prompt_toolkit import prompt +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from prompt_toolkit.key_binding import KeyBindings +from schemas.typings import Colors + +bindings = KeyBindings() + +# BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/") +BASE_URL = os.environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/") +# BASE_URL = environ.get("CHATGPT_BASE_URL", "https://bypass.churchless.tech/") + +def create_keybindings(key: str = "c-@") -> KeyBindings: + """ + Create keybindings for prompt_toolkit. Default key is ctrl+space. + For possible keybindings, see: https://python-prompt-toolkit.readthedocs.io/en/stable/pages/advanced_topics/key_bindings.html#list-of-special-keys + """ + + @bindings.add(key) + def _(event: dict) -> None: + event.app.exit(result=event.app.current_buffer.text) + + return bindings + + +def create_session() -> PromptSession: + return PromptSession(history=InMemoryHistory()) + + +def create_completer(commands: list, pattern_str: str = "$") -> WordCompleter: + return WordCompleter(words=commands, pattern=re.compile(pattern_str)) + + +def get_input( + session: PromptSession = None, + completer: WordCompleter = None, + key_bindings: KeyBindings = None, +) -> str: + """ + Multiline input function. + """ + return ( + session.prompt( + completer=completer, + multiline=True, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=key_bindings, + ) + if session + else prompt(multiline=True) + ) + + +async def get_input_async( + session: PromptSession = None, + completer: WordCompleter = None, +) -> str: + """ + Multiline input function. + """ + return ( + await session.prompt_async( + completer=completer, + multiline=True, + auto_suggest=AutoSuggestFromHistory(), + ) + if session + else prompt(multiline=True) + ) + + +def get_filtered_keys_from_object(obj: object, *keys: str) -> any: + """ + Get filtered list of object variable names. + :param keys: List of keys to include. If the first key is "not", the remaining keys will be removed from the class keys. + :return: List of class keys. + """ + class_keys = obj.__dict__.keys() + if not keys: + return set(class_keys) + + # Remove the passed keys from the class keys. + if keys[0] == "not": + return {key for key in class_keys if key not in keys[1:]} + # Check if all passed keys are valid + if invalid_keys := set(keys) - class_keys: + raise ValueError( + f"Invalid keys: {invalid_keys}", + ) + # Only return specified keys that are in class_keys + return {key for key in keys if key in class_keys} + +def generate_random_hex(length: int = 17) -> str: + """Generate a random hex string + Args: + length (int, optional): Length of the hex string. Defaults to 17. + Returns: + str: Random hex string + """ + return secrets.token_hex(length) + + +def random_int(min: int, max: int) -> int: + """Generate a random integer + Args: + min (int): Minimum value + max (int): Maximum value + Returns: + int: Random integer + """ + return secrets.randbelow(max - min) + min + + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", + ) + +log = logging.getLogger(__name__) + + +def logger(is_timed: bool): + """Logger decorator + Args: + is_timed (bool): Whether to include function running time in exit log + Returns: + _type_: decorated function + """ + + def decorator(func): + wraps(func) + + def wrapper(*args, **kwargs): + log.debug( + "Entering %s with args %s and kwargs %s", + func.__name__, + args, + kwargs, + ) + start = time.time() + out = func(*args, **kwargs) + end = time.time() + if is_timed: + log.debug( + "Exiting %s with return value %s. Took %s seconds.", + func.__name__, + out, + end - start, + ) + else: + log.debug("Exiting %s with return value %s", func.__name__, out) + return out + + return wrapper + + return decorator +