Sanchit Gandhi commited on
Commit
cb91e87
·
1 Parent(s): 9738fad

Update convert_scan_to_unrolled.py

Browse files
Files changed (1) hide show
  1. convert_scan_to_unrolled.py +7 -2
convert_scan_to_unrolled.py CHANGED
@@ -1,5 +1,5 @@
1
  from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
2
- from transformers import SpeechEncoderDecoderModel, AutoConfig
3
  from flax.traverse_util import flatten_dict, unflatten_dict
4
  import collections
5
 
@@ -44,4 +44,9 @@ def scanned_to_unrolled(params):
44
 
45
  unrolled_model.params = scanned_to_unrolled(model.params)
46
 
47
- unrolled_model.save_pretrained("./")
 
 
 
 
 
 
1
  from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
2
+ from transformers import SpeechEncoderDecoderModel, AutoConfig, AutoFeatureExtractor, AutoTokenizer
3
  from flax.traverse_util import flatten_dict, unflatten_dict
4
  import collections
5
 
 
44
 
45
  unrolled_model.params = scanned_to_unrolled(model.params)
46
 
47
+ unrolled_model.save_pretrained("./")
48
+
49
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
50
+ feature_extractor.save_pretrained("./")
51
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
52
+ tokenizer.save_pretrained("./")