You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
4.1 KiB
131 lines
4.1 KiB
import logging
|
|
from typing import List, Optional
|
|
|
|
# Names of the loggers that ``get_logger`` has already configured through the
# standard-library path. Only the keys matter; every value is ``True``.
logger_initialized = {}


def get_logger(
    name: str,
    log_file: Optional[str] = None,
    log_level: int = logging.INFO,
    file_mode: str = "w",
) -> logging.Logger:
    """Initialize and get a logger by name.

    If ``mmengine`` can be imported, the logger is obtained from
    ``mmengine.logging.MMLogger`` and returned directly. Otherwise the
    standard ``logging`` module is used: if the logger has not been
    initialized, this method will initialize it by adding one or two
    handlers, otherwise the initialized logger will be directly returned.
    During initialization, a StreamHandler will always be added. If
    `log_file` is specified, a FileHandler will also be added.

    A logger whose dotted ancestor (e.g. "a" for "a.b") has already been
    initialized is returned without handlers of its own, because its records
    propagate to the ancestor. Names that merely share a textual prefix
    (e.g. "ab" after "a") are unrelated and are initialized normally.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    # use logger in mmengine if exists.
    try:
        from mmengine.logging import MMLogger

        if MMLogger.check_instance_created(name):
            logger = MMLogger.get_instance(name)
        else:
            logger = MMLogger.get_instance(
                name,
                logger_name=name,
                log_file=log_file,
                log_level=log_level,
                file_mode=file_mode,
            )
        return logger
    except Exception:
        # Deliberate best-effort: mmengine is optional, so any failure here
        # (missing package or an error raised by MMLogger) falls back to the
        # standard-library logger configured below.
        pass

    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a". The match is made on the
    # dotted boundary so that "ab" is not mistaken for a child of "a". An
    # empty registered name denotes the root logger, the ancestor of all.
    for logger_name in logger_initialized:
        if not logger_name or name.startswith(logger_name + "."):
            return logger

    # handle duplicate logs to the console
    # Plain StreamHandlers on the root logger are raised to ERROR; the exact
    # type check leaves subclasses such as FileHandler untouched.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        # logging.FileHandler itself defaults to mode 'a'; this function
        # defaults to 'w' and exposes `file_mode` so callers can switch back
        # to appending.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    logger_initialized[name] = True

    return logger
|
|
|
|
|
|
def filter_suffix(
    response: str, suffixes: Optional[List[str]] = None
) -> str:
    """Strip trailing suffixes from a response.

    The suffixes are checked once each, in the order given, against the
    response as left by the previous removals.

    Args:
        response (str): generated response by LLMs.
        suffixes (list[str] | None): suffixes to be deleted from the end of
            the response. ``None`` leaves the response untouched.

    Return:
        str: a clean response.
    """
    if suffixes is not None:
        for suffix in suffixes:
            if not response.endswith(suffix):
                continue
            # Slice by the computed length rather than a negative index so
            # that an empty suffix leaves the response unchanged.
            keep = len(response) - len(suffix)
            response = response[:keep]
    return response
|
|
|
|
|
|
# TODO remove stop_word_offsets stuff and make it clean
|
|
def _stop_words(stop_words: List[str], tokenizer: object):
|
|
"""return list of stop-words to numpy.ndarray."""
|
|
import numpy as np
|
|
|
|
if stop_words is None:
|
|
return None
|
|
assert isinstance(stop_words, List) and all(
|
|
isinstance(elem, str) for elem in stop_words
|
|
), f"stop_words must be a list but got {type(stop_words)}"
|
|
stop_indexes = []
|
|
for stop_word in stop_words:
|
|
stop_indexes += tokenizer.indexes_containing_token(stop_word)
|
|
assert isinstance(stop_indexes, List) and all(
|
|
isinstance(elem, int) for elem in stop_indexes
|
|
), "invalid stop_words"
|
|
# each id in stop_indexes represents a stop word
|
|
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
|
|
# detailed explanation about fastertransformer's stop_indexes
|
|
stop_word_offsets = range(1, len(stop_indexes) + 1)
|
|
stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(
|
|
np.int32
|
|
)
|
|
return stop_words
|