|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import sys |
|
|
|
from fairseq.dataclass.initialize import hydra_init |
|
from fairseq_cli.train import main as pre_main |
|
from fairseq import distributed_utils, metrics |
|
from fairseq.dataclass.configs import FairseqConfig |
|
|
|
import hydra |
|
import torch |
|
from omegaconf import OmegaConf |
|
|
|
|
|
# Module logger. An explicit name is used instead of ``__name__`` so records keep
# a stable label even when this file runs as a script (``__name__ == "__main__"``).
logger = logging.getLogger("fairseq_cli.hydra_train")
|
|
|
|
|
@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    """Run fairseq training from a Hydra-composed configuration.

    The config is rebuilt as a fully resolved OmegaConf object (enums rendered
    as strings) and switched to struct mode. Training is then launched through
    ``distributed_utils.call_main(cfg, pre_main)``. When ``cfg.common.profile``
    is set, training runs under the CUDA profiler with autograd NVTX ranges.

    Args:
        cfg: configuration object supplied by the ``@hydra.main`` decorator.

    Returns:
        The smoothed value of ``cfg.checkpoint.best_checkpoint_metric`` from
        the ``"valid"`` metrics, or ``float("inf")`` when that value cannot be
        obtained. (Presumably consumed by Hydra sweepers as an objective —
        confirm with the sweep configuration.)

    Raises:
        Any exception raised during training, unless
        ``cfg.common.suppress_crashes`` is set. In that case the exception is
        logged and a value is still returned.
    """
    # Rebuild from a plain container so every interpolation is resolved up
    # front and enum values become strings.
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))

    # Struct mode: touching a key that does not exist raises instead of
    # silently creating it.
    OmegaConf.set_struct(cfg, True)

    if cfg.common.reset_logging:
        reset_logging()

    try:
        if cfg.common.profile:
            # Profile the whole run and annotate autograd ops with NVTX ranges.
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main)
        else:
            distributed_utils.call_main(cfg, pre_main)
    except BaseException as e:
        # BaseException on purpose: with suppress_crashes, even
        # KeyboardInterrupt/SystemExit are logged rather than propagated, so
        # execution falls through and a metric value is still returned.
        if not cfg.common.suppress_crashes:
            raise
        logger.error("Crashed! %s", e)

    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric
        )
    except Exception:
        # Presumably no validation metric was recorded (e.g. training stopped
        # before validating). Narrowed from a bare ``except:`` so interpreter
        # exits and Ctrl-C are not swallowed here.
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
|
|
|
|
|
def reset_logging():
    """Replace every root-logger handler with a single stdout handler.

    The root level comes from the ``LOGLEVEL`` environment variable
    (upper-cased), defaulting to ``"INFO"``. Records are formatted as
    ``<time> | <level> | <logger name> | <message>``.
    """
    root = logging.getLogger()

    # Iterate over a snapshot: removeHandler() mutates root.handlers in place.
    # Removing from a list while iterating it skips every other element, which
    # previously left some of the old handlers attached.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)
|
|
|
|
|
def cli_main():
    """Command-line entry point for Hydra-based fairseq training.

    Reads the config name from Hydra's parsed command-line arguments and
    falls back to ``"config"`` when it is absent or cannot be determined.
    Passes that name to ``hydra_init``, then calls ``hydra_main()`` with no
    arguments; its ``@hydra.main`` decorator supplies the config.
    """
    try:
        # Private Hydra helper, imported lazily so a missing or renamed API
        # only degrades to the default config name.
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # Best-effort fallback. Narrowed from a bare ``except:`` so that
        # KeyboardInterrupt and SystemExit propagate instead of being swallowed.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    hydra_init(cfg_name)
    hydra_main()
|
|
|
|
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    cli_main()
|
|