Adding checkpointing, wandb, and new mlm script
- README.md +29 -1
- perplexity.py +22 -0
- run_mlm_flax.py +60 -31
- tokens.py +3 -1
README.md
CHANGED
````diff
@@ -1,19 +1,47 @@
+---
+language: no
+license: CC-BY 4.0
+tags:
+- spanish
+- roberta
+pipeline_tag: fill-mask
+widget:
+- text: "Lo hizo en un abrir y cerar de <mask>."
+---
+
 # BERTIN
+
 BERTIN is a series of BERT-based models for Spanish. This one is a RoBERTa-large model trained from scratch on the Spanish portion of mC4 using [Flax](https://github.com/google/flax), including training scripts.
 
 This is part of the
 [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
 
+## Spanish mC4
+
+The Spanish portion of mC4 contains about 416 million records and 235 billion words.
+
+```bash
+$ zcat c4/multilingual/c4-es*.tfrecord*.json.gz | wc -l
+416057992
+```
+
+```bash
+$ zcat c4/multilingual/c4-es*.tfrecord-*.json.gz | jq -r '.text | split(" ") | length' | paste -s -d+ - | bc
+235303687795
+```
+
 ## Team members
+
 - Javier de la Rosa (versae)
 - Manu Romero (mrm8488)
 - María Grandury (mariagrandury)
 - Ari Polakov (aripo99)
 - Pablogps
 - daveni
-- Sri Lakshmi
+- Sri Lakshmi
 
 ## Useful links
+
 - [Community Week timeline](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104#summary-timeline-calendar-6)
 - [Community Week README](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md)
 - [Community Week thread](https://discuss.huggingface.co/t/bertin-pretrain-roberta-large-from-scratch-in-spanish/7125)
````
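
As an aside (not part of the commit), the same corpus can be inspected through the `datasets` library instead of the raw `.json.gz` shards. A minimal sketch, assuming the `mc4`/`es` configuration that `perplexity.py` below also uses:

```python
# Sketch: stream a few records of the Spanish mC4 split without materializing the corpus.
# Assumes the Hugging Face `datasets` library and the "mc4"/"es" configuration used in perplexity.py.
from datasets import load_dataset

mc4_es = load_dataset("mc4", "es", streaming=True)  # iterable dataset; shards are fetched lazily

for i, record in enumerate(mc4_es["train"]):
    print(record["text"][:80])  # each record carries the document text
    if i == 2:  # stop after a handful of documents
        break
```
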
perplexity.py
ADDED
```diff
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+import kenlm
+from datasets import load_dataset
+from tqdm import tqdm
+
+
+def pp(log_score, length):
+    return 10.0 ** (-log_score / length)
+
+# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
+model = kenlm.Model("es.arpa.bin")
+mc4 = load_dataset("mc4", "es", streaming=True)
+with open("mc4-es-perplexity.txt", "w") as f:
+    for sample in tqdm(mc4["train"].shuffle(buffer_size=100_000), total=416057992):
+        lines = sample["text"].split("\n")
+        doc_log_score, doc_length = 0, 0
+        for line in lines:
+            log_score = model.score(line)
+            length = len(line.split()) + 1
+            doc_log_score += log_score
+            doc_length += length
+        f.write(f"{pp(doc_log_score, doc_length)}\n")
```
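
The script streams shuffled mC4 documents, scores each line with the cc_net Spanish KenLM model linked in the comment (`es.arpa.bin`, which has to be downloaded first), aggregates the line log-scores into a document perplexity `10 ** (-log_score / length)`, and writes one value per document to `mc4-es-perplexity.txt`. As an illustration only, the resulting file could be summarized like this; everything except the file name is a hypothetical addition:

```python
# Sketch: summarize the per-document perplexities written by perplexity.py.
# Hypothetical helper for illustration; only the file name comes from the script above.
import statistics

with open("mc4-es-perplexity.txt") as f:
    perplexities = [float(line) for line in f if line.strip()]

print(f"documents scored: {len(perplexities)}")
print(f"median perplexity: {statistics.median(perplexities):.1f}")
# statistics.quantiles with n=4 returns the three quartile cut points
print("quartiles:", [round(q, 1) for q in statistics.quantiles(perplexities, n=4)])
```
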
run_mlm_flax.py
CHANGED
```diff
@@ -56,22 +56,6 @@ from transformers import (
 )
 
 
-# Cache the result
-has_tensorboard = is_tensorboard_available()
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
@@ -126,6 +110,9 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+    dataset_streaming: bool = field(
+        default=False, metadata={"help": "Whether dataset_name should be retrieved using streaming if available."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -269,7 +256,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     return batch_idx
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)
@@ -278,6 +265,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
     for i, val in enumerate(vals):
         summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 
@@ -315,10 +304,6 @@ if __name__ == "__main__":
 
     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
-    #logger.warning(
-    #    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-    #    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    #)
 
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
@@ -337,7 +322,7 @@ if __name__ == "__main__":
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.dataset_streaming)
 
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
@@ -345,12 +330,14 @@ if __name__ == "__main__":
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
+                streaming=data_args.dataset_streaming,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
+                streaming=data_args.dataset_streaming,
             )
     else:
         data_files = {}
@@ -469,10 +456,32 @@ if __name__ == "__main__":
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )
-
     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+            # Enable Weight&Biases
+            import wandb
+            wandb.init(
+                entity='wandb',
+                project='hf-flax-bertin-roberta-es',
+                sync_tensorboard=True,
+            )
+            wandb.config.update(training_args)
+            wandb.config.update(model_args)
+            wandb.config.update(data_args)
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
 
     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -521,7 +530,7 @@ if __name__ == "__main__":
         learning_rate=linear_decay_lr_schedule_fn,
         b1=training_args.adam_beta1,
         b2=training_args.adam_beta2,
-        eps=
+        eps=training_args.adam_epsilon,
         weight_decay=training_args.weight_decay,
         mask=decay_mask_fn,
     )
@@ -601,7 +610,7 @@ if __name__ == "__main__":
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
 
         # Gather the indexes for creating the batch and do a training step
-        for
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
@@ -610,11 +619,31 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
 
-
-
-
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []
+
+            if training_args.save_strategy == "steps" and cur_step and cur_step % training_args.save_steps == 0:
+                if jax.process_index() == 0:
+                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    model.save_pretrained(
+                        Path(str(training_args.output_dir)) / "checkpoints" / f"checkpoint-{cur_step}",
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        temp_dir=True,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )
 
         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])
@@ -645,7 +674,7 @@ if __name__ == "__main__":
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
```
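
With `save_strategy="steps"`, the modified script now writes intermediate weights to `output_dir/checkpoints/checkpoint-<step>` (and optionally pushes them to the Hub), and metric logging happens every `logging_steps` training steps, mirrored to Weights & Biases via `wandb.init(sync_tensorboard=True)`. A minimal sketch of reloading one of those checkpoints; the concrete path, step number, and model class are illustrative assumptions, not taken from the commit:

```python
# Sketch: reload a step checkpoint written by the modified run_mlm_flax.py.
# The checkpoints/checkpoint-<step> layout comes from the diff above; the output
# directory, step number, and model class here are assumptions for illustration.
from transformers import FlaxRobertaForMaskedLM

checkpoint_dir = "./output/checkpoints/checkpoint-10000"  # hypothetical path
model = FlaxRobertaForMaskedLM.from_pretrained(checkpoint_dir)
print(model.config.model_type)  # "roberta" for a RoBERTa-style config like BERTIN's
```
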
tokens.py
CHANGED
```diff
@@ -3,12 +3,14 @@ from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 
 # Load dataset
-dataset = load_dataset("
+dataset = load_dataset("oscar", "unshuffled_deduplicated_es")
+
 # Instantiate tokenizer
 tokenizer = ByteLevelBPETokenizer()
 def batch_iterator(batch_size=100_000_000):
     for i in range(0, len(dataset), batch_size):
         yield dataset["text"][i: i + batch_size]
+
 # Customized training
 tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
     "<s>",
```