Spaces:
Build error
Build error
| from __future__ import annotations | |
| from argparse import ArgumentParser | |
| import datasets | |
| import gradio as gr | |
| import numpy as np | |
| import openai | |
| from dataset_creation.generate_txt_dataset import generate | |
| def main(openai_model: str): | |
| dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train") | |
| captions = dataset[np.random.permutation(len(dataset))]["TEXT"] | |
| index = 0 | |
| def click_random(): | |
| nonlocal index | |
| output = captions[index] | |
| index = (index + 1) % len(captions) | |
| return output | |
| def click_generate(input: str): | |
| if input == "": | |
| raise gr.Error("Input caption is missing!") | |
| edit_output = generate(openai_model, input) | |
| if edit_output is None: | |
| return "Failed :(", "Failed :(" | |
| return edit_output | |
| with gr.Blocks(css="footer {visibility: hidden}") as demo: | |
| txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...") # fmt: skip | |
| txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False) | |
| txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False) | |
| with gr.Row(): | |
| clear_btn = gr.Button("Clear") | |
| random_btn = gr.Button("Random Input") | |
| generate_btn = gr.Button("Generate Instruction + Edited Caption") | |
| clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output]) | |
| random_btn.click(fn=click_random, inputs=[], outputs=[txt_input]) | |
| generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output]) | |
| demo.launch(share=True) | |
| if __name__ == "__main__": | |
| parser = ArgumentParser() | |
| parser.add_argument("--openai-api-key", required=True, type=str) | |
| parser.add_argument("--openai-model", required=True, type=str) | |
| args = parser.parse_args() | |
| openai.api_key = args.openai_api_key | |
| main(args.openai_model) | |