---
base_model:
- fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes
language:
- multilingual
- ca
- cy
- da
- de
- en
- es
- et
- eu
- fa
- fr
- ga
- hr
- hu
- id
- is
- it
- ja
- ko
- nl
- 'no'
- pl
- pt
- qu
- ro
- sr
- sv
- tr
- zh
- yue
datasets:
- fdemelo/ipa-childes-split
license: apache-2.0
pipeline_tag: text-generation
---

ONNX version of [fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes](https://huggingface.co/fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes).
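
The export recipe is not included in this card, so the following is only a sketch of how such a single-file graph could be produced. It assumes the ONNX model was created with a plain `torch.onnx.export` of the full encoder-decoder forward pass; the dummy input, dynamic-axis names, and opset version are assumptions, not the author's actual script.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()
model.config.use_cache = False  # export a stateless graph: no KV-cache inputs/outputs

# Dummy inputs only fix the graph signature; the sequence axes stay dynamic.
enc = tokenizer(["<en>: hello"], return_tensors="pt", add_special_tokens=False)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

torch.onnx.export(
    model,
    (enc["input_ids"], enc["attention_mask"], decoder_input_ids),
    "byt5_g2p_model.onnx",
    input_names=["input_ids", "attention_mask", "decoder_input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "encoder_sequence"},
        "attention_mask": {0: "batch", 1: "encoder_sequence"},
        "decoder_input_ids": {0: "batch", 1: "decoder_sequence"},
        "logits": {0: "batch", 1: "decoder_sequence"},
    },
    opset_version=14,
)
```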
Inference example:

```python
from transformers import AutoTokenizer
import onnxruntime
import numpy as np


def infer_onnx(text: str, lang: str, onnx_model_path: str = "byt5_g2p_model.onnx"):
    """
    Performs grapheme-to-phoneme inference with ONNX Runtime using the exported ONNX model.

    Args:
        text (str): The input text to convert to phonemes.
        lang (str): The language tag (e.g., "en").
        onnx_model_path (str): The path to the exported ONNX model.
    """
    model_name = 'fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("\n--- Performing inference with ONNX Runtime ---")

    # Create an ONNX Runtime session
    try:
        session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
    except Exception as e:
        print(f"Error loading ONNX model: {e}")
        return

    # Get input and output names from the ONNX model
    onnx_input_names = [inp.name for inp in session.get_inputs()]
    onnx_output_names = [out.name for out in session.get_outputs()]

    # Prepare the input: the model expects a language-tag prefix, e.g. "<en>: hello"
    input_text_for_onnx = f"<{lang}>: {text}"
    inputs_for_onnx = tokenizer([input_text_for_onnx], return_tensors="pt", add_special_tokens=False)

    input_ids_np = inputs_for_onnx["input_ids"].cpu().numpy()
    attention_mask_np = inputs_for_onnx["attention_mask"].cpu().numpy()

    # Manual greedy decoding loop for ONNX Runtime.
    # This reproduces the greedy decoding done by the `generate` method.
    generated_ids = []

    # T5-style models use the pad token as the decoder start token,
    # so generation begins from pad_token_id.
    current_decoder_input_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    # Ensure it's a batch of 1
    decoder_input_ids_np = np.array([[current_decoder_input_id]])

    max_length = 512  # Same as in the original predict_byt5

    # The exported graph contains the full encoder-decoder forward pass, so the
    # encoder hidden states are handled implicitly; only the growing
    # decoder_input_ids need to be fed at each step.
    for _ in range(max_length):
        # Prepare inputs for the current step
        onnx_inputs = {
            "input_ids": input_ids_np,
            "attention_mask": attention_mask_np,
            "decoder_input_ids": decoder_input_ids_np,
        }

        # Run inference
        outputs = session.run(onnx_output_names, onnx_inputs)
        logits = outputs[0]

        # Logits for the last position: batch 0, last token, all vocab logits
        next_token_logits = logits[0, -1, :]

        # Greedy decoding: pick the token with the highest logit
        next_token_id = int(np.argmax(next_token_logits))
        generated_ids.append(next_token_id)

        # Stop at the end-of-sequence token
        if next_token_id == tokenizer.eos_token_id:
            break

        # Append the new token to the decoder input sequence for the next step
        decoder_input_ids_np = np.concatenate((decoder_input_ids_np, np.array([[next_token_id]])), axis=1)

    # Decode the generated phoneme IDs
    onnx_phones = tokenizer.batch_decode([generated_ids], skip_special_tokens=True)
    print(f"ONNX Runtime Inference: {onnx_phones}")
    return onnx_phones
```
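
A minimal call could look like the following; it assumes `byt5_g2p_model.onnx` is in the working directory, and the language tag must be one of the codes listed in the metadata above.

```python
if __name__ == "__main__":
    # "en" is just an example tag; any of the language codes above can be used.
    phones = infer_onnx("hello world", "en", onnx_model_path="byt5_g2p_model.onnx")
    print(phones)  # a list with one IPA string for the input sentence
```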