linoyts HF Staff commited on
Commit
6612d96
·
verified ·
1 Parent(s): 5b2d557

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+
4
+ import spaces
5
+ import torch
6
+ import spaces
7
+ import random
8
+
9
+ from diffusers import AutoPipelineForText2Image
10
+ from PIL import Image
11
+
12
+
13
+ MAX_SEED = np.iinfo(np.int32).max
14
+ MAX_IMAGE_SIZE = 2048
15
+
16
+ pipe = AutoPipelineForText2Image.from_pretrained(
17
+ "ostris/Flex.2-preview",
18
+ custom_pipeline="pipeline.py",
19
+ torch_dtype=torch.bfloat16,
20
+ ).to("cuda")
21
+
22
+ # def calculate_optimal_dimensions(image: Image.Image):
23
+ # # Extract the original dimensions
24
+ # original_width, original_height = image.size
25
+
26
+ # # Set constants
27
+ # MIN_ASPECT_RATIO = 9 / 16
28
+ # MAX_ASPECT_RATIO = 16 / 9
29
+ # FIXED_DIMENSION = 1024
30
+
31
+ # # Calculate the aspect ratio of the original image
32
+ # original_aspect_ratio = original_width / original_height
33
+
34
+ # # Determine which dimension to fix
35
+ # if original_aspect_ratio > 1: # Wider than tall
36
+ # width = FIXED_DIMENSION
37
+ # height = round(FIXED_DIMENSION / original_aspect_ratio)
38
+ # else: # Taller than wide
39
+ # height = FIXED_DIMENSION
40
+ # width = round(FIXED_DIMENSION * original_aspect_ratio)
41
+
42
+ # # Ensure dimensions are multiples of 8
43
+ # width = (width // 8) * 8
44
+ # height = (height // 8) * 8
45
+
46
+ # # Enforce aspect ratio limits
47
+ # calculated_aspect_ratio = width / height
48
+ # if calculated_aspect_ratio > MAX_ASPECT_RATIO:
49
+ # width = (height * MAX_ASPECT_RATIO // 8) * 8
50
+ # elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
51
+ # height = (width / MIN_ASPECT_RATIO // 8) * 8
52
+
53
+ # # Ensure width and height remain above the minimum dimensions
54
+ # width = max(width, 576) if width == FIXED_DIMENSION else width
55
+ # height = max(height, 576) if height == FIXED_DIMENSION else height
56
+
57
+ # return width, height
58
+
59
+ @spaces.GPU
60
+ def infer(edit_images, prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
61
+ image = edit_images["background"]
62
+ width, height = calculate_optimal_dimensions(image)
63
+ mask = edit_images["layers"][0]
64
+ if randomize_seed:
65
+ seed = random.randint(0, MAX_SEED)
66
+ image = pipe(
67
+ prompt=prompt,
68
+ # image=image,
69
+ # mask_image=mask,
70
+ inpaint_image=image,
71
+ inpaint_mask=mask,
72
+ height=height,
73
+ width=width,
74
+ guidance_scale=guidance_scale,
75
+ num_inference_steps=num_inference_steps,
76
+ generator=torch.Generator("cpu").manual_seed(seed)
77
+ ).images[0]
78
+ return image, seed
79
+
80
+ examples = [
81
+ "a tiny astronaut hatching from an egg on the moon",
82
+ "a cat holding a sign that says hello world",
83
+ "an anime illustration of a wiener schnitzel",
84
+ ]
85
+
86
+ css="""
87
+ #col-container {
88
+ margin: 0 auto;
89
+ max-width: 1000px;
90
+ }
91
+ """
92
+
93
+ with gr.Blocks(css=css) as demo:
94
+
95
+ with gr.Column(elem_id="col-container"):
96
+ gr.Markdown(f"""# Flex.2 Preview - inpaint
97
+ """)
98
+ with gr.Row():
99
+ with gr.Column():
100
+ edit_image = gr.ImageEditor(
101
+ label='Upload and draw mask for inpainting',
102
+ type='pil',
103
+ sources=["upload", "webcam"],
104
+ image_mode='RGB',
105
+ layers=False,
106
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
107
+ height=600
108
+ )
109
+ prompt = gr.Text(
110
+ label="Prompt",
111
+ show_label=False,
112
+ max_lines=1,
113
+ placeholder="Enter your prompt",
114
+ container=False,
115
+ )
116
+ run_button = gr.Button("Run")
117
+
118
+ result = gr.Image(label="Result", show_label=False)
119
+
120
+ with gr.Accordion("Advanced Settings", open=False):
121
+
122
+ seed = gr.Slider(
123
+ label="Seed",
124
+ minimum=0,
125
+ maximum=MAX_SEED,
126
+ step=1,
127
+ value=0,
128
+ )
129
+
130
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
131
+
132
+ with gr.Row():
133
+
134
+ height = gr.Slider(64, 2048, value=512, step=64, label="Height")
135
+ width = gr.Slider(64, 2048, value=512, step=64, label="Width")
136
+
137
+ with gr.Row():
138
+
139
+ guidance_scale = gr.Slider(0.0, 20.0, value=3.5, step=0.1, label="Guidance Scale")
140
+
141
+ num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
142
+
143
+ gr.on(
144
+ triggers=[run_button.click, prompt.submit],
145
+ fn = infer,
146
+ inputs = [edit_image, prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
147
+ outputs = [result, seed]
148
+ )
149
+
150
+ demo.launch()