Sanchit Gandhi
commited on
Commit
·
cb91e87
1
Parent(s):
9738fad
Update convert_scan_to_unrolled.py
Browse files
convert_scan_to_unrolled.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
|
2 |
-
from transformers import SpeechEncoderDecoderModel, AutoConfig
|
3 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
4 |
import collections
|
5 |
|
@@ -44,4 +44,9 @@ def scanned_to_unrolled(params):
|
|
44 |
|
45 |
unrolled_model.params = scanned_to_unrolled(model.params)
|
46 |
|
47 |
-
unrolled_model.save_pretrained("./")
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
|
2 |
+
from transformers import SpeechEncoderDecoderModel, AutoConfig, AutoFeatureExtractor, AutoTokenizer
|
3 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
4 |
import collections
|
5 |
|
|
|
44 |
|
45 |
unrolled_model.params = scanned_to_unrolled(model.params)
|
46 |
|
47 |
+
unrolled_model.save_pretrained("./")
|
48 |
+
|
49 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
50 |
+
feature_extractor.save_pretrained("./")
|
51 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
52 |
+
tokenizer.save_pretrained("./")
|