#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys

from fairseq.dataclass.initialize import hydra_init
from fairseq_cli.train import main as pre_main
from fairseq import distributed_utils, metrics
from fairseq.dataclass.configs import FairseqConfig

import hydra
import torch
from omegaconf import OmegaConf

logger = logging.getLogger("fairseq_cli.hydra_train")


@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    """Run fairseq training with a Hydra-composed config.

    The incoming config is resolved into a plain OmegaConf object in struct
    mode and passed, together with ``fairseq_cli.train.main``, to
    ``distributed_utils.call_main``.

    Returns:
        The smoothed ``"valid"`` value of
        ``cfg.checkpoint.best_checkpoint_metric`` from ``fairseq.metrics``,
        or ``float("inf")`` when that value cannot be obtained. Returning a
        number lets Hydra sweepers use it as an objective.
    """
    # Resolve interpolations and convert enums to strings once, then switch on
    # struct mode so accessing/adding unknown keys raises instead of passing.
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
    OmegaConf.set_struct(cfg, True)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main)
        else:
            distributed_utils.call_main(cfg, pre_main)
    except BaseException as e:
        # Deliberately broad: when suppress_crashes is set, any failure is only
        # logged so the function can still return a value to the sweeper.
        if not cfg.common.suppress_crashes:
            raise
        logger.error("Crashed! " + str(e))

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        # Any lookup failure means "no value"; it falls back to inf below.
        # (A bare except here used to swallow KeyboardInterrupt/SystemExit too.)
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val


def reset_logging():
    """Replace every root-logger handler with a single stdout handler.

    The root level is read from the ``LOGLEVEL`` environment variable
    (upper-cased, default ``"INFO"``).
    """
    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing from the list being iterated skips every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)


def cli_main():
    """Command-line entry point.

    Reads the config name from Hydra's own argument parser, falling back to
    ``"config"``. It then initializes the Hydra config store via
    ``hydra_init`` and launches ``hydra_main``.
    """
    try:
        # NOTE(review): hydra._internal is a private Hydra module and may
        # change between Hydra versions; the fallback below covers that.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()


if __name__ == "__main__":
    cli_main()