# Source: Hugging Face Space file view -- last commit 08f9035 ("add cache") by joaogante (HF staff); file size 10 kB.
import gradio as gr
import huggingface_hub as hfh
from requests.exceptions import HTTPError
from functools import lru_cache
from general_suggestions import GENERAL_SUGGESTIONS
from model_suggestions import MODEL_SUGGESTIONS
from task_suggestions import TASK_SUGGESTIONS
# =====================================================================================================================
# DATA
# =====================================================================================================================
# Dict with the tasks considered in this Space, {task label shown in the UI: task tag}
TASK_TYPES = {
    "✍️ Text Generation": "txtgen",
    "🤏 Summarization": "summ",
    "🫂 Translation": "trans",
    "💬 Conversational / Chatbot": "chat",
    "🤷 Text Question Answering": "txtqa",
    "🕵️ (Table/Document/Visual) Question Answering": "otherqa",
    "🎤 Automatic Speech Recognition": "asr",
    "🌇 Image to Text": "img2txt",
}

# Dict matching all task types with their possible hub tags, {task tag: (possible, hub, tags)}
HUB_TAGS = {
    "txtgen": ("text-generation", "text2text-generation"),
    "summ": ("summarization", "text-generation", "text2text-generation"),
    "trans": ("translation", "text-generation", "text2text-generation"),
    "chat": ("conversational", "text-generation", "text2text-generation"),
    "txtqa": ("text-generation", "text2text-generation"),
    "otherqa": ("table-question-answering", "document-question-answering", "visual-question-answering"),
    "asr": ("automatic-speech-recognition",),
    "img2txt": ("image-to-text",),
}

# Import-time sanity checks: both tables must describe the same task tags. The first check used to compare
# TASK_TYPES with itself, which could never fail; it now compares the two tables as intended.
assert len(TASK_TYPES) == len(HUB_TAGS)
assert all(tag in HUB_TAGS for tag in TASK_TYPES.values())
# Dict with the problems considered in this Space, {problem label shown in the UI: problem tag}.
# The keys populate the "What would you like to improve?" dropdown; the values are the tags that
# get_suggestions() matches against the `problem_tags` of each suggestion entry.
PROBLEMS = {
    "🤔 Baseline. I'm getting gibberish and I want a baseline": "baseline",
    "😵 Crashes. I want to prevent my model from crashing again": "crashes",
    "🤥 Hallucinations. I would like to reduce them": "hallucinations",
    "📏 Length. I want to control the length of the output": "length",
    "🌈 Prompting. I want better outputs without changing my generation options": "prompting",
    "😵‍💫 Repetitions. Make them stop make them stop": "repetitions",
    "📈 Quality. I want better outputs without changing my prompt": "quality",
    "🏎 Speed! Make it faster!": "speed",
}
INIT_MARKDOWN = """
 
👈 Fill in as much information as you can...
 
 
 
 
 
 
👈 ... then click here!
"""
# Returned by get_suggestions() when the selected problem is not one of the sections that are implemented
# (currently only "length" and "quality", or no problem at all).
DEMO_MARKDOWN = """
⛔️ This is still a demo 🤗 Working sections include "Length" and "Quality" ⛔️
"""
# Error message template; expects a `model_name` placeholder. Used when a model name was given but no tags
# could be retrieved for it.
MODEL_PROBLEM = """
😱 Could not retrieve model tags for the specified model, `{model_name}`. Ensure that the model name matches a Hub
model repo, that it is a public model, and that it has Hub tags.
"""
# Header prepended to every list of suggestions. NOTE: the identifier is misspelled ("SUGGETIONS"), but it is
# referenced under this exact name by get_suggestions(), so it must not be renamed here in isolation.
SUGGETIONS_HEADER = """
#### ✨ Here is a list of suggestions for you -- click to expand ✨
"""
# Emoji injected into each suggestion through its `match_emoji` placeholder.
PERFECT_MATCH_EMOJI = "✅"
POSSIBLE_MATCH_EMOJI = "❓"
# Hint shown when some input was left empty; the two positional placeholders are filled in once, at import time.
MISSING_INPUTS = """
💡 You can filter suggestions with {} if you add more inputs. Suggestions with {} are a perfect match.
""".format(POSSIBLE_MATCH_EMOJI, PERFECT_MATCH_EMOJI)
# The space below is reserved for suggestions that require advanced logic and/or formatting
# Template placeholders: `count` (position in the list), `model_name`, `task_type` and `tags`.
TASK_MODEL_MISMATCH = """
<details><summary>{count}. Select a model better suited for your task.</summary>
&nbsp;
🤔 Why? &nbsp;
The selected model (`{model_name}`) doesn't have a tag compatible with the task you selected ("{task_type}").
Expected tags for this task are: {tags} &nbsp;
🤗 How? &nbsp;
Our recommendation is to go to our [tasks page](https://huggingface.co/tasks) and select one of the suggested
models as a starting point. &nbsp;
😱 Caveats &nbsp;
1. The tags of a model are defined by the community and are not always accurate. If you think the model is incorrectly
tagged or missing a tag, please open an issue on the [model card](https://huggingface.co/{model_name}/tree/main).
_________________
</details>
"""
# =====================================================================================================================
# =====================================================================================================================
# SUGGESTIONS LOGIC
# =====================================================================================================================
def is_valid_task_for_model(model_tags, user_task):
    """Tell whether a model's Hub tags are compatible with the selected task.

    Args:
        model_tags: Hub tags of the model; an empty collection means "no model given".
        user_task: Task tag (a key of `HUB_TAGS`), or "" when no task was selected.

    Returns:
        True when there is nothing to check (no tags or no task) or when at least one of the Hub tags
        expected for `user_task` is present in `model_tags`; False otherwise.
    """
    if len(model_tags) != 0 and user_task != "":
        # Both inputs are present: look for the first expected tag carried by the model.
        for expected_tag in HUB_TAGS[user_task]:
            if expected_tag in model_tags:
                return True
        return False
    # No model / no task tag = no problem :)
    return True
# The key is free text typed by users, so the cache must be bounded; the previous `int(2e10)` limit was
# effectively "never evict".
@lru_cache(maxsize=1024)
def get_model_tags(model_name):
    """Return the Hub tags of `model_name`, or an empty list when they cannot be retrieved.

    Args:
        model_name: Hub repo id typed by the user (e.g. "google/flan-t5-xl"); "" means "no model".

    Returns:
        The tags reported by the Hub for the model. An empty list is returned when no name was given,
        when the lookup fails, or when the model has no tags.

    Note:
        Results are memoized, and that includes the empty list produced by a failed lookup: a transient
        Hub error for a given name keeps being served from the cache until the entry is evicted. The
        returned list is the cached object itself, so callers must not mutate it.
    """
    if model_name == "":
        return []
    try:
        model_tags = hfh.HfApi().model_info(model_name).tags
    except (HTTPError, ValueError):
        # HTTPError: the repo is missing/private or the request failed.
        # ValueError: presumably raised by huggingface_hub's repo-id validation for malformed names
        # (e.g. containing spaces) -- user input must not crash the app either way.
        return []
    # `tags` may be None for some repos (assumption based on the Hub client typing it as optional);
    # callers rely on getting something that supports len().
    return model_tags or []
# Bounded cache: `model_name` is free text, so an effectively unlimited cache would grow without limit.
@lru_cache(maxsize=1024)
def get_suggestions(task_type, model_name, problem_type):
    """Build the Markdown shown in the suggestions panel for the given user inputs.

    Args:
        task_type: Task label selected in the UI (a key of `TASK_TYPES`), or "" when not selected.
        model_name: Hub repo id typed by the user, or "" when not given.
        problem_type: Problem label selected in the UI (a key of `PROBLEMS`), or "" when not selected.

    Returns:
        A Markdown string, one of:
        * `INIT_MARKDOWN` when every input is empty;
        * `MODEL_PROBLEM` when a model name was given but no tags could be retrieved for it;
        * `TASK_MODEL_MISMATCH` when the model tags are incompatible with the selected task;
        * `DEMO_MARKDOWN` when the selected problem has no working section yet;
        * otherwise `SUGGETIONS_HEADER`, optionally followed by `MISSING_INPUTS`, followed by the
          numbered model-specific, task-specific and general suggestions (in that order).

    Note:
        Results are memoized per input triple, including the `MODEL_PROBLEM` outcome.
    """
    # Normalize the raw widget values: treat None like an empty field (a cleared widget may presumably
    # send None) and ignore stray whitespace around a pasted model name.
    task_type = task_type or ""
    problem_type = problem_type or ""
    model_name = (model_name or "").strip()

    # Check if the inputs were given
    if not (task_type or model_name or problem_type):
        return INIT_MARKDOWN

    # If there is a model name but no model tags, something went wrong
    model_tags = get_model_tags(model_name)
    if model_name and not model_tags:
        return MODEL_PROBLEM.format(model_name=model_name)

    user_problem = PROBLEMS.get(problem_type, "")
    user_task = TASK_TYPES.get(task_type, "")

    # Check if the model is valid for the task. If not, return straight away
    if not is_valid_task_for_model(model_tags, user_task):
        possible_tags = " ".join("`" + tag + "`" for tag in HUB_TAGS[user_task])
        # The template quotes "the task you selected", so show the label the user picked in the
        # dropdown rather than the internal task tag.
        return TASK_MODEL_MISMATCH.format(
            count=1, model_name=model_name, task_type=task_type, tags=possible_tags
        )

    # Demo shortcut: only a few sections are working
    if user_problem not in ("", "length", "quality"):
        return DEMO_MARKDOWN

    no_problem = user_problem == ""
    no_task = user_task == ""
    no_model = not model_tags

    def problem_matches(problem_tags):
        # A missing problem acts as a wildcard.
        return no_problem or user_problem in problem_tags

    # First: model-specific suggestions
    model_matches = [
        suggestion
        for model_tag, problem_tags, suggestion in MODEL_SUGGESTIONS
        if problem_matches(problem_tags) and (no_model or model_tag in model_tags)
    ]
    # Second: task-specific suggestions
    task_matches = [
        suggestion
        for task_tags, problem_tags, suggestion in TASK_SUGGESTIONS
        if problem_matches(problem_tags) and (no_task or user_task in task_tags)
    ]
    # Finally: general suggestions for the problem
    general_matches = [
        suggestion for problem_tags, suggestion in GENERAL_SUGGESTIONS if problem_matches(problem_tags)
    ]

    # A section is only a "perfect" match when every input it filters on was actually provided.
    model_emoji = POSSIBLE_MATCH_EMOJI if (no_problem or no_model) else PERFECT_MATCH_EMOJI
    task_emoji = POSSIBLE_MATCH_EMOJI if (no_problem or no_task) else PERFECT_MATCH_EMOJI
    general_emoji = POSSIBLE_MATCH_EMOJI if no_problem else PERFECT_MATCH_EMOJI

    selected = (
        [(template, model_emoji) for template in model_matches]
        + [(template, task_emoji) for template in task_matches]
        + [(template, general_emoji) for template in general_matches]
    )
    # Numbering runs continuously across the three sections.
    suggestions = "".join(
        template.format(count=count, match_emoji=emoji)
        for count, (template, emoji) in enumerate(selected, start=1)
    )

    # Prepends needed bits: hint that more inputs would narrow down the "possible" matches.
    if (
        (task_type == "" and task_matches)
        or (model_name == "" and model_matches)
        or (problem_type == "" and general_matches)
    ):
        suggestions = MISSING_INPUTS + suggestions
    return SUGGETIONS_HEADER + suggestions
# =====================================================================================================================
# =====================================================================================================================
# GRADIO
# =====================================================================================================================
# Static Markdown rendered above and below the interactive widgets. The text lines start at column 0
# on purpose: they are part of the string contents.
_PAGE_HEADER = """
# 🚀💬 Improving Generated Text 💬🚀
This is a ever-evolving guide on how to improve your text generation results. It is community-led and
curated by Hugging Face 🤗
"""
_PAGE_FOOTER = """
&nbsp;
Is your problem not on the list? Need more suggestions? Have you spotted an error? Please open a
[new discussion](https://huggingface.co/spaces/joaogante/generate_quality_improvement/discussions) 🙏
"""

# NOTE(review): the nesting below (inputs + button in the left column, output in the wider right column)
# is the assumed layout -- confirm against the deployed Space.
with gr.Blocks() as demo:
    gr.Markdown(_PAGE_HEADER)
    with gr.Row():
        # Left column: the three user inputs and the button that triggers the lookup.
        with gr.Column():
            problem_type = gr.Dropdown(
                choices=["", *PROBLEMS],
                value="",
                label="What would you like to improve?",
                interactive=True,
            )
            task_type = gr.Dropdown(
                choices=["", *TASK_TYPES],
                value="",
                label="Which task are you working on?",
                interactive=True,
            )
            model_name = gr.Textbox(
                placeholder="e.g. google/flan-t5-xl",
                label="Which model are you using?",
                interactive=True,
            )
            button = gr.Button(value="Get Suggestions!")
        # Right column: the rendered suggestions, starting with the usage hint.
        with gr.Column(scale=2):
            suggestions = gr.Markdown(value=INIT_MARKDOWN)
    # The input order must match the positional parameters of get_suggestions.
    button.click(get_suggestions, inputs=[task_type, model_name, problem_type], outputs=suggestions)
    gr.Markdown(_PAGE_FOOTER)
# =====================================================================================================================
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()