# xlm-roberta-sentiment-requests
This model is a fine-tuned version of `cardiffnlp/twitter-xlm-roberta-base-sentiment` on the `community-datasets/disaster_response_messages` dataset. It has been adapted into a multi-head classification model for analyzing social media messages during disaster events.
It achieves the following results on the evaluation set:
- Loss: 0.1465
- F1 Micro: 0.7240
- F1 Macro: 0.3505
- Subset Accuracy (exact match across all labels): 0.2588
## Model description
This model uses a shared XLM-RoBERTa base to encode input text. The resulting text representation is fed into two separate, independent classification layers (heads):

- A Sentiment Head (frozen from the pre-trained model) with 3 outputs for the `positive`, `neutral`, and `negative` classes.
- A Multi-Label Head (newly created and fine-tuned) with 41 outputs, which are decoded to predict the presence or absence of 37 different disaster-related categories (35 binary categories plus the 3-way `genre` and `related` categories, stored one-hot).

This dual-head architecture allows for a nuanced understanding of a message, capturing both its emotional content and its specific, actionable information.
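In shape terms, each head is just a linear projection of the pooled sentence embedding. Below is a minimal sketch with dummy tensors; the 768-dimensional hidden size of XLM-RoBERTa base and the batch size of 2 are illustrative assumptions, not values taken from this card:

```python
import torch
from torch import nn

hidden_size = 768  # hidden size of XLM-RoBERTa base (assumed for illustration)
cls_embeddings = torch.randn(2, hidden_size)  # pooled [CLS] vectors for a dummy batch of 2 messages

sentiment_head = nn.Linear(hidden_size, 3)    # negative / neutral / positive
multilabel_head = nn.Linear(hidden_size, 41)  # 41 raw outputs, decoded into 37 categories

print(sentiment_head(cls_embeddings).shape)   # torch.Size([2, 3])
print(multilabel_head(cls_embeddings).shape)  # torch.Size([2, 41])
```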
## Intended uses & limitations
This model is intended for organizations and researchers involved in humanitarian aid and disaster response. Potential applications include:
- Automated Triage: Quickly sorting through thousands of social media messages to identify the most urgent requests for help.
- Situational Awareness: Building a real-time map of needs by aggregating categorized messages.
- Resource Allocation: Directing resources more effectively by understanding the specific types of aid being requested.
Important: due to its custom architecture, this model cannot be used with the standard `pipeline("text-classification")` function. Please see the usage code below for the correct implementation.
## How to Use
This model requires custom code to handle its two-headed output. The following is a complete, self-contained Python script to run inference. You will need `transformers`, `torch`, `safetensors`, and `huggingface_hub` installed (`pip install transformers torch safetensors huggingface_hub`).
The script automatically downloads all necessary files, including the model weights and metadata. Simply copy the code blocks below and run the script.
The script is broken into logical blocks:

1. Model Architecture: a Python class that defines the model's structure. This blueprint is required to load the saved weights.
2. Label Definitions: a "decoder ring" of functions to translate the model's numerical outputs into human-readable labels.
3. Setup & Loading: a function that handles all of the one-time setup.
4. Prediction Function: the core logic that takes text and produces a dictionary of predictions.
5. Main Execution: an example of how to run the script.

Concatenating blocks 1 through 5 yields the complete, runnable inference script.
1. Model Architecture: We define the necessary imports and the model architecture.
```python
import json
from typing import Dict, Any

import torch
from torch import nn
from transformers import AutoTokenizer, AutoConfig, AutoModel, PreTrainedModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class MultiHeadClassificationModel(PreTrainedModel):
    """A shared XLM-RoBERTa backbone feeding a sentiment head and a multi-label head."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        num_multilabels = kwargs.get("num_multilabels")
        if num_multilabels is None:
            raise ValueError("`num_multilabels` must be provided to initialize the model.")
        self.backbone = AutoModel.from_config(config)
        self.sentiment_classifier = nn.Linear(config.hidden_size, config.num_sentiment_labels)
        self.multilabel_classifier = nn.Linear(config.hidden_size, num_multilabels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.backbone(input_ids, attention_mask=attention_mask, **kwargs)
        # Pool the sequence with the first (<s>/[CLS]) token's hidden state.
        cls_token_output = outputs.last_hidden_state[:, 0, :]
        sentiment_logits = self.sentiment_classifier(cls_token_output)
        multilabel_logits = self.multilabel_classifier(cls_token_output)
        return {"sentiment_logits": sentiment_logits, "multilabel_logits": multilabel_logits}
```
2. Label Definitions: We embed the label definitions, which are essential for interpreting the model's output.
```python
def get_all_labels() -> Dict[str, Dict[int, str]]:
    return {
        'sentiment': get_sentiment_labels(), 'genre': get_genre_labels(), 'related': get_related_labels(),
        'request': get_request_labels(), 'offer': get_offer_labels(), 'aid_related': get_aid_related_labels(),
        'medical_help': get_medical_help_labels(), 'medical_products': get_medical_products_labels(),
        'search_and_rescue': get_search_and_rescue_labels(), 'security': get_security_labels(),
        'military': get_military_labels(), 'child_alone': get_child_alone_labels(), 'water': get_water_labels(),
        'food': get_food_labels(), 'shelter': get_shelter_labels(), 'clothing': get_clothing_labels(),
        'money': get_money_labels(), 'missing_people': get_missing_people_labels(),
        'refugees': get_refugees_labels(), 'death': get_death_labels(), 'other_aid': get_other_aid_labels(),
        'infrastructure_related': get_infrastructure_related_labels(), 'transport': get_transport_labels(),
        'buildings': get_buildings_labels(), 'electricity': get_electricity_labels(), 'tools': get_tools_labels(),
        'hospitals': get_hospitals_labels(), 'shops': get_shops_labels(), 'aid_centers': get_aid_centers_labels(),
        'other_infrastructure': get_other_infrastructure_labels(), 'weather_related': get_weather_related_labels(),
        'floods': get_floods_labels(), 'storm': get_storm_labels(), 'fire': get_fire_labels(),
        'earthquake': get_earthquake_labels(), 'cold': get_cold_labels(), 'other_weather': get_other_weather_labels(),
        'direct_report': get_direct_report_labels(),
    }

# Multiclass tasks.
def get_sentiment_labels() -> Dict[int, str]: return {0: 'negative', 1: 'neutral', 2: 'positive'}
def get_genre_labels() -> Dict[int, str]: return {0: 'direct', 1: 'news', 2: 'social'}
def get_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes', 2: 'maybe'}

# Binary tasks: each category is a yes/no decision.
def get_request_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_offer_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_aid_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_medical_help_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_medical_products_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_search_and_rescue_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_security_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_military_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_child_alone_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_water_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_food_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_shelter_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_clothing_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_money_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_missing_people_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_refugees_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_death_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_other_aid_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_infrastructure_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_transport_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_buildings_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_electricity_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_tools_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_hospitals_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_shops_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_aid_centers_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_other_infrastructure_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_weather_related_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_floods_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_storm_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_fire_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_earthquake_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_cold_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_other_weather_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
def get_direct_report_labels() -> Dict[int, str]: return {0: 'no', 1: 'yes'}
```
3. Setup & Loading: This setup function downloads and loads all components, including `metadata.json`, from the Hub.
```python
def load_essentials():
    print("Loading model, tokenizer, and metadata... (This may take a moment on first run)")
    hub_repo_id = "spencercdz/xlm-roberta-sentiment-requests"
    subfolder = "final_model"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the model's output structure from the metadata.json file.
    metadata_path = hf_hub_download(repo_id=hub_repo_id, filename="metadata.json", subfolder=subfolder)
    with open(metadata_path, "r") as f:
        file_metadata = json.load(f)

    # Use the metadata to define the number of output neurons for the classification heads.
    binary_tasks = file_metadata["binary_tasks"]
    multiclass_tasks = file_metadata["multiclass_tasks"]
    multilabel_column_names = file_metadata["multilabel_column_names"]
    num_multilabels = len(multilabel_column_names)
    num_sentiment_labels = len(get_sentiment_labels())

    # Load the standard tokenizer and config.
    tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, subfolder=subfolder)
    config = AutoConfig.from_pretrained(hub_repo_id, subfolder=subfolder)
    # Add our custom sentiment label count to the config.
    config.num_sentiment_labels = num_sentiment_labels

    # Manually load the custom model, as it's not a standard transformers architecture.
    # Create a model 'shell' with our custom architecture.
    model_shell = MultiHeadClassificationModel(config=config, num_multilabels=num_multilabels)

    # Download and load the trained weights.
    weights_path = hf_hub_download(repo_id=hub_repo_id, filename="model.safetensors", subfolder=subfolder)
    state_dict = load_file(weights_path, device="cpu")
    # Apply weights to the shell. `strict=False` is required for loading custom heads.
    model_shell.load_state_dict(state_dict, strict=False)

    # Move model to the target device and set to evaluation mode.
    model = model_shell.to(device)
    model.eval()

    # Package all components for use in the predict function.
    metadata_for_prediction = {
        "binary_tasks": binary_tasks,
        "multiclass_tasks": multiclass_tasks,
        "multilabel_column_names": multilabel_column_names,
        "all_labels": get_all_labels(),
        "device": device,
    }
    print("Loading complete.")
    return model, tokenizer, metadata_for_prediction
```
4. Prediction Function: The prediction function takes the loaded components and input text and produces a decoded dictionary of predictions.
```python
def predict(text: str, model, tokenizer, metadata: Dict) -> Dict[str, Any]:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(metadata['device'])
    with torch.no_grad():
        outputs = model(**inputs)

    # Sentiment head: softmax over 3 mutually exclusive classes.
    sentiment_probs = torch.softmax(outputs['sentiment_logits'], dim=-1).cpu().numpy()
    # Multi-label head: independent sigmoid per output neuron.
    multilabel_probs = torch.sigmoid(outputs['multilabel_logits']).cpu().numpy()

    results = {}
    sentiment_decoder = metadata['all_labels']['sentiment']
    sentiment_pred_idx = int(sentiment_probs.argmax())
    results['sentiment'] = {
        'prediction': sentiment_decoder.get(sentiment_pred_idx, "unknown"),
        'confidence': sentiment_probs[0, sentiment_pred_idx].item(),
    }

    # Binary tasks: one output neuron each, thresholded at 0.5.
    for task_name in metadata['binary_tasks']:
        idx = metadata['multilabel_column_names'].index(task_name)
        prob = multilabel_probs[0, idx]
        pred = 1 if prob > 0.5 else 0
        results[task_name] = {
            'prediction': metadata['all_labels'][task_name][pred],
            'confidence': (prob if pred == 1 else 1 - prob).item(),
        }

    # Multiclass tasks (e.g. genre, related): one-hot columns named "<task>_0", "<task>_1", ...
    for task_name, num_classes in metadata['multiclass_tasks'].items():
        start_idx = metadata['multilabel_column_names'].index(f"{task_name}_0")
        task_probs = multilabel_probs[0, start_idx : start_idx + num_classes]
        pred_idx = int(task_probs.argmax())
        results[task_name] = {
            'prediction': metadata['all_labels'][task_name].get(pred_idx, "unknown"),
            'confidence': task_probs[pred_idx].item(),
        }
    return results
```
5. Main Execution: The main execution block shows how to use the functions and print the raw JSON output.
```python
if __name__ == "__main__":
    model, tokenizer, metadata = load_essentials()
    input_text = "I need food, water, and shelter. Help me! People are dying. We need more items."
    print(f"\n--- Predicting for Input ---\n\"{input_text}\"")
    predictions = predict(input_text, model, tokenizer, metadata)

    # Print the raw dictionary output as formatted JSON.
    print("\n--- RAW DICTIONARY OUTPUT ---")
    print(json.dumps(predictions, indent=4))
```
### Sample Output

```
{'sentiment': {'prediction': 'negative', 'confidence': 0.999014139175415}, 'request': {'prediction': 'yes', 'confidence': 0.9999805688858032}, 'offer': {'prediction': 'no', 'confidence': 0.9995545148849487}, 'aid_related': {'prediction': 'yes', 'confidence': 0.9995179176330566}, 'medical_help': {'prediction': 'no', 'confidence': 0.9931818246841431}, 'medical_products': {'prediction': 'no', 'confidence': 0.9975765943527222}, 'search_and_rescue': {'prediction': 'no', 'confidence': 0.9981554746627808}, 'security': {'prediction': 'no', 'confidence': 0.999071478843689}, 'military': {'prediction': 'no', 'confidence': 0.9981452226638794}, 'child_alone': {'prediction': 'no', 'confidence': 0.9998688697814941}, 'water': {'prediction': 'yes', 'confidence': 0.9991873502731323}, 'food': {'prediction': 'yes', 'confidence': 0.9998394250869751}, 'shelter': {'prediction': 'yes', 'confidence': 0.9997198581695557}, 'clothing': {'prediction': 'no', 'confidence': 0.9982467889785767}, 'money': {'prediction': 'no', 'confidence': 0.9985392093658447}, 'missing_people': {'prediction': 'no', 'confidence': 0.998404324054718}, 'refugees': {'prediction': 'no', 'confidence': 0.9981242418289185}, 'death': {'prediction': 'yes', 'confidence': 0.9850122332572937}, 'other_aid': {'prediction': 'no', 'confidence': 0.9654157757759094}, 'infrastructure_related': {'prediction': 'no', 'confidence': 0.984534740447998}, 'transport': {'prediction': 'no', 'confidence': 0.9972304105758667}, 'buildings': {'prediction': 'no', 'confidence': 0.9881182312965393}, 'electricity': {'prediction': 'no', 'confidence': 0.9988776445388794}, 'tools': {'prediction': 'no', 'confidence': 0.9995874166488647}, 'hospitals': {'prediction': 'no', 'confidence': 0.999099850654602}, 'shops': {'prediction': 'no', 'confidence': 0.9996023178100586}, 'aid_centers': {'prediction': 'no', 'confidence': 0.9981774091720581}, 'other_infrastructure': {'prediction': 'no', 'confidence': 0.9968826770782471}, 'weather_related': {'prediction': 'no', 'confidence': 0.9632836580276489}, 'floods': {'prediction': 'no', 'confidence': 0.9960920810699463}, 'storm': {'prediction': 'no', 'confidence': 0.9963870048522949}, 'fire': {'prediction': 'no', 'confidence': 0.9993714094161987}, 'earthquake': {'prediction': 'no', 'confidence': 0.99778151512146}, 'cold': {'prediction': 'no', 'confidence': 0.9991660118103027}, 'other_weather': {'prediction': 'no', 'confidence': 0.9974269866943359}, 'direct_report': {'prediction': 'yes', 'confidence': 0.9763266444206238}, 'genre': {'prediction': 'direct', 'confidence': 0.9912198185920715}, 'related': {'prediction': 'yes', 'confidence': 0.9997092485427856}}
```
## Training and evaluation data

This model was fine-tuned on the `community-datasets/disaster_response_messages` dataset, which contains over 26,000 messages from real disaster events. Each message is labeled with 37 different categories, such as `aid_related` and `weather_related`, as well as the message `genre` (direct, news, social). The `sentiment` labels were added programmatically for the purpose of this multi-task training. A quick way to load and inspect the dataset is sketched after the split summary below.
The dataset was split into:
- Training set: ~21,000 samples
- Validation set: ~2,600 samples
- Test set: ~2,600 samples
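The dataset can be loaded with the `datasets` library. This is a sketch under the assumption that the Hub repository exposes the standard train/validation/test splits; sizes are approximate:

```python
from datasets import load_dataset

# Download the disaster response messages dataset from the Hub.
ds = load_dataset("community-datasets/disaster_response_messages")

print(ds)              # expected: train / validation / test splits (~21k / ~2.6k / ~2.6k rows)
print(ds["train"][0])  # a single message with its genre, related, request, ... label columns
```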
## Training procedure

The model was trained using the `transformers.Trainer` with a custom `MultiHeadClassificationModel` architecture. The training process optimized a combined loss from both the sentiment and multi-label classification heads. The best model was selected based on the F1 Micro score on the validation set.
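The exact loss formulation and head weighting are not spelled out in this card. As a rough sketch, a combined objective over these two heads would typically pair cross-entropy for the mutually exclusive sentiment classes with binary cross-entropy for the independent multi-label outputs; the unweighted sum below is an assumption, not the verified training code:

```python
import torch
from torch import nn

# Assumed combined objective: unweighted sum of the two heads' losses.
ce_loss = nn.CrossEntropyLoss()    # sentiment head: 3 mutually exclusive classes
bce_loss = nn.BCEWithLogitsLoss()  # multi-label head: 41 independent sigmoid outputs

def combined_loss(outputs, sentiment_targets, multilabel_targets):
    # `outputs` is the dict returned by MultiHeadClassificationModel.forward.
    # sentiment_targets: (batch,) long tensor of class indices.
    # multilabel_targets: (batch, 41) float tensor of 0/1 labels.
    loss_sentiment = ce_loss(outputs["sentiment_logits"], sentiment_targets)
    loss_multilabel = bce_loss(outputs["multilabel_logits"], multilabel_targets)
    return loss_sentiment + loss_multilabel
```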
### Training hyperparameters
The following hyperparameters were used during training (an illustrative `TrainingArguments` sketch follows the list):
- learning_rate: 2e-05
- train_batch_size: 32
- eval_batch_size: 32
- seed: 42
- optimizer: AdamW (torch) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 1000 (early stopping patience of 50 epochs)
- mixed_precision_training: Native AMP
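These settings map onto `transformers.TrainingArguments` roughly as follows. This is an illustrative sketch, not the verified configuration; `output_dir` and `metric_for_best_model` in particular are assumptions:

```python
from transformers import TrainingArguments, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="xlm-roberta-sentiment-requests",  # assumption
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    seed=42,
    lr_scheduler_type="linear",
    num_train_epochs=1000,
    fp16=True,                         # native AMP mixed precision
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_micro",  # assumption: best checkpoint chosen by F1 Micro
)

# Early stopping with a patience of 50 evaluation rounds (epochs here).
early_stopping = EarlyStoppingCallback(early_stopping_patience=50)
```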
### Training results

The final results on the evaluation set come from the best checkpoint, saved at epoch 594. A truncated history of the most informative rows is shown below; for the full data, refer to `training_log.csv` in the repository.
| Training Loss | Epoch | Step | Validation Loss | F1 Micro | F1 Macro | Subset Accuracy |
|---|---|---|---|---|---|---|
| 0.4267 | 1.0 | 658 | 0.2727 | 0.4953 | 0.0722 | 0.1053 |
| 0.2662 | 2.0 | 1316 | 0.2291 | 0.5446 | 0.0906 | 0.1123 |
| 0.2366 | 3.0 | 1974 | 0.2143 | 0.5682 | 0.1031 | 0.1279 |
| 0.2234 | 4.0 | 2632 | 0.2058 | 0.5878 | 0.1160 | 0.1333 |
| 0.2156 | 5.0 | 3290 | 0.1997 | 0.6022 | 0.1255 | 0.1380 |
| ... | ... | ... | ... | ... | ... | ... |
| 0.1773 | 25.0 | 16450 | 0.1670 | 0.6714 | 0.2305 | 0.1955 |
| 0.1694 | 50.0 | 32900 | 0.1592 | 0.6911 | 0.2701 | 0.2223 |
| 0.1662 | 75.0 | 49350 | 0.1558 | 0.7018 | 0.2960 | 0.2309 |
| 0.164 | 100.0 | 65800 | 0.1537 | 0.7077 | 0.3098 | 0.2425 |
| 0.1627 | 125.0 | 82250 | 0.1522 | 0.7104 | 0.3184 | 0.2449 |
| 0.1617 | 150.0 | 98700 | 0.1513 | 0.7130 | 0.3243 | 0.2449 |
| 0.1612 | 175.0 | 115150 | 0.1504 | 0.7143 | 0.3285 | 0.2499 |
| 0.1606 | 200.0 | 131600 | 0.1498 | 0.7161 | 0.3314 | 0.2515 |
| 0.16 | 250.0 | 164500 | 0.1488 | 0.7183 | 0.3383 | 0.2538 |
| 0.1592 | 300.0 | 197400 | 0.1482 | 0.7204 | 0.3423 | 0.2534 |
| 0.1589 | 350.0 | 230300 | 0.1476 | 0.7214 | 0.3450 | 0.2581 |
| 0.1584 | 400.0 | 263200 | 0.1474 | 0.7223 | 0.3459 | 0.2588 |
| 0.1584 | 450.0 | 296100 | 0.1471 | 0.7231 | 0.3487 | 0.2588 |
| 0.158 | 500.0 | 329000 | 0.1468 | 0.7232 | 0.3494 | 0.2612 |
| 0.1577 | 550.0 | 361900 | 0.1467 | 0.7239 | 0.3503 | 0.2600 |
| ... | ... | ... | ... | ... | ... | ... |
| 0.1574 | 591.0 | 388878 | 0.1466 | 0.7243 | 0.3510 | 0.2596 |
| 0.1576 | 592.0 | 389536 | 0.1465 | 0.7234 | 0.3496 | 0.2596 |
| 0.1582 | 593.0 | 390194 | 0.1465 | 0.7239 | 0.3504 | 0.2592 |
| 0.158 | 594.0 | 390852 | 0.1465 | 0.7240 | 0.3505 | 0.2588 |
### Framework versions
- Transformers 4.52.4
- PyTorch 2.7.1+cu128
- Datasets 3.6.0
- Tokenizers 0.21.2
## Evaluation results

Self-reported on community-datasets/disaster_response_messages:
- F1 Micro: 0.724
- F1 Macro: 0.350
- Subset Accuracy: 0.259