Getting "!!!!" at output only

#3
by Ranjit - opened
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Constants
HF_TOKEN = "xxxxx"

# torch._dynamo.config.suppress_errors = True
# torch._dynamo.disable()

# Load model and tokenizer
model_id = "ai4bharat/IndicTrans3-beta"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

# List of target languages
LANGUAGES = [
    "Hindi", "Bengali", "Telugu", "Marathi", "Tamil", "Urdu", "Gujarati", "Kannada",
    "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili", "Santali", "Kashmiri",
    "Nepali", "Sindhi", "Konkani", "Dogri", "Manipuri", "Bodo",
]

# Simple formatting function
def format_message_for_translation(message, target_lang):
    return f"Translate the following text to {target_lang}: {message}"

def translate(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    # Format input message
    conversation = []
    translation_request = format_message_for_translation(message, target_language)
    conversation.append({"role": "user", "content": translation_request})

    # Tokenize using chat template
    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True
    )

    MAX_INPUT_TOKEN_LENGTH = 4096
    # Trim input if too long
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

    input_ids = input_ids.to(model.device)
    print(input_ids)
    
    # Generate output (non-streaming, blocking call)
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Decode output (skip input tokens)
    print(output_ids)
    generated_text = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    print(generated_text)
    return generated_text

# Gradio UI
demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.Textbox(label="Enter text to translate"),
        gr.Dropdown(choices=LANGUAGES, label="Target Language", value="Hindi"),
    ],
    outputs=gr.Textbox(label="Translated Output"),
    title="IndicTrans3-beta Translator",
)

if __name__ == "__main__":
    demo.launch(debug=True)

Input IDs:
tensor([[128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696,
25, 6790, 220, 2366, 18, 198, 15724, 2696, 25,
220, 1721, 12044, 220, 2366, 20, 271, 128009, 128006,
882, 128007, 271, 28573, 279, 2768, 1495, 311, 45080,
25, 22691, 11, 358, 1097, 1618, 13, 128009, 128006,
78191, 128007, 271]], device='cuda:0')

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set TRANSFORMERS_VERBOSITY=info for more details.

Output IDs:
tensor([[128000, 128006, 9125, ..., 0, 0, 0]],
device='cuda:0')

At the gradio output, i am just getting: "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"

is there some issue with tokens mapping or model mismatch? can anyone suggest?

Sign up or log in to comment