ms180's picture
initial commit
068a50e
raw
history blame
3.13 kB
# source: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard/blob/main/src/utils_display.py
import html
from dataclasses import dataclass
from typing import Optional, Tuple

from transformers import AutoConfig
# These classes define the user-facing column names in one place, so that a
# rename does not have to be repeated throughout the code.
@dataclass
class ColumnContent:
    """Describe one user-facing column: its label, its type and two visibility flags."""

    # Column label; the AutoEvalColumn entries pass the displayed header text here.
    name: str
    # Column type; the entries in this file use "str", "markdown" or "number".
    type: str
    # NOTE(review): presumably whether the column is shown before the user
    # changes the column selection — confirm against the UI code.
    displayed_by_default: bool
    # Defaults to False. NOTE(review): presumably marks a column that is never
    # shown — confirm against the callers.
    hidden: bool = False
def fields(raw_class):
    """Collect the values of the non-dunder attributes defined on ``raw_class``.

    Only the class's own ``__dict__`` is inspected (inherited attributes are not
    included). An attribute is skipped when its name starts with ``"__"`` or
    ends with ``"__"``; the remaining values are returned in definition order.
    """
    return [
        attr_value
        for attr_name, attr_value in raw_class.__dict__.items()
        if not (attr_name.startswith("__") or attr_name.endswith("__"))
    ]
@dataclass(frozen=True)
class AutoEvalColumn:
    """Catalogue of the user-facing columns used for the automatic evaluations.

    Every class attribute is a ``ColumnContent`` built from the displayed
    column name, the column type (``"str"``, ``"markdown"`` or ``"number"``)
    and the ``displayed_by_default`` flag.

    The attributes carry no type annotations, so ``@dataclass`` does not turn
    them into instance fields; they stay plain class-level constants.
    """

    # Model description columns.
    model = ColumnContent(name="Model", type="markdown", displayed_by_default=True)
    model_size = ColumnContent(name="Size (M)", type="number", displayed_by_default=True)
    train_config = ColumnContent(name="Training Config", type="str", displayed_by_default=True)
    model_config = ColumnContent(name="Model Config", type="str", displayed_by_default=True)

    # Software version columns.
    espnet_version = ColumnContent(name="espnet version", type="str", displayed_by_default=True)
    pytorch_version = ColumnContent(name="pytorch version", type="str", displayed_by_default=True)

    # WER columns.
    wer_test_clean = ColumnContent(name="WER (test-clean)", type="number", displayed_by_default=True)
    wer_test_other = ColumnContent(name="WER (test-other)", type="number", displayed_by_default=True)
    wer_dev_clean = ColumnContent(name="WER (dev-clean)", type="number", displayed_by_default=True)
    wer_dev_other = ColumnContent(name="WER (dev-other)", type="number", displayed_by_default=True)

    # CER columns.
    cer_test_clean = ColumnContent(name="CER (test-clean)", type="number", displayed_by_default=True)
    cer_test_other = ColumnContent(name="CER (test-other)", type="number", displayed_by_default=True)
    cer_dev_clean = ColumnContent(name="CER (dev-clean)", type="number", displayed_by_default=True)
    cer_dev_other = ColumnContent(name="CER (dev-other)", type="number", displayed_by_default=True)
def model_hyperlink(link, model_name):
    """Return an HTML anchor that opens ``link`` in a new tab, labelled ``model_name``.

    Both values are converted to text and HTML-escaped before being
    interpolated, so characters such as ``"``, ``<`` or ``&`` can neither
    terminate the ``href`` attribute early nor inject markup into the page.
    Values without HTML-special characters produce exactly the same output as
    plain interpolation.

    Args:
        link: Target URL of the anchor.
        model_name: Text shown as the link label.

    Returns:
        The ``<a>`` element as a string.
    """
    # quote=True is required for the attribute value; the label only needs
    # the markup characters (&, <, >) neutralised.
    safe_link = html.escape(str(link), quote=True)
    safe_label = html.escape(str(model_name), quote=False)
    return f'<a target="_blank" href="{safe_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{safe_label}</a>'
def make_clickable_names(df):
    """Turn every ``"Model"`` cell into a hyperlink pointing at the row's ``"Links"`` URL.

    The frame is modified in place and also returned.

    The column is built with a comprehension instead of
    ``df.apply(..., axis=1)``: on a frame with no rows ``apply`` returns a
    DataFrame rather than a Series, and assigning that to a single column
    raises. The comprehension yields an empty column in that case.

    Args:
        df: DataFrame with ``"Links"`` and ``"Model"`` columns.

    Returns:
        The same DataFrame, with ``"Model"`` replaced by HTML anchors.

    Raises:
        KeyError: If either the ``"Links"`` or the ``"Model"`` column is missing.
    """
    df["Model"] = [
        model_hyperlink(link, name) for link, name in zip(df["Links"], df["Model"])
    ]
    return df
def styled_error(error):
    """Wrap ``error`` in a centred, red, 20px HTML paragraph."""
    css = "color: red; font-size: 20px; text-align: center;"
    return f"<p style='{css}'>{error}</p>"
def styled_warning(warn):
    """Wrap ``warn`` in a centred, orange, 20px HTML paragraph."""
    css = "color: orange; font-size: 20px; text-align: center;"
    return f"<p style='{css}'>{warn}</p>"
def styled_message(message):
    """Wrap ``message`` in a centred, green, 20px HTML paragraph."""
    css = "color: green; font-size: 20px; text-align: center;"
    return f"<p style='{css}'>{message}</p>"
def has_no_nan_values(df, columns):
    """Return a boolean Series that is True for rows with no missing value in ``columns``."""
    present = df[columns].notna()
    return present.all(axis="columns")
def has_nan_values(df, columns):
    """Return a boolean Series that is True for rows with a missing value in any of ``columns``."""
    missing = df[columns].isna()
    return missing.any(axis="columns")
def is_model_on_hub(model_name: str, revision: str) -> Tuple[bool, Optional[str]]:
    """Check whether the model's config can be loaded from the hub without remote code.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``.
        revision: Revision passed to ``AutoConfig.from_pretrained``.

    Returns:
        ``(True, None)`` when the config loads. Otherwise ``(False, reason)``,
        where ``reason`` is a sentence fragment describing the failure.
    """
    try:
        AutoConfig.from_pretrained(
            model_name, revision=revision, trust_remote_code=False
        )
    except ValueError:
        # NOTE(review): every ValueError is reported as "needs remote code";
        # confirm no other ValueError cause needs a different message.
        return (
            False,
            "needs to be launched with `trust_remote_code=True`. For safety reason, we do not allow these models to be automatically submitted to the leaderboard.",
        )
    except Exception as e:
        # Best-effort check: any other failure is logged and reported as
        # "not found" instead of propagating to the caller.
        print(f"Could not get the model config from the hub.: {e}")
        return False, "was not found on hub!"
    return True, None