---
license: cc-by-nc-4.0
tags:
- audio
- music
- continual-pretraining
- cross-cultural-MIR
metrics:
- roc_auc
- average_precision
- f1
pipeline_tag: audio-classification
---
# CultureMERT: Continual Pre-Training for Cross-Cultural Music Representation Learning
CultureMERT-95M is a multi-culturally adapted 95M-parameter music foundation model based on MERT-v1-95M. It is developed through a two-stage continual pre-training strategy on 650 hours of culturally diverse audio spanning Greek, Turkish, and Indian musical traditions. The model significantly improves representation quality for "non-Western" music, achieving an average ROC-AUC improvement of 4.43% across culturally diverse music tagging tasks, surpassing prior state-of-the-art, while maintaining strong performance on Western-centric benchmarks such as MagnaTagATune and FMA-medium.
## Model Details
- Architecture: 12-layer Transformer encoder (768-dim hidden states) with a 7-layer 1D CNN frontend
- Input: Raw mono audio at 24 kHz
- Training context length: 5 seconds
- Pre-training objective: MLM-style multi-task objective combining masked prediction of discrete EnCodec acoustic tokens with reconstruction of continuous constant-Q transform (CQT) spectrogram features, at a 75 Hz frame rate (see the sketch below)
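As a quick sanity check on these numbers, the following sketch (illustrative only, not part of the official codebase) chunks a waveform into the model's 5-second context windows at 24 kHz and computes the expected number of encoder output frames at the 75 Hz feature rate; the exact frame count at inference can differ by a frame or two depending on the CNN frontend's padding.

```python
import numpy as np

SAMPLE_RATE = 24_000   # model expects 24 kHz mono audio
CONTEXT_SECONDS = 5    # training context length
FRAME_RATE = 75        # encoder output feature rate (Hz)

def chunk_waveform(waveform: np.ndarray) -> list:
    """Split a mono waveform into non-overlapping 5-second chunks (drops the last partial chunk)."""
    chunk_len = SAMPLE_RATE * CONTEXT_SECONDS  # 120,000 samples per chunk
    n_chunks = len(waveform) // chunk_len
    return [waveform[i * chunk_len:(i + 1) * chunk_len] for i in range(n_chunks)]

waveform = np.random.randn(SAMPLE_RATE * 12)  # dummy 12-second clip
chunks = chunk_waveform(waveform)
print(len(chunks), "chunks of", SAMPLE_RATE * CONTEXT_SECONDS, "samples")  # 2 chunks of 120000 samples
print("expected frames per chunk:", CONTEXT_SECONDS * FRAME_RATE)          # ~375 frames at 75 Hz
```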
## Training Data
| Dataset | Music Tradition | Hours Used |
|---|---|---|
| Lyra | Greek traditional/folk | 50h |
| Turkish-makam | Turkish/Ottoman classical | 200h |
| Hindustani | North Indian classical | 200h |
| Carnatic | South Indian classical | 200h |
Note: The datasets used were obtained under research-use agreements and are not redistributed.
## Evaluation
We evaluate CultureMERT-95M via probing on both Western and non-Western music tagging tasks to assess its cross-cultural generalization performance. All results are averaged over five random seeds. Metrics used:
- ROC-AUC (Receiver Operating Characteristic - Area Under Curve)
- Mean Average Precision (mAP)
- Micro-F1 and Macro-F1
We follow the MARBLE benchmark protocol under constrained settings (frozen backbone, lightweight probing head); a scoring sketch follows below.
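For concreteness, here is a minimal, hypothetical sketch of how such a probing evaluation can be scored with scikit-learn, assuming multi-label tag predictions `y_score` (probe outputs) and a binary ground-truth matrix `y_true`; it is illustrative only, not the exact MARBLE evaluation code.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

# Hypothetical multi-label tagging outputs:
# y_true: (n_clips, n_tags) binary ground truth; y_score: (n_clips, n_tags) probe probabilities.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(100, 10))
y_score = np.clip(0.3 * y_true + 0.7 * rng.random((100, 10)), 0, 1)

roc_auc = roc_auc_score(y_true, y_score, average="macro")        # ROC-AUC
m_ap = average_precision_score(y_true, y_score, average="macro")  # mean average precision
y_pred = (y_score >= 0.5).astype(int)                             # threshold scores for F1
micro_f1 = f1_score(y_true, y_pred, average="micro")
macro_f1 = f1_score(y_true, y_pred, average="macro")
print(f"ROC-AUC {roc_auc:.3f} | mAP {m_ap:.3f} | micro-F1 {micro_f1:.3f} | macro-F1 {macro_f1:.3f}")
```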
### Evaluation Datasets
- Non-Western traditions:
  - Turkish-makam (Ottoman classical)
  - Hindustani (North Indian classical)
  - Carnatic (South Indian classical)
  - Lyra (Greek traditional/folk)
- Western benchmarks:
  - MagnaTagATune (MTAT)
  - FMA-medium
### ROC-AUC and mAP
Each cell reports ROC-AUC / mAP (%); the Avg. column averages over all six datasets and both metrics.
| Model | Turkish-makam | Hindustani | Carnatic | Lyra | FMA | MTAT | Avg. |
|---|---|---|---|---|---|---|---|
| MERT-v1-95M | 83.2 / 53.3 | 82.4 / 52.9 | 74.9 / 39.7 | 85.7 / 56.5 | 90.7 / 48.1 | 89.6 / 35.9 | 66.1 |
| CultureMERT-95M | 89.6 / 60.6 | 88.2 / 63.5 | 79.2 / 43.1 | 86.9 / 56.7 | 90.7 / 48.1 | 89.4 / 35.9 | 69.3 |
### Micro-F1 and Macro-F1
Each cell reports Micro-F1 / Macro-F1 (%); the Avg. column averages over all six datasets and both metrics.
| Model | Turkish-makam | Hindustani | Carnatic | Lyra | FMA | MTAT | Avg. |
|---|---|---|---|---|---|---|---|
| MERT-v1-95M | 73.0 / 38.9 | 71.1 / 33.2 | 80.1 / 30.0 | 72.4 / 42.6 | 57.0 / 36.9 | 35.7 / 21.2 | 49.3 |
| CultureMERT-95M | 77.4 / 45.8 | 77.8 / 50.4 | 82.7 / 32.5 | 73.1 / 43.1 | 58.3 / 36.6 | 35.6 / 22.9 | 52.9 |
Averaged over the four non-Western datasets, CultureMERT-95M outperforms the original MERT-v1-95M by 4.43 percentage points in ROC-AUC, with consistent gains of 5.4 in mAP, 3.6 in Micro-F1, and 6.8 in Macro-F1, while exhibiting minimal forgetting on the Western benchmarks (FMA and MTAT). The snippet below shows how these averages follow from the tables.
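The reported averages can be reproduced directly from the per-dataset entries above (a simple check using only the numbers in the tables):

```python
# Per-dataset gains of CultureMERT-95M over MERT-v1-95M on the four non-Western datasets
# (Turkish-makam, Hindustani, Carnatic, Lyra), taken from the tables above.
roc_auc_gains = [89.6 - 83.2, 88.2 - 82.4, 79.2 - 74.9, 86.9 - 85.7]
map_gains      = [60.6 - 53.3, 63.5 - 52.9, 43.1 - 39.7, 56.7 - 56.5]
micro_f1_gains = [77.4 - 73.0, 77.8 - 71.1, 82.7 - 80.1, 73.1 - 72.4]
macro_f1_gains = [45.8 - 38.9, 50.4 - 33.2, 32.5 - 30.0, 43.1 - 42.6]

mean = lambda xs: sum(xs) / len(xs)
print(f"ROC-AUC:  +{mean(roc_auc_gains):.2f}")   # ~4.43
print(f"mAP:      +{mean(map_gains):.2f}")       # ~5.4
print(f"Micro-F1: +{mean(micro_f1_gains):.2f}")  # ~3.6
print(f"Macro-F1: +{mean(macro_f1_gains):.2f}")  # ~6.8
```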
## Model Usage
```python
from transformers import Wav2Vec2FeatureExtractor, AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from datasets import load_dataset

# Load model weights and preprocessor config
model = AutoModel.from_pretrained("ntua-slp/CultureMERT-95M", trust_remote_code=True)
processor = Wav2Vec2FeatureExtractor.from_pretrained("ntua-slp/CultureMERT-95M", trust_remote_code=True)

# Load example audio
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation").sort("id")
audio_array = dataset[0]["audio"]["array"]
sampling_rate = dataset.features["audio"].sampling_rate

# Resample to the model's expected sampling rate if needed
resample_rate = processor.sampling_rate
if resample_rate != sampling_rate:
    print(f"Resampling from {sampling_rate} Hz to {resample_rate} Hz")
    resampler = T.Resample(sampling_rate, resample_rate)
    input_audio = resampler(torch.from_numpy(audio_array).float())
else:
    input_audio = audio_array

# Extract hidden states
inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Representations: 13 layers (CNN output + 12 Transformer layers)
# NOTE: each layer performs differently on different downstream tasks - choose empirically
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
print(all_layer_hidden_states.shape)  # [13 layers, time steps, 768 feature_dim]

# For utterance-level classification tasks, you can simply average the representation over time
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
print(time_reduced_hidden_states.shape)  # [13, 768]

# You can also use a learnable weighted average over layers
aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
print(weighted_avg_hidden_states.shape)  # [768]
```
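Building on the snippet above, a probing head for a downstream tagging task could look like the following. This is an illustrative sketch (learnable layer weights plus a linear classifier), not the exact probing setup used in the paper; `NUM_TAGS` is a placeholder for your task, and `time_reduced_hidden_states` is reused from the snippet above.

```python
import torch
from torch import nn

NUM_TAGS = 10  # placeholder: number of tags in your downstream dataset

class ProbingHead(nn.Module):
    """Learnable weighted average over the 13 layer representations, followed by a linear classifier."""
    def __init__(self, num_layers: int = 13, hidden_dim: int = 768, num_tags: int = NUM_TAGS):
        super().__init__()
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))
        self.classifier = nn.Linear(hidden_dim, num_tags)

    def forward(self, layer_features: torch.Tensor) -> torch.Tensor:
        # layer_features: [batch, num_layers, hidden_dim] (e.g., time-averaged hidden states)
        weights = torch.softmax(self.layer_weights, dim=0)              # [num_layers]
        pooled = (layer_features * weights[None, :, None]).sum(dim=1)   # [batch, hidden_dim]
        return self.classifier(pooled)                                   # [batch, num_tags] tag logits

# Example: probe the time-reduced features from above (add a batch dimension)
probe = ProbingHead()
logits = probe(time_reduced_hidden_states.unsqueeze(0))  # [1, NUM_TAGS]
# For multi-label tagging, train with nn.BCEWithLogitsLoss() while keeping the CultureMERT backbone frozen.
```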
## Ethical Considerations
This model is released under a non-commercial CC BY-NC 4.0 license and is intended for academic research. While it is designed to address cultural bias in MIR, its training data and pretraining paradigm may still reflect cultural and dataset-specific biases. The model should not be used in commercial or generative applications without explicit consideration of cultural representation, proper licensing, and the consent of the relevant communities or dataset curators.
## Citation
...