sanchit-gandhi (HF staff) committed
Commit e41bb55
Parent: 31d7cf2

Saving weights and logs of epoch 0

config.json CHANGED
@@ -1,4 +1,5 @@
  {
+   "_name_or_path": "./",
    "architectures": [
      "SpeechEncoderDecoderModel"
    ],
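The only change here is the new "_name_or_path" entry: save_pretrained stamps the config with the path the model was loaded from, which becomes "./" once training is driven from the local checkpoint directory. A minimal sketch of the round trip (running from inside the checkpoint directory is assumed):

    from transformers import AutoConfig

    # from_pretrained records the location the config came from;
    # save_pretrained then persists it as "_name_or_path"
    config = AutoConfig.from_pretrained("./")
    print(config._name_or_path)  # "./"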
events.out.tfevents.1647623127.t1v-n-4eb331dd-w-0.109040.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec8cc2df64b0ca4a2236b9256fd012c09849a8ebffb8bba82e1f30d87f3b832c
+ size 40
events.out.tfevents.1647624498.t1v-n-4eb331dd-w-0.110942.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d960465927b4baa73a2e2f7d162582318d2bc5d03535397ef6fad511e0a93ea0
+ size 40
events.out.tfevents.1647625887.t1v-n-4eb331dd-w-0.115000.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92e87288c18763cf7eae5684e5cfd830378ebd19fc3e30e79387aa93421828eb
+ size 40
events.out.tfevents.1647626125.t1v-n-4eb331dd-w-0.116613.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be739bafc4ee271f78a915990b4341088fcfdce8601b621e59a9475c1ccd9a81
+ size 14838
events.out.tfevents.1647626511.t1v-n-4eb331dd-w-0.118537.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2eed09b08de9e164e23417088d103a92327435e856b346f8ca112028e7f7d74c
+ size 40
events.out.tfevents.1647626831.t1v-n-4eb331dd-w-0.120349.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a77d713c620597fd2f46f26536620b1f808b02fc61ff831d87c7663750336bb
+ size 8961
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76f60c95bce357dcd23f2d771192e67edceba6efa2097e874e50084f0859cd91
+ size 2353635949
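The events.out.tfevents.* entries and flax_model.msgpack above are committed as Git LFS pointer files: three "key value" lines giving the LFS spec version, the sha256 of the real object, and its size in bytes. The 40-byte event files are effectively empty TensorBoard logs, the two larger ones (14838 and 8961 bytes) carry the logged training scalars, and the ~2.35 GB msgpack holds the Flax model weights. A small sketch of reading such a pointer, assuming only the three-line format shown above:

    def parse_lfs_pointer(text: str) -> dict:
        # each pointer line is "key value"; the oid value is "<algo>:<hex digest>"
        fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
        algo, digest = fields["oid"].split(":", 1)
        return {"version": fields["version"], "algo": algo, "digest": digest, "size": int(fields["size"])}

    pointer = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:76f60c95bce357dcd23f2d771192e67edceba6efa2097e874e50084f0859cd91\n"
        "size 2353635949"
    )
    print(parse_lfs_pointer(pointer)["size"])  # 2353635949 bytes, about 2.35 GB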
run_flax_speech_recognition_seq2seq.py DELETED
@@ -1,897 +0,0 @@
- #!/usr/bin/env python
- # coding=utf-8
- # Copyright 2022 The HuggingFace Team All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Fine-tuning the Flax library models for sequence to sequence speech recognition.
- """
- # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
-
- import logging
- import os
- import sys
- import time
- from dataclasses import field
- from functools import partial
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Union
-
- import datasets
- import numpy as np
- from datasets import DatasetDict, load_dataset, load_metric
- from tqdm import tqdm
-
- import flax
- import jax
- import jax.numpy as jnp
- import optax
- import transformers
- from flax import jax_utils, traverse_util
- from flax.jax_utils import unreplicate
- from flax.training import train_state
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
- from huggingface_hub import Repository
- from transformers import (
-     AutoConfig,
-     AutoFeatureExtractor,
-     AutoProcessor,
-     AutoTokenizer,
-     FlaxAutoModelForSpeechSeq2Seq,
-     HfArgumentParser,
-     Seq2SeqTrainingArguments,
-     is_tensorboard_available,
- )
- from transformers.file_utils import get_full_repo_name
- from transformers.trainer_utils import get_last_checkpoint, is_main_process
- from transformers.utils import check_min_version
- from transformers.utils.versions import require_version
-
-
- # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
- check_min_version("4.17.0.dev0")
-
- require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
-
- logger = logging.getLogger(__name__)
-
-
- @flax.struct.dataclass
- class ModelArguments:
-     """
-     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
-     """
-
-     model_name_or_path: str = field(
-         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
-     )
-     config_name: Optional[str] = field(
-         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
-     )
-     tokenizer_name: Optional[str] = field(
-         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
-     )
-     feature_extractor_name: Optional[str] = field(
-         default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
-     )
-     cache_dir: Optional[str] = field(
-         default=None,
-         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
-     )
-     use_fast_tokenizer: bool = field(
-         default=True,
-         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
-     )
-     model_revision: str = field(
-         default="main",
-         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
-     )
-     use_auth_token: bool = field(
-         default=False,
-         metadata={
-             "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
-             "with private models)."
-         },
-     )
-     freeze_feature_encoder: bool = field(
-         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
-     )
-
-
- @flax.struct.dataclass
- class DataTrainingArguments:
-     """
-     Arguments pertaining to what data we are going to input our model for training and eval.
-     """
-
-     dataset_name: str = field(
-         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
-     )
-     dataset_config_name: Optional[str] = field(
-         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
-     )
-     text_column: Optional[str] = field(
-         default=None,
-         metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
-     )
-     overwrite_cache: bool = field(
-         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
-     )
-     preprocessing_num_workers: Optional[int] = field(
-         default=None,
-         metadata={"help": "The number of processes to use for the preprocessing."},
-     )
-     max_train_samples: Optional[int] = field(
-         default=None,
-         metadata={
-             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-             "value if set."
-         },
-     )
-     max_eval_samples: Optional[int] = field(
-         default=None,
-         metadata={
-             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-             "value if set."
-         },
-     )
-     audio_column_name: str = field(
-         default="audio",
-         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
-     )
-     text_column_name: str = field(
-         default="text",
-         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
-     )
-     max_duration_in_seconds: float = field(
-         default=20.0,
-         metadata={
-             "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
-         },
-     )
-     min_duration_in_seconds: float = field(
-         default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
-     )
-     max_target_length: Optional[int] = field(
-         default=128,
-         metadata={
-             "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
-             "than this will be truncated, sequences shorter will be padded."
-         },
-     )
-     min_target_length: Optional[int] = field(
-         default=0,
-         metadata={
-             "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
-             "than this will be filtered."
-         },
-     )
-     pad_input_to_multiple_of: Optional[int] = field(
-         default=None,
-         metadata={
-             "help": "If set will pad the input sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
-         },
-     )
-     pad_target_to_multiple_of: Optional[int] = field(
-         default=None,
-         metadata={
-             "help": "If set will pad the target sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
-         },
-     )
-     preprocessing_only: bool = field(
-         default=False,
-         metadata={
-             "help": "Whether to only do data preprocessing and skip training. "
-             "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
-             "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
-             "so that the cached datasets can consequently be loaded in distributed training"
-         },
-     )
-     train_split_name: str = field(
-         default="train",
-         metadata={
-             "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
-         },
-     )
-     eval_split_name: str = field(
-         default="test",
-         metadata={
-             "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
-         },
-     )
-     do_lower_case: bool = field(
-         default=True,
-         metadata={"help": "Whether the target text should be lower cased."},
-     )
-
-
- class TrainState(train_state.TrainState):
-     dropout_rng: jnp.ndarray
-
-     def replicate(self):
-         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
-
-
- def shift_tokens_right(label_ids: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
-     """
-     Shift label ids one token to the right.
-     """
-     shifted_label_ids = np.zeros_like(label_ids)
-     shifted_label_ids[:, 1:] = label_ids[:, :-1]
-     shifted_label_ids[:, 0] = decoder_start_token_id
-
-     return shifted_label_ids
-
-
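A quick worked example of the shift implemented above, with invented label ids and decoder_start_token_id=2: every row moves one position to the right, the final token is dropped, and the start token fills position 0, which is exactly the teacher-forcing input the decoder expects.

    import numpy as np

    labels = np.array([[5, 6, 7],
                       [8, 9, 10]])
    shifted = np.zeros_like(labels)
    shifted[:, 1:] = labels[:, :-1]  # mirrors shift_tokens_right above
    shifted[:, 0] = 2                # decoder_start_token_id
    print(shifted)                   # [[2 5 6] [2 8 9]]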
- @flax.struct.dataclass
- class FlaxDataCollatorSpeechSeq2SeqWithPadding:
-     """
-     Data collator that will dynamically pad the inputs received.
-     Args:
-         processor ([`Wav2Vec2Processor`])
-             The processor used for processing the data.
-         decoder_start_token_id (`int`)
-             The begin-of-sentence token id of the decoder.
-         input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
-             Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
-             among:
-             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-               sequence is provided).
-             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
-               maximum acceptable input length for the model if that argument is not provided.
-             * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
-               different lengths).
-         target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
-             Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
-             See above for details.
-         max_input_length (:obj:`float`, `optional`):
-             Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
-         max_target_length (:obj:`int`, `optional`):
-             Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
-         pad_input_to_multiple_of (:obj:`int`, `optional`):
-             If set will pad the input sequence to a multiple of the provided value.
-             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
-             7.5 (Volta).
-         pad_target_to_multiple_of (:obj:`int`, `optional`):
-             If set will pad the target sequence to a multiple of the provided value.
-             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
-             7.5 (Volta).
-     """
-
-     processor: Any
-     decoder_start_token_id: int
-     input_padding: Union[bool, str] = "max_length"
-     target_padding: Union[bool, str] = "max_length"
-     max_input_length: Optional[float] = None
-     max_target_length: Optional[int] = None
-     pad_input_to_multiple_of: Optional[int] = None
-     pad_target_to_multiple_of: Optional[int] = None
-
-     def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
-         # split inputs and labels since they have to be of different lengths and need
-         # different padding methods
-         input_features = [{"input_values": feature["input_values"]} for feature in features]
-         label_features = [{"input_ids": feature["labels"]} for feature in features]
-
-         # reformat list to dict and set to numpy format
-         batch = self.processor.feature_extractor.pad(
-             input_features,
-             max_length=self.max_input_length,
-             padding=self.input_padding,
-             pad_to_multiple_of=self.pad_input_to_multiple_of,
-             return_tensors="np",
-         )
-
-         labels_batch = self.processor.tokenizer.pad(
-             label_features,
-             max_length=self.max_target_length,
-             padding=self.target_padding,
-             pad_to_multiple_of=self.pad_target_to_multiple_of,
-             return_tensors="np",
-         )
-
-         # if bos token is appended in previous tokenization step,
-         # cut bos token here as it's appended later anyway
-         labels = labels_batch["input_ids"]
-         if (labels[:, 0] == self.decoder_start_token_id).all().item():
-             labels = labels[:, 1:]
-             labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
-
-         decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
-
-         # replace padding with -100 to ignore loss correctly
-         labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
-         labels = labels.filled(fill_value=-100)
-
-         batch["inputs"] = batch.pop("input_values")
-         batch["labels"] = labels
-         batch["decoder_input_ids"] = decoder_input_ids
-         # decoder_attention_mask known to give issues with NaNs
-         # remove decoder_attention_mask as an arg for the time being - handled by the causal mask in XXXForCausalLM
-         # batch["decoder_attention_mask"] = labels_batch.attention_mask
-
-         return batch
-
-
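The masked-array step in __call__ above is what turns padding positions into the -100 ignore index. A self-contained illustration with an invented one-sample batch (the pad id of 0 and the attention mask are hypothetical):

    import numpy as np

    labels = np.array([[5, 6, 7, 0, 0]])          # last two positions are padding
    attention_mask = np.array([[1, 1, 1, 0, 0]])  # 1 = real token, 0 = padding

    masked = np.ma.array(labels, mask=np.not_equal(attention_mask, 1))
    print(masked.filled(fill_value=-100))         # [[   5    6    7 -100 -100]]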
- def write_train_metric(summary_writer, train_metrics, train_time, step):
-     summary_writer.scalar("train_time", train_time, step)
-
-     train_metrics = get_metrics(train_metrics)
-     for key, vals in train_metrics.items():
-         tag = f"train_{key}"
-         for i, val in enumerate(vals):
-             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-
- def write_eval_metric(summary_writer, eval_metrics, step):
-     for metric_name, value in eval_metrics.items():
-         summary_writer.scalar(f"eval_{metric_name}", value, step)
-
-
- def create_learning_rate_fn(
-     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
- ) -> Callable[[int], jnp.array]:
-     """Returns a linear warmup, linear_decay learning rate function."""
-     steps_per_epoch = train_ds_size // train_batch_size
-     num_train_steps = steps_per_epoch * num_train_epochs
-     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
-     decay_fn = optax.linear_schedule(
-         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
-     )
-     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
-     return schedule_fn
-
-
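Joining the two optax schedules above yields a triangular profile: linear warmup from 0 to the peak learning rate over num_warmup_steps, then linear decay back to 0 over the remaining steps. A quick check of the shape with invented sizes (1,000 examples, batch size 32, 1 epoch, 10 warmup steps; the 3e-4 peak matches the --learning_rate flag in run_librispeech.sh below):

    import optax

    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=3e-4, transition_steps=10)
    decay_fn = optax.linear_schedule(init_value=3e-4, end_value=0.0, transition_steps=21)  # (1000 // 32) - 10
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[10])
    print(schedule_fn(0), schedule_fn(10), schedule_fn(31))  # 0.0, then the 3e-4 peak, then 0.0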
- def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
-     num_samples = len(samples_idx)
-     samples_to_remove = num_samples % batch_size
-
-     if samples_to_remove != 0:
-         samples_idx = samples_idx[:-samples_to_remove]
-     sections_split = num_samples // batch_size
-     batch_idx = np.split(samples_idx, sections_split)
-     return batch_idx
-
-
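Note that the helper drops the remainder so every batch has exactly batch_size samples; a ragged final batch would change array shapes and trigger an XLA recompilation on TPU. With 10 indices and a batch size of 4, the trailing two samples are simply discarded:

    import numpy as np

    samples_idx = np.arange(10)               # stand-in for shuffled dataset indices
    samples_to_remove = len(samples_idx) % 4  # 2 samples dropped
    batch_idx = np.split(samples_idx[:-samples_to_remove], len(samples_idx) // 4)
    print(batch_idx)                          # [array([0, 1, 2, 3]), array([4, 5, 6, 7])]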
- def main():
-     # 1. Parse input arguments
-     # See all possible arguments in src/transformers/training_args.py
-     # or by passing the --help flag to this script.
-     # We now keep distinct sets of args, for a cleaner separation of concerns.
-     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
-
-     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-         # If we pass only one argument to the script and it's the path to a json file,
-         # let's parse it to get our arguments.
-         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-     else:
-         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-
-     # 2. Setup logging
-     logging.basicConfig(
-         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-         datefmt="%m/%d/%Y %H:%M:%S",
-         handlers=[logging.StreamHandler(sys.stdout)],
-     )
-     # We only want one process per machine to log things on the screen.
-     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
-     if jax.process_index() == 0:
-         datasets.utils.logging.set_verbosity_warning()
-         transformers.utils.logging.set_verbosity_info()
-     else:
-         datasets.utils.logging.set_verbosity_error()
-         transformers.utils.logging.set_verbosity_error()
-
-     # Log on each process the small summary:
-     logger.warning(
-         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
-         f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-     )
-
-     # Set the verbosity to info of the Transformers logger (on main process only):
-     if is_main_process(training_args.local_rank):
-         transformers.utils.logging.set_verbosity_info()
-     logger.info("Training/evaluation parameters %s", training_args)
-
-     logger.info(f"JAX devices: {jax.device_count()}")
-
-     # 3. Detecting last checkpoint and eventually continuing from it
-     last_checkpoint = None
-     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
-         last_checkpoint = get_last_checkpoint(training_args.output_dir)
-         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-             raise ValueError(
-                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
-                 "Use --overwrite_output_dir to overcome."
-             )
-         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
-             logger.info(
-                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
-                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
-             )
-
-     # 4. Load dataset
-     raw_datasets = DatasetDict()
-
-     if training_args.do_train:
-         raw_datasets["train"] = load_dataset(
-             data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
-         )
-
-     if training_args.do_eval:
-         raw_datasets["eval"] = load_dataset(
-             data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name
-         )
-
-     if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
-         raise ValueError(
-             f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
-             "Make sure to set `--audio_column_name` to the correct audio column - one of "
-             f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
-         )
-
-     if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
-         raise ValueError(
-             f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
-             "Make sure to set `--text_column_name` to the correct text column - one of "
-             f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
-         )
-
-     # 5. Load pretrained model, tokenizer, and feature extractor
-     #
-     # Distributed training:
-     # The .from_pretrained methods guarantee that only one local process can concurrently download model & vocab.
-     config = AutoConfig.from_pretrained(
-         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
-         cache_dir=model_args.cache_dir,
-         revision=model_args.model_revision,
-         use_auth_token=True if model_args.use_auth_token else None,
-     )
-
-     feature_extractor = AutoFeatureExtractor.from_pretrained(
-         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
-         cache_dir=model_args.cache_dir,
-         revision=model_args.model_revision,
-         use_auth_token=True if model_args.use_auth_token else None,
-     )
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-         cache_dir=model_args.cache_dir,
-         use_fast=model_args.use_fast_tokenizer,
-         revision=model_args.model_revision,
-         use_auth_token=True if model_args.use_auth_token else None,
-     )
-     model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
-         model_args.model_name_or_path,
-         config=config,
-         cache_dir=model_args.cache_dir,
-         revision=model_args.model_revision,
-         use_auth_token=True if model_args.use_auth_token else None,
-     )
-
-     if model.config.decoder_start_token_id is None:
-         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
-
-     # 6. Resample speech dataset if necessary
-     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
-     if dataset_sampling_rate != feature_extractor.sampling_rate:
-         raw_datasets = raw_datasets.cast_column(
-             data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
-         )
-
-     # 7. Preprocessing the datasets.
-     # We need to read the audio files as arrays and tokenize the targets.
-     max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
-     min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
-     max_target_length = data_args.max_target_length
-     min_target_length = data_args.min_target_length
-     pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
-     pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
-     audio_column_name = data_args.audio_column_name
-     num_workers = data_args.preprocessing_num_workers
-     text_column_name = data_args.text_column_name
-     model_input_name = feature_extractor.model_input_names[0]
-     do_lower_case = data_args.do_lower_case
-
-     if data_args.max_train_samples is not None:
-         raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
-
-     if data_args.max_eval_samples is not None:
-         raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
-
-     def prepare_dataset(batch):
-         # process audio
-         sample = batch[audio_column_name]
-         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-         # process audio length
-         batch[model_input_name] = inputs.input_values[0]
-         batch["input_length"] = len(batch["input_values"])
-
-         # process targets
-         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
-         batch["labels"] = tokenizer(input_str).input_ids
-         batch["labels_length"] = len(batch["labels"])
-         return batch
-
-     with training_args.main_process_first(desc="dataset map pre-processing"):
-         vectorized_datasets = raw_datasets.map(
-             prepare_dataset,
-             remove_columns=next(iter(raw_datasets.values())).column_names,
-             num_proc=data_args.preprocessing_num_workers,
-             desc="preprocess train dataset",
-         )
-
-     # filter data with inputs shorter than min_input_length or longer than
-     # max_input_length
-     def is_audio_in_length_range(length):
-         return length > min_input_length and length < max_input_length
-
-     vectorized_datasets = vectorized_datasets.filter(
-         is_audio_in_length_range,
-         num_proc=num_workers,
-         input_columns=["input_length"],
-     )
-
-     # filter data with targets shorter than min_target_length or longer than
-     # max_target_length
-     def is_labels_in_length_range(length):
-         return length > min_target_length and length < max_target_length
-
-     vectorized_datasets = vectorized_datasets.filter(
-         is_labels_in_length_range,
-         num_proc=num_workers,
-         input_columns=["labels_length"],
-     )
-
-     # for large datasets it is advised to run the preprocessing on a
-     # single machine first with `args.preprocessing_only` since there will most likely
-     # be a timeout when running the script in distributed mode.
-     # In a second step `args.preprocessing_only` can then be set to `False` to load the
-     # cached dataset
-     if data_args.preprocessing_only:
-         cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
-         logger.info(f"Data preprocessing finished. Files cached at {cache}.")
-         return
-
-     # 8. Load Metric
-     metric = load_metric("wer")
-
-     def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
-         padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
-
-         pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
-         # we do not want to group tokens when computing the metrics
-         label_str = tokenizer.batch_decode(padded_ids, skip_special_tokens=True)
-
-         wer = metric.compute(predictions=pred_str, references=label_str)
-
-         return {"wer": wer}
-
-     # 9. Create a single speech processor
-     if is_main_process(training_args.local_rank):
-         # save feature extractor, tokenizer and config
-         feature_extractor.save_pretrained(training_args.output_dir)
-         tokenizer.save_pretrained(training_args.output_dir)
-         config.save_pretrained(training_args.output_dir)
-
-     processor = AutoProcessor.from_pretrained(training_args.output_dir)
-
-     data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
-         processor=processor,
-         decoder_start_token_id=model.config.decoder_start_token_id,
-         input_padding="max_length",
-         target_padding="max_length",
-         max_input_length=max_input_length,
-         max_target_length=max_target_length,
-         pad_input_to_multiple_of=pad_input_to_multiple_of,
-         pad_target_to_multiple_of=pad_target_to_multiple_of,
-     )
-
-     # Enable tensorboard only on the master node
-     has_tensorboard = is_tensorboard_available()
-     if has_tensorboard and jax.process_index() == 0:
-         try:
-             from flax.metrics.tensorboard import SummaryWriter
-
-             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-         except ImportError as ie:
-             has_tensorboard = False
-             logger.warning(
-                 f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
-             )
-     else:
-         logger.warning(
-             "Unable to display metrics through TensorBoard because the package is not installed: "
-             "Please run `pip install tensorboard` to enable."
-         )
-
-     # 10. Handle the repository creation
-     if training_args.push_to_hub:
-         if training_args.hub_model_id is None:
-             repo_name = get_full_repo_name(
-                 Path(training_args.output_dir).absolute().name, token=training_args.hub_token
-             )
-         else:
-             repo_name = training_args.hub_model_id
-         repo = Repository(training_args.output_dir, clone_from=repo_name)
-
-     # 11. Initialize our training
-     rng = jax.random.PRNGKey(training_args.seed)
-     rng, dropout_rng = jax.random.split(rng)
-
-     # Store some constants
-     num_epochs = int(training_args.num_train_epochs)
-     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-     steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
-     total_train_steps = steps_per_epoch * num_epochs
-     gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
-
-     # Create learning rate schedule
-     linear_decay_lr_schedule_fn = create_learning_rate_fn(
-         len(vectorized_datasets["train"]),
-         train_batch_size,
-         training_args.num_train_epochs,
-         training_args.warmup_steps,
-         training_args.learning_rate,
-     )
-
-     # We use Optax's "masking" functionality to not apply weight decay
-     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
-     # mask boolean with the same structure as the parameters.
-     # The mask is True for parameters that should be decayed.
-     # Note that this mask is specifically adapted for FlaxBart.
-     # For FlaxT5, one should correct the layer norm parameter naming
-     # accordingly - see `run_t5_mlm_flax.py` e.g.
-     # TODO: check param dictionary of encoder and decoder match the layer_norm_params list
-     def decay_mask_fn(params):
-         flat_params = traverse_util.flatten_dict(params)
-         layer_norm_params = [
-             (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
-         ]
-         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
-         return traverse_util.unflatten_dict(flat_mask)
-
-     # create adam optimizer
-     adamw = optax.adamw(
-         learning_rate=linear_decay_lr_schedule_fn,
-         b1=training_args.adam_beta1,
-         b2=training_args.adam_beta2,
-         eps=training_args.adam_epsilon,
-         weight_decay=training_args.weight_decay,
-         mask=decay_mask_fn,
-     )
-
-     # augment adam optimizer to facilitate gradient accumulation (ignore for now)
-     # optim = optax.chain(adamw, optax.apply_every(gradient_accumulation_steps))
-
-     # Setup train state
-     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
-
-     # label smoothed cross entropy
-     def loss_fn(logits, labels, label_smoothing_factor=0.0):
-         """
-         The label smoothing implementation is adapted from Flax's official example:
-         https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
-         """
-         vocab_size = logits.shape[-1]
-         confidence = 1.0 - label_smoothing_factor
-         low_confidence = (1.0 - confidence) / (vocab_size - 1)
-         normalizing_constant = -(
-             confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
-         )
-         soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
-
-         loss = optax.softmax_cross_entropy(logits, soft_labels)
-         loss = loss - normalizing_constant
-
-         # ignore padded tokens from loss, i.e. where labels are not set to -100
-         padding = labels > 0
-         loss = loss * padding
-         loss = loss.sum() / padding.sum()
-         return loss
-
-     # Define gradient update step fn
-     def train_step(state, batch, label_smoothing_factor=0.0):
-         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
-
-         def compute_loss(params):
-             labels = batch.pop("labels")
-             outputs = state.apply_fn(
-                 **batch,
-                 params=params,
-                 dropout_rng=dropout_rng,
-                 freeze_feature_encoder=model_args.freeze_feature_encoder,
-                 return_dict=True,
-                 output_attentions=True,
-                 output_hidden_states=True,
-                 train=True,
-             )
-             encoder_hidden_states = jnp.asarray(outputs.encoder_hidden_states)
-             encoder_outputs = outputs.encoder_last_hidden_state
-             decoder_hidden_states = jnp.asarray(outputs.decoder_hidden_states)
-             logits = outputs.logits
-
-             # check for nan in inputs by taking l2-norm over inputs
-             # a single nan in the inputs will return a nan when normed
-             logs = {"inputs": jnp.linalg.norm(batch["inputs"])}
-
-             # check for nan in encoder_hidden_states, encoder_outputs
-             logs["encoder_hidden_states"] = jnp.linalg.norm(
-                 encoder_hidden_states.reshape(-1, encoder_hidden_states.shape[0]), axis=0
-             )
-             logs["encoder_outputs"] = jnp.linalg.norm(encoder_outputs)
-
-             # check for nan in decoder_hidden_states, decoder_outputs (logits)
-             logs["decoder_hidden_states"] = jnp.linalg.norm(
-                 decoder_hidden_states.reshape(-1, decoder_hidden_states.shape[0]), axis=0
-             )
-             logs["logits"] = jnp.linalg.norm(logits)
-
-             loss = loss_fn(logits, labels, label_smoothing_factor)
-             # normalize loss over gradient accumulation steps (ignore for now)
-             # loss = loss / gradient_accumulation_steps
-             return loss, logs
-
-         grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
-         (loss, logs), grad = grad_fn(state.params)
-         # TODO: compute loss correctly over pmapped axis
-         grad = jax.lax.pmean(grad, "batch")
-
-         # compute gradient norm for monitoring
-         # (re-introduce when there are no NaNs on the forward pass, currently meaningless)
-         # grad_norm = jnp.linalg.norm(jax.tree_util.tree_leaves(jax.tree_map(jnp.linalg.norm, grad)))
-
-         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
-
-         # don't log learning-rate and grad-norm until forward pass returns real-valued numbers
-         metrics = {"loss": loss}
-         metrics.update(logs)
-         metrics = jax.lax.pmean(metrics, axis_name="batch")
-
-         return new_state, metrics
-
-     # Define eval fn
-     def eval_step(params, batch, label_smoothing_factor=0.0):
-         labels = batch.pop("labels")
-         logits = model(**batch, params=params, train=False)[0]
-         loss = loss_fn(logits, labels, label_smoothing_factor)
-
-         # summarize metrics
-         metrics = {"loss": loss}
-         metrics = jax.lax.pmean(metrics, axis_name="batch")
-         return metrics
-
-     # Define generation function
-     gen_kwargs = {"max_length": training_args.generation_max_length, "num_beams": training_args.generation_num_beams}
-
-     def generate_step(params, batch):
-         model.params = params
-         output_ids = model.generate(batch["inputs"], **gen_kwargs)
-         return output_ids.sequences
-
-     # Create parallel version of the train and eval step
-     p_train_step = jax.pmap(
-         partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
-     )
-     p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
-     p_generate_step = jax.pmap(generate_step, "batch")
-
-     # Replicate the train state on each device
-     state = state.replicate()
-
-     logger.info("***** Running training *****")
-     logger.info(f"  Num examples = {len(vectorized_datasets['train'])}")
-     logger.info(f"  Num Epochs = {num_epochs}")
-     logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
-     logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
-     logger.info(f"  Total optimization steps = {total_train_steps}")
-
-     train_time = 0
-     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
-     for epoch in epochs:
-         # ======================== Training ================================
-         train_start = time.time()
-
-         # Create sampling rng
-         rng, input_rng = jax.random.split(rng)
-         train_metrics = []
-
-         # Generate an epoch by shuffling sampling indices from the train dataset
-         num_train_samples = len(vectorized_datasets["train"])
-         train_samples_idx = np.random.permutation(np.arange(num_train_samples))
-         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
-
-         # Gather the indexes for creating the batch and do a training step
-         for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
-             samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
-             batch = data_collator(samples)
-             batch = shard(batch.data)
-             state, train_metric = p_train_step(state, batch)
-             train_metrics.append(train_metric)
-
-             cur_step = epoch * (num_train_samples // train_batch_size) + step
-
-             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
-                 # Save metrics
-                 train_metric = jax_utils.unreplicate(train_metric)
-                 train_time += time.time() - train_start
-                 # if has_tensorboard and jax.process_index() == 0:
-                 #     write_train_metric(summary_writer, train_metrics, train_time, cur_step)
-
-                 # Log everything
-                 metric_desc = " ".join([f"{key}: {value} |" for key, value in train_metric.items()])
-                 epochs.write(f"Step... ({cur_step}) | {metric_desc}")
-
-                 train_metrics = []
-
-         # epochs.write(
-         #     f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-         # )
-
-         continue  # deliberately skips the evaluation block below
-         # ======================== Evaluating ==============================
-         eval_metrics = []
-         eval_preds = []
-         eval_labels = []
-
-         num_eval_samples = len(vectorized_datasets["eval"])
-         eval_samples_idx = jnp.arange(num_eval_samples)
-         eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-             samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
-             batch = data_collator(samples)
-             batch = shard(batch.data)
-             labels = batch["labels"]
-
-             metrics = p_eval_step(state.params, batch)
-             eval_metrics.append(metrics)
-
-             # generation
-             if training_args.predict_with_generate:
-                 generated_ids = p_generate_step(state.params, batch)
-                 eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                 eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
-
-         # normalize eval metrics
-         eval_metrics = get_metrics(eval_metrics)
-         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
-
-         # compute WER metric
-         wer_desc = ""
-         if training_args.predict_with_generate:
-             wer_metric = compute_metrics(eval_preds, eval_labels)
-             eval_metrics.update(wer_metric)
-             wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
-
-         # Print metrics and update progress bar
-         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
-         epochs.write(desc)
-         epochs.desc = desc
-
-         # Save metrics
-         if has_tensorboard and jax.process_index() == 0:
-             cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
-             write_eval_metric(summary_writer, eval_metrics, cur_step)
-
-         # save checkpoint after each epoch and push checkpoint to the hub
-         if jax.process_index() == 0:
-             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-             model.save_pretrained(training_args.output_dir, params=params)
-             tokenizer.save_pretrained(training_args.output_dir)
-             if training_args.push_to_hub:
-                 repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
-
-
- if __name__ == "__main__":
-     main()
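One subtlety worth noting in loss_fn above: the normalizing constant it subtracts is the entropy of the smoothed target distribution, so a model that predicts the soft targets exactly scores a loss of zero. A check with invented numbers (vocab_size=4, label_smoothing_factor=0.1):

    import numpy as np

    vocab_size, smoothing = 4, 0.1
    confidence = 1.0 - smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    constant = -(confidence * np.log(confidence)
                 + (vocab_size - 1) * low_confidence * np.log(low_confidence + 1e-20))
    print(round(constant, 3))  # 0.435, the entropy of [0.9, 0.033, 0.033, 0.033]

    # the cross entropy of a perfect prediction equals that entropy, so loss - constant == 0
    soft_labels = np.array([confidence] + [low_confidence] * (vocab_size - 1))
    print(round(-(soft_labels * np.log(soft_labels)).sum() - constant, 6))  # ~0.0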
run_flax_speech_recognition_seq2seq.py ADDED
@@ -0,0 +1 @@
+ /home/sanchitgandhi/transformers/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py
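(A single path as the entire file body is how Git stores a symbolic link, so the script appears to have been re-added as a symlink into the local transformers checkout rather than as a copy.)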
run_librispeech.sh CHANGED
@@ -1,5 +1,5 @@
  #!/usr/bin/env bash
- JAX_DEFAULT_MATMUL_PRECISION=float32 python run_flax_speech_recognition_seq2seq.py \
+ python run_flax_speech_recognition_seq2seq.py \
  --dataset_name="librispeech_asr" \
  --model_name_or_path="./" \
  --dataset_config_name="clean" \
@@ -10,11 +10,11 @@
  --length_column_name="input_length" \
  --overwrite_output_dir \
  --num_train_epochs="1" \
- --per_device_train_batch_size="2" \
- --per_device_eval_batch_size="2" \
+ --per_device_train_batch_size="4" \
+ --per_device_eval_batch_size="4" \
  --logging_steps="1" \
- --max_duration_in_seconds="10" \
- --max_target_length="32" \
+ --max_duration_in_seconds="15" \
+ --max_target_length="64" \
  --generation_max_length="40" \
  --generation_num_beams="1" \
  --learning_rate="3e-4" \
@@ -25,5 +25,7 @@
  --predict_with_generate \
  --do_lower_case \
  --do_eval \
- --do_train
+ --do_train \
+ --push_to_hub \
+ --use_auth_token
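With the per-device batch size raised from 2 to 4, the effective batch size follows the script's train_batch_size = per_device_train_batch_size * jax.device_count() formula, e.g. 32 if the t1v-* host seen in the event filenames is an 8-device TPU VM (an assumption; the device count is not recorded in this diff):

    per_device_train_batch_size = 4
    device_count = 8  # hypothetical jax.device_count() for an 8-core TPU VM
    print(per_device_train_batch_size * device_count)  # 32 samples per optimizer step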
tokenizer_config.json CHANGED
@@ -1 +1 @@
- {"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "facebook/bart-large-cnn", "tokenizer_class": "BartTokenizer"}
+ {"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "BartTokenizer"}