# TODO
# Remove duplication in the code used to generate markdown
# Periodically re-check tracked models to confirm they are all still valid and public

import os
import re
import sys
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from cachetools import TTLCache, cached
from diskcache import Cache
from dotenv import load_dotenv
from huggingface_hub import (
    HfApi,
    comment_discussion,
    create_discussion,
    dataset_info,
    get_repo_discussions,
)
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from sqlitedict import SqliteDict
from toolz import concat, count, unique
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

# Treat a local (macOS) run as development mode
local = sys.platform.startswith("darwin")
cache_location = "cache/" if local else "/data/cache"
save_dir = "test_data" if local else "/data/"
Path(save_dir).mkdir(parents=True, exist_ok=True)

cache = Cache(cache_location)
load_dotenv()

user_agent = os.getenv("USER_AGENT")
HF_TOKEN = os.getenv("HF_TOKEN")
REPO = "librarian-bots/dataset-to-model-monitor"  # where notification discussions land
AUTHOR = "librarian-bot"  # the account that opens the discussions
hf_api = HfApi(user_agent=user_agent)
ten_min_cache = TTLCache(maxsize=5_000, ttl=600)
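
# A sketch of the on-disk state this app keeps (shapes inferred from the code
# below; the example ids are hypothetical, not from the source):
#
#   models_to_dataset.sqlite
#       dataset_id -> list of model ids seen so far, e.g.
#       "biglam/brill_iconclass" -> ["some-org/some-model", ...]
#
#   tracked_dataset_to_users.sqlite
#       dataset_id -> list of usernames tracking it, e.g.
#       "biglam/brill_iconclass" -> ["some-user", "another-user"]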

def get_datasets_for_user(username: str) -> List[str]:
    datasets = hf_api.list_datasets(author=username)
    return [dataset.id for dataset in datasets]

def get_models_for_dataset(dataset_id: str) -> Dict[str, List[str]]:
    results = list(hf_api.list_models(filter=f"dataset:{dataset_id}"))
    if results:
        # deduplicate model ids
        results = list({result.id for result in results})
    return {dataset_id: results}

def generate_dataset_model_map(
    dataset_ids: List[str],
) -> Dict[str, List[str]]:
    # query the Hub for each dataset concurrently and merge the per-dataset dicts
    results = thread_map(get_models_for_dataset, dataset_ids)
    return {key: value for d in results for key, value in d.items()}
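
# Illustrative only (hypothetical output):
# generate_dataset_model_map(["biglam/brill_iconclass"]) returns something
# shaped like {"biglam/brill_iconclass": ["some-org/some-model", ...]}, with an
# empty list for datasets that no model metadata currently references.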

def maybe_update_datasets_to_model_map(dataset_id):
    with SqliteDict(f"{save_dir}/models_to_dataset.sqlite") as dataset_to_model_map_db:
        if dataset_id not in dataset_to_model_map_db:
            dataset_to_model_map_db[dataset_id] = list(
                get_models_for_dataset(dataset_id)[dataset_id]
            )
            dataset_to_model_map_db.commit()
            # return the new total number of tracked datasets
            return len(dataset_to_model_map_db)
        return False

def datasets_tracked_by_user(username):
    with SqliteDict(
        f"{save_dir}/tracked_dataset_to_users.sqlite"
    ) as tracked_dataset_to_users_db:
        return [
            dataset
            for dataset, users in tracked_dataset_to_users_db.items()
            if username in users
        ]

def update_tracked_dataset_to_users(dataset_id: str, username: str):
    with SqliteDict(
        f"{save_dir}/tracked_dataset_to_users.sqlite",
    ) as tracked_dataset_to_users_db:
        if dataset_id in tracked_dataset_to_users_db:
            # check if the user is already tracking this dataset
            if username not in tracked_dataset_to_users_db[dataset_id]:
                users_for_dataset = tracked_dataset_to_users_db[dataset_id]
                users_for_dataset.append(username)
                tracked_dataset_to_users_db[dataset_id] = list(set(users_for_dataset))
                tracked_dataset_to_users_db.commit()
        else:
            tracked_dataset_to_users_db[dataset_id] = [username]
            tracked_dataset_to_users_db.commit()
    return datasets_tracked_by_user(username)

# Matches the documented wildcard syntax, e.g. "biglam/*", capturing the user or org name
HUB_ORG_OR_USERNAME_GLOB_PATTERN = re.compile(r"^([^/]+)/\*$")


def match_org_user_glob_pattern(hub_id):
    if match := re.match(HUB_ORG_OR_USERNAME_GLOB_PATTERN, hub_id):
        return match[1]
    return None
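
# Expected behaviour of the wildcard matcher (a sketch, assuming the
# "user-or-org/*" syntax described in the app's own help text):
#   match_org_user_glob_pattern("biglam/*")                -> "biglam"
#   match_org_user_glob_pattern("biglam/brill_iconclass")  -> None  (a full dataset id)
#   match_org_user_glob_pattern("biglam")                  -> None  (no wildcard)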

def grab_dataset_ids_for_user_or_org(hub_id: str) -> List[str]:
    datasets_for_org = hf_api.list_datasets(author=hub_id)
    return [dataset.id for dataset in datasets_for_org if not dataset.private]

def parse_hub_id_entry(hub_id: str) -> Tuple[Union[str, List[str]], Optional[str]]:
    """Return a single dataset id, or a list of dataset ids for a glob, together
    with the matched user/org name (None when a single dataset id was given)."""
    if match := match_org_user_glob_pattern(hub_id):
        return grab_dataset_ids_for_user_or_org(match), match
    try:
        dataset_info(hub_id)
        return hub_id, None
    except HFValidationError as e:
        raise gr.Error(f"Invalid format for Hugging Face Hub dataset ID. {e}") from e
    except RepositoryNotFoundError as e:
        raise gr.Error("Invalid Hugging Face Hub dataset ID") from e

def remove_user_from_tracking_datasets(dataset_id, profile: gr.OAuthProfile | None):
    if not profile and not local:
        return "You must be logged in to remove a dataset"
    username = profile.preferred_username
    dataset_id, match = parse_hub_id_entry(dataset_id)
    if isinstance(dataset_id, str):
        return _remove_user_from_tracking_datasets(dataset_id, username)
    # a glob was supplied, so stop tracking every dataset for the user/org
    for dataset in dataset_id:
        _remove_user_from_tracking_datasets(dataset, username)
    return f"Stopped tracking datasets for username or org: {match}"

def _remove_user_from_tracking_datasets(dataset_id: str, username):
    with SqliteDict(
        f"{save_dir}/tracked_dataset_to_users.sqlite"
    ) as tracked_dataset_to_users_db:
        users = tracked_dataset_to_users_db.get(dataset_id)
        if users is None:
            return "Dataset not being tracked"
        try:
            users.remove(username)
        except ValueError:
            return "No longer tracking dataset"
        tracked_dataset_to_users_db[dataset_id] = users
        if len(users) < 1:
            # nobody tracks this dataset any more, so drop it from both stores
            del tracked_dataset_to_users_db[dataset_id]
            with SqliteDict(
                f"{save_dir}/models_to_dataset.sqlite"
            ) as dataset_to_models_db:
                del dataset_to_models_db[dataset_id]
                dataset_to_models_db.commit()
        tracked_dataset_to_users_db.commit()
    return "Dataset no longer being tracked"

def user_unsubscribe_all(profile: gr.OAuthProfile | None):
    if not profile and not local:
        return "You must be logged in to unsubscribe from datasets"
    username = profile.preferred_username
    datasets_tracked = datasets_tracked_by_user(username)
    for dataset_id in datasets_tracked:
        _remove_user_from_tracking_datasets(dataset_id, username)
    assert len(datasets_tracked_by_user(username)) == 0
    return f"Unsubscribed from {len(datasets_tracked)} datasets"

def user_update(hub_id, profile: gr.OAuthProfile | None):
    if not profile and not local:
        return "Please login to track a dataset"
    username = profile.preferred_username
    hub_id, match = parse_hub_id_entry(hub_id)
    if isinstance(hub_id, str):
        return _user_update(hub_id, username)
    return glob_update_tracked_datasets(hub_id, username, match)

def glob_update_tracked_datasets(hub_ids, username, match):
    for id_ in tqdm(hub_ids):
        _user_update(id_, username)
    response = "## Dataset tracking summary \n\n"
    response += (
        f"All datasets under the user or organization {match} are now being tracked \n\n"
    )
    tracked_datasets = datasets_tracked_by_user(username)
    response += (
        "You are currently tracking whether new models have been trained on"
        f" {len(tracked_datasets)} datasets.\n\n"
    )
    if tracked_datasets:
        response += "### Datasets being tracked \n\n"
        response += (
            "You are currently monitoring whether new models have been trained on the"
            " following datasets:\n"
        )
        for dataset in tracked_datasets:
            response += f"- [{dataset}](https://huggingface.co/datasets/{dataset})\n"
    return response

def _user_update(hub_id: str, username: str) -> str:
    """Update the user's tracked datasets and return a response string."""
    response = ""
    if number_datasets_being_tracked := maybe_update_datasets_to_model_map(hub_id):
        response += (
            "New dataset being tracked! Now tracking"
            f" {number_datasets_being_tracked} datasets \n\n"
        )
    else:
        response += f"Dataset {hub_id} is already being tracked. \n\n"
    tracked_datasets = update_tracked_dataset_to_users(hub_id, username)
    response += (
        "You are currently tracking whether new models have been trained on"
        f" {len(tracked_datasets)} datasets."
    )
    if tracked_datasets:
        response += (
            "\nYou are currently monitoring whether new models have been trained on the"
            " following datasets:\n"
        )
        for dataset in tracked_datasets:
            response += f"- [{dataset}](https://huggingface.co/datasets/{dataset})\n"
    else:
        response += "You are not currently tracking any datasets."
    return response

def check_for_new_models_for_dataset_and_update() -> Dict[str, Set[str]]:
    with SqliteDict(f"{save_dir}/models_to_dataset.sqlite") as old_results_db:
        dataset_ids = list(old_results_db.keys())
        new_results = generate_dataset_model_map(dataset_ids)
        # keep only datasets with models that were not in the last snapshot
        models_to_notify_about = {
            dataset_id: set(models).difference(old_results_db[dataset_id])
            for dataset_id, models in new_results.items()
            if set(models).difference(old_results_db[dataset_id])
        }
        for dataset_id, models in new_results.items():
            old_results_db[dataset_id] = models
        old_results_db.commit()
    return models_to_notify_about
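
# Illustrative only (hypothetical values): if "some-org/new-model" has appeared
# since the last snapshot, this returns
# {"biglam/brill_iconclass": {"some-org/new-model"}}; datasets with no new
# models are omitted, so an empty dict means there is nothing to notify about.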

def get_repo_discussion_by_author_and_type(
    repo, author, token, repo_type="space", include_prs=False
):
    discussions = get_repo_discussions(repo, repo_type=repo_type, token=token)
    for discussion in discussions:
        if discussion.author == author:
            if not include_prs and discussion.is_pull_request:
                continue
            yield discussion

def create_discussion_text_body(dataset_id, new_models, users_to_notify):
    usernames = [f"@{username}" for username in users_to_notify]
    usernames_string = ", ".join(usernames)
    dataset_id_markdown_url = (
        f"[{dataset_id}](https://huggingface.co/datasets/{dataset_id})"
    )
    description = (
        f"Hey {usernames_string}! Librarian bot found new models trained on the"
        f" {dataset_id_markdown_url} dataset!\n\n"
    )
    description += f"New models trained on {dataset_id}:\n"
    markdown_items = [
        f"- {hub_id_to_huggingface_hub_url_markdown(model)}" for model in new_models
    ]
    description += "\n".join(markdown_items)
    return description
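
# A sketch of the markdown body this builds (usernames and model ids are
# hypothetical):
#
#   Hey @some-user, @another-user! Librarian bot found new models trained on the
#   [biglam/brill_iconclass](https://huggingface.co/datasets/biglam/brill_iconclass) dataset!
#
#   New models trained on biglam/brill_iconclass:
#   - [some-org/some-model](https://huggingface.co/some-org/some-model)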

def maybe_create_discussion(
    repo: str,
    dataset_id: str,
    new_models: Union[List, str],
    users_to_notify: List[str],
    author: str,
    token: str,
):
    title = f"Discussion tracking new models trained on {dataset_id}"
    discussions = get_repo_discussion_by_author_and_type(repo, author, token)
    description = create_discussion_text_body(dataset_id, new_models, users_to_notify)
    if discussion_for_dataset := next(
        (discussion for discussion in discussions if title == discussion.title),
        None,
    ):
        # a discussion for this dataset already exists, so add a comment to it
        comment_discussion(
            repo,
            discussion_for_dataset.num,
            description,
            token=token,
            repo_type="space",
        )
    else:
        create_discussion(
            repo,
            title,
            token=token,
            description=description,
            repo_type="space",
        )

def hub_id_to_huggingface_hub_url_markdown(hub_id: str) -> str:
    return f"[{hub_id}](https://huggingface.co/{hub_id})"

def notify_about_new_models():
    print("running notifications")
    if models_to_notify_about := check_for_new_models_for_dataset_and_update():
        for dataset_id, new_models in models_to_notify_about.items():
            with SqliteDict(
                f"{save_dir}/tracked_dataset_to_users.sqlite"
            ) as tracked_dataset_to_users_db:
                users_to_notify = tracked_dataset_to_users_db.get(dataset_id, [])
            if not users_to_notify:
                # nobody is tracking this dataset, so there is no one to ping
                continue
            maybe_create_discussion(
                REPO, dataset_id, new_models, users_to_notify, AUTHOR, HF_TOKEN
            )
        print("notified about new models")

def number_of_users_tracking_datasets():
    with SqliteDict(
        f"{save_dir}/tracked_dataset_to_users.sqlite"
    ) as tracked_dataset_to_users_db:
        # flatten all user lists, deduplicate, and count
        return count(unique(concat(tracked_dataset_to_users_db.values())))


def number_of_datasets_tracked():
    with SqliteDict(f"{save_dir}/models_to_dataset.sqlite") as datasets_to_models_db:
        return len(datasets_to_models_db)

def generate_summary_stats():
    return (
        f"Currently there are {number_of_users_tracking_datasets()} users tracking"
        f" datasets, with a total of {number_of_datasets_tracked()} datasets being"
        " tracked."
    )

def _user_stats(username: str):
    if not (tracked_datasets := datasets_tracked_by_user(username)):
        return "You are not currently tracking any datasets"
    response = (
        "You are currently tracking whether new models have been trained on"
        f" {len(tracked_datasets)} datasets.\n\n"
    )
    response += "### Datasets being tracked \n\n"
    response += (
        "You are currently monitoring whether new models have been trained on the"
        " following datasets:\n"
    )
    for dataset in tracked_datasets:
        response += f"- [{dataset}](https://huggingface.co/datasets/{dataset})\n"
    return response

def user_stats(profile: gr.OAuthProfile | None):
    if not profile and not local:
        return "You must be logged in to see your tracked datasets"
    username = profile.preferred_username
    return _user_stats(username)

markdown_text = """
The Hugging Face Hub allows users to specify the dataset used to train a model in the model metadata.
This metadata makes it possible to find models trained on a particular dataset, which can be very useful when looking for models suitable for a particular task.\n\n
This Gradio app lets you track datasets hosted on the Hugging Face Hub and get a notification when new models are trained on a dataset you are tracking.

1. Submit the Hugging Face Hub ID for the dataset you are interested in tracking.
2. If a new model is listed as being trained on this dataset, Librarian Bot will ping you in a discussion on the Hugging Face Hub to let you know.
3. Librarian Bot checks for new models for each tracked dataset every 12 hours.

**Tip:** *You can use a wildcard `*` to track all datasets for a user or organization on the Hub. For example, `biglam/*` will create alerts for all the datasets under the `biglam` Hugging Face organization.*

**You need to be logged in to your Hugging Face account to use this app.** If you don't have a Hugging Face Hub account you can get one <a href="https://huggingface.co/join">here</a>.
"""

with gr.Blocks() as demo:
    gr.Markdown(
        '<div style="text-align: center;"><h1> 🤖 Librarian Bot Dataset-to-Model'
        ' Monitor 🤖 </h1><i><p style="font-size: 20px;">✨ Get alerts when a new'
        " model is created from a dataset you are interested in! ✨</p></i></div>"
    )
    with gr.Row():
        gr.Markdown(markdown_text)
    with gr.Row():
        hub_id = gr.Textbox(
            placeholder="e.g. biglam/brill_iconclass",
            label="Hugging Face Hub ID for dataset to track",
        )
        with gr.Column():
            track_button = gr.Button("Track new models for dataset")
            with gr.Row():
                remove_specific_datasets = gr.Button("Stop tracking dataset")
                remove_all = gr.Button("⛔️ Unsubscribe from all datasets ⛔️")
    with gr.Row(variant="compact"):
        gr.LoginButton(size="sm")
        gr.LogoutButton(size="sm")
        summary_stats_btn = gr.Button(
            "Summary stats for datasets being tracked by this app", size="sm"
        )
        user_stats_btn = gr.Button("List my tracked datasets", size="sm")
    with gr.Row():
        output = gr.Markdown()
    track_button.click(user_update, [hub_id], output)
    remove_specific_datasets.click(
        remove_user_from_tracking_datasets, [hub_id], output
    )
    remove_all.click(user_unsubscribe_all, [], output)
    summary_stats_btn.click(generate_summary_stats, [], output)
    user_stats_btn.click(user_stats, [], output)

scheduler = BackgroundScheduler()
if local:
    # check more frequently when developing locally
    scheduler.add_job(notify_about_new_models, "interval", minutes=30)
else:
    # in production, check for new models every 12 hours
    scheduler.add_job(
        notify_about_new_models,
        CronTrigger.from_crontab("0 */12 * * *"),
    )
scheduler.start()

demo.queue(max_size=5)
demo.launch()