# Page residue from the Hugging Face file viewer (not part of the program):
# Akjava's picture
# resize back 32
# 1686e90
# raw
# history blame
# 5.85 kB
import spaces
import gradio as gr
import re
from PIL import Image
import os
import numpy as np
import torch
from diffusers import FluxImg2ImgPipeline
# Precision used for the pipeline weights.
dtype = torch.bfloat16
# Prefer the GPU when CUDA is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time and reused by the request handler.
# Use `dtype` here so the precision is configured in exactly one place.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)
def sanitize_prompt(prompt):
    """Return ``prompt`` with every disallowed character removed.

    Kept characters: ASCII letters and digits, whitespace, and the punctuation
    ``. , ! ? -``. Anything else is dropped, not replaced.
    """
    # The character class is negated, so it matches what must be stripped.
    return re.sub(r"[^a-zA-Z0-9\s.,!?-]", "", prompt)
def convert_to_fit_size(original_width_and_height, maximum_size = 2048):
    """Scale a ``(width, height)`` pair down so neither side exceeds ``maximum_size``.

    The aspect ratio is preserved. Sizes already within the limit are returned
    unchanged, so images are never upscaled.

    Args:
        original_width_and_height: ``(width, height)`` in pixels, e.g. ``PIL.Image.size``.
        maximum_size: Largest allowed value for either side.

    Returns:
        ``(new_width, new_height)`` as ints. When scaling happens, the longer
        side becomes exactly ``maximum_size``. The other side is rounded down
        but never below 1.
    """
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    # Exact integer floor(side * maximum_size / longest). The float form
    # int(w * (m / w)) can land one pixel short of m through rounding
    # (e.g. a 3136-px side yields 2047 instead of 2048).
    longest = max(width, height)
    # Clamp to 1 so extreme aspect ratios don't collapse a side to 0.
    new_width = max(1, width * maximum_size // longest)
    new_height = max(1, height * maximum_size // longest)
    return new_width, new_height
def adjust_to_multiple_of_32(width: int, height: int, multiple: int = 32):
    """Round ``width`` and ``height`` down to the nearest multiple of ``multiple``.

    A side smaller than ``multiple`` is raised to ``multiple`` instead of
    collapsing to 0, because a zero width/height is not a usable image size.

    Args:
        width: Width in pixels.
        height: Height in pixels.
        multiple: Granularity to snap to (defaults to 32, as the name says).

    Returns:
        ``(width, height)`` snapped to ``multiple``.
    """
    def snap(side: int) -> int:
        return max(multiple, side - side % multiple)

    return snap(width), snap(height)
@spaces.GPU(duration=120)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
    """Run the FLUX img2img pipeline on ``image`` and return the result as a NumPy array.

    Args:
        image: Input PIL image from the upload component, or ``None`` when
            nothing was uploaded.
        prompt: Text prompt passed to the pipeline.
        strength: Img2img strength passed to the pipeline.
        seed: Seed for the ``torch.Generator``. Coerced to ``int`` because
            ``gr.Number`` can deliver floats.
        inference_step: Number of inference steps. Coerced to ``int`` for the
            same reason.
        progress: Gradio progress tracker (``track_tqdm=True`` mirrors tqdm bars).

    Returns:
        The generated image as a NumPy array, resized to the input's size
        after ``convert_to_fit_size``. Returns ``None`` when ``image`` is ``None``.
    """
    progress(0, desc="Starting")

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        if image is None:
            print("empty input image returned")
            return None

        generator = torch.Generator(device).manual_seed(int(seed))

        # Size the output should have: the input, scaled down to the size limit.
        fit_width, fit_height = convert_to_fit_size(image.size)
        # The pipeline itself is run at that size snapped to multiples of 32.
        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
        image = image.resize((width, height), Image.LANCZOS)

        # More parameters: https://huggingface.co/docs/diffusers/api/pipelines/flux
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            width=width,
            height=height,
            guidance_scale=0,
            num_inference_steps=int(num_inference_steps),
            max_sequence_length=256,
        )
        result = output.images[0]
        # Accept either a PIL image or an array-like from the pipeline.
        pil_image = result if isinstance(result, Image.Image) else Image.fromarray(result)

        # Resize back from the multiple-of-32 working size to the fitted size.
        if pil_image.size != (fit_width, fit_height):
            # Image.resize takes the target size as ONE (width, height) tuple.
            pil_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
        return np.array(pil_image)

    return process_img2img(image, prompt, strength, seed, inference_step)
def read_file(path: str) -> str:
    """Return the complete text of the UTF-8 file at ``path``."""
    with open(path, encoding="utf-8") as handle:
        return handle.read()
# Stylesheet passed to gr.Blocks(css=...).
# NOTE(review): none of the components built in this file use the
# #col-left / #col-right ids or the .grid-container / .image / .text classes.
# Presumably they style markup inside the demo_*.html snippets — verify
# before removing.
css="""
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap:10px
}
.image {
width: 128px;
height: 128px;
object-fit: cover;
}
.text {
font-size: 16px;
}
"""
# Gradio UI: upload + prompt + settings on the left, generated image on the right.
with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
    with gr.Row():
        with gr.Column():
            image = gr.Image(height=800, sources=['upload', 'clipboard'], image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
            with gr.Row(elem_id="prompt-container", equal_height=False):
                with gr.Row():
                    prompt = gr.Textbox(label="Prompt", value="a women", placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
            btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row(equal_height=True):
                    strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
                    # precision=0 makes Gradio deliver ints, which the seed
                    # generator and the step count expect.
                    seed = gr.Number(value=100, minimum=0, step=1, precision=0, label="seed")
                    inference_step = gr.Number(value=4, minimum=1, step=4, precision=0, label="inference_step")
                id_input = gr.Text(label="Name", visible=False)
        with gr.Column():
            image_out = gr.Image(height=800, sources=[], label="Output", elem_id="output-img", format="jpg")

    gr.Examples(
        examples=[
            ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women ,eyes closed,mouth opened"],
            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women ,eyes closed,mouth opened"],
            ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women ,hand on neck"],
            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women ,hand on neck"],
        ],
        inputs=[image, image_out, prompt],
    )

    # A single HTML component; the original wrapped a gr.HTML inside another
    # gr.HTML's value.
    gr.HTML(read_file("demo_footer.html"))

    # Run on button click or on Enter in the prompt box.
    gr.on(
        triggers=[btn.click, prompt.submit],
        fn=process_images,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out],
    )
def main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()