Spaces:
Runtime error
Runtime error
allow model selection
Browse files
app.py
CHANGED
|
@@ -72,9 +72,9 @@ def update_labels(show_labels):
|
|
| 72 |
updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
|
| 73 |
return updated_gallery
|
| 74 |
|
| 75 |
-
def generate_images_wrapper(prompts, pw, show_labels):
|
| 76 |
global image_paths_global, image_labels_global
|
| 77 |
-
image_paths, image_labels = generate_images(prompts, pw)
|
| 78 |
image_paths_global = image_paths
|
| 79 |
|
| 80 |
# store this as a global so we can handle toggle state
|
|
@@ -118,7 +118,7 @@ def download_all_images():
|
|
| 118 |
|
| 119 |
return zip_path
|
| 120 |
|
| 121 |
-
def generate_images(prompts, pw):
|
| 122 |
# Check for a valid password
|
| 123 |
|
| 124 |
if pw != os.getenv("PW"):
|
|
@@ -151,7 +151,7 @@ def generate_images(prompts, pw):
|
|
| 151 |
|
| 152 |
try:
|
| 153 |
#what model to use?
|
| 154 |
-
model = ImageGenerationModel.from_pretrained(
|
| 155 |
response = model.generate_images(
|
| 156 |
prompt=prompt_w_challenge,
|
| 157 |
number_of_images=1,
|
|
@@ -168,9 +168,11 @@ def generate_images(prompts, pw):
|
|
| 168 |
response[0].save(filename)
|
| 169 |
image_label = f"{i+1}: {text}"
|
| 170 |
|
|
|
|
|
|
|
| 171 |
try:
|
| 172 |
# Save the prompt, model, image URL, generation time and creation timestamp to the database
|
| 173 |
-
mongo_collection.insert_one({"user": user_initials, "text": text, "model":
|
| 174 |
except Exception as e:
|
| 175 |
print(e)
|
| 176 |
raise gr.Error("An error occurred while saving the prompt to the database.")
|
|
@@ -195,7 +197,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 195 |
|
| 196 |
gr.Markdown("# <center>Prompt de Resistance Vertex Imagen</center>")
|
| 197 |
|
| 198 |
-
pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service"
|
| 199 |
|
| 200 |
#instructions
|
| 201 |
with gr.Accordion("Instructions & Tips",label="instructions",open=False):
|
|
@@ -212,9 +214,13 @@ with gr.Blocks(css=css) as demo:
|
|
| 212 |
#prompts
|
| 213 |
with gr.Accordion("Prompts",label="Prompts",open=True):
|
| 214 |
text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
|
|
|
|
|
|
|
|
|
|
| 215 |
with gr.Row():
|
| 216 |
btn = gr.Button("Generate Images")
|
| 217 |
|
|
|
|
| 218 |
#output
|
| 219 |
with gr.Accordion("Image Outputs",label="Image Outputs",open=True):
|
| 220 |
output_images = gr.Gallery(label="Image Outputs", elem_id="gallery-images", show_label=True, columns=[3], rows=[1], object_fit="contain", height="auto", allow_preview=False)
|
|
@@ -230,8 +236,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 230 |
|
| 231 |
#submissions
|
| 232 |
#trigger generation either through hitting enter in the text field, or clicking the button.
|
| 233 |
-
btn.click(fn=generate_images_wrapper, inputs=[text, pw, show_labels ], outputs=output_images, api_name=False)
|
| 234 |
-
text.submit(fn=generate_images_wrapper, inputs=[text, pw, show_labels], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
|
| 235 |
show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
|
| 236 |
|
| 237 |
#downloads
|
|
|
|
| 72 |
updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
|
| 73 |
return updated_gallery
|
| 74 |
|
| 75 |
+
def generate_images_wrapper(prompts, pw, show_labels,model):
|
| 76 |
global image_paths_global, image_labels_global
|
| 77 |
+
image_paths, image_labels = generate_images(prompts, pw,model)
|
| 78 |
image_paths_global = image_paths
|
| 79 |
|
| 80 |
# store this as a global so we can handle toggle state
|
|
|
|
| 118 |
|
| 119 |
return zip_path
|
| 120 |
|
| 121 |
+
def generate_images(prompts, pw,model_name):
|
| 122 |
# Check for a valid password
|
| 123 |
|
| 124 |
if pw != os.getenv("PW"):
|
|
|
|
| 151 |
|
| 152 |
try:
|
| 153 |
#what model to use?
|
| 154 |
+
model = ImageGenerationModel.from_pretrained(model_name)
|
| 155 |
response = model.generate_images(
|
| 156 |
prompt=prompt_w_challenge,
|
| 157 |
number_of_images=1,
|
|
|
|
| 168 |
response[0].save(filename)
|
| 169 |
image_label = f"{i+1}: {text}"
|
| 170 |
|
| 171 |
+
model_for_db = f"imagen-{model_name}"
|
| 172 |
+
|
| 173 |
try:
|
| 174 |
# Save the prompt, model, image URL, generation time and creation timestamp to the database
|
| 175 |
+
mongo_collection.insert_one({"user": user_initials, "text": text, "model": model_for_db, "image_url": image_url, "gen_time": gen_time, "timestamp": time.time(), "challenge": challenge})
|
| 176 |
except Exception as e:
|
| 177 |
print(e)
|
| 178 |
raise gr.Error("An error occurred while saving the prompt to the database.")
|
|
|
|
| 197 |
|
| 198 |
gr.Markdown("# <center>Prompt de Resistance Vertex Imagen</center>")
|
| 199 |
|
| 200 |
+
pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service")
|
| 201 |
|
| 202 |
#instructions
|
| 203 |
with gr.Accordion("Instructions & Tips",label="instructions",open=False):
|
|
|
|
| 214 |
#prompts
|
| 215 |
with gr.Accordion("Prompts",label="Prompts",open=True):
|
| 216 |
text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
|
| 217 |
+
|
| 218 |
+
model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
|
| 219 |
+
|
| 220 |
with gr.Row():
|
| 221 |
btn = gr.Button("Generate Images")
|
| 222 |
|
| 223 |
+
|
| 224 |
#output
|
| 225 |
with gr.Accordion("Image Outputs",label="Image Outputs",open=True):
|
| 226 |
output_images = gr.Gallery(label="Image Outputs", elem_id="gallery-images", show_label=True, columns=[3], rows=[1], object_fit="contain", height="auto", allow_preview=False)
|
|
|
|
| 236 |
|
| 237 |
#submissions
|
| 238 |
#trigger generation either through hitting enter in the text field, or clicking the button.
|
| 239 |
+
btn.click(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model ], outputs=output_images, api_name=False)
|
| 240 |
+
text.submit(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
|
| 241 |
show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
|
| 242 |
|
| 243 |
#downloads
|