Spaces:
Paused
Paused
import gradio as gr | |
import numpy as np | |
import time | |
import requests | |
import json | |
import os | |
import tempfile | |
import logging | |
from PIL import Image | |
from io import BytesIO | |
# Service endpoints and credentials, read once from the environment at import
# time. Any variable that is unset yields None.
#   odnapi     -- base URL; niji_api polls f"{odnapi}/message/<id>" for status
#   fetapi     -- URL that niji_api POSTs {"msg": prompt} to
#   auth_token -- sent verbatim as the Authorization header on status polls
odnapi, fetapi, auth_token = (
    os.environ.get(var_name) for var_name in ("odnapi_url", "fetapi_url", "auth_token")
)

# Module-wide logging: configure the root logger at INFO and grab a
# module-scoped logger for the helpers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def fetch_image(url, timeout=30):
    """Download the image at *url* and return it as a PIL image.

    Args:
        url: Absolute URL of the image to download.
        timeout: Seconds to wait on the connection/read before giving up.
            Without a timeout ``requests.get`` can block indefinitely.

    Returns:
        A ``PIL.Image.Image`` opened from the response body.

    Raises:
        requests.exceptions.RequestException: on network failure, timeout,
            or a non-2xx status (``HTTPError``). The error is logged first.
        PIL.UnidentifiedImageError: if the body is not a recognised image.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Fail on 4xx/5xx here instead of handing an HTML error page to PIL.
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error("Failed to fetch image from %s: %s", url, e)
        raise
    return Image.open(BytesIO(response.content))
def split_image(img, rows=2, cols=2):
    """Cut *img* into a ``rows`` x ``cols`` grid of tiles.

    With the default 2x2 grid the tiles are, in order: top-left, top-right,
    bottom-left, bottom-right. This is the four-way split niji_api applies
    to progress previews.

    Tile edges fall at ``width * c // cols`` and ``height * r // rows``. When
    a dimension is not evenly divisible, tiles therefore differ by at most
    one pixel. For the 2x2 case, the right/bottom tiles get the extra pixel.

    Args:
        img: Image exposing ``size`` (width, height) and ``crop(box)``.
        rows: Number of horizontal bands (>= 1).
        cols: Number of vertical bands (>= 1).

    Returns:
        List of ``rows * cols`` cropped tiles in row-major order.

    Raises:
        ValueError: if ``rows`` or ``cols`` is less than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got {rows}x{cols}")
    width, height = img.size
    # Precompute the cut positions once; xs/ys include both outer edges.
    xs = [width * c // cols for c in range(cols + 1)]
    ys = [height * r // rows for r in range(rows + 1)]
    return [
        img.crop((xs[c], ys[r], xs[c + 1], ys[r + 1]))
        for r in range(rows)
        for c in range(cols)
    ]
def save_image(img, suffix='.png', fmt='PNG'):
    """Write *img* to a new temporary file and return that file's path.

    The file is created with ``delete=False`` so it outlives this call; the
    caller owns its eventual removal.

    Args:
        img: Image to save; only its ``save(fp, format=...)`` method is used.
        suffix: Filename suffix for the temporary file.
        fmt: Encoder format passed to ``img.save``. Defaults to ``'PNG'``
            (the previously hard-coded value). Pass a matching value when
            using a non-PNG ``suffix`` so the extension and contents agree.

    Returns:
        Filesystem path of the written image.

    Raises:
        Whatever ``img.save`` raises. The partially written temp file is
        removed first so a failed save does not leak files.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            img.save(tmp, format=fmt)
    except BaseException:
        # delete=False means nothing else will ever clean this file up.
        try:
            os.remove(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name
def download_and_split_image(url):
    """Fetch the image at *url*, split it with split_image, save every tile.

    Returns:
        Temp-file paths, one per tile, in the order split_image produced
        them. The files are created with ``delete=False`` by save_image, so
        they persist after this call.
    """
    tiles = split_image(fetch_image(url))
    return list(map(save_image, tiles))
def niji_api(prompt, poll_interval=1.0, timeout=60):
    """Submit *prompt* and stream preview and final images (a generator).

    Used as a Gradio callback: every ``yield`` replaces the gallery contents
    with a list of ``(image, caption)`` pairs.

    Flow:
      1. POST ``{'msg': prompt}`` to ``fetapi``; the JSON reply must contain
         a ``messageId``.
      2. Poll ``{odnapi}/message/{messageId}`` (``auth_token`` is sent as the
         ``Authorization`` header) until the reported ``progress`` reaches
         100. Whenever the reply carries a ``progressImageUrl``, that image
         is split into tiles, saved locally, and yielded with a "N% done"
         caption.
      3. Yield the URLs in the final reply's ``response.imageUrls`` with
         1-based "image i/n" captions.

    Args:
        prompt: Prompt text, including any ``--`` flags.
        poll_interval: Seconds to wait between status polls. The wait also
            applies after a failed poll.
        timeout: Per-request network timeout in seconds.

    Yields:
        Lists of ``(path_or_url, caption)`` tuples.

    Raises:
        ValueError: if the submission request fails (network error/non-2xx).
        KeyError: if a reply lacks ``messageId`` or ``response.imageUrls``.
    """
    try:
        response = requests.post(
            fetapi,
            headers={'Content-Type': 'application/json'},
            data=json.dumps({'msg': prompt}),
            timeout=timeout,
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error("Failed to make POST request: %s", e)
        raise ValueError("Invalid Response") from e
    message_id = response.json()['messageId']

    progress = 0
    # The loop can only exit after a successful poll, so `data` is always
    # bound to the last status reply once the loop finishes.
    while progress < 100:
        try:
            response = requests.get(
                f'{odnapi}/message/{message_id}',
                headers={'Authorization': auth_token},
                timeout=timeout,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Keep retrying transient failures, but wait first so an outage
            # does not turn into a tight request loop against the server.
            logger.warning("Failure in getting message response: %s", e)
            time.sleep(poll_interval)
            continue
        data = response.json()
        progress = data.get('progress', 0)
        if progress_image_url := data.get('progressImageUrl'):
            yield [(img, f"{progress}% done") for img in download_and_split_image(progress_image_url)]
        if progress < 100:
            time.sleep(poll_interval)

    image_urls = data['response']['imageUrls']
    total = len(image_urls)
    yield [(iurl, f"image {idx}/{total}") for idx, iurl in enumerate(image_urls, start=1)]
# Page layout: a title banner, a prompt row (textbox + button), and a gallery
# whose contents are replaced by each value niji_api yields.
with gr.Blocks() as demo:
    gr.HTML('''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
  <div style="
    display: inline-flex;
    gap: 0.8rem;
    font-size: 1.75rem;
    justify-content: center;
    margin-bottom: 10px;
  ">
    <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
      MidJourney / NijiJourney Playground 🎨
    </h1>
  </div>
  <div>
    <p style="align-items: center; margin-bottom: 7px;">
      Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, add a text prompt for what you want to draw
    </p>
  </div>
</div>
''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=1,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        # An int height is interpreted as pixels; a unitless string such as
        # "4096" is not valid CSS.
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height=4096)
    btn.click(niji_api, text, gallery)

# niji_api is a generator, and Gradio streams generator output only through
# its queue. launch(enable_queue=...) is deprecated (removed in Gradio 4), so
# enable the queue explicitly before launching.
demo.queue().launch(debug=True)