import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

# Load the pretrained MarioGPT model and move it to the GPU.
mario_lm = MarioLM()
device = torch.device('cuda')
mario_lm = mario_lm.to(device)
TILE_DIR = "data/tiles"
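
# Optional variation (not required by the demo): fall back to CPU when no GPU
# is present. Generation is much slower on CPU, but this keeps the app runnable
# without CUDA:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   mario_lm = mario_lm.to(device)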

def update(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt=""):
    # If no free-form prompt is given, compose one from the selected attributes.
    if prompt == "":
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    print(f"Using prompt: {prompt}")
    prompts = [prompt]
    # Autoregressively sample `level_size` tile tokens conditioned on the prompt.
    generated_level = mario_lm.sample(
        prompts=prompts,
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
    )
    # Render the generated token sequence as a PNG image of the level.
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
    return img
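
# Example: `update` can also be called directly (e.g. from a Python shell),
# bypassing the UI. A minimal sketch, assuming CUDA is available and that the
# returned image is a PIL image:
#
#   img = update("some", "many", "some", "low", temperature=2.0, level_size=1399)
#   img.save("generated_level.png")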

with gr.Blocks() as demo:
    gr.Markdown("## Demo for ['MarioGPT: Open-Ended Text2Level Generation through Large Language Models'](https://github.com/shyamsn97/mario-gpt). Enter a text prompt or select parameters below!")
    text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt, e.g. 'many pipes, many enemies, some blocks, low elevation', or compose your prompt below")
    with gr.Accordion(label="Compose your prompt", open=False):
        pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
        enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
        blocks = gr.Radio(["little", "some", "many"], label="blocks")
        elevation = gr.Radio(["low", "high"], label="elevation")
        temperature = gr.Number(value=2.0, label="temperature: increase this for more stochastic, but lower-quality, generations")
        level_size = gr.Number(value=1399, precision=0, label="level_size")
    btn = gr.Button("Generate level")
    level_image = gr.Image()
    btn.click(
        fn=update,
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt],
        outputs=level_image,
    )
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.0],
            ["no", "no", "many", "high", 2.4],
        ],
        # Each example row supplies one value per input component; temperature
        # is included so the rows match the five inputs listed below.
        inputs=[pipes, enemies, blocks, elevation, temperature],
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )

demo.launch()
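
# To run this demo locally (a sketch; assumes the `mario-gpt` package provides
# the modules imported above and that a CUDA GPU is available):
#
#   pip install mario-gpt gradio
#   python app.py   # or whatever this file is saved as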