# matrix-game-2 / app.py
# Author: dn6 (HF Staff)
# Commit f6ff21d (verified): "Upload folder using huggingface_hub" — 2.55 kB
import torch
import spaces
import gradio as gr
from diffusers import ModularPipelineBlocks
from diffusers.utils import export_to_video, load_image
from diffusers.modular_pipelines import WanModularPipeline
class MatrixGameWanModularPipeline(WanModularPipeline):
    """
    Matrix-Game-2 flavour of the Wan modular pipeline.

    The only customisation is the default sample size exposed through the
    ``default_sample_height`` / ``default_sample_width`` properties.

    <Tip warning={true}>
    This is an experimental feature and is likely to change in the future.
    </Tip>
    """

    # Values reported by the properties below. NOTE(review): these look like
    # latent-space sizes rather than pixels — confirm against the Wan VAE
    # scale factor before relying on that.
    _DEFAULT_SAMPLE_HEIGHT = 44
    _DEFAULT_SAMPLE_WIDTH = 80

    @property
    def default_sample_height(self):
        """Default sample height for this pipeline (overrides the Wan value)."""
        return self._DEFAULT_SAMPLE_HEIGHT

    @property
    def default_sample_width(self):
        """Default sample width for this pipeline (overrides the Wan value)."""
        return self._DEFAULT_SAMPLE_WIDTH
# Hub repositories used to assemble the pipeline.
# NOTE(review): blocks come from "diffusers/..." while components come from
# "diffusers-internal-dev/..." — confirm this split is intentional.
MODULAR_BLOCKS_REPO = "diffusers/matrix-game-2-modular"
IMAGE_TO_ACTION_REPO = "dn6/matrix-game-image-to-action"
COMPONENTS_REPO = "diffusers-internal-dev/matrix-game-2-modular"

# Per-component dtypes: bfloat16 by default, float32 for the VAE.
COMPONENT_DTYPES = {"default": torch.bfloat16, "vae": torch.float32}

blocks = ModularPipelineBlocks.from_pretrained(MODULAR_BLOCKS_REPO, trust_remote_code=True)
image_to_action_block = ModularPipelineBlocks.from_pretrained(
    IMAGE_TO_ACTION_REPO, trust_remote_code=True
)
# Place the image-to-action step at index 0, ahead of the existing sub-blocks.
blocks.sub_blocks.insert("image_to_action", image_to_action_block, 0)

pipe = MatrixGameWanModularPipeline(blocks, COMPONENTS_REPO)
pipe.load_components(
    trust_remote_code=True,
    device_map="cuda",
    torch_dtype=COMPONENT_DTYPES,
)
@spaces.GPU(duration=300)
def predict(image, prompt):
    """Generate a Matrix-Game-2 video from an input image and a text prompt.

    Args:
        image: Either a single PIL image (as produced by
            ``gr.Image(type="pil")``) or a ``gr.Gallery(type="pil")`` value,
            i.e. a list whose items are images or ``(image, caption)`` pairs.
            Only the first gallery item is used.
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        Path to the exported ``.mp4`` file.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Accept a Gallery value: take its first item, then drop any caption.
    if isinstance(image, (list, tuple)):
        image = image[0] if image else None
        if isinstance(image, (list, tuple)):
            image = image[0]
    if image is None:
        raise gr.Error("Please provide an input image.")

    output = pipe(image=image, prompt=prompt, num_frames=141)
    # With no explicit path, export_to_video writes to a fresh temporary .mp4,
    # so concurrent requests no longer overwrite a shared "output.mp4".
    return export_to_video(output.values["videos"][0])
# NOTE(review): not referenced anywhere in this file (no gr.Examples is built).
examples = []
# Custom CSS handed to gr.Blocks(css=css). Only #col-container matches an
# elem_id used in this file; #logo-title and #edit_text have no matching
# component here and look like leftovers from another demo.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#logo-title {
text-align: center;
}
#logo-title img {
width: 400px;
}
#edit_text{margin-top: -62px !important}
"""
# Gradio UI: one input image plus a prompt in, one generated video out.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # The pipeline is given a single image, so use a single-image
                # input instead of a multi-image Gallery.
                input_image = gr.Image(
                    label="Input Image",
                    show_label=False,
                    type="pil",
                    interactive=True,
                )
            with gr.Column():
                # predict() returns the path of an .mp4 file; gr.Gallery
                # expects a list of media, so a Video component is used.
                result = gr.Video(label="Result", show_label=False)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="describe the motion or actions to generate",
                container=False,
            )
            run_button = gr.Button("Run!", variant="primary")

    # Run generation on button click or when Enter is pressed in the prompt.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=predict,
        inputs=[input_image, prompt],
        outputs=[result],
    )
def main():
    """Start the Gradio server for the demo."""
    demo.launch()


if __name__ == "__main__":
    main()