Jarbas committed · verified
Commit 4ed478b · 1 Parent(s): 8f28f86

Update README.md

Files changed (1):
  1. README.md +94 -1
README.md CHANGED
@@ -36,4 +36,97 @@ datasets:
  - fdemelo/ipa-childes-split
  license: apache-2.0
  pipeline_tag: text-generation
- ---
+ ---
+
+ ONNX version of [fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes](https://huggingface.co/fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes).
+
+ Inference example:
+
+ ```python
+ from transformers import AutoTokenizer
+ import onnxruntime
+ import numpy as np
+
+
+ def infer_onnx(text: str, lang: str, onnx_model_path: str = "byt5_g2p_model.onnx"):
+     """
+     Performs grapheme-to-phoneme inference with ONNX Runtime using the exported ONNX version of the ByT5 model.
+
+     Args:
+         text (str): The input text to convert to phonemes.
+         lang (str): The language tag (e.g., "en").
+         onnx_model_path (str): The path to the exported ONNX model.
+     """
+     model_name = 'fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes'
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # --- Perform inference with ONNX Runtime ---
+     print("\n--- Performing inference with ONNX Runtime ---")
+
+     # Create an ONNX Runtime session
+     try:
+         session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
+     except Exception as e:
+         print(f"Error loading ONNX model: {e}")
+         return
+
+     # Get input and output names from the ONNX model
+     onnx_input_names = [inp.name for inp in session.get_inputs()]
+     onnx_output_names = [out.name for out in session.get_outputs()]
+
+     # Prepare the actual input for ONNX inference
+     input_text_for_onnx = f"<{lang}>: {text}"
+     inputs_for_onnx = tokenizer([input_text_for_onnx], return_tensors="pt", add_special_tokens=False)
+
+     input_ids_np = inputs_for_onnx["input_ids"].cpu().numpy()
+     attention_mask_np = inputs_for_onnx["attention_mask"].cpu().numpy()
+
+     # Manual greedy decoding loop for ONNX Runtime.
+     # This simulates the 'generate' method's greedy decoding.
+     generated_ids = []
+     # T5 models typically use pad_token_id as the initial token for generation
+     # or a specific decoder_start_token_id.
+     # For T5, the decoder_start_token_id is usually the pad_token_id.
+     current_decoder_input_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+
+     # Ensure it's a batch of 1
+     decoder_input_ids_np = np.array([[current_decoder_input_id]])
+
+     max_length = 512  # Same as in the original predict_byt5
+
+     # The exported graph contains the full T5 forward pass, so the encoder hidden states
+     # needed for cross-attention are handled inside the graph; we only have to feed
+     # input_ids, attention_mask and the growing decoder_input_ids.
+
+     for _ in range(max_length):
+         # Prepare inputs for the current step
+         onnx_inputs = {
+             "input_ids": input_ids_np,
+             "attention_mask": attention_mask_np,
+             "decoder_input_ids": decoder_input_ids_np
+         }
+
+         # Run inference
+         outputs = session.run(onnx_output_names, onnx_inputs)
+         logits = outputs[0]  # Get the logits
+
+         # Get the logits for the last token in the sequence
+         next_token_logits = logits[0, -1, :]  # Batch 0, last token, all vocab logits
+
+         # Greedy decoding: pick the token with the highest logit
+         next_token_id = np.argmax(next_token_logits)
+         generated_ids.append(next_token_id)
+
+         # Check for end-of-sequence token
+         if next_token_id == tokenizer.eos_token_id:
+             break
+
+         # Update the decoder input for the next step:
+         # append the new token to the decoder input sequence
+         decoder_input_ids_np = np.concatenate((decoder_input_ids_np, np.array([[next_token_id]])), axis=1)
+
+     # Decode the generated ONNX phoneme IDs
+     onnx_phones = tokenizer.batch_decode([generated_ids], skip_special_tokens=True)
+     print(f"ONNX Runtime Inference: {onnx_phones}")
+     return onnx_phones
+ ```
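
For reference, a minimal usage sketch. It assumes the exported file is fetched from this repository with `huggingface_hub` and keeps the default file name used above; the `repo_id` value is a placeholder, not a confirmed identifier.

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id: replace with this repository's actual id on the Hub.
repo_id = "<this-onnx-repo>"
# File name assumed from the default argument of infer_onnx() above.
onnx_path = hf_hub_download(repo_id=repo_id, filename="byt5_g2p_model.onnx")

phones = infer_onnx("hello world", "en", onnx_model_path=onnx_path)
print(phones)
```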
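
The decoding loop above expects an ONNX graph that takes `input_ids`, `attention_mask` and `decoder_input_ids` and returns the LM logits. As a hedged sketch only (an assumption about that interface, not the actual script used to produce the file in this repository), such a graph could be exported from the base model with `torch.onnx.export`:

```python
# Hedged export sketch: produces a single graph with the interface expected by
# infer_onnx(): (input_ids, attention_mask, decoder_input_ids) -> logits.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()


class FullForward(torch.nn.Module):
    """Wraps the seq2seq model so the traced graph returns only the logits."""

    def __init__(self, seq2seq):
        super().__init__()
        self.seq2seq = seq2seq

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        out = self.seq2seq(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
        )
        return out.logits


# Dummy inputs for tracing; the axes are marked dynamic below.
enc = tokenizer(["<en>: hello"], return_tensors="pt", add_special_tokens=False)
decoder_start = model.config.decoder_start_token_id  # for ByT5/T5 this is the pad token id
decoder_input_ids = torch.tensor([[decoder_start]])

torch.onnx.export(
    FullForward(model),
    (enc["input_ids"], enc["attention_mask"], decoder_input_ids),
    "byt5_g2p_model.onnx",
    input_names=["input_ids", "attention_mask", "decoder_input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "source_len"},
        "attention_mask": {0: "batch", 1: "source_len"},
        "decoder_input_ids": {0: "batch", 1: "target_len"},
        "logits": {0: "batch", 1: "target_len"},
    },
    opset_version=14,
)
```

Exporting the full forward pass without a KV cache keeps everything in one file and matches the simple decoding loop above, at the cost of re-running the encoder on every step.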