from fastapi import FastAPI, Request
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import os

hf_token = os.environ.get("HF_TOKEN")

app = FastAPI()


@app.on_event("startup")
async def load_model():
    global tokenizer, encoder_sess, decoder_sess

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small", cache_dir="/tmp")

    encoder_path = hf_hub_download(
        repo_id="viswadarshan06/mt5-tamil-paraphrase-onnx",
        filename="encoder_model.onnx",
        cache_dir="/tmp",
        token=hf_token,
    )
    decoder_path = hf_hub_download(
        repo_id="viswadarshan06/mt5-tamil-paraphrase-onnx",
        filename="decoder_model.onnx",
        cache_dir="/tmp",
        token=hf_token,
    )

    encoder_sess = ort.InferenceSession(encoder_path)
    decoder_sess = ort.InferenceSession(decoder_path)


@app.post("/generate")
async def generate(request: Request):
    data = await request.json()
    input_text = data.get("text", "")

    # Encode input
    enc_inputs = tokenizer(input_text, return_tensors="np")
    input_ids = enc_inputs["input_ids"].astype(np.int64)
    attention_mask = enc_inputs["attention_mask"].astype(np.int64)

    encoder_outputs = encoder_sess.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    })

    # Prepare decoder input: T5-family models start decoding from the <pad> token
    decoder_input_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int64)

    # Run decoding step-by-step (greedy loop)
    output_ids = []
    for _ in range(64):  # max generation length = 64 tokens
        decoder_inputs = {
            "input_ids": decoder_input_ids,
            "encoder_hidden_states": encoder_outputs[0],
            "encoder_attention_mask": attention_mask,
        }
        decoder_outputs = decoder_sess.run(None, decoder_inputs)

        next_token_logits = decoder_outputs[0][:, -1, :]  # shape (1, vocab_size)
        next_token_id = int(np.argmax(next_token_logits, axis=-1)[0])

        # Stop as soon as the model emits end-of-sequence
        if next_token_id == tokenizer.eos_token_id:
            break

        # Append the chosen token and feed the extended sequence back in
        output_ids.append(next_token_id)
        decoder_input_ids = np.append(decoder_input_ids, [[next_token_id]], axis=1)

    # Decode output tokens
    output_text = tokenizer.decode(output_ids, skip_special_tokens=True)

    return {
        "input_text": input_text,
        "generated_paraphrase": output_text,
        "num_tokens": len(output_ids),
    }
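
# For local testing, the script can run its own uvicorn server. This is a
# minimal sketch, not part of the original deployment: the filename (app.py),
# host, and port 8000 are assumptions, and the curl payload below is an
# illustrative example request, not a fixed API contract.
if __name__ == "__main__":
    import uvicorn

    # Serve this app directly; assumes uvicorn is installed alongside fastapi.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the /generate endpoint (assuming localhost:8000):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<Tamil sentence to paraphrase>"}'
#
# Expected response shape:
#   {"input_text": "...", "generated_paraphrase": "...", "num_tokens": 17}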