# niji-playground / app.py
# Last commit by nyanko7: 56611b9 ("Update app.py")
import gradio as gr
import numpy as np
import time
import requests
import json
import os
import tempfile
import logging
from PIL import Image
from io import BytesIO
# Backend endpoints and the auth token are read from environment variables;
# each is None when the corresponding variable is unset.
odnapi, fetapi, auth_token = map(
    os.getenv, ("odnapi_url", "fetapi_url", "auth_token")
)

# Module-wide logging at INFO level and this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def fetch_image(url, timeout=30):
    """Download *url* and return it as a PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled server would block the caller forever.

    Returns:
        The ``PIL.Image.Image`` opened from the downloaded bytes.

    Raises:
        requests.exceptions.RequestException: On network errors or a non-2xx
            HTTP status (logged before being re-raised).
        PIL.UnidentifiedImageError: If the body is not a decodable image.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Fail here on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
    except requests.exceptions.RequestException:
        logger.exception("Failed to fetch image from %s", url)
        raise
    return Image.open(BytesIO(response.content))
def split_image(img):
    """Cut *img* into its four quadrants.

    The tiles are returned in reading order: top-left, top-right, bottom-left,
    bottom-right. With odd dimensions the right column and bottom row get the
    extra pixel.
    """
    w, h = img.size
    mid_x, mid_y = w // 2, h // 2
    # Rows outer, columns inner, so the boxes come out in reading order.
    boxes = [
        (left, top, right, bottom)
        for top, bottom in ((0, mid_y), (mid_y, h))
        for left, right in ((0, mid_x), (mid_x, w))
    ]
    return [img.crop(box) for box in boxes]
def save_image(img, suffix='.png'):
    """Write *img* to a new temporary file and return that file's path.

    The file is created with ``delete=False``, so it outlives this call and
    the caller owns it. Note the data is always PNG-encoded; *suffix* only
    affects the file name.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        img.save(handle, 'PNG')
    finally:
        # Closing flushes the bytes to disk before the path is handed out.
        handle.close()
    return handle.name
def download_and_split_image(url):
    """Fetch the image at *url*, cut it into quadrants and save each one.

    Returns the four temporary PNG file paths in the order produced by
    ``split_image`` (top-left, top-right, bottom-left, bottom-right).
    """
    quadrants = split_image(fetch_image(url))
    paths = []
    for quadrant in quadrants:
        paths.append(save_image(quadrant))
    return paths
def niji_api(prompt, *, poll_interval=1.0, max_wait=600.0):
    """Submit *prompt* to the generation backend and stream gallery updates.

    This is a generator intended as a Gradio event handler. Each yielded value
    is a list of ``(image, caption)`` tuples suitable for ``gr.Gallery``.

    - While the job runs, it yields the four quadrants of the latest progress
      preview as local temp-file paths, captioned ``"<progress>% done"``.
    - When progress reaches 100, it yields the final image URLs, captioned
      ``"image 1/N"`` … ``"image N/N"``.

    Args:
        prompt: The text prompt sent to the submit endpoint as ``{"msg": prompt}``.
        poll_interval: Seconds to wait between status polls, including after a
            failed poll.
        max_wait: Overall deadline in seconds for the job to reach 100%.

    Raises:
        ValueError: If the job cannot be submitted, or the submit response has
            no usable ``messageId``.
        TimeoutError: If the job does not finish within *max_wait* seconds.
    """
    http_timeout = 30  # per-request seconds; requests never times out by default

    # --- Submit the job -------------------------------------------------
    try:
        response = requests.post(
            fetapi,
            headers={'Content-Type': 'application/json'},
            data=json.dumps({'msg': prompt}),
            timeout=http_timeout,
        )
        response.raise_for_status()  # Check for HTTP errors.
        message_id = response.json()['messageId']
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as e:
        # ValueError covers non-JSON bodies; KeyError/TypeError cover JSON
        # without a top-level 'messageId'.
        logger.error("Failed to make POST request: %s", e)
        raise ValueError("Invalid Response") from e

    # --- Poll until done, streaming previews ------------------------------
    deadline = time.monotonic() + max_wait
    while True:
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"Job {message_id} did not finish within {max_wait} seconds"
            )
        try:
            response = requests.get(
                f'{odnapi}/message/{message_id}',
                headers={'Authorization': auth_token},
                timeout=http_timeout,
            )
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.warning("Failure in getting message response: %s", e)
            # Back off before retrying so a failing backend is not hammered
            # in a tight loop.
            time.sleep(poll_interval)
            continue

        # Treat a missing or null progress value as "not started yet".
        progress = data.get('progress') or 0
        if progress_image_url := data.get('progressImageUrl'):
            try:
                previews = download_and_split_image(progress_image_url)
            except (requests.exceptions.RequestException, OSError) as e:
                # Previews are best-effort; a broken one must not abort the job.
                logger.warning("Failed to fetch progress image: %s", e)
            else:
                yield [(img, f"{progress}% done") for img in previews]
        if progress >= 100:
            break
        time.sleep(poll_interval)

    # --- Final result -----------------------------------------------------
    image_urls = data['response']['imageUrls']
    yield [
        (url, f"image {idx}/{len(image_urls)}")
        for idx, url in enumerate(image_urls, start=1)
    ]
# Gradio UI: a header, a prompt box with a generate button, and a gallery that
# niji_api streams its results into.
with gr.Blocks() as demo:
    gr.HTML('''
        <div style="text-align: center; max-width: 650px; margin: 0 auto;">
          <div style="
            display: inline-flex;
            gap: 0.8rem;
            font-size: 1.75rem;
            justify-content: center;
            margin-bottom: 10px;
          ">
            <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
              MidJourney / NijiJourney Playground 🎨
            </h1>
          </div>
          <div>
            <p style="align-items: center; margin-bottom: 7px;">
              Demo for <a href="https://MidJourney.com/" target="_blank">MidJourney</a>: add a text prompt describing what you want to draw.
            </p>
          </div>
        </div>
    ''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=1,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")
    btn.click(niji_api, text, gallery)

if __name__ == "__main__":
    # niji_api is a generator; Gradio needs the request queue to stream its
    # intermediate yields. `launch(enable_queue=...)` is deprecated in favour
    # of `.queue()`.
    demo.queue().launch(debug=True)