parent
c1397fdf26
commit
3f497e5207
@ -0,0 +1,198 @@
|
||||
"""
|
||||
A module that contains all the types used in this project
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
|
||||
python_version = list(platform.python_version_tuple())
|
||||
SUPPORT_ADD_NOTES = int(python_version[0]) >= 3 and int(python_version[1]) >= 11
|
||||
|
||||
|
||||
class ChatbotError(Exception):
|
||||
"""
|
||||
Base class for all Chatbot errors in this Project
|
||||
"""
|
||||
|
||||
def __init__(self, *args: object) -> None:
|
||||
if SUPPORT_ADD_NOTES:
|
||||
super().add_note(
|
||||
"Please check that the input is correct, or you can resolve this issue by filing an issue",
|
||||
)
|
||||
super().add_note("Project URL: https://github.com/acheong08/ChatGPT")
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class ActionError(ChatbotError):
|
||||
"""
|
||||
Subclass of ChatbotError
|
||||
|
||||
An object that throws an error because the execution of an operation is blocked
|
||||
"""
|
||||
|
||||
def __init__(self, *args: object) -> None:
|
||||
if SUPPORT_ADD_NOTES:
|
||||
super().add_note(
|
||||
"The current operation is not allowed, which may be intentional",
|
||||
)
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class ActionNotAllowedError(ActionError):
|
||||
"""
|
||||
Subclass of ActionError
|
||||
|
||||
An object that throws an error because the execution of an unalloyed operation is blocked
|
||||
"""
|
||||
|
||||
|
||||
class ActionRefuseError(ActionError):
|
||||
"""
|
||||
Subclass of ActionError
|
||||
|
||||
An object that throws an error because the execution of a refused operation is blocked.
|
||||
"""
|
||||
|
||||
|
||||
class CLIError(ChatbotError):
|
||||
"""
|
||||
Subclass of ChatbotError
|
||||
|
||||
The error caused by a CLI program error
|
||||
"""
|
||||
|
||||
|
||||
class ErrorType(Enum):
|
||||
"""
|
||||
Enumeration class for different types of errors.
|
||||
"""
|
||||
|
||||
USER_ERROR = -1
|
||||
UNKNOWN_ERROR = 0
|
||||
SERVER_ERROR = 1
|
||||
RATE_LIMIT_ERROR = 2
|
||||
INVALID_REQUEST_ERROR = 3
|
||||
EXPIRED_ACCESS_TOKEN_ERROR = 4
|
||||
INVALID_ACCESS_TOKEN_ERROR = 5
|
||||
PROHIBITED_CONCURRENT_QUERY_ERROR = 6
|
||||
AUTHENTICATION_ERROR = 7
|
||||
CLOUDFLARE_ERROR = 8
|
||||
|
||||
|
||||
class Error(ChatbotError):
|
||||
"""
|
||||
Base class for exceptions in V1 module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: str,
|
||||
message: str,
|
||||
*args: object,
|
||||
code: Union[ErrorType, int] = ErrorType.UNKNOWN_ERROR,
|
||||
) -> None:
|
||||
self.source: str = source
|
||||
self.message: str = message
|
||||
self.code: ErrorType | int = code
|
||||
super().__init__(*args)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.source}: {self.message} (code: {self.code})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.source}: {self.message} (code: {self.code})"
|
||||
|
||||
|
||||
class AuthenticationError(ChatbotError):
|
||||
"""
|
||||
Subclass of ChatbotError
|
||||
|
||||
The object of the error thrown by a validation failure or exception
|
||||
"""
|
||||
|
||||
def __init__(self, *args: object) -> None:
|
||||
if SUPPORT_ADD_NOTES:
|
||||
super().add_note(
|
||||
"Please check if your key is correct, maybe it may not be valid",
|
||||
)
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class APIConnectionError(ChatbotError):
|
||||
"""
|
||||
Subclass of ChatbotError
|
||||
|
||||
An exception object thrown when an API connection fails or fails to connect due to network or
|
||||
other miscellaneous reasons
|
||||
"""
|
||||
|
||||
def __init__(self, *args: object) -> None:
|
||||
if SUPPORT_ADD_NOTES:
|
||||
super().add_note(
|
||||
"Please check if there is a problem with your network connection",
|
||||
)
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class NotAllowRunning(ActionNotAllowedError):
|
||||
"""
|
||||
Subclass of ActionNotAllowedError
|
||||
|
||||
Direct startup is not allowed for some reason
|
||||
"""
|
||||
|
||||
|
||||
class ResponseError(APIConnectionError):
|
||||
"""
|
||||
Subclass of APIConnectionError
|
||||
|
||||
Error objects caused by API request errors due to network or other miscellaneous reasons
|
||||
"""
|
||||
|
||||
|
||||
class OpenAIError(APIConnectionError):
|
||||
"""
|
||||
Subclass of APIConnectionError
|
||||
|
||||
Error objects caused by OpenAI's own server errors
|
||||
"""
|
||||
|
||||
|
||||
class RequestError(APIConnectionError):
|
||||
"""
|
||||
Subclass of APIConnectionError
|
||||
|
||||
There is a problem with the API response due to network or other miscellaneous reasons, or there
|
||||
is no reply to the object that caused the error at all
|
||||
"""
|
||||
|
||||
|
||||
class Colors:
|
||||
"""
|
||||
Colors for printing
|
||||
"""
|
||||
|
||||
HEADER = "\033[95m"
|
||||
OKBLUE = "\033[94m"
|
||||
OKCYAN = "\033[96m"
|
||||
OKGREEN = "\033[92m"
|
||||
WARNING = "\033[93m"
|
||||
FAIL = "\033[91m"
|
||||
ENDC = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
UNDERLINE = "\033[4m"
|
||||
|
||||
def __init__(self) -> None:
|
||||
if os.getenv("NO_COLOR"):
|
||||
Colors.HEADER = ""
|
||||
Colors.OKBLUE = ""
|
||||
Colors.OKCYAN = ""
|
||||
Colors.OKGREEN = ""
|
||||
Colors.WARNING = ""
|
||||
Colors.FAIL = ""
|
||||
Colors.ENDC = ""
|
||||
Colors.BOLD = ""
|
||||
Colors.UNDERLINE = ""
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,166 @@
|
||||
from functools import wraps
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
|
||||
import time
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
from schemas.typings import Colors
|
||||
|
||||
bindings = KeyBindings()
|
||||
|
||||
# BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/")
|
||||
BASE_URL = os.environ.get("CHATGPT_BASE_URL", "https://ai.fakeopen.com/api/")
|
||||
# BASE_URL = environ.get("CHATGPT_BASE_URL", "https://bypass.churchless.tech/")
|
||||
|
||||
def create_keybindings(key: str = "c-@") -> KeyBindings:
|
||||
"""
|
||||
Create keybindings for prompt_toolkit. Default key is ctrl+space.
|
||||
For possible keybindings, see: https://python-prompt-toolkit.readthedocs.io/en/stable/pages/advanced_topics/key_bindings.html#list-of-special-keys
|
||||
"""
|
||||
|
||||
@bindings.add(key)
|
||||
def _(event: dict) -> None:
|
||||
event.app.exit(result=event.app.current_buffer.text)
|
||||
|
||||
return bindings
|
||||
|
||||
|
||||
def create_session() -> PromptSession:
|
||||
return PromptSession(history=InMemoryHistory())
|
||||
|
||||
|
||||
def create_completer(commands: list, pattern_str: str = "$") -> WordCompleter:
|
||||
return WordCompleter(words=commands, pattern=re.compile(pattern_str))
|
||||
|
||||
|
||||
def get_input(
|
||||
session: PromptSession = None,
|
||||
completer: WordCompleter = None,
|
||||
key_bindings: KeyBindings = None,
|
||||
) -> str:
|
||||
"""
|
||||
Multiline input function.
|
||||
"""
|
||||
return (
|
||||
session.prompt(
|
||||
completer=completer,
|
||||
multiline=True,
|
||||
auto_suggest=AutoSuggestFromHistory(),
|
||||
key_bindings=key_bindings,
|
||||
)
|
||||
if session
|
||||
else prompt(multiline=True)
|
||||
)
|
||||
|
||||
|
||||
async def get_input_async(
|
||||
session: PromptSession = None,
|
||||
completer: WordCompleter = None,
|
||||
) -> str:
|
||||
"""
|
||||
Multiline input function.
|
||||
"""
|
||||
return (
|
||||
await session.prompt_async(
|
||||
completer=completer,
|
||||
multiline=True,
|
||||
auto_suggest=AutoSuggestFromHistory(),
|
||||
)
|
||||
if session
|
||||
else prompt(multiline=True)
|
||||
)
|
||||
|
||||
|
||||
def get_filtered_keys_from_object(obj: object, *keys: str) -> any:
|
||||
"""
|
||||
Get filtered list of object variable names.
|
||||
:param keys: List of keys to include. If the first key is "not", the remaining keys will be removed from the class keys.
|
||||
:return: List of class keys.
|
||||
"""
|
||||
class_keys = obj.__dict__.keys()
|
||||
if not keys:
|
||||
return set(class_keys)
|
||||
|
||||
# Remove the passed keys from the class keys.
|
||||
if keys[0] == "not":
|
||||
return {key for key in class_keys if key not in keys[1:]}
|
||||
# Check if all passed keys are valid
|
||||
if invalid_keys := set(keys) - class_keys:
|
||||
raise ValueError(
|
||||
f"Invalid keys: {invalid_keys}",
|
||||
)
|
||||
# Only return specified keys that are in class_keys
|
||||
return {key for key in keys if key in class_keys}
|
||||
|
||||
def generate_random_hex(length: int = 17) -> str:
|
||||
"""Generate a random hex string
|
||||
Args:
|
||||
length (int, optional): Length of the hex string. Defaults to 17.
|
||||
Returns:
|
||||
str: Random hex string
|
||||
"""
|
||||
return secrets.token_hex(length)
|
||||
|
||||
|
||||
def random_int(min: int, max: int) -> int:
|
||||
"""Generate a random integer
|
||||
Args:
|
||||
min (int): Minimum value
|
||||
max (int): Maximum value
|
||||
Returns:
|
||||
int: Random integer
|
||||
"""
|
||||
return secrets.randbelow(max - min) + min
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s",
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def logger(is_timed: bool):
|
||||
"""Logger decorator
|
||||
Args:
|
||||
is_timed (bool): Whether to include function running time in exit log
|
||||
Returns:
|
||||
_type_: decorated function
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
wraps(func)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
log.debug(
|
||||
"Entering %s with args %s and kwargs %s",
|
||||
func.__name__,
|
||||
args,
|
||||
kwargs,
|
||||
)
|
||||
start = time.time()
|
||||
out = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
if is_timed:
|
||||
log.debug(
|
||||
"Exiting %s with return value %s. Took %s seconds.",
|
||||
func.__name__,
|
||||
out,
|
||||
end - start,
|
||||
)
|
||||
else:
|
||||
log.debug("Exiting %s with return value %s", func.__name__, out)
|
||||
return out
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
Loading…
Reference in new issue