Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from tokenizers import Tokenizer | |
| from transformer.config import load_config | |
| from transformer.components.decoding import beam_search | |
| from transformer.transformer import Transformer | |
# --- Runtime configuration --------------------------------------------------

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Locations of the artifacts loaded below (relative paths, resolved against
# the process working directory).
CONFIG_PATH = "configs/config.yaml"
MODEL_PATH = "model_checkpoint.pt"
TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"

# Passed to beam_search as ``max_len`` by translate().
MAX_LEN = 128

# --- One-time loading at import ---------------------------------------------

config = load_config(CONFIG_PATH)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Id of the "[PAD]" token; translate() compares source ids against it to
# build the source mask.
padding_idx = tokenizer.token_to_id("[PAD]")

model = Transformer.load_from_checkpoint(
    checkpoint_path=MODEL_PATH,
    config=config,
    device=DEVICE,
)
def translate(text: str, beam_size: int = 4) -> str:
    """Translate ``text`` with beam search and return the decoded string.

    Args:
        text: Source sentence to translate.
        beam_size: Number of beams to keep while decoding. The value is cast
            to ``int`` before use because UI sliders may deliver a float.

    Returns:
        The decoded translation with special tokens stripped, or an empty
        string when ``text`` is empty or contains only whitespace.
    """
    # Nothing to translate: return early instead of tokenizing and running
    # beam search on an empty source sequence.
    if not text or not text.strip():
        return ""

    # Wrap the ids in a list so the tensor has a leading batch dim of 1.
    src_ids = torch.tensor([tokenizer.encode(text).ids], device=DEVICE)

    # True where the token is not padding. The two unsqueezes turn the
    # (1, seq) mask into (1, 1, 1, seq) — presumably so it broadcasts inside
    # the model's attention; confirm against the Transformer implementation.
    src_mask = (src_ids != padding_idx).unsqueeze(1).unsqueeze(2)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        candidates = beam_search(
            model,
            src_ids,
            src_mask,
            tokenizer,
            max_len=MAX_LEN,
            beam_size=int(beam_size),
        )

    # The first entry of the beam-search output is the sequence for our
    # single input sentence.
    result_ids = candidates[0]
    return tokenizer.decode(result_ids, skip_special_tokens=True)
# Gradio UI: an input column (text + beam size), an output column, and a
# centred row holding the Translate button and example sentences.
# NOTE(review): the nesting below is the most plausible reading of the
# layout (slider under the input box, examples next to the button) —
# confirm against the deployed app.
with gr.Blocks(title="Transformer From Scratch Translation Demo") as demo:
    # Page header and a short usage note.
    gr.Markdown(
        "# Transformer From Scratch Translation Demo\n"
        "Translate English to German using a custom Transformer model trained from scratch.\n\n"
        "**Note:** This model was trained on the WMT14 English-German news dataset. "
        "It works best on formal, news-style sentences and may not perform well on "
        "everyday informal or conversational text."
    )

    # Top row: source text and decoding option on the left, result on the right.
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="English Text",
                placeholder="Enter text to translate...",
                lines=3,
            )
            beam_size = gr.Slider(
                minimum=1,
                maximum=8,
                step=1,
                value=4,
                label="Beam Size",
            )

        with gr.Column():
            output_text = gr.Textbox(
                label="German Translation",
                lines=3,
                interactive=False,
                show_copy_button=True,
                show_label=True,
            )

    # Bottom row: empty side columns (scale 1) flank the wider middle column
    # (scale 2) so the controls sit in the centre.
    with gr.Row():
        with gr.Column(scale=1):
            pass

        with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
            translate_btn = gr.Button("Translate")
            gr.Examples(
                examples=[
                    ["Hello, how are you?"],
                    ["The weather is nice today."],
                    ["I love machine learning."],
                ],
                inputs=[input_text],
            )

        with gr.Column(scale=1):
            pass

    # Clicking the button runs translate(text, beam_size) and writes the
    # returned string into the output box.
    translate_btn.click(
        translate,
        inputs=[input_text, beam_size],
        outputs=[output_text],
    )
# Start the Gradio server only when this file is executed as a script, not
# when it is imported.
if __name__ == "__main__":
    demo.launch()