|
import os |
|
import logging |
|
|
|
from transformers.trainer_callback import TrainerCallback |
|
|
|
|
|
def _setup_logger(): |
|
log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") |
|
logger = logging.getLogger() |
|
logger.setLevel(logging.INFO) |
|
|
|
console_handler = logging.StreamHandler() |
|
console_handler.setFormatter(log_format) |
|
|
|
data_dir = './data/' |
|
os.makedirs(data_dir, exist_ok=True) |
|
file_handler = logging.FileHandler('{}/log.txt'.format(data_dir)) |
|
file_handler.setFormatter(log_format) |
|
|
|
logger.handlers = [console_handler, file_handler] |
|
|
|
return logger |
|
|
|
|
|
# Module-wide logger. Building it at import time configures the root
# logger's console and file handlers once for everything in this module.
logger: logging.Logger = _setup_logger()
|
|
|
|
|
class LoggerCallback(TrainerCallback):
    """Trainer callback that sends each metrics dict to the module ``logger``."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log ``logs`` without its ``total_flos`` entry, on world process zero only.

        Args:
            args: Training arguments supplied by the trainer (unused here).
            state: Trainer state; only ``state.is_world_process_zero`` is read.
            control: Trainer control object (unused here).
            logs: Mapping of metric names to values, or ``None``.
            **kwargs: Extra keyword arguments from the trainer (ignored).
        """
        if logs is None:
            # Nothing to report. Previously, logs.pop raised AttributeError here.
            return
        if not state.is_world_process_zero:
            return
        # Build a filtered copy instead of popping from ``logs``. The caller's
        # dict may be shared with other callbacks, which should still see
        # ``total_flos``.
        filtered = {key: value for key, value in logs.items() if key != "total_flos"}
        logger.info(filtered)
|
|