from fastapi import FastAPI, Request
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import os

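# HF_TOKEN is read from the environment; it is only needed if the model repo
# is gated or private (hf_hub_download accepts token=None for public repos).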
hf_token = os.environ.get("HF_TOKEN")

app = FastAPI()

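# Load the tokenizer and both ONNX sessions once at startup so every request
# reuses them instead of reloading the model.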
@app.on_event("startup")
async def load_model():
    global tokenizer, encoder_sess, decoder_sess

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small", cache_dir="/tmp")

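    # Fetch the exported encoder/decoder ONNX graphs from the Hub (cached in /tmp).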
    encoder_path = hf_hub_download(
        repo_id="viswadarshan06/mt5-tamil-paraphrase-onnx",
        filename="encoder_model.onnx",
        cache_dir="/tmp",
        token=hf_token,
    )
    decoder_path = hf_hub_download(
        repo_id="viswadarshan06/mt5-tamil-paraphrase-onnx",
        filename="decoder_model.onnx",
        cache_dir="/tmp",
        token=hf_token,
    )

    encoder_sess = ort.InferenceSession(encoder_path)
    decoder_sess = ort.InferenceSession(decoder_path)


@app.post("/generate")
async def generate(request: Request):
    data = await request.json()
    input_text = data.get("text", "")

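    # Tokenize and run the encoder once; onnxruntime expects int64 input tensors.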
    enc_inputs = tokenizer(input_text, return_tensors="np")
    input_ids = enc_inputs["input_ids"].astype(np.int64)
    attention_mask = enc_inputs["attention_mask"].astype(np.int64)

    encoder_outputs = encoder_sess.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    })

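    # mT5 (like all T5-family models) uses the pad token as the decoder start token.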
    decoder_input_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int64)

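    # Greedy decoding without a KV cache: the decoder is re-run on the full
    # prefix at every step, capped at 64 generated tokens or EOS.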
    output_ids = []
    for _ in range(64):
        decoder_inputs = {
            "input_ids": decoder_input_ids,
            "encoder_hidden_states": encoder_outputs[0],
            "encoder_attention_mask": attention_mask,
        }
        decoder_outputs = decoder_sess.run(None, decoder_inputs)
        next_token_logits = decoder_outputs[0][:, -1, :]
        next_token_id = int(np.argmax(next_token_logits, axis=-1)[0])
        if next_token_id == tokenizer.eos_token_id:
            break
        output_ids.append(next_token_id)
        decoder_input_ids = np.append(decoder_input_ids, [[next_token_id]], axis=1)

    output_text = tokenizer.decode(output_ids, skip_special_tokens=True)

    return {
        "input_text": input_text,
        "generated_paraphrase": output_text,
        "num_tokens": len(output_ids),
    }
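
# Example request (assuming the app is served locally on port 8000, e.g. with
# `uvicorn app:app --port 8000`; the payload text is a placeholder):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<Tamil sentence to paraphrase>"}'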