# app.py — mT5 Tamil paraphrase ONNX inference service.
# Provenance: viswadarshan06, "Update app.py", commit fcb7d2b (verified).
# Third-party imports (duplicate hf_hub_download import merged into one line).
from fastapi import FastAPI, Request
from huggingface_hub import hf_hub_download, HfApi
import onnxruntime as ort
from transformers import AutoTokenizer

# Standard library / numerics.
import numpy as np
import os

# Token for downloading from the Hub (required if the model repo is gated).
hf_token = os.environ.get("HF_TOKEN")

app = FastAPI()
@app.on_event("startup")
async def load_model():
    """Fetch the mT5 tokenizer and the ONNX encoder/decoder once at startup.

    Populates the module-level globals (`tokenizer`, `encoder_sess`,
    `decoder_sess`) that the /generate endpoint reads.
    """
    global tokenizer, encoder_sess, decoder_sess

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small", cache_dir="/tmp")

    # Download both ONNX graphs from the same Hub repo; the token is needed
    # when the repo is gated/private.
    repo = "viswadarshan06/mt5-tamil-paraphrase-onnx"
    local_paths = {}
    for onnx_file in ("encoder_model.onnx", "decoder_model.onnx"):
        local_paths[onnx_file] = hf_hub_download(
            repo_id=repo,
            filename=onnx_file,
            cache_dir="/tmp",
            token=hf_token,
        )

    encoder_sess = ort.InferenceSession(local_paths["encoder_model.onnx"])
    decoder_sess = ort.InferenceSession(local_paths["decoder_model.onnx"])
@app.post("/generate")
async def generate(request: Request):
    """Greedy-decode a paraphrase for the posted text.

    Expects a JSON body like {"text": "..."}. Returns the original input,
    the generated paraphrase, and the number of generated tokens.
    Relies on the module globals `tokenizer`, `encoder_sess`, `decoder_sess`
    populated by the startup hook.
    """
    data = await request.json()
    input_text = data.get("text", "")

    # Tokenize and run the ONNX encoder once; its hidden states are reused
    # on every decoding step below.
    enc_inputs = tokenizer(input_text, return_tensors="np")
    input_ids = enc_inputs["input_ids"].astype(np.int64)
    attention_mask = enc_inputs["attention_mask"].astype(np.int64)
    encoder_outputs = encoder_sess.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    })

    # T5-family models start decoding from the pad token.
    decoder_input_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int64)

    # Greedy decoding: take the argmax token each step, stop at EOS or the cap.
    output_ids = []
    for _ in range(64):  # max generated length = 64 tokens
        decoder_outputs = decoder_sess.run(None, {
            "input_ids": decoder_input_ids,
            "encoder_hidden_states": encoder_outputs[0],
            "encoder_attention_mask": attention_mask,
        })
        next_token_logits = decoder_outputs[0][:, -1, :]  # shape (1, vocab_size)
        # Convert the NumPy scalar to a plain int so the EOS comparison,
        # tokenizer.decode, and JSON serialization all behave predictably.
        next_token_id = int(np.argmax(next_token_logits, axis=-1)[0])
        if next_token_id == tokenizer.eos_token_id:
            break
        output_ids.append(next_token_id)
        decoder_input_ids = np.append(decoder_input_ids, [[next_token_id]], axis=1)

    output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return {
        "input_text": input_text,
        "generated_paraphrase": output_text,
        "num_tokens": len(output_ids),
    }