metadata
license: mit
base_model: roberta-base
tags:
- generated_from_trainer
metrics:
- accuracy
- recall
- precision
- f1
model-index:
- name: roberta-base-suicide-prediction-phr
results: []
datasets:
- vibhorag101/suicide_prediction_dataset_phr
language:
- en
library_name: transformers
roberta-base-suicide-prediction-phr
This model is a fine-tuned version of roberta-base on a Suicide Prediction dataset sourced from Reddit. It achieves the following results on the evaluation/validation set:
- Loss: 0.1543
- Accuracy: {'accuracy': 0.9652972367116438}
- Recall: {'recall': 0.966571403827834}
- Precision: {'precision': 0.9638169257340242}
- F1: {'f1': 0.9651921995935487}
Model description
More information needed
Training and evaluation data
The dataset is sourced from Reddit and is available on Kaggle.
The dataset contains text with binary labels for suicide or non-suicide.
The evaluation set had ~23000 samples, while the training set had ~186k samples, i.e. 80:10:10 (train:test:val) split.
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 3
Training results
Training Loss | Epoch | Step | Validation Loss | Accuracy | Recall | Precision | F1 |
---|---|---|---|---|---|---|---|
0.2023 | 0.09 | 1000 | 0.1868 | {'accuracy': 0.9415010561710566} | {'recall': 0.9389451805663809} | {'precision': 0.943274752044545} | {'f1': 0.9411049867627274} |
0.1792 | 0.17 | 2000 | 0.1465 | {'accuracy': 0.9528387291460103} | {'recall': 0.9615484541439335} | {'precision': 0.9446949714966392} | {'f1': 0.9530472103004292} |
0.1596 | 0.26 | 3000 | 0.1871 | {'accuracy': 0.9523645298961072} | {'recall': 0.9399844115354637} | {'precision': 0.9634297887448962} | {'f1': 0.9515627054749485} |
0.1534 | 0.34 | 4000 | 0.1563 | {'accuracy': 0.9518041126007674} | {'recall': 0.974971854161254} | {'precision': 0.9314139157772814} | {'f1': 0.9526952695269527} |
0.1553 | 0.43 | 5000 | 0.1691 | {'accuracy': 0.9513730223735828} | {'recall': 0.93141075604053} | {'precision': 0.9697051663510955} | {'f1': 0.950172276702889} |
0.1537 | 0.52 | 6000 | 0.1347 | {'accuracy': 0.9568478682588266} | {'recall': 0.9644063393089114} | {'precision': 0.9496844618795839} | {'f1': 0.9569887852876723} |
0.1515 | 0.6 | 7000 | 0.1276 | {'accuracy': 0.9565461050997974} | {'recall': 0.9426690915389279} | {'precision': 0.9691924138545098} | {'f1': 0.9557467732022126} |
0.1453 | 0.69 | 8000 | 0.1351 | {'accuracy': 0.960210372030866} | {'recall': 0.9589503767212263} | {'precision': 0.961031070994619} | {'f1': 0.959989596428107} |
0.1526 | 0.78 | 9000 | 0.1423 | {'accuracy': 0.9610725524852352} | {'recall': 0.9612020438209059} | {'precision': 0.9606196988056085} | {'f1': 0.9609107830829834} |
0.1437 | 0.86 | 10000 | 0.1365 | {'accuracy': 0.9599948269172738} | {'recall': 0.9625010825322594} | {'precision': 0.9573606684468946} | {'f1': 0.9599239937813093} |
0.1317 | 0.95 | 11000 | 0.1275 | {'accuracy': 0.9616760788032935} | {'recall': 0.9653589676972374} | {'precision': 0.9579752492265383} | {'f1': 0.9616529353405513} |
0.125 | 1.03 | 12000 | 0.1428 | {'accuracy': 0.9608138983489244} | {'recall': 0.9522819780029445} | {'precision': 0.9684692619341201} | {'f1': 0.9603074101567617} |
0.1135 | 1.12 | 13000 | 0.1627 | {'accuracy': 0.960770789326206} | {'recall': 0.9544470425218672} | {'precision': 0.966330556773345} | {'f1': 0.9603520390379923} |
0.1096 | 1.21 | 14000 | 0.1240 | {'accuracy': 0.9624520412122257} | {'recall': 0.9566987096215467} | {'precision': 0.9675074443860571} | {'f1': 0.962072719355541} |
0.1213 | 1.29 | 15000 | 0.1502 | {'accuracy': 0.9616760788032935} | {'recall': 0.9659651857625358} | {'precision': 0.9574248927038627} | {'f1': 0.9616760788032936} |
0.1166 | 1.38 | 16000 | 0.1574 | {'accuracy': 0.958873992326594} | {'recall': 0.9438815276695246} | {'precision': 0.9726907630522088} | {'f1': 0.9580696202531646} |
0.1214 | 1.47 | 17000 | 0.1626 | {'accuracy': 0.9562443419407682} | {'recall': 0.9773101238416905} | {'precision': 0.9374480810765908} | {'f1': 0.9569641721433114} |
0.1064 | 1.55 | 18000 | 0.1653 | {'accuracy': 0.9624089321895073} | {'recall': 0.9622412747899888} | {'precision': 0.9622412747899888} | {'f1': 0.9622412747899888} |
0.1046 | 1.64 | 19000 | 0.1608 | {'accuracy': 0.9640039660300901} | {'recall': 0.9697756993158396} | {'precision': 0.9584046559397467} | {'f1': 0.9640566484438896} |
0.1043 | 1.72 | 20000 | 0.1556 | {'accuracy': 0.960770789326206} | {'recall': 0.9493374902572097} | {'precision': 0.9712058119961017} | {'f1': 0.9601471489883507} |
0.0995 | 1.81 | 21000 | 0.1646 | {'accuracy': 0.9602534810535845} | {'recall': 0.9752316619035247} | {'precision': 0.9465411448264268} | {'f1': 0.9606722402320423} |
0.1065 | 1.9 | 22000 | 0.1721 | {'accuracy': 0.9627106953485365} | {'recall': 0.9710747380271932} | {'precision': 0.9547854223433242} | {'f1': 0.9628611910179897} |
0.1204 | 1.98 | 23000 | 0.1214 | {'accuracy': 0.9629693494848471} | {'recall': 0.961028838659392} | {'precision': 0.9644533286980705} | {'f1': 0.9627380384331756} |
0.0852 | 2.07 | 24000 | 0.1583 | {'accuracy': 0.9643919472345562} | {'recall': 0.9624144799515025} | {'precision': 0.9659278574532811} | {'f1': 0.9641679680721846} |
0.0812 | 2.16 | 25000 | 0.1594 | {'accuracy': 0.9635728758029055} | {'recall': 0.9572183251060882} | {'precision': 0.9692213258505787} | {'f1': 0.9631824321380331} |
0.0803 | 2.24 | 26000 | 0.1629 | {'accuracy': 0.9639177479846532} | {'recall': 0.9608556334978783} | {'precision': 0.9664634146341463} | {'f1': 0.963651365787988} |
0.0832 | 2.33 | 27000 | 0.1570 | {'accuracy': 0.9631417855757209} | {'recall': 0.9658785831817788} | {'precision': 0.9603065266058206} | {'f1': 0.9630844954881052} |
0.0887 | 2.41 | 28000 | 0.1551 | {'accuracy': 0.9623227141440703} | {'recall': 0.9669178141508616} | {'precision': 0.9577936004117698} | {'f1': 0.9623340803309774} |
0.084 | 2.5 | 29000 | 0.1585 | {'accuracy': 0.9644350562572747} | {'recall': 0.9613752489824197} | {'precision': 0.96698606271777} | {'f1': 0.9641724931602031} |
0.0807 | 2.59 | 30000 | 0.1601 | {'accuracy': 0.9639177479846532} | {'recall': 0.9699489044773534} | {'precision': 0.9580838323353293} | {'f1': 0.9639798597065025} |
0.079 | 2.67 | 31000 | 0.1645 | {'accuracy': 0.9628400224166919} | {'recall': 0.9558326838139777} | {'precision': 0.9690929844586882} | {'f1': 0.9624171607952564} |
0.0913 | 2.76 | 32000 | 0.1560 | {'accuracy': 0.9642626201664009} | {'recall': 0.964752749631939} | {'precision': 0.9635011243729459} | {'f1': 0.9641265307888701} |
0.0927 | 2.85 | 33000 | 0.1491 | {'accuracy': 0.9649523645298961} | {'recall': 0.9659651857625358} | {'precision': 0.9637117677553136} | {'f1': 0.9648371610224472} |
0.0882 | 2.93 | 34000 | 0.1543 | {'accuracy': 0.9652972367116438} | {'recall': 0.966571403827834} | {'precision': 0.9638169257340242} | {'f1': 0.9651921995935487} |
Framework versions
- Transformers 4.31.0
- Pytorch 2.1.0+cu121
- Datasets 2.14.5
- Tokenizers 0.13.3