from .sst2 import SST2ProbInferenceForMC


# Registry of known tasks: maps a task name (string key) to the object
# imported for it. `load_task` resolves names against this mapping, so a new
# task becomes loadable by adding an entry here.
task_mapper = {"sst2": SST2ProbInferenceForMC}


def load_task(name):
    """Return the entry registered in ``task_mapper`` under ``name``.

    Args:
        name: Key to look up in ``task_mapper``.

    Returns:
        The value stored in ``task_mapper`` for ``name``, returned as-is
        (it is not called or instantiated here).

    Raises:
        ValueError: If ``name`` is not a key of ``task_mapper``.
    """
    # Happy path first: a registered name resolves straight to its entry.
    if name in task_mapper:
        return task_mapper[name]

    raise ValueError(f"Unrecognized dataset `{name}`")