heaversm commited on
Commit
de6d38a
·
1 Parent(s): 1e53245

allow model selection

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -72,9 +72,9 @@ def update_labels(show_labels):
72
  updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
73
  return updated_gallery
74
 
75
- def generate_images_wrapper(prompts, pw, show_labels):
76
  global image_paths_global, image_labels_global
77
- image_paths, image_labels = generate_images(prompts, pw)
78
  image_paths_global = image_paths
79
 
80
  # store this as a global so we can handle toggle state
@@ -118,7 +118,7 @@ def download_all_images():
118
 
119
  return zip_path
120
 
121
- def generate_images(prompts, pw):
122
  # Check for a valid password
123
 
124
  if pw != os.getenv("PW"):
@@ -151,7 +151,7 @@ def generate_images(prompts, pw):
151
 
152
  try:
153
  #what model to use?
154
- model = ImageGenerationModel.from_pretrained("imagegeneration@002")
155
  response = model.generate_images(
156
  prompt=prompt_w_challenge,
157
  number_of_images=1,
@@ -168,9 +168,11 @@ def generate_images(prompts, pw):
168
  response[0].save(filename)
169
  image_label = f"{i+1}: {text}"
170
 
 
 
171
  try:
172
  # Save the prompt, model, image URL, generation time and creation timestamp to the database
173
- mongo_collection.insert_one({"user": user_initials, "text": text, "model": "imagen", "image_url": image_url, "gen_time": gen_time, "timestamp": time.time(), "challenge": challenge})
174
  except Exception as e:
175
  print(e)
176
  raise gr.Error("An error occurred while saving the prompt to the database.")
@@ -195,7 +197,7 @@ with gr.Blocks(css=css) as demo:
195
 
196
  gr.Markdown("# <center>Prompt de Resistance Vertex Imagen</center>")
197
 
198
- pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service", value="REBEL.pier6moment")
199
 
200
  #instructions
201
  with gr.Accordion("Instructions & Tips",label="instructions",open=False):
@@ -212,9 +214,13 @@ with gr.Blocks(css=css) as demo:
212
  #prompts
213
  with gr.Accordion("Prompts",label="Prompts",open=True):
214
  text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
 
 
 
215
  with gr.Row():
216
  btn = gr.Button("Generate Images")
217
 
 
218
  #output
219
  with gr.Accordion("Image Outputs",label="Image Outputs",open=True):
220
  output_images = gr.Gallery(label="Image Outputs", elem_id="gallery-images", show_label=True, columns=[3], rows=[1], object_fit="contain", height="auto", allow_preview=False)
@@ -230,8 +236,8 @@ with gr.Blocks(css=css) as demo:
230
 
231
  #submissions
232
  #trigger generation either through hitting enter in the text field, or clicking the button.
233
- btn.click(fn=generate_images_wrapper, inputs=[text, pw, show_labels ], outputs=output_images, api_name=False)
234
- text.submit(fn=generate_images_wrapper, inputs=[text, pw, show_labels], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
235
  show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
236
 
237
  #downloads
 
72
  updated_gallery = [(path, label if show_labels else "") for path, label in zip(image_paths_global, image_labels_global)]
73
  return updated_gallery
74
 
75
+ def generate_images_wrapper(prompts, pw, show_labels,model):
76
  global image_paths_global, image_labels_global
77
+ image_paths, image_labels = generate_images(prompts, pw,model)
78
  image_paths_global = image_paths
79
 
80
  # store this as a global so we can handle toggle state
 
118
 
119
  return zip_path
120
 
121
+ def generate_images(prompts, pw,model_name):
122
  # Check for a valid password
123
 
124
  if pw != os.getenv("PW"):
 
151
 
152
  try:
153
  #what model to use?
154
+ model = ImageGenerationModel.from_pretrained(model_name)
155
  response = model.generate_images(
156
  prompt=prompt_w_challenge,
157
  number_of_images=1,
 
168
  response[0].save(filename)
169
  image_label = f"{i+1}: {text}"
170
 
171
+ model_for_db = f"imagen-{model_name}"
172
+
173
  try:
174
  # Save the prompt, model, image URL, generation time and creation timestamp to the database
175
+ mongo_collection.insert_one({"user": user_initials, "text": text, "model": model_for_db, "image_url": image_url, "gen_time": gen_time, "timestamp": time.time(), "challenge": challenge})
176
  except Exception as e:
177
  print(e)
178
  raise gr.Error("An error occurred while saving the prompt to the database.")
 
197
 
198
  gr.Markdown("# <center>Prompt de Resistance Vertex Imagen</center>")
199
 
200
+ pw = gr.Textbox(label="Password", type="password", placeholder="Enter the password to unlock the service")
201
 
202
  #instructions
203
  with gr.Accordion("Instructions & Tips",label="instructions",open=False):
 
214
  #prompts
215
  with gr.Accordion("Prompts",label="Prompts",open=True):
216
  text = gr.Textbox(label="What do you want to create?", placeholder="Enter your text and then click on the \"Image Generate\" button")
217
+
218
+ model = gr.Dropdown(choices=["imagegeneration@002", "imagegeneration@005"], label="Model", value="imagegeneration@005")
219
+
220
  with gr.Row():
221
  btn = gr.Button("Generate Images")
222
 
223
+
224
  #output
225
  with gr.Accordion("Image Outputs",label="Image Outputs",open=True):
226
  output_images = gr.Gallery(label="Image Outputs", elem_id="gallery-images", show_label=True, columns=[3], rows=[1], object_fit="contain", height="auto", allow_preview=False)
 
236
 
237
  #submissions
238
  #trigger generation either through hitting enter in the text field, or clicking the button.
239
+ btn.click(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model ], outputs=output_images, api_name=False)
240
+ text.submit(fn=generate_images_wrapper, inputs=[text, pw, show_labels,model], outputs=output_images, api_name="generate_image") # Generate an api endpoint in Gradio / HF
241
  show_labels.change(fn=update_labels, inputs=[show_labels], outputs=[output_images])
242
 
243
  #downloads