VAREdit-8B-512 / app.py
cai-qi's picture
Update app.py
1e20cf4 verified
"""
Gradio app for VAREdit image editing model.
Provides web interface for editing images with text instructions.
"""
import spaces
import gradio as gr
import os
import tempfile
from PIL import Image
import logging
from infer import load_model, generate_image
import os
from huggingface_hub import snapshot_download
import torch
# Select the compute device: CUDA when torch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level logger; logging is configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@spaces.GPU
def edit_image(
    input_image: Image.Image,
    instruction: str,
    cfg: float = 4.0,
    tau: float = 0.5,
    seed: int = -1
) -> Image.Image:
    """Edit an image according to a text instruction.

    The input image is written to a temporary JPEG file whose path is handed
    to ``generate_image`` together with the module-level ``model_components``.
    The temporary file is removed afterwards, whether or not generation
    succeeded.

    Args:
        input_image: Image to edit; ``None`` is rejected.
        instruction: Editing instruction; empty or whitespace-only text is
            rejected.
        cfg: Forwarded to ``generate_image`` as ``cfg`` (labelled
            "CFG Scale" in the UI).
        tau: Forwarded to ``generate_image`` as ``tau`` (labelled
            "Temperature" in the UI).
        seed: Seed forwarded to ``generate_image``. ``-1`` (or an empty
            field, i.e. ``None``) is forwarded as ``None``.

    Returns:
        The value returned by ``generate_image``.

    Raises:
        gr.Error: If the image or instruction is missing, or if saving the
            image or running ``generate_image`` fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")
    if not instruction or not instruction.strip():
        raise gr.Error("Please provide an editing instruction")

    # Tracked outside the try so the finally clause can clean up even when
    # saving the image fails half-way.
    temp_path = None
    try:
        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so
        # normalise to RGB before saving instead of failing with OSError.
        rgb_image = input_image
        if rgb_image.mode != "RGB":
            rgb_image = rgb_image.convert("RGB")

        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
            temp_path = tmp_file.name
            # Write through the open handle rather than re-opening by name.
            rgb_image.save(tmp_file, 'JPEG')

        # -1 and an empty seed field both mean "no fixed seed".
        resolved_seed = None if seed is None or seed == -1 else int(seed)

        return generate_image(
            model_components,
            temp_path,
            instruction,
            cfg=cfg,
            tau=tau,
            seed=resolved_seed
        )
    except Exception as e:
        # Log with traceback, then surface a user-facing error in the UI.
        logger.exception("Image editing failed: %s", e)
        raise gr.Error(f"Failed to edit image: {str(e)}") from e
    finally:
        # Clean up the temporary file on every exit path.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and return it without launching.

    The layout is a two-column row (inputs and advanced settings on the left,
    the edited image on the right) followed by a list of clickable examples.
    The "Edit Image" button is wired to ``edit_image``.
    """
    # Example rows: [image path, instruction].
    example_rows = [
        ["assets/test_3.jpg", "change shirt to a black-and-white striped Breton top, add a red beret, set the background to an artist's loft with a window view of the Eiffel Tower"],
        ["assets/test.jpg", "Add glasses to this girl and change hair color to red"],
        ["assets/test_1.jpg", "replace all the bullets with shimmering, multi-colored butterflies."],
        ["assets/test_4.jpg", "Set the scene against a dark, blurred-out server room, make all text and arrows glow with a vibrant cyan light"],
    ]
    instruction_hint = "e.g., 'Remove glasses from this person', 'Change the sky to sunset', 'Add a hat'"

    def _run_example(img, inst):
        # Examples only supply image + instruction; the remaining arguments
        # of edit_image fall back to its defaults.
        return edit_image(img, inst)

    with gr.Blocks(title="VAREdit Image Editor") as app:
        gr.Markdown("# VAREdit Image Editor")
        gr.Markdown("Edit images using natural language instructions with the VAREdit model.")

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():
                source_box = gr.Image(label="Input Image", type="pil")
                prompt_box = gr.Textbox(
                    label="Editing Instruction",
                    placeholder=instruction_hint,
                    lines=2,
                )

                with gr.Accordion("Advanced Settings", open=False):
                    cfg_slider = gr.Slider(
                        label="CFG Scale (Guidance Strength)",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=3.0,
                    )
                    tau_slider = gr.Slider(
                        label="Temperature (Tau)",
                        minimum=0.1,
                        maximum=1.0,
                        step=0.01,
                        value=0.1,
                    )
                    seed_box = gr.Number(
                        label="Seed (-1 for random)",
                        value=-1,
                        precision=0,
                    )

                run_button = gr.Button("Edit Image", variant="primary", size="lg")

            # Right column: result.
            with gr.Column():
                result_box = gr.Image(label="Edited Image")

        # Example images and instructions
        gr.Markdown("## Examples")
        gr.Examples(
            examples=example_rows,
            inputs=[source_box, prompt_box],
            outputs=result_box,
            fn=_run_example,
            cache_examples=False,
        )

        # Run the edit when the button is pressed.
        run_button.click(
            fn=edit_image,
            inputs=[source_box, prompt_box, cfg_slider, tau_slider, seed_box],
            outputs=result_box,
        )

    return app
# Hub repo id of the VAREdit weights; the same string is used as the local
# download directory, so the files end up under ./HiDream-ai/VAREdit.
model_path = "HiDream-ai/VAREdit"
snapshot_download(
    repo_id=model_path,
    repo_type="model",
    local_dir=model_path,
    max_workers=16,
)

# Derive both load_model path arguments from model_path so the download
# location and the load location cannot drift apart.
model_components = load_model(model_path, f"{model_path}/8B-512.pth", "8B", 512)

if __name__ == "__main__":
    demo = create_interface()
    demo.queue(max_size=50, default_concurrency_limit=16).launch(show_api=False)