G2P models
Collection
Phonemizer models for usage with TTS systems
•
2 items
•
Updated
ONNX version of fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes
Inference example
from transformers import AutoTokenizer
import onnxruntime
import numpy as np
def infer_onnx(text: str, lang: str, onnx_model_path: str = "byt5_g2p_model.onnx", max_length: int = 512):
    """Convert text to phonemes by greedy decoding with an ONNX ByT5 G2P model.

    The function loads an already-exported ONNX model from
    ``onnx_model_path`` (it does not perform the export itself), tokenizes
    the prompt ``"<{lang}>: {text}"`` and runs a manual greedy decoding loop
    through ONNX Runtime on the CPU execution provider.

    Args:
        text: The input text to convert to phonemes.
        lang: The language tag (e.g., "en").
        onnx_model_path: Path of the ONNX model file to load.
        max_length: Maximum number of decoding steps before giving up on
            seeing an end-of-sequence token.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the single
        generated sequence, or ``None`` when the ONNX model could not be
        loaded (the error is printed, not raised).
    """
    model_name = 'fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("\n--- Performing inference with ONNX Runtime ---")
    # Loading is best-effort: report the problem and return None instead of
    # propagating the exception to the caller.
    try:
        session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
    except Exception as e:
        print(f"Error loading ONNX model: {e}")
        return None

    onnx_input_names = {inp.name for inp in session.get_inputs()}
    onnx_output_names = [out.name for out in session.get_outputs()]

    # Tokenize straight to NumPy so that PyTorch is not needed for inference.
    # int64 is forced explicitly because NumPy's default integer width is
    # platform dependent.
    input_text_for_onnx = f"<{lang}>: {text}"
    encoded = tokenizer([input_text_for_onnx], return_tensors="np", add_special_tokens=False)
    encoder_candidates = {
        "input_ids": np.asarray(encoded["input_ids"], dtype=np.int64),
        "attention_mask": np.asarray(encoded["attention_mask"], dtype=np.int64),
    }
    # Only feed encoder-side tensors that the exported graph actually
    # declares; ONNX Runtime rejects unknown input names.
    encoder_feed = {
        name: array
        for name, array in encoder_candidates.items()
        if name in onnx_input_names
    }

    # The pad token is used as the decoder start token (falling back to 0
    # when the tokenizer defines none).
    start_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    eos_token_id = tokenizer.eos_token_id
    decoder_input_ids_np = np.array([[start_token_id]], dtype=np.int64)

    generated_ids = []
    # Manual greedy decoding: every step re-runs the graph on the full
    # decoder prefix and keeps only the prediction for the last position.
    for _ in range(max_length):
        onnx_inputs = {**encoder_feed, "decoder_input_ids": decoder_input_ids_np}
        outputs = session.run(onnx_output_names, onnx_inputs)
        # NOTE(review): assumes the first graph output is the logits tensor
        # shaped (batch, decoder_len, vocab) - confirm against the export.
        logits = outputs[0]
        next_token_id = int(np.argmax(logits[0, -1, :]))
        generated_ids.append(next_token_id)

        if next_token_id == eos_token_id:
            break

        # Extend the decoder prefix with the token that was just chosen.
        decoder_input_ids_np = np.concatenate(
            (decoder_input_ids_np, np.array([[next_token_id]], dtype=np.int64)),
            axis=1,
        )

    onnx_phones = tokenizer.batch_decode([generated_ids], skip_special_tokens=True)
    print(f"ONNX Runtime Inference: {onnx_phones}")
    return onnx_phones
``