# Realtime-FLUX / app.py
import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
from custom_pipeline import FLUXPipelineWithIntermediateOutputs
from transformers import pipeline
# Translation model (Korean -> English), run on CPU
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
# ์ƒ์ˆ˜ ์ •์˜
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_INFERENCE_STEPS = 1
GPU_DURATION = 15  # Reduced GPU allocation time (seconds)
# ๋ชจ๋ธ ์„ค์ •
def setup_model():
dtype = torch.float16
pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=dtype
).to("cuda")
return pipe
pipe = setup_model()
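# FLUX.1-schnell is a timestep-distilled model; that is why generation below
# uses guidance_scale=0 and only 1-4 inference steps.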
# Menu labels (Korean UI strings)
labels = {
    "Generated Image": "생성된 이미지",
    "Prompt": "프롬프트",
    "Enhance Image": "이미지 향상",
    "Advanced Options": "고급 설정",
    "Seed": "시드",
    "Randomize Seed": "랜덤 시드",
    "Width": "너비",
    "Height": "높이",
    "Inference Steps": "추론 단계",
    "Inspiration Gallery": "영감 갤러리"
}
def translate_if_korean(text):
    """Safely translate Korean text to English, returning the input unchanged on failure."""
try:
if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
return translator(text)[0]['translation_text']
return text
except Exception as e:
        print(f"Translation error: {e}")
return text
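# Illustrative behaviour (exact wording depends on the translation model):
#   translate_if_korean("눈 내리는 숲")   -> e.g. "A snowy forest"
#   translate_if_korean("a snowy forest") -> "a snowy forest" (unchanged)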
# Image generation
@spaces.GPU(duration=GPU_DURATION)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
try:
        # Validate inputs
if not isinstance(seed, (int, type(None))):
seed = None
randomize_seed = True
prompt = translate_if_korean(prompt)
if seed is None or randomize_seed:
seed = random.randint(0, MAX_SEED)
        # Clamp image dimensions to the supported range
width = min(max(256, width), MAX_IMAGE_SIZE)
height = min(max(256, height), MAX_IMAGE_SIZE)
generator = torch.Generator().manual_seed(seed)
start_time = time.time()
        with torch.autocast("cuda"):
for img in pipe.generate_images(
prompt=prompt,
guidance_scale=0,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator
):
                latency = f"처리 시간: {(time.time()-start_time):.2f} 초"
                # Free cached CUDA memory between intermediate frames
if torch.cuda.is_available():
torch.cuda.empty_cache()
yield img, seed, latency
except Exception as e:
        print(f"Image generation error: {e}")
        yield None, seed, f"오류: {str(e)}"
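# Minimal consumption sketch (illustrative, not wired into the UI):
# generate_image is a generator, so intermediate frames stream out as they
# are produced and the last yielded tuple holds the finished image.
def _final_frame(prompt_text, **kwargs):
    result = None
    for frame, used_seed, latency_msg in generate_image(prompt_text, **kwargs):
        result = (frame, used_seed, latency_msg)
    return result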
# Example image generation (first frame only)
def generate_example_image(prompt):
try:
return next(generate_image(prompt, randomize_seed=True))
except Exception as e:
        print(f"Example generation error: {e}")
        return None, None, f"오류: {str(e)}"
# Example prompts
examples = [
    "비너 슈니첼의 애니메이션 일러스트레이션",
"A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
"A floating island made of books with waterfalls of knowledge cascading down",
"A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
"An ancient temple being reclaimed by nature, with robots performing archaeology",
"A cosmic coffee shop where baristas are constellations serving drinks made of stardust"
]
css = """
footer {
visibility: hidden;
}
"""
def create_snow_effect():
    # CSS for the falling-snow animation
snow_css = """
@keyframes snowfall {
0% {
transform: translateY(-10vh) translateX(0);
opacity: 1;
}
100% {
transform: translateY(100vh) translateX(100px);
opacity: 0.3;
}
}
.snowflake {
position: fixed;
color: white;
font-size: 1.5em;
user-select: none;
z-index: 1000;
pointer-events: none;
animation: snowfall linear infinite;
}
"""
    # JavaScript that spawns and removes snowflake elements
snow_js = """
function createSnowflake() {
const snowflake = document.createElement('div');
        snowflake.innerHTML = '❄';
snowflake.className = 'snowflake';
snowflake.style.left = Math.random() * 100 + 'vw';
snowflake.style.animationDuration = Math.random() * 3 + 2 + 's';
snowflake.style.opacity = Math.random();
document.body.appendChild(snowflake);
setTimeout(() => {
snowflake.remove();
}, 5000);
}
setInterval(createSnowflake, 200);
"""
    # HTML that bundles the CSS and JavaScript
snow_html = f"""
<style>
{snow_css}
</style>
<script>
{snow_js}
</script>
"""
return gr.HTML(snow_html)
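# Note: some Gradio versions insert gr.HTML content via innerHTML, in which
# case embedded <script> tags are not executed; if the snowfall does not
# animate, passing the script through the `js` parameter of gr.Blocks is an
# alternative.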
# To use the effect in a Gradio app, call create_snow_effect() inside the
# `with gr.Blocks(...)` context, as below.
# Gradio UI layout
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
create_snow_effect()
with gr.Column(elem_id="app-container"):
with gr.Row():
with gr.Column(scale=3):
result = gr.Image(label=labels["Generated Image"],
show_label=False,
interactive=False)
with gr.Column(scale=1):
prompt = gr.Text(
label=labels["Prompt"],
                    placeholder="생성하고 싶은 이미지를 설명해주세요...",
lines=3,
show_label=False,
container=False,
)
                enhanceBtn = gr.Button(f"🚀 {labels['Enhance Image']}")
                with gr.Accordion(labels["Advanced Options"]):
with gr.Row():
latency = gr.Text(show_label=False)
with gr.Row():
seed = gr.Number(
label=labels["Seed"],
value=42,
precision=0,
minimum=0,
maximum=MAX_SEED
)
randomize_seed = gr.Checkbox(
label=labels["Randomize Seed"],
value=True
)
with gr.Row():
width = gr.Slider(
label=labels["Width"],
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=DEFAULT_WIDTH
)
height = gr.Slider(
label=labels["Height"],
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=DEFAULT_HEIGHT
)
num_inference_steps = gr.Slider(
label=labels["Inference Steps"],
minimum=1,
maximum=4,
step=1,
value=DEFAULT_INFERENCE_STEPS
)
with gr.Row():
            gr.Markdown(f"### 🌟 {labels['Inspiration Gallery']}")
with gr.Row():
gr.Examples(
examples=examples,
fn=generate_example_image,
inputs=[prompt],
                outputs=[result, seed, latency],
cache_examples=False
)
    # Event handlers
def validated_generate(*args):
try:
return next(generate_image(*args))
except Exception as e:
            print(f"Validated generation error: {e}")
            return None, args[1], f"오류: {str(e)}"
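    # The Enhance button below omits randomize_seed and num_inference_steps,
    # so generate_image falls back to its defaults: a fresh random seed and
    # DEFAULT_INFERENCE_STEPS.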
enhanceBtn.click(
fn=generate_image,
inputs=[prompt, seed, width, height],
outputs=[result, seed, latency],
show_progress="hidden",
show_api=False,
queue=False
)
gr.on(
triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
fn=validated_generate,
inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
outputs=[result, seed, latency],
show_progress="hidden",
show_api=False,
trigger_mode="always_last",
queue=False
)
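    # trigger_mode="always_last" debounces rapid typing/slider events: while a
    # generation is running, only the most recent pending trigger is executed.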
if __name__ == "__main__":
demo.launch()