import torch


def l1(output, target):
    """Compute the mean absolute error between two tensors.

    Args:
        output: Predicted tensor.
        target: Reference tensor; combined with ``output`` by elementwise
            subtraction, so standard torch broadcasting rules apply.

    Returns:
        A 0-dim tensor holding the mean of ``|output - target|`` taken over
        every element.
    """
    abs_diff = (output - target).abs()
    return abs_diff.mean()


def l1_wav(output_dict, target_dict):
    """Compute the L1 loss between the ``'segment'`` entries of two mappings.

    Args:
        output_dict: Mapping holding the prediction under the ``'segment'`` key.
        target_dict: Mapping holding the reference under the ``'segment'`` key.

    Returns:
        The result of ``l1`` applied to the two ``'segment'`` values.

    Raises:
        KeyError: if either argument is a dict without a ``'segment'`` key.
    """
    # NOTE(review): the name suggests 'segment' holds waveform tensors of
    # matching shape — confirm against callers.
    return l1(output_dict['segment'], target_dict['segment'])


def get_loss_function(loss_type):
    """Look up a loss function by name.

    Args:
        loss_type: Name of the loss. Only ``"l1_wav"`` is currently supported.

    Returns:
        The ``l1_wav`` function when ``loss_type == "l1_wav"``.

    Raises:
        NotImplementedError: for any other ``loss_type``. The message names the
            rejected value and the supported options.
    """
    if loss_type == "l1_wav":
        return l1_wav

    # Keep NotImplementedError (not ValueError) so existing callers that catch
    # it are unaffected; only the message is made informative.
    raise NotImplementedError(
        f"Unsupported loss_type {loss_type!r}; supported types: 'l1_wav'."
    )