Preparing code for final runs
- config.py +2 -2
- run_mlm_flax_stream.py +69 -23
- run_stream.sh +5 -6
 
    	
config.py CHANGED

@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 from transformers import RobertaConfig
 config = RobertaConfig.from_pretrained("roberta-large")
-config.save_pretrained("./")
+config.save_pretrained("./configs/large")
 
 config = RobertaConfig.from_pretrained("roberta-base")
-config.save_pretrained("./")
+config.save_pretrained("./configs/base")
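With the two configurations saved under ./configs/large and ./configs/base, the training script can point at either one by path. A minimal sketch of how such a directory is consumed (this snippet is illustrative, not part of the commit):

from transformers import AutoConfig

# Matches the --config_name flag used in run_stream.sh below.
config = AutoConfig.from_pretrained("./configs/base")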
    	
run_mlm_flax_stream.py CHANGED

@@ -21,13 +21,16 @@ Here is the full list of checkpoints on the hub that can be fine-tuned by this s
 https://huggingface.co/models?filter=masked-lm
 """
 import logging
+import json
 import os
+import shutil
 import sys
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
+import joblib
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
@@ -39,9 +42,10 @@ from tqdm import tqdm
 import flax
 import jax
 import jax.numpy as jnp
-import kenlm
+import kenlm  # pip install https://github.com/kpu/kenlm/archive/master.zip
 import optax
 from flax import jax_utils, traverse_util
+from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
 from transformers import (
@@ -334,6 +338,26 @@ def write_eval_metric(summary_writer, eval_metrics, step):
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 
 
+def save_checkpoint_files(state, data_collator, training_args, save_dir):
+    unreplicated_state = jax_utils.unreplicate(state)
+    with open(os.path.join(save_dir, "optimizer_state.msgpack"), "wb") as f:
+        f.write(to_bytes(unreplicated_state.opt_state))
+    joblib.dump(training_args, os.path.join(save_dir, "training_args.joblib"))
+    joblib.dump(data_collator, os.path.join(save_dir, "data_collator.joblib"))
+    with open(os.path.join(save_dir, "training_state.json"), "w") as f:
+        json.dump({"step": unreplicated_state.step.item()}, f)
+
+
+def rotate_checkpoints(path, max_checkpoints=5):
+    paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
+    if len(paths) > max_checkpoints:
+        for path_to_delete in paths[max_checkpoints:]:
+            try:
+                shutil.rmtree(path_to_delete)
+            except OSError:
+                os.remove(path_to_delete)
+
+
 if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
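The two helpers above cover only the save side. A minimal restore sketch, under the assumption that a fresh TrainState with the same optimizer has already been built; restore_checkpoint_files is illustrative and not part of this commit:

# Illustrative inverse of save_checkpoint_files(); not part of this commit.
import json
import os

import joblib
from flax.serialization import from_bytes


def restore_checkpoint_files(save_dir, state):
    # from_bytes() needs a matching pytree as a template, so the freshly
    # created opt_state of the new TrainState serves as the target.
    with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    # Training args and the data collator were pickled with joblib.dump().
    training_args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    # The step counter was written to a small JSON file.
    with open(os.path.join(save_dir, "training_state.json")) as f:
        step = json.load(f)["step"]

    return state.replace(step=step, opt_state=opt_state), training_args, data_collator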
@@ -391,19 +415,31 @@ if __name__ == "__main__":
             filepaths["train"] = data_args.train_file
         if data_args.validation_file:
             filepaths["validation"] = data_args.validation_file
-        [13 removed lines not preserved in this capture]
+        try:
+            dataset = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                cache_dir=model_args.cache_dir,
+                streaming=True,
+                split="train",
+                sampling_method=sampling_args.sampling_method,
+                sampling_factor=sampling_args.sampling_factor,
+                boundaries=sampling_args.boundaries,
+                perplexity_model=sampling_args.perplexity_model,
+                seed=training_args.seed,
+                data_files=filepaths,
+            )
+        except Exception as exc:
+            logger.warning(
+                f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
+            )
+            dataset = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                cache_dir=model_args.cache_dir,
+                streaming=True,
+                split="train",
+            )
 
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
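The sampling_method, sampling_factor, boundaries and perplexity_model keyword arguments are only understood by the project's local ./mc4 loading script, which is why the call is wrapped in try/except and falls back to the plain hub dataset. As a rough sketch of what such perplexity-based sampling does (the function below is an illustration under that assumption, not the actual loading-script code; the model path and numbers are placeholders):

import kenlm
import numpy as np


def keep_document(text, model, boundaries, weights, rng):
    # Illustrative only: quartile-based perplexity sampling sketch.
    # KenLM reports perplexity directly for a whitespace-tokenized string.
    perplexity = model.perplexity(text)
    # boundaries are perplexity quartile edges; find the bucket this doc falls in.
    bucket = int(np.searchsorted(boundaries, perplexity))
    # Keep the document with the bucket's sampling probability.
    return rng.uniform() < weights[bucket]


# Hypothetical usage:
# model = kenlm.Model("es.arpa.bin")
# rng = np.random.default_rng(0)
# keep_document("un texto de ejemplo", model,
#               boundaries=[20.0, 40.0, 80.0], weights=[1.0, 0.5, 0.25, 0.1], rng=rng)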
@@ -662,15 +698,25 @@ if __name__ == "__main__":
                 write_eval_metric(summary_writer, eval_metrics, step)
             eval_metrics = []
 
-        [9 removed lines not preserved in this capture]
+        # save checkpoint after eval_steps
+        if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
+            print(f"Saving checkpoint at {step + 1} steps")
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+                push_to_hub=training_args.push_to_hub,
+                commit_message=f"Saving weights and logs of step {step + 1}",
+            )
+            save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
+            checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}"
+            checkpoints_dir.mkdir(parents=True, exist_ok=True)
+            model.save_pretrained(checkpoints_dir, params=params,)
+            save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
+            rotate_checkpoints(
+                Path(training_args.output_dir) / "checkpoints",
+                max_checkpoints=training_args.save_total_limit
+            )
 
         # update tqdm bar
         steps.update(1)
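Assuming the usual behaviour of model.save_pretrained for Flax models (it writes config.json and flax_model.msgpack), each checkpoint-{step} directory should end up holding the model files plus optimizer_state.msgpack, training_args.joblib, data_collator.joblib and training_state.json, with the same set mirrored at the top level of output_dir and pushed to the hub when --push_to_hub is set.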
    	
run_stream.sh CHANGED

@@ -4,9 +4,10 @@ python ./run_mlm_flax_stream.py \
     --output_dir="./outputs" \
     --model_type="roberta" \
     --config_name="./configs/base" \
-    --tokenizer_name="./" \
+    --tokenizer_name="./configs/base" \
     --dataset_name="./mc4" \
     --dataset_config_name="es" \
+    --train_file="path/to/mc4-es-train-50M-XXX.jsonl" \
     --max_seq_length="128" \
     --pad_to_max_length  \
     --per_device_train_batch_size="256" \
@@ -16,13 +17,11 @@ python ./run_mlm_flax_stream.py \
     --adam_epsilon="1e-6" \
     --learning_rate="6e-4" \
     --weight_decay="0.01" \
-    --
-    --save_steps="1000" \
+    --save_steps="10000" \
     --save_total_limit="5" \
     --warmup_steps="24000" \
     --overwrite_output_dir \
-    --num_train_steps="
-    --eval_steps="
+    --num_train_steps="250000" \
+    --eval_steps="10000" \
     --dtype="bfloat16" \
-    --sampling_method="steps" \
     --logging_steps="500" 2>&1 | tee run_stream.log
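With --num_train_steps="250000", --save_steps="10000" and --save_total_limit="5", a checkpoint is written roughly every 10,000 steps (about 25 saves over the run), and rotate_checkpoints keeps only the five most recent checkpoint-{step} directories under outputs/checkpoints, while evaluation runs on the same 10,000-step cadence via --eval_steps.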