Update README.md
README.md
datasets:
- fdemelo/ipa-childes-split
license: apache-2.0
pipeline_tag: text-generation
---
ONNX export of [fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes](https://huggingface.co/fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes), a multilingual ByT5 grapheme-to-phoneme model.

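The export step itself is not included in this card; the snippet below is only a rough sketch of how a single-graph ONNX file with the layout the inference example expects (inputs `input_ids`, `attention_mask`, `decoder_input_ids`; output logits) could be produced with `torch.onnx.export`. The wrapper class, dummy inputs and opset version are illustrative assumptions, not the exact export recipe used for this repository.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class ByT5LogitsWrapper(torch.nn.Module):
    """Returns only the LM logits so the exported graph has a single output."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
        )
        return out.logits


model_name = "fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()

# Dummy inputs only fix the tensor ranks; dynamic_axes keeps the shapes flexible.
enc = tokenizer(["<en>: hello"], return_tensors="pt", add_special_tokens=False)
decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]])

torch.onnx.export(
    ByT5LogitsWrapper(model),
    (enc["input_ids"], enc["attention_mask"], decoder_input_ids),
    "byt5_g2p_model.onnx",
    input_names=["input_ids", "attention_mask", "decoder_input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "encoder_seq"},
        "attention_mask": {0: "batch", 1: "encoder_seq"},
        "decoder_input_ids": {0: "batch", 1: "decoder_seq"},
        "logits": {0: "batch", 1: "decoder_seq"},
    },
    opset_version=14,  # assumed; any opset supported by your onnxruntime build works
)
```
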
Inference example:

```python
from transformers import AutoTokenizer
import onnxruntime
import numpy as np


def infer_onnx(text: str, lang: str, onnx_model_path: str = "byt5_g2p_model.onnx"):
    """
    Performs inference with the exported ByT5 ONNX model using ONNX Runtime.

    Args:
        text (str): The input text to convert to phonemes.
        lang (str): The language tag (e.g., "en").
        onnx_model_path (str): The path to the exported ONNX model file.
    """
    model_name = 'fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # --- Perform inference with ONNX Runtime ---
    print("\n--- Performing inference with ONNX Runtime ---")

    # Create an ONNX Runtime session
    try:
        session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
    except Exception as e:
        print(f"Error loading ONNX model: {e}")
        return

    # Get the output names from the ONNX model (the first output holds the logits)
    onnx_output_names = [out.name for out in session.get_outputs()]

    # Prepend the language tag and tokenize to NumPy arrays for ONNX Runtime
    input_text_for_onnx = f"<{lang}>: {text}"
    inputs_for_onnx = tokenizer([input_text_for_onnx], return_tensors="np", add_special_tokens=False)

    input_ids_np = inputs_for_onnx["input_ids"].astype(np.int64)
    attention_mask_np = inputs_for_onnx["attention_mask"].astype(np.int64)

    # Manual greedy decoding loop for ONNX Runtime.
    # This simulates the `generate` method's greedy decoding.
    # T5 models use the pad_token_id as the decoder_start_token_id,
    # so decoding starts from the pad token.
    generated_ids = []
    current_decoder_input_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    # Shape (1, 1): a batch of one with a single start token
    decoder_input_ids_np = np.array([[current_decoder_input_id]], dtype=np.int64)

    max_length = 512  # Same as in the original predict_byt5

    # The encoder hidden states are recomputed inside the exported graph on every step,
    # so only input_ids, attention_mask and the growing decoder_input_ids are fed in.
    for _ in range(max_length):
        # Prepare inputs for the current step
        onnx_inputs = {
            "input_ids": input_ids_np,
            "attention_mask": attention_mask_np,
            "decoder_input_ids": decoder_input_ids_np,
        }

        # Run inference
        outputs = session.run(onnx_output_names, onnx_inputs)
        logits = outputs[0]

        # Get the logits for the last token in the decoder sequence
        next_token_logits = logits[0, -1, :]

        # Greedy decoding: pick the token with the highest logit
        next_token_id = int(np.argmax(next_token_logits))
        generated_ids.append(next_token_id)

        # Stop at the end-of-sequence token
        if next_token_id == tokenizer.eos_token_id:
            break

        # Append the new token to the decoder input for the next step
        decoder_input_ids_np = np.concatenate(
            (decoder_input_ids_np, np.array([[next_token_id]], dtype=np.int64)), axis=1
        )

    # Decode the generated phoneme IDs
    onnx_phones = tokenizer.batch_decode([generated_ids], skip_special_tokens=True)
    print(f"ONNX Runtime Inference: {onnx_phones}")
    return onnx_phones
```
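
A quick usage check (the example sentence is just a placeholder; the function returns the list produced by `batch_decode`):

```python
# Assumes byt5_g2p_model.onnx has already been placed in the working directory.
phones = infer_onnx("hello world", "en")
print(phones)  # a single-element list with the IPA transcription
```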