File size: 3,644 Bytes
7f467ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56611b9
7f467ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import numpy as np
import time
import requests
import json
import os
import tempfile
import logging
from PIL import Image
from io import BytesIO

# Set up logging first so that missing configuration can be reported at startup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Service configuration, read from the environment (None when unset).
odnapi = os.getenv("odnapi_url")      # base URL polled as GET {odnapi}/message/<id>
fetapi = os.getenv("fetapi_url")      # URL that new prompts are POSTed to
auth_token = os.getenv("auth_token")  # sent as the Authorization header when polling

# Surface misconfiguration immediately instead of as an opaque request error later.
_missing_env = [
    name
    for name, value in (
        ("odnapi_url", odnapi),
        ("fetapi_url", fetapi),
        ("auth_token", auth_token),
    )
    if not value
]
if _missing_env:
    logger.warning(
        "Environment variable(s) not set: %s; API requests will likely fail",
        ", ".join(_missing_env),
    )

def fetch_image(url, timeout=30):
    """Download *url* and decode the response body as a PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled server would block the caller forever.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.exceptions.RequestException: on a network failure or a
            non-2xx HTTP status (logged before being re-raised).
        PIL.UnidentifiedImageError: if the body is not a recognisable image.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Fail on 4xx/5xx here rather than handing an HTML error page to PIL.
        response.raise_for_status()
    except requests.exceptions.RequestException:
        logger.exception("Failed to fetch image from %s", url)
        raise
    return Image.open(BytesIO(response.content))

def split_image(img):
    """Cut *img* into four quadrants.

    The cut lines are at ``width // 2`` and ``height // 2``. For odd sizes,
    the right and bottom pieces get the extra pixel.

    Returns:
        ``[top_left, top_right, bottom_left, bottom_right]``, each the result
        of ``img.crop`` on that quadrant's ``(left, upper, right, lower)`` box.
    """
    full_w, full_h = img.size
    mid_x, mid_y = full_w // 2, full_h // 2
    boxes = (
        (0, 0, mid_x, mid_y),            # top-left
        (mid_x, 0, full_w, mid_y),       # top-right
        (0, mid_y, mid_x, full_h),       # bottom-left
        (mid_x, mid_y, full_w, full_h),  # bottom-right
    )
    return [img.crop(box) for box in boxes]

def save_image(img, suffix='.png'):
    """Write *img* to a fresh temporary file and return that file's path.

    The file is created with ``delete=False``, so it persists after this call
    and the caller is responsible for removing it. The data is always written
    in PNG format; *suffix* only controls the file name's extension.
    """
    handle = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with handle:
        img.save(handle, 'PNG')
    return handle.name

def download_and_split_image(url):
    """Fetch the image at *url*, split it into quadrants, and save each one.

    Returns:
        A list of temporary PNG file paths, one per quadrant, in the order
        produced by ``split_image``. The files are created with
        ``delete=False`` by ``save_image``, so they are not cleaned up here.
    """
    quadrants = split_image(fetch_image(url))
    return list(map(save_image, quadrants))

def niji_api(prompt, *, poll_interval=1.0, request_timeout=30, max_wait=None):
    """Submit *prompt* to the image API and stream results to a Gradio gallery.

    This is a generator. While the job runs it yields preview images for each
    status poll that carries a ``progressImageUrl``. Once the reported
    ``progress`` reaches 100 it yields the final image URLs.

    Args:
        prompt: Prompt text, POSTed to ``fetapi`` as ``{"msg": prompt}``.
        poll_interval: Seconds to wait between status polls. This also applies
            after a failed poll, so errors do not turn into a tight retry loop.
        request_timeout: Per-request timeout in seconds for ``requests``.
        max_wait: Optional overall limit in seconds on polling. ``None`` (the
            default) polls until the job reports completion.

    Yields:
        Lists of ``(image, caption)`` tuples:
        - previews: temp-file paths of the four quadrants, captioned
          ``"<progress>% done"``;
        - final: image URLs, captioned ``"image <i>/<n>"`` (1-based).

    Raises:
        ValueError: if submission fails, or a response lacks the expected
            ``messageId`` / ``response.imageUrls`` fields.
        TimeoutError: if ``max_wait`` elapses before the job completes.
    """
    try:
        response = requests.post(
            fetapi,
            headers={'Content-Type': 'application/json'},
            data=json.dumps({'msg': prompt}),
            timeout=request_timeout,
        )
        response.raise_for_status()  # Check for HTTP errors.
        message_id = response.json()['messageId']
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError) as err:
        # Covers network/HTTP failures as well as a non-JSON body or a missing id.
        logger.exception("Failed to make POST request")
        raise ValueError("Invalid Response") from err

    status_url = f'{odnapi}/message/{message_id}'
    deadline = None if max_wait is None else time.monotonic() + max_wait
    progress = 0
    while progress < 100:
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Message {message_id} did not finish within {max_wait} seconds"
            )
        try:
            response = requests.get(
                status_url,
                headers={'Authorization': auth_token},
                timeout=request_timeout,
            )
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError):
            logger.warning("Failure in getting message response", exc_info=True)
            # Back off before retrying instead of hammering the server.
            time.sleep(poll_interval)
            continue
        # A null "progress" is treated like a missing one.
        progress = data.get('progress') or 0
        if progress_image_url := data.get('progressImageUrl'):
            try:
                previews = download_and_split_image(progress_image_url)
            except (requests.exceptions.RequestException, OSError):
                # Previews are best-effort; a bad preview must not abort the job.
                logger.warning("Failed to load progress preview", exc_info=True)
            else:
                yield [(img, f"{progress}% done") for img in previews]
        if progress < 100:
            time.sleep(poll_interval)

    # `data` is the last successful status response (the one reporting >= 100).
    try:
        image_urls = data['response']['imageUrls']
    except (KeyError, TypeError) as err:
        raise ValueError("Invalid Response") from err
    total = len(image_urls)
    yield [(url, f"image {idx}/{total}") for idx, url in enumerate(image_urls, start=1)]

with gr.Blocks() as demo:
    # Page header: title banner and a one-line description.
    gr.HTML('''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
  <div style="
        display: inline-flex;
        gap: 0.8rem;
        font-size: 1.75rem;
        justify-content: center;
        margin-bottom: 10px;
      ">
    <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
      MidJourney / NijiJourney Playground 🎨
    </h1>
  </div>
  <div>
    <p style="align-items: center; margin-bottom: 7px;">
      Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, add a text prompt for what you want to draw
    </p>
  </div>
</div>
    ''')
    with gr.Column(variant="panel"):
        # Prompt box and trigger button share one row; results appear below.
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=1,
                container=False,
            )
            btn = gr.Button("Generate image", scale=0)
        # NOTE(review): height is passed as the unitless string "4096"; confirm
        # Gradio interprets it as intended (an int would mean pixels).
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")
    # niji_api is a generator: each yielded list (progress previews, then the
    # final images) is meant to update the gallery.
    btn.click(niji_api, text, gallery)

# The generator callback needs Gradio's request queue. launch(enable_queue=...)
# is deprecated in Gradio 3.x and removed in 4.x, so enable the queue via
# Blocks.queue(), which returns the Blocks instance for chaining.
demo.queue().launch(debug=True)