Spaces:
Paused
Paused
from __future__ import annotations | |
from argparse import ArgumentParser | |
import datasets | |
import gradio as gr | |
import numpy as np | |
import openai | |
from dataset_creation.generate_txt_dataset import generate | |
def main(openai_model: str):
    """Build and launch the caption-editing Gradio demo.

    Loads the ``train`` split of the aesthetics caption dataset, shuffles its
    ``TEXT`` column once, and serves a UI with an editable input caption and
    two read-only boxes that display whatever ``generate`` returns.

    Args:
        openai_model: Model identifier forwarded unchanged to ``generate``.
    """
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Shuffle once up front; the "Random Input" button then just walks this list.
    shuffled_order = np.random.permutation(len(dataset))
    captions = dataset[shuffled_order]["TEXT"]
    cursor = 0

    def click_random():
        """Return the next shuffled caption, wrapping around at the end."""
        nonlocal cursor
        caption, cursor = captions[cursor], (cursor + 1) % len(captions)
        return caption

    def click_generate(caption: str):
        """Run ``generate`` on the caption and return its result for the two output boxes."""
        if caption == "":
            raise gr.Error("Input caption is missing!")
        failure = "Failed :("
        result = generate(openai_model, caption)
        # ``generate`` signals failure with None; show a placeholder in both boxes.
        # presumably a successful result is an (instruction, edited caption) pair — confirm in generate
        return (failure, failure) if result is None else result

    def clear_all():
        """Blank out the input box and both output boxes."""
        return "", "", ""

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        txt_input = gr.Textbox(
            lines=3,
            label="Input Caption",
            interactive=True,
            placeholder="Type image caption here...",
        )
        txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
        txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)

        with gr.Row():
            clear_btn = gr.Button("Clear")
            random_btn = gr.Button("Random Input")
            generate_btn = gr.Button("Generate Instruction + Edited Caption")

        result_boxes = [txt_edit, txt_output]
        clear_btn.click(fn=clear_all, inputs=[], outputs=[txt_input, *result_boxes])
        random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
        generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=result_boxes)

    demo.launch(share=True)
if __name__ == "__main__":
    # Script entry point: read the OpenAI API key and model name from the
    # command line (two positional arguments, in that order), configure the
    # openai client, and start the demo.
    parser = ArgumentParser()
    # argparse uses a positional's name verbatim as its attribute name (dashes
    # are only converted to underscores for "--optional" flags). Declaring the
    # positionals with underscores makes `args.openai_api_key` and
    # `args.openai_model` resolvable; `metavar` keeps the dashed spelling in
    # the usage/help output so the command line looks the same as before.
    parser.add_argument(
        "openai_api_key",
        metavar="openai-api-key",
        type=str,
        help="API key assigned to openai.api_key",
    )
    parser.add_argument(
        "openai_model",
        metavar="openai-model",
        type=str,
        help="model name passed to main()",
    )
    args = parser.parse_args()
    openai.api_key = args.openai_api_key
    main(args.openai_model)