feat: Remove typing

Former-commit-id: 8ed39cac53
discord-bot-framework
Zack 1 year ago
parent 8665c249d0
commit f2eb9c461a

@ -23,14 +23,10 @@ try:
except ImportError: except ImportError:
pass pass
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator
from typing import Generator
from typing import NoReturn
import tempfile import tempfile
import random import random
import os import os
# Import function type # Import function type
from typing import Callable as function
import httpx import httpx
import requests import requests
@ -39,12 +35,9 @@ from OpenAIAuth import Auth0 as Authenticator
from rich.live import Live from rich.live import Live
from rich.markdown import Markdown from rich.markdown import Markdown
from . import __version__
from . import typings as t
import argparse import argparse
import re import re
from typing import Set
from prompt_toolkit import prompt from prompt_toolkit import prompt
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
@ -55,6 +48,34 @@ from prompt_toolkit.key_binding import KeyBindings
bindings = KeyBindings() bindings = KeyBindings()
BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/")
class Colors:
"""
Colors for printing
"""
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKCYAN = "\033[96m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"
def __init__(self) -> None:
if os.getenv("NO_COLOR"):
Colors.HEADER = ""
Colors.OKBLUE = ""
Colors.OKCYAN = ""
Colors.OKGREEN = ""
Colors.WARNING = ""
Colors.FAIL = ""
Colors.ENDC = ""
Colors.BOLD = ""
Colors.UNDERLINE = ""
def create_keybindings(key: str = "c-@") -> KeyBindings: def create_keybindings(key: str = "c-@") -> KeyBindings:
""" """
@ -115,7 +136,7 @@ async def get_input_async(
) )
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]: def get_filtered_keys_from_object(obj: object, *keys: str) -> any:
""" """
Get filtered list of object variable names. Get filtered list of object variable names.
:param keys: List of keys to include. If the first key is "not", the remaining keys will be removed from the class keys. :param keys: List of keys to include. If the first key is "not", the remaining keys will be removed from the class keys.
@ -203,9 +224,8 @@ def logger(is_timed: bool):
BASE_URL = environ.get("CHATGPT_BASE_URL", "http://192.168.250.249:9898/api/")
bcolors = t.Colors() bcolors = Colors()
def captcha_solver(images: list[str], challenge_details: dict) -> int: def captcha_solver(images: list[str], challenge_details: dict) -> int:
@ -340,7 +360,7 @@ class Chatbot:
cached_access_token = self.__get_cached_access_token( cached_access_token = self.__get_cached_access_token(
self.config.get("email", None), self.config.get("email", None),
) )
except t.Error as error: except Colors.Error as error:
if error.code == 5: if error.code == 5:
raise raise
cached_access_token = None cached_access_token = None
@ -413,7 +433,7 @@ class Chatbot:
if "access_token" in self.config: if "access_token" in self.config:
self.set_access_token(self.config["access_token"]) self.set_access_token(self.config["access_token"])
elif "email" not in self.config or "password" not in self.config: elif "email" not in self.config or "password" not in self.config:
error = t.AuthenticationError("Insufficient login details provided!") error = Colors.AuthenticationError("Insufficient login details provided!")
raise error raise error
if "access_token" not in self.config: if "access_token" not in self.config:
try: try:
@ -470,26 +490,26 @@ class Chatbot:
d_access_token = base64.b64decode(s_access_token[1]) d_access_token = base64.b64decode(s_access_token[1])
d_access_token = json.loads(d_access_token) d_access_token = json.loads(d_access_token)
except binascii.Error: except binascii.Error:
error = t.Error( error = Colors.Error(
source="__get_cached_access_token", source="__get_cached_access_token",
message="Invalid access token", message="Invalid access token",
code=t.ErrorType.INVALID_ACCESS_TOKEN_ERROR, code=Colors.ErrorType.INVALID_ACCESS_TOKEN_ERROR,
) )
raise error from None raise error from None
except json.JSONDecodeError: except json.JSONDecodeError:
error = t.Error( error = Colors.Error(
source="__get_cached_access_token", source="__get_cached_access_token",
message="Invalid access token", message="Invalid access token",
code=t.ErrorType.INVALID_ACCESS_TOKEN_ERROR, code=Colors.ErrorType.INVALID_ACCESS_TOKEN_ERROR,
) )
raise error from None raise error from None
exp = d_access_token.get("exp", None) exp = d_access_token.get("exp", None)
if exp is not None and exp < time.time(): if exp is not None and exp < time.time():
error = t.Error( error = Colors.Error(
source="__get_cached_access_token", source="__get_cached_access_token",
message="Access token expired", message="Access token expired",
code=t.ErrorType.EXPIRED_ACCESS_TOKEN_ERROR, code=Colors.ErrorType.EXPIRED_ACCESS_TOKEN_ERROR,
) )
raise error raise error
@ -535,7 +555,7 @@ class Chatbot:
"""Login to OpenAI by email and password""" """Login to OpenAI by email and password"""
if not self.config.get("email") and not self.config.get("password"): if not self.config.get("email") and not self.config.get("password"):
log.error("Insufficient login details provided!") log.error("Insufficient login details provided!")
error = t.AuthenticationError("Insufficient login details provided!") error = Colors.AuthenticationError("Insufficient login details provided!")
raise error raise error
auth = Authenticator( auth = Authenticator(
email_address=self.config.get("email"), email_address=self.config.get("email"),
@ -553,7 +573,7 @@ class Chatbot:
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
**kwargs, **kwargs,
) -> Generator[dict, None, None]: ) -> any:
log.debug("Sending the payload") log.debug("Sending the payload")
if ( if (
@ -591,10 +611,10 @@ class Chatbot:
line = str(line)[2:-1] line = str(line)[2:-1]
if line.lower() == "internal server error": if line.lower() == "internal server error":
log.error(f"Internal Server Error: {line}") log.error(f"Internal Server Error: {line}")
error = t.Error( error = Colors.Error(
source="ask", source="ask",
message="Internal Server Error", message="Internal Server Error",
code=t.ErrorType.SERVER_ERROR, code=Colors.ErrorType.SERVER_ERROR,
) )
raise error raise error
if not line or line is None: if not line or line is None:
@ -676,7 +696,7 @@ class Chatbot:
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
**kwargs, **kwargs,
) -> Generator[dict, None, None]: ) -> any:
"""Ask a question to the chatbot """Ask a question to the chatbot
Args: Args:
messages (list[dict]): The messages to send messages (list[dict]): The messages to send
@ -699,10 +719,10 @@ class Chatbot:
""" """
print(conversation_id) print(conversation_id)
if parent_id and not conversation_id: if parent_id and not conversation_id:
raise t.Error( raise Colors.Error(
source="User", source="User",
message="conversation_id must be set once parent_id is set", message="conversation_id must be set once parent_id is set",
code=t.ErrorType.USER_ERROR, code=Colors.ErrorType.USER_ERROR,
) )
print(conversation_id) print(conversation_id)
if conversation_id and conversation_id != self.conversation_id: if conversation_id and conversation_id != self.conversation_id:
@ -771,7 +791,7 @@ class Chatbot:
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
**kwargs, **kwargs,
) -> Generator[dict, None, None]: ) -> any:
"""Ask a question to the chatbot """Ask a question to the chatbot
Args: Args:
prompt (str): The question prompt (str): The question
@ -818,7 +838,7 @@ class Chatbot:
model: str = "", model: str = "",
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
) -> Generator[dict, None, None]: ) -> any:
"""let the chatbot continue to write. """let the chatbot continue to write.
Args: Args:
conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None.
@ -838,10 +858,10 @@ class Chatbot:
} }
""" """
if parent_id and not conversation_id: if parent_id and not conversation_id:
raise t.Error( raise Colors.Error(
source="User", source="User",
message="conversation_id must be set once parent_id is set", message="conversation_id must be set once parent_id is set",
code=t.ErrorType.USER_ERROR, code=Colors.ErrorType.USER_ERROR,
) )
if conversation_id and conversation_id != self.conversation_id: if conversation_id and conversation_id != self.conversation_id:
@ -913,7 +933,7 @@ class Chatbot:
try: try:
response.raise_for_status() response.raise_for_status()
except requests.exceptions.HTTPError as ex: except requests.exceptions.HTTPError as ex:
error = t.Error( error = Colors.Error(
source="OpenAI", source="OpenAI",
message=response.text, message=response.text,
code=response.status_code, code=response.status_code,
@ -1146,7 +1166,7 @@ class AsyncChatbot(Chatbot):
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
**kwargs, **kwargs,
) -> AsyncGenerator[dict, None]: ) -> any:
log.debug("Sending the payload") log.debug("Sending the payload")
cid, pid = data["conversation_id"], data["parent_message_id"] cid, pid = data["conversation_id"], data["parent_message_id"]
@ -1165,10 +1185,10 @@ class AsyncChatbot(Chatbot):
async for line in response.aiter_lines(): async for line in response.aiter_lines():
if line.lower() == "internal server error": if line.lower() == "internal server error":
log.error(f"Internal Server Error: {line}") log.error(f"Internal Server Error: {line}")
error = t.Error( error = Colors.Error(
source="ask", source="ask",
message="Internal Server Error", message="Internal Server Error",
code=t.ErrorType.SERVER_ERROR, code=Colors.ErrorType.SERVER_ERROR,
) )
raise error raise error
if not line or line is None: if not line or line is None:
@ -1247,7 +1267,7 @@ class AsyncChatbot(Chatbot):
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
**kwargs, **kwargs,
) -> AsyncGenerator[dict, None]: ) -> any:
"""Post messages to the chatbot """Post messages to the chatbot
Args: Args:
messages (list[dict]): the messages to post messages (list[dict]): the messages to post
@ -1270,10 +1290,10 @@ class AsyncChatbot(Chatbot):
} }
""" """
if parent_id and not conversation_id: if parent_id and not conversation_id:
raise t.Error( raise Colors.Error(
source="User", source="User",
message="conversation_id must be set once parent_id is set", message="conversation_id must be set once parent_id is set",
code=t.ErrorType.USER_ERROR, code=Colors.ErrorType.USER_ERROR,
) )
if conversation_id and conversation_id != self.conversation_id: if conversation_id and conversation_id != self.conversation_id:
@ -1338,7 +1358,7 @@ class AsyncChatbot(Chatbot):
auto_continue: bool = False, auto_continue: bool = False,
timeout: int = 360, timeout: int = 360,
**kwargs, **kwargs,
) -> AsyncGenerator[dict, None]: ) -> any:
"""Ask a question to the chatbot """Ask a question to the chatbot
Args: Args:
prompt (str): The question to ask prompt (str): The question to ask
@ -1391,7 +1411,7 @@ class AsyncChatbot(Chatbot):
model: str = "", model: str = "",
auto_continue: bool = False, auto_continue: bool = False,
timeout: float = 360, timeout: float = 360,
) -> AsyncGenerator[dict, None]: ) -> any:
"""let the chatbot continue to write """let the chatbot continue to write
Args: Args:
conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None. conversation_id (str | None, optional): UUID for the conversation to continue on. Defaults to None.
@ -1412,10 +1432,10 @@ class AsyncChatbot(Chatbot):
} }
""" """
if parent_id and not conversation_id: if parent_id and not conversation_id:
error = t.Error( error = Colors.Error(
source="User", source="User",
message="conversation_id must be set once parent_id is set", message="conversation_id must be set once parent_id is set",
code=t.ErrorType.SERVER_ERROR, code=Colors.ErrorType.SERVER_ERROR,
) )
raise error raise error
if conversation_id and conversation_id != self.conversation_id: if conversation_id and conversation_id != self.conversation_id:
@ -1590,7 +1610,7 @@ class AsyncChatbot(Chatbot):
response.raise_for_status() response.raise_for_status()
except httpx.HTTPStatusError as ex: except httpx.HTTPStatusError as ex:
await response.aread() await response.aread()
error = t.Error( error = Colors.Error(
source="OpenAI", source="OpenAI",
message=response.text, message=response.text,
code=response.status_code, code=response.status_code,
@ -1623,7 +1643,7 @@ def configure() -> dict:
@logger(is_timed=False) @logger(is_timed=False)
def main(config: dict) -> NoReturn: def main(config: dict) -> any:
""" """
Main function for the chatGPT program. Main function for the chatGPT program.
""" """
@ -1747,7 +1767,7 @@ def main(config: dict) -> NoReturn:
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
exit() exit()
except Exception as exc: except Exception as exc:
error = t.CLIError("command line program unknown error") error = Colors.CLIError("command line program unknown error")
raise error from exc raise error from exc
@ -1756,7 +1776,6 @@ if __name__ == "__main__":
f""" f"""
ChatGPT - A command-line interface to OpenAI's ChatGPT (https://chat.openai.com/chat) ChatGPT - A command-line interface to OpenAI's ChatGPT (https://chat.openai.com/chat)
Repo: github.com/acheong08/ChatGPT Repo: github.com/acheong08/ChatGPT
Version: {__version__}
""", """,
) )
print("Type '!help' to show a full list of commands") print("Type '!help' to show a full list of commands")

Loading…
Cancel
Save