"""Gradio app that suggests ways to improve text-generation results.

The user picks a problem, a task and (optionally) a Hub model name; the app
answers with a Markdown list of suggestions drawn from the model-specific,
task-specific and general suggestion tables.
"""
from functools import lru_cache

import gradio as gr
import huggingface_hub as hfh
from requests.exceptions import HTTPError

from general_suggestions import GENERAL_SUGGESTIONS
from model_suggestions import MODEL_SUGGESTIONS
from task_suggestions import TASK_SUGGESTIONS

# =====================================================================================================================
# DATA
# =====================================================================================================================
# Dict with the tasks considered in this space, {task: task tag}
TASK_TYPES = {
    "✍️ Text Generation": "txtgen",
    "🤏 Summarization": "summ",
    "🫂 Translation": "trans",
    "💬 Conversational / Chatbot": "chat",
    "🤷 Text Question Answering": "txtqa",
    "🕵️ (Table/Document/Visual) Question Answering": "otherqa",
    "🎤 Automatic Speech Recognition": "asr",
    "🌇 Image to Text": "img2txt",
}

# Dict matching all task types with their possible hub tags, {task tag: (possible, hub, tags)}
HUB_TAGS = {
    "txtgen": ("text-generation", "text2text-generation"),
    "summ": ("summarization", "text-generation", "text2text-generation"),
    "trans": ("translation", "text-generation", "text2text-generation"),
    "chat": ("conversational", "text-generation", "text2text-generation"),
    "txtqa": ("text-generation", "text2text-generation"),
    "otherqa": ("table-question-answering", "document-question-answering", "visual-question-answering"),
    "asr": ("automatic-speech-recognition",),
    "img2txt": ("image-to-text",),
}

# Sanity checks on the two tables above: every task must have hub tags and vice versa.
# (The first check used to compare TASK_TYPES with itself, which can never fail.)
assert len(TASK_TYPES) == len(HUB_TAGS), "TASK_TYPES and HUB_TAGS must have the same number of entries"
assert all(tag in HUB_TAGS for tag in TASK_TYPES.values()), "every task tag needs an entry in HUB_TAGS"

# Dict with the problems considered in this space, {problem: problem tag}
PROBLEMS = {
    "🤔 Baseline. I'm getting gibberish and I want a baseline": "baseline",
    "😵 Crashes. I want to prevent my model from crashing again": "crashes",
    "🤥 Hallucinations. I would like to reduce them": "hallucinations",
    "📏 Length. I want to control the length of the output": "length",
    "🌈 Prompting. I want better outputs without changing my generation options": "prompting",
    "😵‍💫 Repetitions. Make them stop make them stop": "repetitions",
    "📈 Quality. I want better outputs without changing my prompt": "quality",
    "🏎 Speed! Make it faster!": "speed",
}

# NOTE(review): the Markdown literals in this file were recovered from a whitespace-collapsed copy. Line breaks are
# reconstructed, and the vertical spacers are assumed to have been `&nbsp;` paragraphs -- confirm against the
# rendered page.
INIT_MARKDOWN = """
&nbsp;

👈 Fill in as much information as you can...

&nbsp;

&nbsp;

&nbsp;

&nbsp;

&nbsp;

&nbsp;

👈 ... then click here!
"""

DEMO_MARKDOWN = """
⛔️ This is still a demo 🤗 Working sections include "Length" and "Quality" ⛔️
"""

MODEL_PROBLEM = """
😱 Could not retrieve model tags for the specified model, `{model_name}`. Ensure that the model name matches a Hub
model repo, that it is a public model, and that it has Hub tags.
"""

SUGGETIONS_HEADER = """
#### ✨ Here is a list of suggestions for you -- click to expand ✨
"""

PERFECT_MATCH_EMOJI = "✅"
POSSIBLE_MATCH_EMOJI = "❓"

MISSING_INPUTS = """
💡 You can filter suggestions with {} if you add more inputs. Suggestions with {} are a perfect match.
""".format(POSSIBLE_MATCH_EMOJI, PERFECT_MATCH_EMOJI)

# The space below is reserved for suggestions that require advanced logic and/or formatting
TASK_MODEL_MISMATCH = """
{count}. Select a model better suited for your task.

&nbsp;

🤔 Why? &nbsp;

The selected model (`{model_name}`) doesn't have a tag compatible with the task you selected ("{task_type}").
Expected tags for this task are: {tags}

&nbsp;

🤗 How? &nbsp;

Our recommendation is to go to our [tasks page](https://huggingface.co/tasks) and select one of the suggested
models as a starting point.

&nbsp;

😱 Caveats &nbsp;

1. The tags of a model are defined by the community and are not always accurate. If you think the model is
incorrectly tagged or missing a tag, please open an issue on the
[model card](https://huggingface.co/{model_name}/tree/main).

_________________
"""
# =====================================================================================================================


# =====================================================================================================================
# SUGGESTIONS LOGIC
# =====================================================================================================================
def is_valid_task_for_model(model_tags, user_task):
    """Return True when the model's hub tags are compatible with the task.

    Args:
        model_tags: collection of hub tags of the model (may be empty).
        user_task: task tag (a key of `HUB_TAGS`), or "" when no task was selected.

    Returns:
        True if there is nothing to check (no tags or no task) or if at least one of the
        hub tags expected for the task is present in `model_tags`.
    """
    if not model_tags or not user_task:
        return True  # No model / no task tag = no problem :)

    return any(tag in model_tags for tag in HUB_TAGS[user_task])


@lru_cache(maxsize=1024)
def _fetch_model_tags(model_name):
    """Query the Hub for the tags of `model_name`.

    Exceptions propagate to the caller on purpose: `lru_cache` only stores return values,
    so failed lookups are retried on the next request instead of being remembered.
    The result is a tuple so the cached value cannot be mutated by callers.
    """
    return tuple(hfh.HfApi().model_info(model_name).tags or ())


def get_model_tags(model_name):
    """Return the list of hub tags for `model_name`.

    Args:
        model_name: Hub repo id, or "" when no model was given.

    Returns:
        A (fresh) list of tags. The list is empty when no model name was given, when the
        Hub request fails with an `HTTPError`, or when a `ValueError` is raised for the
        given name (NOTE(review): this is meant to cover huggingface_hub's repo-id
        validation error, which is documented as a `ValueError` subclass -- confirm for
        the pinned library version).
    """
    if model_name == "":
        return []
    try:
        return list(_fetch_model_tags(model_name))
    except (HTTPError, ValueError):
        return []


def get_suggestions(task_type, model_name, problem_type):
    """Build the Markdown shown in the suggestions panel.

    Args:
        task_type: selected key of `TASK_TYPES`, or "" when nothing is selected.
        model_name: Hub model name typed by the user, or "".
        problem_type: selected key of `PROBLEMS`, or "" when nothing is selected.

    Returns:
        A Markdown string: the initial instructions when all inputs are empty, an error
        message when the model tags cannot be retrieved, a single "pick another model"
        suggestion when the model does not match the task, the demo notice for problems
        that are not implemented yet, or the header followed by the numbered suggestions.
    """
    # Normalise the raw UI values: treat None as "not provided" and ignore stray whitespace around the model name.
    task_type = task_type or ""
    model_name = (model_name or "").strip()
    problem_type = problem_type or ""

    # Check if the inputs were given
    if not (task_type or model_name or problem_type):
        return INIT_MARKDOWN

    model_tags = get_model_tags(model_name)

    # If there is a model name but no model tags, something went wrong
    if model_name and not model_tags:
        return MODEL_PROBLEM.format(model_name=model_name)

    user_problem = PROBLEMS.get(problem_type, "")
    user_task = TASK_TYPES.get(task_type, "")

    # Check if the model is valid for the task. If not, return straight away
    if not is_valid_task_for_model(model_tags, user_task):
        possible_tags = " ".join("`" + tag + "`" for tag in HUB_TAGS[user_task])
        # The message refers to "the task you selected", so show the label the user picked rather than the
        # internal task tag.
        return TASK_MODEL_MISMATCH.format(
            count=1, model_name=model_name, task_type=task_type, tags=possible_tags
        )

    # Demo shortcut: only a few sections are working
    if user_problem not in ("", "length", "quality"):
        return DEMO_MARKDOWN

    parts = []
    counter = 0

    # First: model-specific suggestions
    has_model_specific_suggestions = False
    match_emoji = PERFECT_MATCH_EMOJI if (user_problem and model_tags) else POSSIBLE_MATCH_EMOJI
    for model_tag, problem_tags, suggestion in MODEL_SUGGESTIONS:
        if user_problem and user_problem not in problem_tags:
            continue
        if model_tags and model_tag not in model_tags:
            continue
        counter += 1
        parts.append(suggestion.format(count=counter, match_emoji=match_emoji))
        has_model_specific_suggestions = True

    # Second: task-specific suggestions
    has_task_specific_suggestions = False
    match_emoji = PERFECT_MATCH_EMOJI if (user_problem and user_task) else POSSIBLE_MATCH_EMOJI
    for task_tags, problem_tags, suggestion in TASK_SUGGESTIONS:
        if user_problem and user_problem not in problem_tags:
            continue
        if user_task and user_task not in task_tags:
            continue
        counter += 1
        parts.append(suggestion.format(count=counter, match_emoji=match_emoji))
        has_task_specific_suggestions = True

    # Finally: general suggestions for the problem
    has_problem_specific_suggestions = False
    match_emoji = PERFECT_MATCH_EMOJI if user_problem else POSSIBLE_MATCH_EMOJI
    for problem_tags, suggestion in GENERAL_SUGGESTIONS:
        if user_problem and user_problem not in problem_tags:
            continue
        counter += 1
        parts.append(suggestion.format(count=counter, match_emoji=match_emoji))
        has_problem_specific_suggestions = True

    suggestions = "".join(parts)

    # Prepends needed bits: hint that more inputs would narrow down the list
    if (
        (task_type == "" and has_task_specific_suggestions)
        or (model_name == "" and has_model_specific_suggestions)
        or (problem_type == "" and has_problem_specific_suggestions)
    ):
        suggestions = MISSING_INPUTS + suggestions

    return SUGGETIONS_HEADER + suggestions
# =====================================================================================================================


# =====================================================================================================================
# GRADIO
# =====================================================================================================================
demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # 🚀💬 Improving Generated Text 💬🚀

        This is a ever-evolving guide on how to improve your text generation results. It is community-led and
        curated by Hugging Face 🤗
        """
    )

    with gr.Row():
        with gr.Column():
            problem_type = gr.Dropdown(
                label="What would you like to improve?",
                choices=[""] + list(PROBLEMS.keys()),
                interactive=True,
                value="",
            )
            task_type = gr.Dropdown(
                label="Which task are you working on?",
                choices=[""] + list(TASK_TYPES.keys()),
                interactive=True,
                value="",
            )
            model_name = gr.Textbox(
                label="Which model are you using?",
                placeholder="e.g. google/flan-t5-xl",
                interactive=True,
            )
            button = gr.Button(value="Get Suggestions!")

        with gr.Column(scale=2):
            suggestions = gr.Markdown(value=INIT_MARKDOWN)

    # Argument order must match the signature of get_suggestions(task_type, model_name, problem_type)
    button.click(get_suggestions, inputs=[task_type, model_name, problem_type], outputs=suggestions)

    gr.Markdown(
        """
        &nbsp;

        Is your problem not on the list? Need more suggestions? Have you spotted an error? Please open a
        [new discussion](https://huggingface.co/spaces/joaogante/generate_quality_improvement/discussions) 🙏
        """
    )

# =====================================================================================================================

if __name__ == "__main__":
    demo.launch()