saba9 (HF Staff) committed · verified
Commit 3bba4f1 · 1 Parent(s): 6b9572f

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +1 -0
  2. __init__.py +259 -0
  3. __pycache__/__init__.cpython-312.pyc +0 -0
  4. __pycache__/__init__.cpython-313.pyc +0 -0
  5. __pycache__/cli.cpython-312.pyc +0 -0
  6. __pycache__/commit_scheduler.cpython-312.pyc +0 -0
  7. __pycache__/commit_scheduler.cpython-313.pyc +0 -0
  8. __pycache__/context_vars.cpython-312.pyc +0 -0
  9. __pycache__/context_vars.cpython-313.pyc +0 -0
  10. __pycache__/deploy.cpython-312.pyc +0 -0
  11. __pycache__/deploy.cpython-313.pyc +0 -0
  12. __pycache__/dummy_commit_scheduler.cpython-312.pyc +0 -0
  13. __pycache__/dummy_commit_scheduler.cpython-313.pyc +0 -0
  14. __pycache__/file_storage.cpython-312.pyc +0 -0
  15. __pycache__/imports.cpython-312.pyc +0 -0
  16. __pycache__/imports.cpython-313.pyc +0 -0
  17. __pycache__/media.cpython-312.pyc +0 -0
  18. __pycache__/media_commit_scheduler.cpython-312.pyc +0 -0
  19. __pycache__/run.cpython-312.pyc +0 -0
  20. __pycache__/run.cpython-313.pyc +0 -0
  21. __pycache__/sqlite_storage.cpython-312.pyc +0 -0
  22. __pycache__/sqlite_storage.cpython-313.pyc +0 -0
  23. __pycache__/table.cpython-312.pyc +0 -0
  24. __pycache__/typehints.cpython-312.pyc +0 -0
  25. __pycache__/ui.cpython-312.pyc +0 -0
  26. __pycache__/ui.cpython-313.pyc +0 -0
  27. __pycache__/utils.cpython-312.pyc +0 -0
  28. __pycache__/utils.cpython-313.pyc +0 -0
  29. assets/trackio_logo_dark.png +0 -0
  30. assets/trackio_logo_light.png +0 -0
  31. assets/trackio_logo_old.png +3 -0
  32. assets/trackio_logo_type_dark.png +0 -0
  33. assets/trackio_logo_type_dark_transparent.png +0 -0
  34. assets/trackio_logo_type_light.png +0 -0
  35. assets/trackio_logo_type_light_transparent.png +0 -0
  36. cli.py +32 -0
  37. commit_scheduler.py +391 -0
  38. context_vars.py +15 -0
  39. deploy.py +172 -0
  40. dummy_commit_scheduler.py +12 -0
  41. file_storage.py +37 -0
  42. imports.py +288 -0
  43. media.py +241 -0
  44. py.typed +0 -0
  45. run.py +156 -0
  46. sqlite_storage.py +398 -0
  47. table.py +55 -0
  48. typehints.py +17 -0
  49. ui.py +771 -0
  50. utils.py +637 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,259 @@
1
+ import hashlib
2
+ import os
3
+ import warnings
4
+ import webbrowser
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from gradio.blocks import BUILT_IN_THEMES
9
+ from gradio.themes import Default as DefaultTheme
10
+ from gradio.themes import ThemeClass
11
+ from gradio_client import Client
12
+
13
+ from trackio import context_vars, deploy, utils
14
+ from trackio.imports import import_csv, import_tf_events
15
+ from trackio.media import TrackioImage, TrackioVideo
16
+ from trackio.run import Run
17
+ from trackio.sqlite_storage import SQLiteStorage
18
+ from trackio.table import Table
19
+ from trackio.ui import demo
20
+ from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
21
+
22
+ __version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()
23
+
24
+ __all__ = [
25
+ "init",
26
+ "log",
27
+ "finish",
28
+ "show",
29
+ "import_csv",
30
+ "import_tf_events",
31
+ "Image",
32
+ "Table",
33
+ ]
34
+
35
+ Image = TrackioImage
36
+ Video = TrackioVideo
37
+
38
+
39
+ config = {}
40
+
41
+ DEFAULT_THEME = "citrus"
42
+
43
+
44
+ def init(
45
+ project: str,
46
+ name: str | None = None,
47
+ space_id: str | None = None,
48
+ dataset_id: str | None = None,
49
+ config: dict | None = None,
50
+ resume: str = "never",
51
+ settings: Any = None,
52
+ ) -> Run:
53
+ """
54
+ Creates a new Trackio project and returns a [`Run`] object.
55
+
56
+ Args:
57
+ project (`str`):
58
+ The name of the project (can be an existing project to continue tracking or
59
+ a new project to start tracking from scratch).
60
+ name (`str` or `None`, *optional*, defaults to `None`):
61
+ The name of the run (if not provided, a default name will be generated).
62
+ space_id (`str` or `None`, *optional*, defaults to `None`):
63
+ If provided, the project will be logged to a Hugging Face Space instead of
64
+ a local directory. Should be a complete Space name like
65
+ `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
66
+ case the Space will be created in the currently-logged-in Hugging Face
67
+ user's namespace. If the Space does not exist, it will be created. If the
68
+ Space already exists, the project will be logged to it.
69
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
70
+ If a `space_id` is provided, a persistent Hugging Face Dataset will be
71
+ created and the metrics will be synced to it every 5 minutes. Specify a
72
+ Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
73
+ or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
74
+ or `None` (uses the same name as the Space but with the `"_dataset"`
75
+ suffix). If the Dataset does not exist, it will be created. If the Dataset
76
+ already exists, the project will be appended to it.
77
+ config (`dict` or `None`, *optional*, defaults to `None`):
78
+ A dictionary of configuration options. Provided for compatibility with
79
+ `wandb.init()`.
80
+ resume (`str`, *optional*, defaults to `"never"`):
81
+ Controls how to handle resuming a run. Can be one of:
82
+
83
+ - `"must"`: Must resume the run with the given name, raises error if run
84
+ doesn't exist
85
+ - `"allow"`: Resume the run if it exists, otherwise create a new run
86
+ - `"never"`: Never resume a run, always create a new one
87
+ settings (`Any`, *optional*, defaults to `None`):
88
+ Not used. Provided for compatibility with `wandb.init()`.
89
+
90
+ Returns:
91
+ `Run`: A [`Run`] object that can be used to log metrics and finish the run.
92
+ """
93
+ if settings is not None:
94
+ warnings.warn(
95
+ "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
96
+ )
97
+
98
+ if space_id is None and dataset_id is not None:
99
+ raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
100
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
101
+ url = context_vars.current_server.get()
102
+
103
+ if url is None:
104
+ if space_id is None:
105
+ _, url, _ = demo.launch(
106
+ show_api=False,
107
+ inline=False,
108
+ quiet=True,
109
+ prevent_thread_lock=True,
110
+ show_error=True,
111
+ )
112
+ else:
113
+ url = space_id
114
+ context_vars.current_server.set(url)
115
+
116
+ if (
117
+ context_vars.current_project.get() is None
118
+ or context_vars.current_project.get() != project
119
+ ):
120
+ print(f"* Trackio project initialized: {project}")
121
+
122
+ if dataset_id is not None:
123
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
124
+ print(
125
+ f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
126
+ )
127
+ if space_id is None:
128
+ print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
129
+ utils.print_dashboard_instructions(project)
130
+ else:
131
+ deploy.create_space_if_not_exists(space_id, dataset_id)
132
+ print(
133
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
134
+ )
135
+ context_vars.current_project.set(project)
136
+
137
+ client = None
138
+ if not space_id:
139
+ client = Client(url, verbose=False)
140
+
141
+ if resume == "must":
142
+ if name is None:
143
+ raise ValueError("Must provide a run name when resume='must'")
144
+ if name not in SQLiteStorage.get_runs(project):
145
+ raise ValueError(f"Run '{name}' does not exist in project '{project}'")
146
+ resumed = True
147
+ elif resume == "allow":
148
+ resumed = name is not None and name in SQLiteStorage.get_runs(project)
149
+ elif resume == "never":
150
+ if name is not None and name in SQLiteStorage.get_runs(project):
151
+ warnings.warn(
152
+ f"* Warning: resume='never' but a run '{name}' already exists in "
153
+ f"project '{project}'. Generating a new name and instead. If you want "
154
+ "to resume this run, call init() with resume='must' or resume='allow'."
155
+ )
156
+ name = None
157
+ resumed = False
158
+ else:
159
+ raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
160
+
161
+ run = Run(
162
+ url=url,
163
+ project=project,
164
+ client=client,
165
+ name=name,
166
+ config=config,
167
+ space_id=space_id,
168
+ )
169
+
170
+ if resumed:
171
+ print(f"* Resumed existing run: {run.name}")
172
+ else:
173
+ print(f"* Created new run: {run.name}")
174
+
175
+ context_vars.current_run.set(run)
176
+ globals()["config"] = run.config
177
+ return run
178
+
179
+
180
+ def log(metrics: dict, step: int | None = None) -> None:
181
+ """
182
+ Logs metrics to the current run.
183
+
184
+ Args:
185
+ metrics (`dict`):
186
+ A dictionary of metrics to log.
187
+ step (`int` or `None`, *optional*, defaults to `None`):
188
+ The step number. If not provided, the step will be incremented
189
+ automatically.
190
+ """
191
+ run = context_vars.current_run.get()
192
+ if run is None:
193
+ raise RuntimeError("Call trackio.init() before trackio.log().")
194
+ run.log(
195
+ metrics=metrics,
196
+ step=step,
197
+ )
198
+
199
+
200
+ def finish():
201
+ """
202
+ Finishes the current run.
203
+ """
204
+ run = context_vars.current_run.get()
205
+ if run is None:
206
+ raise RuntimeError("Call trackio.init() before trackio.finish().")
207
+ run.finish()
208
+
209
+
210
+ def show(project: str | None = None, theme: str | ThemeClass = DEFAULT_THEME):
211
+ """
212
+ Launches the Trackio dashboard.
213
+
214
+ Args:
215
+ project (`str` or `None`, *optional*, defaults to `None`):
216
+ The name of the project whose runs to show. If not provided, all projects
217
+ will be shown and the user can select one.
218
+ theme (`str` or `ThemeClass`, *optional*, defaults to `"citrus"`):
219
+ A Gradio Theme to use for the dashboard instead of the default `"citrus"`,
220
+ can be a built-in theme (e.g. `'soft'`, `'default'`), a theme from the Hub
221
+ (e.g. `"gstaff/xkcd"`), or a custom Theme class.
222
+ """
223
+ if theme != DEFAULT_THEME:
224
+ # TODO: It's a little hacky to reproduce this theme-setting logic from Gradio Blocks,
225
+ # but in Gradio 6.0, the theme will be set in `launch()` instead, which means that we
226
+ # will be able to remove this code.
227
+ if isinstance(theme, str):
228
+ if theme.lower() in BUILT_IN_THEMES:
229
+ theme = BUILT_IN_THEMES[theme.lower()]
230
+ else:
231
+ try:
232
+ theme = ThemeClass.from_hub(theme)
233
+ except Exception as e:
234
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
235
+ theme = DefaultTheme()
236
+ if not isinstance(theme, ThemeClass):
237
+ warnings.warn("Theme should be a class loaded from gradio.themes")
238
+ theme = DefaultTheme()
239
+ demo.theme: ThemeClass = theme
240
+ demo.theme_css = theme._get_theme_css()
241
+ demo.stylesheets = theme._stylesheets
242
+ theme_hasher = hashlib.sha256()
243
+ theme_hasher.update(demo.theme_css.encode("utf-8"))
244
+ demo.theme_hash = theme_hasher.hexdigest()
245
+
246
+ _, url, share_url = demo.launch(
247
+ show_api=False,
248
+ quiet=True,
249
+ inline=False,
250
+ prevent_thread_lock=True,
251
+ favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
252
+ allowed_paths=[TRACKIO_LOGO_DIR],
253
+ )
254
+
255
+ base_url = share_url + "/" if share_url else url
256
+ dashboard_url = base_url + f"?project={project}" if project else base_url
257
+ print(f"* Trackio UI launched at: {dashboard_url}")
258
+ webbrowser.open(dashboard_url)
259
+ utils.block_except_in_notebook()
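Editor's note: for orientation, here is a minimal usage sketch of the API this `__init__.py` exposes. The project name, run name, and metric values are illustrative, and the snippet assumes the package is installed and importable as `trackio`.

```py
import trackio

# Start a new run in a local project; with no space_id/dataset_id, metrics are
# stored in the local Trackio directory (TRACKIO_DIR).
run = trackio.init(project="demo-project", name="baseline", config={"lr": 1e-3})

for step in range(3):
    # Log a dict of metrics; step is optional and auto-increments when omitted.
    trackio.log({"loss": 1.0 / (step + 1)}, step=step)

trackio.finish()
# trackio.show(project="demo-project")  # optionally launch the dashboard UI
```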
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (11.4 kB)
__pycache__/__init__.cpython-313.pyc ADDED
Binary file (7.35 kB)
__pycache__/cli.cpython-312.pyc ADDED
Binary file (1.43 kB)
__pycache__/commit_scheduler.cpython-312.pyc ADDED
Binary file (18.8 kB)
__pycache__/commit_scheduler.cpython-313.pyc ADDED
Binary file (18.3 kB)
__pycache__/context_vars.cpython-312.pyc ADDED
Binary file (759 Bytes)
__pycache__/context_vars.cpython-313.pyc ADDED
Binary file (745 Bytes)
__pycache__/deploy.cpython-312.pyc ADDED
Binary file (6.75 kB)
__pycache__/deploy.cpython-313.pyc ADDED
Binary file (6.27 kB)
__pycache__/dummy_commit_scheduler.cpython-312.pyc ADDED
Binary file (1.01 kB)
__pycache__/dummy_commit_scheduler.cpython-313.pyc ADDED
Binary file (1.1 kB)
__pycache__/file_storage.cpython-312.pyc ADDED
Binary file (1.63 kB)
__pycache__/imports.cpython-312.pyc ADDED
Binary file (12.7 kB)
__pycache__/imports.cpython-313.pyc ADDED
Binary file (11.6 kB)
__pycache__/media.cpython-312.pyc ADDED
Binary file (12.6 kB)
__pycache__/media_commit_scheduler.cpython-312.pyc ADDED
Binary file (3.66 kB)
__pycache__/run.cpython-312.pyc ADDED
Binary file (7.42 kB)
__pycache__/run.cpython-313.pyc ADDED
Binary file (1.37 kB)
__pycache__/sqlite_storage.cpython-312.pyc ADDED
Binary file (19.1 kB)
__pycache__/sqlite_storage.cpython-313.pyc ADDED
Binary file (13.8 kB)
__pycache__/table.cpython-312.pyc ADDED
Binary file (2.46 kB)
__pycache__/typehints.cpython-312.pyc ADDED
Binary file (851 Bytes)
__pycache__/ui.cpython-312.pyc ADDED
Binary file (30.9 kB)
__pycache__/ui.cpython-313.pyc ADDED
Binary file (5.37 kB)
__pycache__/utils.cpython-312.pyc ADDED
Binary file (17.5 kB)
__pycache__/utils.cpython-313.pyc ADDED
Binary file (9.8 kB)
assets/trackio_logo_dark.png ADDED
assets/trackio_logo_light.png ADDED
assets/trackio_logo_old.png ADDED

Git LFS Details

  • SHA256: 3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
assets/trackio_logo_type_dark.png ADDED
assets/trackio_logo_type_dark_transparent.png ADDED
assets/trackio_logo_type_light.png ADDED
assets/trackio_logo_type_light_transparent.png ADDED
cli.py ADDED
@@ -0,0 +1,32 @@
+ import argparse
+
+ from trackio import show
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Trackio CLI")
+     subparsers = parser.add_subparsers(dest="command")
+
+     ui_parser = subparsers.add_parser(
+         "show", help="Show the Trackio dashboard UI for a project"
+     )
+     ui_parser.add_argument(
+         "--project", required=False, help="Project name to show in the dashboard"
+     )
+     ui_parser.add_argument(
+         "--theme",
+         required=False,
+         default="citrus",
+         help="A Gradio Theme to use for the dashboard instead of the default 'citrus', can be a built-in theme (e.g. 'soft', 'default'), a theme from the Hub (e.g. 'gstaff/xkcd').",
+     )
+
+     args = parser.parse_args()
+
+     if args.command == "show":
+         show(args.project, args.theme)
+     else:
+         parser.print_help()
+
+
+ if __name__ == "__main__":
+     main()
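Editor's note: the CLI simply forwards to `show()`. A sketch of equivalent invocations follows; the `trackio` console-script name and the `python -m trackio.cli` module path are assumptions, since packaging metadata is not part of this diff.

```py
# Shell (assumed entry points):
#   trackio show --project demo-project --theme soft
#   python -m trackio.cli show --project demo-project --theme soft

# Equivalent call from Python:
from trackio import show

show("demo-project", "soft")
```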
commit_scheduler.py ADDED
@@ -0,0 +1,391 @@
1
+ # Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
2
+
3
+ import atexit
4
+ import logging
5
+ import os
6
+ import time
7
+ from concurrent.futures import Future
8
+ from dataclasses import dataclass
9
+ from io import SEEK_END, SEEK_SET, BytesIO
10
+ from pathlib import Path
11
+ from threading import Lock, Thread
12
+ from typing import Callable, Dict, List, Optional, Union
13
+
14
+ from huggingface_hub.hf_api import (
15
+ DEFAULT_IGNORE_PATTERNS,
16
+ CommitInfo,
17
+ CommitOperationAdd,
18
+ HfApi,
19
+ )
20
+ from huggingface_hub.utils import filter_repo_objects
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class _FileToUpload:
27
+ """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
28
+
29
+ local_path: Path
30
+ path_in_repo: str
31
+ size_limit: int
32
+ last_modified: float
33
+
34
+
35
+ class CommitScheduler:
36
+ """
37
+ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
38
+
39
+ The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
40
+ properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
41
+ with the `stop` method. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
42
+ to learn more about how to use it.
43
+
44
+ Args:
45
+ repo_id (`str`):
46
+ The id of the repo to commit to.
47
+ folder_path (`str` or `Path`):
48
+ Path to the local folder to upload regularly.
49
+ every (`int` or `float`, *optional*):
50
+ The number of minutes between each commit. Defaults to 5 minutes.
51
+ path_in_repo (`str`, *optional*):
52
+ Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
53
+ of the repository.
54
+ repo_type (`str`, *optional*):
55
+ The type of the repo to commit to. Defaults to `model`.
56
+ revision (`str`, *optional*):
57
+ The revision of the repo to commit to. Defaults to `main`.
58
+ private (`bool`, *optional*):
59
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
60
+ token (`str`, *optional*):
61
+ The token to use to commit to the repo. Defaults to the token saved on the machine.
62
+ allow_patterns (`List[str]` or `str`, *optional*):
63
+ If provided, only files matching at least one pattern are uploaded.
64
+ ignore_patterns (`List[str]` or `str`, *optional*):
65
+ If provided, files matching any of the patterns are not uploaded.
66
+ squash_history (`bool`, *optional*):
67
+ Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
68
+ useful to avoid degraded performance on the repo when it grows too large.
69
+ hf_api (`HfApi`, *optional*):
70
+ The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
71
+ on_before_commit (`Callable[[], None]`, *optional*):
72
+ If specified, a function that will be called before the CommitScheduler lists files to create a commit.
73
+
74
+ Example:
75
+ ```py
76
+ >>> from pathlib import Path
77
+ >>> from huggingface_hub import CommitScheduler
78
+
79
+ # Scheduler uploads every 10 minutes
80
+ >>> csv_path = Path("watched_folder/data.csv")
81
+ >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
82
+
83
+ >>> with csv_path.open("a") as f:
84
+ ... f.write("first line")
85
+
86
+ # Some time later (...)
87
+ >>> with csv_path.open("a") as f:
88
+ ... f.write("second line")
89
+ ```
90
+
91
+ Example using a context manager:
92
+ ```py
93
+ >>> from pathlib import Path
94
+ >>> from huggingface_hub import CommitScheduler
95
+
96
+ >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
97
+ ... csv_path = Path("watched_folder/data.csv")
98
+ ... with csv_path.open("a") as f:
99
+ ... f.write("first line")
100
+ ... (...)
101
+ ... with csv_path.open("a") as f:
102
+ ... f.write("second line")
103
+
104
+ # Scheduler is now stopped and the last commit has been triggered
105
+ ```
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ *,
111
+ repo_id: str,
112
+ folder_path: Union[str, Path],
113
+ every: Union[int, float] = 5,
114
+ path_in_repo: Optional[str] = None,
115
+ repo_type: Optional[str] = None,
116
+ revision: Optional[str] = None,
117
+ private: Optional[bool] = None,
118
+ token: Optional[str] = None,
119
+ allow_patterns: Optional[Union[List[str], str]] = None,
120
+ ignore_patterns: Optional[Union[List[str], str]] = None,
121
+ squash_history: bool = False,
122
+ hf_api: Optional["HfApi"] = None,
123
+ on_before_commit: Optional[Callable[[], None]] = None,
124
+ ) -> None:
125
+ self.api = hf_api or HfApi(token=token)
126
+ self.on_before_commit = on_before_commit
127
+
128
+ # Folder
129
+ self.folder_path = Path(folder_path).expanduser().resolve()
130
+ self.path_in_repo = path_in_repo or ""
131
+ self.allow_patterns = allow_patterns
132
+
133
+ if ignore_patterns is None:
134
+ ignore_patterns = []
135
+ elif isinstance(ignore_patterns, str):
136
+ ignore_patterns = [ignore_patterns]
137
+ self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
138
+
139
+ if self.folder_path.is_file():
140
+ raise ValueError(
141
+ f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
142
+ )
143
+ self.folder_path.mkdir(parents=True, exist_ok=True)
144
+
145
+ # Repository
146
+ repo_url = self.api.create_repo(
147
+ repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
148
+ )
149
+ self.repo_id = repo_url.repo_id
150
+ self.repo_type = repo_type
151
+ self.revision = revision
152
+ self.token = token
153
+
154
+ self.last_uploaded: Dict[Path, float] = {}
155
+ self.last_push_time: float | None = None
156
+
157
+ if not every > 0:
158
+ raise ValueError(f"'every' must be a positive integer, not '{every}'.")
159
+ self.lock = Lock()
160
+ self.every = every
161
+ self.squash_history = squash_history
162
+
163
+ logger.info(
164
+ f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
165
+ )
166
+ self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
167
+ self._scheduler_thread.start()
168
+ atexit.register(self._push_to_hub)
169
+
170
+ self.__stopped = False
171
+
172
+ def stop(self) -> None:
173
+ """Stop the scheduler.
174
+
175
+ A stopped scheduler cannot be restarted. Mostly for test purposes.
176
+ """
177
+ self.__stopped = True
178
+
179
+ def __enter__(self) -> "CommitScheduler":
180
+ return self
181
+
182
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
183
+ # Upload last changes before exiting
184
+ self.trigger().result()
185
+ self.stop()
186
+ return
187
+
188
+ def _run_scheduler(self) -> None:
189
+ """Dumb thread waiting between each scheduled push to Hub."""
190
+ while True:
191
+ self.last_future = self.trigger()
192
+ time.sleep(self.every * 60)
193
+ if self.__stopped:
194
+ break
195
+
196
+ def trigger(self) -> Future:
197
+ """Trigger a `push_to_hub` and return a future.
198
+
199
+ This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
200
+ immediately, without waiting for the next scheduled commit.
201
+ """
202
+ return self.api.run_as_future(self._push_to_hub)
203
+
204
+ def _push_to_hub(self) -> Optional[CommitInfo]:
205
+ if self.__stopped: # If stopped, already scheduled commits are ignored
206
+ return None
207
+
208
+ logger.info("(Background) scheduled commit triggered.")
209
+ try:
210
+ value = self.push_to_hub()
211
+ if self.squash_history:
212
+ logger.info("(Background) squashing repo history.")
213
+ self.api.super_squash_history(
214
+ repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
215
+ )
216
+ return value
217
+ except Exception as e:
218
+ logger.error(
219
+ f"Error while pushing to Hub: {e}"
220
+ ) # Depending on the setup, error might be silenced
221
+ raise
222
+
223
+ def push_to_hub(self) -> Optional[CommitInfo]:
224
+ """
225
+ Push folder to the Hub and return the commit info.
226
+
227
+ <Tip warning={true}>
228
+
229
+ This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
230
+ queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
231
+ issues.
232
+
233
+ </Tip>
234
+
235
+ The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
236
+ uploads only changed files. If no changes are found, the method returns without committing anything. If you want
237
+ to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
238
+ for example to compress data together in a single file before committing. For more details and examples, check
239
+ out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
240
+ """
241
+ # Check files to upload (with lock)
242
+ with self.lock:
243
+ if self.on_before_commit is not None:
244
+ self.on_before_commit()
245
+
246
+ logger.debug("Listing files to upload for scheduled commit.")
247
+
248
+ # List files from folder (taken from `_prepare_upload_folder_additions`)
249
+ relpath_to_abspath = {
250
+ path.relative_to(self.folder_path).as_posix(): path
251
+ for path in sorted(
252
+ self.folder_path.glob("**/*")
253
+ ) # sorted to be deterministic
254
+ if path.is_file()
255
+ }
256
+ prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
257
+
258
+ # Filter with pattern + filter out unchanged files + retrieve current file size
259
+ files_to_upload: List[_FileToUpload] = []
260
+ for relpath in filter_repo_objects(
261
+ relpath_to_abspath.keys(),
262
+ allow_patterns=self.allow_patterns,
263
+ ignore_patterns=self.ignore_patterns,
264
+ ):
265
+ local_path = relpath_to_abspath[relpath]
266
+ stat = local_path.stat()
267
+ if (
268
+ self.last_uploaded.get(local_path) is None
269
+ or self.last_uploaded[local_path] != stat.st_mtime
270
+ ):
271
+ files_to_upload.append(
272
+ _FileToUpload(
273
+ local_path=local_path,
274
+ path_in_repo=prefix + relpath,
275
+ size_limit=stat.st_size,
276
+ last_modified=stat.st_mtime,
277
+ )
278
+ )
279
+
280
+ # Return if nothing to upload
281
+ if len(files_to_upload) == 0:
282
+ logger.debug("Dropping schedule commit: no changed file to upload.")
283
+ return None
284
+
285
+ # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
286
+ logger.debug("Removing unchanged files since previous scheduled commit.")
287
+ add_operations = [
288
+ CommitOperationAdd(
289
+ # TODO: Cap the file to its current size, even if the user appends data to it while a scheduled commit is happening
290
+ # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
291
+ path_or_fileobj=file_to_upload.local_path,
292
+ path_in_repo=file_to_upload.path_in_repo,
293
+ )
294
+ for file_to_upload in files_to_upload
295
+ ]
296
+
297
+ # Upload files (append mode expected - no need for lock)
298
+ logger.debug("Uploading files for scheduled commit.")
299
+ commit_info = self.api.create_commit(
300
+ repo_id=self.repo_id,
301
+ repo_type=self.repo_type,
302
+ operations=add_operations,
303
+ commit_message="Scheduled Commit",
304
+ revision=self.revision,
305
+ )
306
+
307
+ for file in files_to_upload:
308
+ self.last_uploaded[file.local_path] = file.last_modified
309
+
310
+ self.last_push_time = time.time()
311
+
312
+ return commit_info
313
+
314
+
315
+ class PartialFileIO(BytesIO):
316
+ """A file-like object that reads only the first part of a file.
317
+
318
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
319
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
320
+
321
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
322
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
323
+
324
+ Only supports `read`, `tell` and `seek` methods.
325
+
326
+ Args:
327
+ file_path (`str` or `Path`):
328
+ Path to the file to read.
329
+ size_limit (`int`):
330
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
331
+ will be read (and uploaded).
332
+ """
333
+
334
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
335
+ self._file_path = Path(file_path)
336
+ self._file = self._file_path.open("rb")
337
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
338
+
339
+ def __del__(self) -> None:
340
+ self._file.close()
341
+ return super().__del__()
342
+
343
+ def __repr__(self) -> str:
344
+ return (
345
+ f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
346
+ )
347
+
348
+ def __len__(self) -> int:
349
+ return self._size_limit
350
+
351
+ def __getattribute__(self, name: str):
352
+ if name.startswith("_") or name in (
353
+ "read",
354
+ "tell",
355
+ "seek",
356
+ ): # only 3 public methods supported
357
+ return super().__getattribute__(name)
358
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
359
+
360
+ def tell(self) -> int:
361
+ """Return the current file position."""
362
+ return self._file.tell()
363
+
364
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
365
+ """Change the stream position to the given offset.
366
+
367
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
368
+ """
369
+ if __whence == SEEK_END:
370
+ # SEEK_END => set from the truncated end
371
+ __offset = len(self) + __offset
372
+ __whence = SEEK_SET
373
+
374
+ pos = self._file.seek(__offset, __whence)
375
+ if pos > self._size_limit:
376
+ return self._file.seek(self._size_limit)
377
+ return pos
378
+
379
+ def read(self, __size: Optional[int] = -1) -> bytes:
380
+ """Read at most `__size` bytes from the file.
381
+
382
+ Behavior is the same as a regular file, except that it is capped to the size limit.
383
+ """
384
+ current = self._file.tell()
385
+ if __size is None or __size < 0:
386
+ # Read until file limit
387
+ truncated_size = self._size_limit - current
388
+ else:
389
+ # Read until file limit or __size
390
+ truncated_size = min(__size, self._size_limit - current)
391
+ return self._file.read(truncated_size)
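Editor's note: a hedged usage sketch of the `CommitScheduler` defined above, assuming the module is importable as `trackio.commit_scheduler` and that you are logged in to the Hub with write access; the repo ID and folder name are placeholders. Constructing the scheduler creates the repo if needed and starts the background thread immediately.

```py
from pathlib import Path

from trackio.commit_scheduler import CommitScheduler  # module added in this commit

watched = Path("watched_folder")
watched.mkdir(exist_ok=True)


def flush_before_commit():
    # Runs under the scheduler's lock just before files are listed for upload.
    (watched / "data.csv").touch()


scheduler = CommitScheduler(
    repo_id="username/test_scheduler",  # placeholder repo
    repo_type="dataset",
    folder_path=watched,
    every=2,  # minutes between scheduled commits
    on_before_commit=flush_before_commit,
)

scheduler.trigger().result()  # force an immediate commit instead of waiting
scheduler.stop()
```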
context_vars.py ADDED
@@ -0,0 +1,15 @@
+ import contextvars
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from trackio.run import Run
+
+ current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
+     "current_run", default=None
+ )
+ current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+     "current_project", default=None
+ )
+ current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+     "current_server", default=None
+ )
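Editor's note: these context variables are what `init()`, `log()`, and `finish()` consult; a small sketch of how they behave, with an illustrative project name:

```py
from trackio import context_vars

# All three default to None until trackio.init() populates them.
assert context_vars.current_run.get() is None

context_vars.current_project.set("demo-project")
print(context_vars.current_project.get())  # "demo-project"
```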
deploy.py ADDED
@@ -0,0 +1,172 @@
1
+ import io
2
+ import os
3
+ import time
4
+ from importlib.resources import files
5
+ from pathlib import Path
6
+
7
+ import gradio
8
+ import huggingface_hub
9
+ from gradio_client import Client, handle_file
10
+ from httpx import ReadTimeout
11
+ from huggingface_hub.errors import RepositoryNotFoundError
12
+ from requests import HTTPError
13
+
14
+ from trackio.sqlite_storage import SQLiteStorage
15
+
16
+ SPACE_URL = "https://huggingface.co/spaces/{space_id}"
17
+ PERSISTENT_STORAGE_DIR = "/data/.huggingface/trackio"
18
+
19
+
20
+ def deploy_as_space(
21
+ space_id: str,
22
+ dataset_id: str | None = None,
23
+ ):
24
+ if (
25
+ os.getenv("SYSTEM") == "spaces"
26
+ ): # in case a repo with this function is uploaded to spaces
27
+ return
28
+
29
+ trackio_path = files("trackio")
30
+
31
+ hf_api = huggingface_hub.HfApi()
32
+
33
+ try:
34
+ huggingface_hub.create_repo(
35
+ space_id,
36
+ space_sdk="gradio",
37
+ repo_type="space",
38
+ exist_ok=True,
39
+ )
40
+ except HTTPError as e:
41
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
42
+ print("Need 'write' access token to create a Spaces repo.")
43
+ huggingface_hub.login(add_to_git_credential=False)
44
+ huggingface_hub.create_repo(
45
+ space_id,
46
+ space_sdk="gradio",
47
+ repo_type="space",
48
+ exist_ok=True,
49
+ )
50
+ else:
51
+ raise ValueError(f"Failed to create Space: {e}")
52
+
53
+ with open(Path(trackio_path, "README.md"), "r") as f:
54
+ readme_content = f.read()
55
+ readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
56
+ readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
57
+ hf_api.upload_file(
58
+ path_or_fileobj=readme_buffer,
59
+ path_in_repo="README.md",
60
+ repo_id=space_id,
61
+ repo_type="space",
62
+ )
63
+
64
+ # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
65
+ # Make sure necessary dependencies are installed by creating a requirements.txt.
66
+ requirements_content = """
67
+ pyarrow>=21.0
68
+ mediapy>=1.0.0
69
+ """
70
+ requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
71
+ hf_api.upload_file(
72
+ path_or_fileobj=requirements_buffer,
73
+ path_in_repo="requirements.txt",
74
+ repo_id=space_id,
75
+ repo_type="space",
76
+ )
77
+
78
+ huggingface_hub.utils.disable_progress_bars()
79
+ hf_api.upload_folder(
80
+ repo_id=space_id,
81
+ repo_type="space",
82
+ folder_path=trackio_path,
83
+ ignore_patterns=["README.md"],
84
+ )
85
+
86
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_DIR", PERSISTENT_STORAGE_DIR)
87
+ if hf_token := huggingface_hub.utils.get_token():
88
+ huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
89
+ if dataset_id is not None:
90
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)
91
+
92
+
93
+ def create_space_if_not_exists(
94
+ space_id: str,
95
+ dataset_id: str | None = None,
96
+ ) -> None:
97
+ """
98
+ Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a space variable.
99
+
100
+ Args:
101
+ space_id: The ID of the Space to create.
102
+ dataset_id: The ID of the Dataset to add to the Space.
103
+ """
104
+ if "/" not in space_id:
105
+ raise ValueError(
106
+ f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
107
+ )
108
+ if dataset_id is not None and "/" not in dataset_id:
109
+ raise ValueError(
110
+ f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
111
+ )
112
+ try:
113
+ huggingface_hub.repo_info(space_id, repo_type="space")
114
+ print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
115
+ if dataset_id is not None:
116
+ huggingface_hub.add_space_variable(
117
+ space_id, "TRACKIO_DATASET_ID", dataset_id
118
+ )
119
+ return
120
+ except RepositoryNotFoundError:
121
+ pass
122
+ except HTTPError as e:
123
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
124
+ print("Need 'write' access token to create a Spaces repo.")
125
+ huggingface_hub.login(add_to_git_credential=False)
126
+ huggingface_hub.add_space_variable(
127
+ space_id, "TRACKIO_DATASET_ID", dataset_id
128
+ )
129
+ else:
130
+ raise ValueError(f"Failed to create Space: {e}")
131
+
132
+ print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
133
+ deploy_as_space(space_id, dataset_id)
134
+
135
+
136
+ def wait_until_space_exists(
137
+ space_id: str,
138
+ ) -> None:
139
+ """
140
+ Blocks the current thread until the space exists.
141
+ May raise a TimeoutError if the Space is still not reachable after roughly ten polling attempts (exponential backoff).
142
+
143
+ Args:
144
+ space_id: The ID of the Space to wait for.
145
+ """
146
+ delay = 1
147
+ for _ in range(10):
148
+ try:
149
+ Client(space_id, verbose=False)
150
+ return
151
+ except (ReadTimeout, ValueError):
152
+ time.sleep(delay)
153
+ delay = min(delay * 2, 30)
154
+ raise TimeoutError("Waiting for space to exist took longer than expected")
155
+
156
+
157
+ def upload_db_to_space(project: str, space_id: str) -> None:
158
+ """
159
+ Uploads the database of a local Trackio project to a Hugging Face Space.
160
+
161
+ Args:
162
+ project: The name of the project to upload.
163
+ space_id: The ID of the Space to upload to.
164
+ """
165
+ db_path = SQLiteStorage.get_project_db_path(project)
166
+ client = Client(space_id, verbose=False)
167
+ client.predict(
168
+ api_name="/upload_db_to_space",
169
+ project=project,
170
+ uploaded_db=handle_file(db_path),
171
+ hf_token=huggingface_hub.utils.get_token(),
172
+ )
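Editor's note: a sketch of the deployment helpers above chained together the way the importers use them; the Space, Dataset, and project names are placeholders, and a logged-in Hugging Face account with write access is required.

```py
from trackio import deploy

space_id = "username/trackio-dashboard"            # placeholder
dataset_id = "username/trackio-dashboard_dataset"  # placeholder

deploy.create_space_if_not_exists(space_id, dataset_id)
deploy.wait_until_space_exists(space_id)           # polls until the Space responds
deploy.upload_db_to_space("demo-project", space_id)
print(deploy.SPACE_URL.format(space_id=space_id))
```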
dummy_commit_scheduler.py ADDED
@@ -0,0 +1,12 @@
+ # A dummy object to fit the interface of huggingface_hub's CommitScheduler
+ class DummyCommitSchedulerLock:
+     def __enter__(self):
+         return None
+
+     def __exit__(self, exception_type, exception_value, exception_traceback):
+         pass
+
+
+ class DummyCommitScheduler:
+     def __init__(self):
+         self.lock = DummyCommitSchedulerLock()
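Editor's note: the dummy scheduler presumably lets storage code take `scheduler.lock` uniformly whether or not Hub syncing is enabled; a minimal sketch:

```py
from trackio.dummy_commit_scheduler import DummyCommitScheduler

scheduler = DummyCommitScheduler()
with scheduler.lock:  # no-op context manager mirroring CommitScheduler.lock
    pass  # local database writes would happen here
```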
file_storage.py ADDED
@@ -0,0 +1,37 @@
+ from pathlib import Path
+
+ try:  # absolute imports when installed
+     from trackio.utils import MEDIA_DIR
+ except ImportError:  # relative imports for local execution on Spaces
+     from utils import MEDIA_DIR
+
+
+ class FileStorage:
+     @staticmethod
+     def get_project_media_path(
+         project: str,
+         run: str | None = None,
+         step: int | None = None,
+         filename: str | None = None,
+     ) -> Path:
+         if filename is not None and step is None:
+             raise ValueError("filename requires step")
+         if step is not None and run is None:
+             raise ValueError("step requires run")
+
+         path = MEDIA_DIR / project
+         if run:
+             path /= run
+         if step is not None:
+             path /= str(step)
+         if filename:
+             path /= filename
+         return path
+
+     @staticmethod
+     def init_project_media_path(
+         project: str, run: str | None = None, step: int | None = None
+     ) -> Path:
+         path = FileStorage.get_project_media_path(project, run, step)
+         path.mkdir(parents=True, exist_ok=True)
+         return path
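Editor's note: a short sketch of the two helpers above; the project, run, and filename are illustrative.

```py
from trackio.file_storage import FileStorage

# Create the media directory for step 3 of run "baseline" in "demo-project".
media_dir = FileStorage.init_project_media_path("demo-project", run="baseline", step=3)

# Build the full path of a file inside that directory (filename requires step,
# and step requires run, as enforced above).
sample_path = FileStorage.get_project_media_path(
    "demo-project", run="baseline", step=3, filename="sample.png"
)
print(media_dir, sample_path)
```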
imports.py ADDED
@@ -0,0 +1,288 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from trackio import deploy, utils
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+
9
+
10
+ def import_csv(
11
+ csv_path: str | Path,
12
+ project: str,
13
+ name: str | None = None,
14
+ space_id: str | None = None,
15
+ dataset_id: str | None = None,
16
+ ) -> None:
17
+ """
18
+ Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
19
+ column, may optionally contain a `"timestamp"` column, and any other columns will be
20
+ treated as metrics. It should also include a header row with the column names.
21
+
22
+ TODO: call init() and return a Run object so that the user can continue to log metrics to it.
23
+
24
+ Args:
25
+ csv_path (`str` or `Path`):
26
+ The str or Path to the CSV file to import.
27
+ project (`str`):
28
+ The name of the project to import the CSV file into. Must not be an existing
29
+ project.
30
+ name (`str` or `None`, *optional*, defaults to `None`):
31
+ The name of the Run to import the CSV file into. If not provided, a default
32
+ name will be generated.
35
+ space_id (`str` or `None`, *optional*, defaults to `None`):
36
+ If provided, the project will be logged to a Hugging Face Space instead of a
37
+ local directory. Should be a complete Space name like `"username/reponame"`
38
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
39
+ be created in the currently-logged-in Hugging Face user's namespace. If the
40
+ Space does not exist, it will be created. If the Space already exists, the
41
+ project will be logged to it.
42
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
43
+ If provided, a persistent Hugging Face Dataset will be created and the
44
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
45
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
46
+ `"datasetname"` in which case the Dataset will be created in the
47
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
48
+ exist, it will be created. If the Dataset already exists, the project will
49
+ be appended to it. If not provided, the metrics will be logged to a local
50
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
51
+ will be automatically created with the same name as the Space but with the
52
+ `"_dataset"` suffix.
53
+ """
54
+ if SQLiteStorage.get_runs(project):
55
+ raise ValueError(
56
+ f"Project '{project}' already exists. Cannot import CSV into existing project."
57
+ )
58
+
59
+ csv_path = Path(csv_path)
60
+ if not csv_path.exists():
61
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
62
+
63
+ df = pd.read_csv(csv_path)
64
+ if df.empty:
65
+ raise ValueError("CSV file is empty")
66
+
67
+ column_mapping = utils.simplify_column_names(df.columns.tolist())
68
+ df = df.rename(columns=column_mapping)
69
+
70
+ step_column = None
71
+ for col in df.columns:
72
+ if col.lower() == "step":
73
+ step_column = col
74
+ break
75
+
76
+ if step_column is None:
77
+ raise ValueError("CSV file must contain a 'step' or 'Step' column")
78
+
79
+ if name is None:
80
+ name = csv_path.stem
81
+
82
+ metrics_list = []
83
+ steps = []
84
+ timestamps = []
85
+
86
+ numeric_columns = []
87
+ for column in df.columns:
88
+ if column == step_column:
89
+ continue
90
+ if column == "timestamp":
91
+ continue
92
+
93
+ try:
94
+ pd.to_numeric(df[column], errors="raise")
95
+ numeric_columns.append(column)
96
+ except (ValueError, TypeError):
97
+ continue
98
+
99
+ for _, row in df.iterrows():
100
+ metrics = {}
101
+ for column in numeric_columns:
102
+ value = row[column]
103
+ if bool(pd.notna(value)):
104
+ metrics[column] = float(value)
105
+
106
+ if metrics:
107
+ metrics_list.append(metrics)
108
+ steps.append(int(row[step_column]))
109
+
110
+ if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
111
+ timestamps.append(str(row["timestamp"]))
112
+ else:
113
+ timestamps.append("")
114
+
115
+ if metrics_list:
116
+ SQLiteStorage.bulk_log(
117
+ project=project,
118
+ run=name,
119
+ metrics_list=metrics_list,
120
+ steps=steps,
121
+ timestamps=timestamps,
122
+ )
123
+
124
+ print(
125
+ f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
126
+ )
127
+ print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")
128
+
129
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
130
+ if dataset_id is not None:
131
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
132
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
133
+
134
+ if space_id is None:
135
+ utils.print_dashboard_instructions(project)
136
+ else:
137
+ deploy.create_space_if_not_exists(space_id, dataset_id)
138
+ deploy.wait_until_space_exists(space_id)
139
+ deploy.upload_db_to_space(project, space_id)
140
+ print(
141
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
142
+ )
143
+
144
+
145
+ def import_tf_events(
146
+ log_dir: str | Path,
147
+ project: str,
148
+ name: str | None = None,
149
+ space_id: str | None = None,
150
+ dataset_id: str | None = None,
151
+ ) -> None:
152
+ """
153
+ Imports TensorFlow Events files from a directory into a Trackio project. Each
154
+ subdirectory in the log directory will be imported as a separate run.
155
+
156
+ Args:
157
+ log_dir (`str` or `Path`):
158
+ The str or Path to the directory containing TensorFlow Events files.
159
+ project (`str`):
160
+ The name of the project to import the TensorFlow Events files into. Must not
161
+ be an existing project.
162
+ name (`str` or `None`, *optional*, defaults to `None`):
163
+ The name prefix for runs (if not provided, will use directory names). Each
164
+ subdirectory will create a separate run.
165
+ space_id (`str` or `None`, *optional*, defaults to `None`):
166
+ If provided, the project will be logged to a Hugging Face Space instead of a
167
+ local directory. Should be a complete Space name like `"username/reponame"`
168
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
169
+ be created in the currently-logged-in Hugging Face user's namespace. If the
170
+ Space does not exist, it will be created. If the Space already exists, the
171
+ project will be logged to it.
172
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
173
+ If provided, a persistent Hugging Face Dataset will be created and the
174
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
175
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
176
+ `"datasetname"` in which case the Dataset will be created in the
177
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
178
+ exist, it will be created. If the Dataset already exists, the project will
179
+ be appended to it. If not provided, the metrics will be logged to a local
180
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
181
+ will be automatically created with the same name as the Space but with the
182
+ `"_dataset"` suffix.
183
+ """
184
+ try:
185
+ from tbparse import SummaryReader
186
+ except ImportError:
187
+ raise ImportError(
188
+ "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
189
+ )
190
+
191
+ if SQLiteStorage.get_runs(project):
192
+ raise ValueError(
193
+ f"Project '{project}' already exists. Cannot import TF events into existing project."
194
+ )
195
+
196
+ path = Path(log_dir)
197
+ if not path.exists():
198
+ raise FileNotFoundError(f"TF events directory not found: {path}")
199
+
200
+ # Use tbparse to read all tfevents files in the directory structure
201
+ reader = SummaryReader(str(path), extra_columns={"dir_name"})
202
+ df = reader.scalars
203
+
204
+ if df.empty:
205
+ raise ValueError(f"No TensorFlow events data found in {path}")
206
+
207
+ total_imported = 0
208
+ imported_runs = []
209
+
210
+ # Group by dir_name to create separate runs
211
+ for dir_name, group_df in df.groupby("dir_name"):
212
+ try:
213
+ # Determine run name based on directory name
214
+ if dir_name == "":
215
+ run_name = "main" # For files in the root directory
216
+ else:
217
+ run_name = dir_name # Use directory name
218
+
219
+ if name:
220
+ run_name = f"{name}_{run_name}"
221
+
222
+ if group_df.empty:
223
+ print(f"* Skipping directory {dir_name}: no scalar data found")
224
+ continue
225
+
226
+ metrics_list = []
227
+ steps = []
228
+ timestamps = []
229
+
230
+ for _, row in group_df.iterrows():
231
+ # Convert row values to appropriate types
232
+ tag = str(row["tag"])
233
+ value = float(row["value"])
234
+ step = int(row["step"])
235
+
236
+ metrics = {tag: value}
237
+ metrics_list.append(metrics)
238
+ steps.append(step)
239
+
240
+ # Use wall_time if present, else fallback
241
+ if "wall_time" in group_df.columns and not bool(
242
+ pd.isna(row["wall_time"])
243
+ ):
244
+ timestamps.append(str(row["wall_time"]))
245
+ else:
246
+ timestamps.append("")
247
+
248
+ if metrics_list:
249
+ SQLiteStorage.bulk_log(
250
+ project=project,
251
+ run=str(run_name),
252
+ metrics_list=metrics_list,
253
+ steps=steps,
254
+ timestamps=timestamps,
255
+ )
256
+
257
+ total_imported += len(metrics_list)
258
+ imported_runs.append(run_name)
259
+
260
+ print(
261
+ f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
262
+ )
263
+ print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")
264
+
265
+ except Exception as e:
266
+ print(f"* Error processing directory {dir_name}: {e}")
267
+ continue
268
+
269
+ if not imported_runs:
270
+ raise ValueError("No valid TensorFlow events data could be imported")
271
+
272
+ print(f"* Total imported events: {total_imported}")
273
+ print(f"* Created runs: {', '.join(imported_runs)}")
274
+
275
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
276
+ if dataset_id is not None:
277
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
278
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
279
+
280
+ if space_id is None:
281
+ utils.print_dashboard_instructions(project)
282
+ else:
283
+ deploy.create_space_if_not_exists(space_id, dataset_id)
284
+ deploy.wait_until_space_exists(space_id)
285
+ deploy.upload_db_to_space(project, space_id)
286
+ print(
287
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
288
+ )
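Editor's note: a usage sketch for the two importers; the file and directory names are placeholders, and the target projects must not already exist.

```py
import trackio

# CSV import: the file needs a header row with a "step" column; other numeric
# columns are treated as metrics, and an optional "timestamp" column is kept.
trackio.import_csv("training_log.csv", project="csv-import-demo", name="run-1")

# TensorBoard import: each subdirectory of the log dir becomes its own run.
# Requires the tbparse extra: pip install trackio[tensorboard]
trackio.import_tf_events("tb_logs/", project="tb-import-demo")
```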
media.py ADDED
@@ -0,0 +1,241 @@
1
+ import os
2
+ import shutil
3
+ import uuid
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ import mediapy as mp
9
+ import numpy as np
10
+ from PIL import Image as PILImage
11
+
12
+ try: # absolute imports when installed
13
+ from trackio.file_storage import FileStorage
14
+ from trackio.utils import MEDIA_DIR
15
+ except ImportError: # relative imports for local execution on Spaces
16
+ from file_storage import FileStorage
17
+ from utils import MEDIA_DIR
18
+
19
+
20
+ class TrackioMedia(ABC):
21
+ """
22
+ Abstract base class for Trackio media objects
23
+ Provides shared functionality for file handling and serialization.
24
+ """
25
+
26
+ TYPE: str
27
+
28
+ def __init_subclass__(cls, **kwargs):
29
+ """Ensure subclasses define the TYPE attribute."""
30
+ super().__init_subclass__(**kwargs)
31
+ if not hasattr(cls, "TYPE") or cls.TYPE is None:
32
+ raise TypeError(f"Class {cls.__name__} must define TYPE attribute")
33
+
34
+ def __init__(self, value, caption: str | None = None):
35
+ self.caption = caption
36
+ self._value = value
37
+ self._file_path: Path | None = None
38
+
39
+ # Validate file existence for string/Path inputs
40
+ if isinstance(self._value, str | Path):
41
+ if not os.path.isfile(self._value):
42
+ raise ValueError(f"File not found: {self._value}")
43
+
44
+ def _file_extension(self) -> str:
45
+ if self._file_path:
46
+ return self._file_path.suffix[1:].lower()
47
+ if isinstance(self._value, str | Path):
48
+ path = Path(self._value)
49
+ return path.suffix[1:].lower()
50
+ if hasattr(self, "_format") and self._format:
51
+ return self._format
52
+ return "unknown"
53
+
54
+ def _get_relative_file_path(self) -> Path | None:
55
+ return self._file_path
56
+
57
+ def _get_absolute_file_path(self) -> Path | None:
58
+ if self._file_path:
59
+ return MEDIA_DIR / self._file_path
60
+ return None
61
+
62
+ def _save(self, project: str, run: str, step: int = 0):
63
+ if self._file_path:
64
+ return
65
+
66
+ media_dir = FileStorage.init_project_media_path(project, run, step)
67
+ filename = f"{uuid.uuid4()}.{self._file_extension()}"
68
+ file_path = media_dir / filename
69
+
70
+ # Delegate to subclass-specific save logic
71
+ self._save_media(file_path)
72
+
73
+ self._file_path = file_path.relative_to(MEDIA_DIR)
74
+
75
+ @abstractmethod
76
+ def _save_media(self, file_path: Path):
77
+ """
78
+ Performs the actual media saving logic.
79
+ """
80
+ pass
81
+
82
+ def _to_dict(self) -> dict:
83
+ if not self._file_path:
84
+ raise ValueError("Media must be saved to file before serialization")
85
+ return {
86
+ "_type": self.TYPE,
87
+ "file_path": str(self._get_relative_file_path()),
88
+ "caption": self.caption,
89
+ }
90
+
91
+
92
+ TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
93
+
94
+
95
+ class TrackioImage(TrackioMedia):
96
+ """
97
+ Initializes an Image object.
98
+
99
+ Args:
100
+ value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`):
101
+ A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
102
+ caption (`str`, *optional*, defaults to `None`):
103
+ A string caption for the image.
104
+ """
105
+
106
+ TYPE = "trackio.image"
107
+
108
+ def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
109
+ super().__init__(value, caption)
110
+ self._format: str | None = None
111
+
112
+ if (
113
+ isinstance(self._value, np.ndarray | PILImage.Image)
114
+ and self._format is None
115
+ ):
116
+ self._format = "png"
117
+
118
+ def _as_pil(self) -> PILImage.Image | None:
119
+ try:
120
+ if isinstance(self._value, np.ndarray):
121
+ arr = np.asarray(self._value).astype("uint8")
122
+ return PILImage.fromarray(arr).convert("RGBA")
123
+ if isinstance(self._value, PILImage.Image):
124
+ return self._value.convert("RGBA")
125
+ except Exception as e:
126
+ raise ValueError(f"Failed to process image data: {self._value}") from e
127
+ return None
128
+
129
+ def _save_media(self, file_path: Path):
130
+ if pil := self._as_pil():
131
+ pil.save(file_path, format=self._format)
132
+ elif isinstance(self._value, str | Path):
133
+ if os.path.isfile(self._value):
134
+ shutil.copy(self._value, file_path)
135
+ else:
136
+ raise ValueError(f"File not found: {self._value}")
137
+
138
+
139
+ TrackioVideoSourceType = str | Path | np.ndarray
140
+ TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
141
+
142
+
143
+ class TrackioVideo(TrackioMedia):
144
+ """
145
+ Initializes a Video object.
146
+
147
+ Args:
148
+ value (`str`, `Path`, or `numpy.ndarray`):
149
+ A path to a video file, or a numpy array of shape (frames, channels, height, width) or (batch, frames, channels, height, width).
150
+ caption (`str`, *optional*, defaults to `None`):
151
+ A string caption for the video.
152
+ fps (`int`, *optional*, defaults to `None`):
153
+ Frames per second for the video. Only relevant when using a numpy array.
154
+ format (`Literal["gif", "mp4", "webm"]`, *optional*, defaults to `None`):
155
+ Video format ("gif", "mp4", or "webm"). Only relevant when using a numpy array.
156
+ """
157
+
158
+ TYPE = "trackio.video"
159
+
160
+ def __init__(
161
+ self,
162
+ value: TrackioVideoSourceType,
163
+ caption: str | None = None,
164
+ fps: int | None = None,
165
+ format: TrackioVideoFormatType | None = None,
166
+ ):
167
+ super().__init__(value, caption)
168
+ self._fps = fps
169
+ self._format = format
170
+ if isinstance(self._value, np.ndarray) and self._format is None:
171
+ self._format = "gif"
172
+
173
+ @property
174
+ def _codec(self) -> str | None:
175
+ match self._format:
176
+ case "gif":
177
+ return "gif"
178
+ case "mp4":
179
+ return "h264"
180
+ case "webm":
181
+ return "vp9"
182
+ case _:
183
+ return None
184
+
185
+ def _save_media(self, file_path: Path):
186
+ if isinstance(self._value, np.ndarray):
187
+ video = TrackioVideo._process_ndarray(self._value)
188
+ mp.write_video(file_path, video, fps=self._fps, codec=self._codec)
189
+ elif isinstance(self._value, str | Path):
190
+ if os.path.isfile(self._value):
191
+ shutil.copy(self._value, file_path)
192
+ else:
193
+ raise ValueError(f"File not found: {self._value}")
194
+
195
+ @staticmethod
196
+ def _process_ndarray(value: np.ndarray) -> np.ndarray:
197
+ # Verify value is either 4D (single video) or 5D array (batched videos).
198
+ # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
199
+ if value.ndim < 4:
200
+ raise ValueError(
201
+ "Video requires at least 4 dimensions (frames, channels, height, width)"
202
+ )
203
+ if value.ndim > 5:
204
+ raise ValueError(
205
+ "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
206
+ )
207
+ if value.ndim == 4:
208
+ # Reshape to 5D with single batch: (1, frames, channels, height, width)
209
+ value = value[np.newaxis, ...]
210
+
211
+ value = TrackioVideo._tile_batched_videos(value)
212
+ return value
213
+
214
+ @staticmethod
215
+ def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
216
+ """
217
+ Tiles a batch of videos into a grid of videos.
218
+
219
+ Input format: (batch, frames, channels, height, width) - original FCHW format
220
+ Output format: (frames, total_height, total_width, channels)
221
+ """
222
+ batch_size, frames, channels, height, width = video.shape
223
+
224
+ next_pow2 = 1 << (batch_size - 1).bit_length()
225
+ if batch_size != next_pow2:
226
+ pad_len = next_pow2 - batch_size
227
+ pad_shape = (pad_len, frames, channels, height, width)
228
+ padding = np.zeros(pad_shape, dtype=video.dtype)
229
+ video = np.concatenate((video, padding), axis=0)
230
+ batch_size = next_pow2
231
+
232
+ n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
233
+ n_cols = batch_size // n_rows
234
+
235
+ # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
236
+ video = video.reshape(n_rows, n_cols, frames, channels, height, width)
237
+
238
+ # Rearrange dimensions to (frames, total_height, total_width, channels)
239
+ video = video.transpose(2, 0, 4, 1, 5, 3)
240
+ video = video.reshape(frames, n_rows * height, n_cols * width, channels)
241
+ return video
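
A minimal usage sketch for the media classes above, assuming the top-level trackio.init / trackio.log / trackio.finish API accepts these objects; the project name, metric keys, and array shapes are illustrative only:

import numpy as np
import trackio
from trackio.media import TrackioImage, TrackioVideo

run = trackio.init(project="demo-project")  # hypothetical project name
frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)     # (height, width, channels)
clip = np.random.randint(0, 255, (16, 3, 64, 64), dtype=np.uint8)  # (frames, channels, height, width)
trackio.log({
    "sample/image": TrackioImage(frame, caption="random noise"),
    "sample/video": TrackioVideo(clip, fps=8, format="gif"),
})
trackio.finish()
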
py.typed ADDED
File without changes
run.py ADDED
@@ -0,0 +1,156 @@
1
+ import threading
2
+ import time
3
+
4
+ import huggingface_hub
5
+ from gradio_client import Client, handle_file
6
+
7
+ from trackio.media import TrackioMedia
8
+ from trackio.sqlite_storage import SQLiteStorage
9
+ from trackio.table import Table
10
+ from trackio.typehints import LogEntry, UploadEntry
11
+ from trackio.utils import (
12
+ RESERVED_KEYS,
13
+ fibo,
14
+ generate_readable_name,
15
+ serialize_values,
16
+ )
17
+
18
+ BATCH_SEND_INTERVAL = 0.5
19
+
20
+
21
+ class Run:
22
+ def __init__(
23
+ self,
24
+ url: str,
25
+ project: str,
26
+ client: Client | None,
27
+ name: str | None = None,
28
+ config: dict | None = None,
29
+ space_id: str | None = None,
30
+ ):
31
+ self.url = url
32
+ self.project = project
33
+ self._client_lock = threading.Lock()
34
+ self._client_thread = None
35
+ self._client = client
36
+ self._space_id = space_id
37
+ self.name = name or generate_readable_name(
38
+ SQLiteStorage.get_runs(project), space_id
39
+ )
40
+ self.config = config or {}
41
+ self._queued_logs: list[LogEntry] = []
42
+ self._queued_uploads: list[UploadEntry] = []
43
+ self._stop_flag = threading.Event()
44
+
45
+ self._client_thread = threading.Thread(target=self._init_client_background)
46
+ self._client_thread.daemon = True
47
+ self._client_thread.start()
48
+
49
+ def _batch_sender(self):
50
+ """Send batched logs every BATCH_SEND_INTERVAL."""
51
+ while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
52
+ # If the stop flag has been set, then just quickly send all
53
+ # the logs and exit.
54
+ if not self._stop_flag.is_set():
55
+ time.sleep(BATCH_SEND_INTERVAL)
56
+
57
+ with self._client_lock:
58
+ if self._client is None:
59
+ return
60
+ if self._queued_logs:
61
+ logs_to_send = self._queued_logs.copy()
62
+ self._queued_logs.clear()
63
+ self._client.predict(
64
+ api_name="/bulk_log",
65
+ logs=logs_to_send,
66
+ hf_token=huggingface_hub.utils.get_token(),
67
+ )
68
+ if self._queued_uploads:
69
+ uploads_to_send = self._queued_uploads.copy()
70
+ self._queued_uploads.clear()
71
+ self._client.predict(
72
+ api_name="/bulk_upload_media",
73
+ uploads=uploads_to_send,
74
+ hf_token=huggingface_hub.utils.get_token(),
75
+ )
76
+
77
+ def _init_client_background(self):
78
+ if self._client is None:
79
+ fib = fibo()
80
+ for sleep_coefficient in fib:
81
+ try:
82
+ client = Client(self.url, verbose=False)
83
+
84
+ with self._client_lock:
85
+ self._client = client
86
+ break
87
+ except Exception:
88
+ pass
89
+ if sleep_coefficient is not None:
90
+ time.sleep(0.1 * sleep_coefficient)
91
+
92
+ self._batch_sender()
93
+
94
+ def _process_media(self, metrics, step: int | None) -> dict:
95
+ """
96
+ Serialize media in metrics and upload to space if needed.
97
+ """
98
+ serializable_metrics = {}
99
+ if not step:
100
+ step = 0
101
+ for key, value in metrics.items():
102
+ if isinstance(value, TrackioMedia):
103
+ value._save(self.project, self.name, step)
104
+ serializable_metrics[key] = value._to_dict()
105
+ if self._space_id:
106
+ # Upload local media when deploying to space
107
+ upload_entry: UploadEntry = {
108
+ "project": self.project,
109
+ "run": self.name,
110
+ "step": step,
111
+ "uploaded_file": handle_file(value._get_absolute_file_path()),
112
+ }
113
+ with self._client_lock:
114
+ self._queued_uploads.append(upload_entry)
115
+ else:
116
+ serializable_metrics[key] = value
117
+ return serializable_metrics
118
+
119
+ @staticmethod
120
+ def _replace_tables(metrics):
121
+ for k, v in metrics.items():
122
+ if isinstance(v, Table):
123
+ metrics[k] = v._to_dict()
124
+
125
+ def log(self, metrics: dict, step: int | None = None):
126
+ for k in metrics.keys():
127
+ if k in RESERVED_KEYS or k.startswith("__"):
128
+ raise ValueError(
129
+ f"Please do not use this reserved key as a metric: {k}"
130
+ )
131
+ Run._replace_tables(metrics)
132
+
133
+ metrics = self._process_media(metrics, step)
134
+ metrics = serialize_values(metrics)
135
+ log_entry: LogEntry = {
136
+ "project": self.project,
137
+ "run": self.name,
138
+ "metrics": metrics,
139
+ "step": step,
140
+ }
141
+
142
+ with self._client_lock:
143
+ self._queued_logs.append(log_entry)
144
+
145
+ def finish(self):
146
+ """Cleanup when run is finished."""
147
+ self._stop_flag.set()
148
+
149
+ # Wait for the batch sender to finish before joining the client thread.
150
+ time.sleep(2 * BATCH_SEND_INTERVAL)
151
+
152
+ if self._client_thread is not None:
153
+ print(
154
+ f"* Run finished. Uploading logs to Trackio Space: {self.url} (please wait...)"
155
+ )
156
+ self._client_thread.join()
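
As a sketch of the batching behavior above: log() only appends to an in-memory queue under the client lock, and the background _batch_sender thread flushes it every BATCH_SEND_INTERVAL seconds. Direct construction below is illustrative; in normal use trackio.init() creates the Run, and the dashboard URL shown is hypothetical:

from trackio.run import Run

run = Run(url="http://127.0.0.1:7860", project="demo-project", client=None)
for step in range(3):
    run.log({"train/loss": 1.0 / (step + 1)}, step=step)  # queued, not written immediately
run.finish()  # sets the stop flag; the sender thread drains the queue before exiting
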
sqlite_storage.py ADDED
@@ -0,0 +1,398 @@
1
+ import json
2
+ import os
3
+ import sqlite3
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from threading import Lock
7
+
8
+ import huggingface_hub as hf
9
+ import pandas as pd
10
+
11
+ try: # absolute imports when installed
12
+ from trackio.commit_scheduler import CommitScheduler
13
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
14
+ from trackio.utils import (
15
+ TRACKIO_DIR,
16
+ deserialize_values,
17
+ serialize_values,
18
+ )
19
+ except Exception: # relative imports for local execution on Spaces
20
+ from commit_scheduler import CommitScheduler
21
+ from dummy_commit_scheduler import DummyCommitScheduler
22
+ from utils import TRACKIO_DIR, deserialize_values, serialize_values
23
+
24
+
25
+ class SQLiteStorage:
26
+ _dataset_import_attempted = False
27
+ _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
28
+ _scheduler_lock = Lock()
29
+
30
+ @staticmethod
31
+ def _get_connection(db_path: Path) -> sqlite3.Connection:
32
+ conn = sqlite3.connect(str(db_path))
33
+ conn.row_factory = sqlite3.Row
34
+ return conn
35
+
36
+ @staticmethod
37
+ def get_project_db_filename(project: str) -> str:
38
+ """Get the database filename for a specific project."""
39
+ safe_project_name = "".join(
40
+ c for c in project if c.isalnum() or c in ("-", "_")
41
+ ).rstrip()
42
+ if not safe_project_name:
43
+ safe_project_name = "default"
44
+ return f"{safe_project_name}.db"
45
+
46
+ @staticmethod
47
+ def get_project_db_path(project: str) -> Path:
48
+ """Get the database path for a specific project."""
49
+ filename = SQLiteStorage.get_project_db_filename(project)
50
+ return TRACKIO_DIR / filename
51
+
52
+ @staticmethod
53
+ def init_db(project: str) -> Path:
54
+ """
55
+ Initialize the SQLite database with required tables.
56
+ If there is a dataset ID provided, copies from that dataset instead.
57
+ Returns the database path.
58
+ """
59
+ db_path = SQLiteStorage.get_project_db_path(project)
60
+ db_path.parent.mkdir(parents=True, exist_ok=True)
61
+ with SQLiteStorage.get_scheduler().lock:
62
+ with sqlite3.connect(db_path) as conn:
63
+ cursor = conn.cursor()
64
+ cursor.execute("""
65
+ CREATE TABLE IF NOT EXISTS metrics (
66
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
67
+ timestamp TEXT NOT NULL,
68
+ run_name TEXT NOT NULL,
69
+ step INTEGER NOT NULL,
70
+ metrics TEXT NOT NULL
71
+ )
72
+ """)
73
+ cursor.execute(
74
+ """
75
+ CREATE INDEX IF NOT EXISTS idx_metrics_run_step
76
+ ON metrics(run_name, step)
77
+ """
78
+ )
79
+ conn.commit()
80
+ return db_path
81
+
82
+ @staticmethod
83
+ def export_to_parquet():
84
+ """
85
+ Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
86
+ """
87
+ # don't attempt to export (potentially wrong/blank) data before importing for the first time
88
+ if not SQLiteStorage._dataset_import_attempted:
89
+ return
90
+ all_paths = os.listdir(TRACKIO_DIR)
91
+ db_paths = [f for f in all_paths if f.endswith(".db")]
92
+ for db_path in db_paths:
93
+ db_path = TRACKIO_DIR / db_path
94
+ parquet_path = db_path.with_suffix(".parquet")
95
+ if (not parquet_path.exists()) or (
96
+ db_path.stat().st_mtime > parquet_path.stat().st_mtime
97
+ ):
98
+ with sqlite3.connect(db_path) as conn:
99
+ df = pd.read_sql("SELECT * from metrics", conn)
100
+ # break out the single JSON metrics column into individual columns
101
+ metrics = df["metrics"].copy()
102
+ metrics = pd.DataFrame(
103
+ metrics.apply(
104
+ lambda x: deserialize_values(json.loads(x))
105
+ ).values.tolist(),
106
+ index=df.index,
107
+ )
108
+ del df["metrics"]
109
+ for col in metrics.columns:
110
+ df[col] = metrics[col]
111
+ df.to_parquet(parquet_path)
112
+
113
+ @staticmethod
114
+ def import_from_parquet():
115
+ """
116
+ Imports data into each project's DB file from its matching file at the same path but with extension ".parquet".
117
+ """
118
+ all_paths = os.listdir(TRACKIO_DIR)
119
+ parquet_paths = [f for f in all_paths if f.endswith(".parquet")]
120
+ for parquet_path in parquet_paths:
121
+ parquet_path = TRACKIO_DIR / parquet_path
122
+ db_path = parquet_path.with_suffix(".db")
123
+ df = pd.read_parquet(parquet_path)
124
+ with sqlite3.connect(db_path) as conn:
125
+ # fix up df to have a single JSON metrics column
126
+ if "metrics" not in df.columns:
127
+ # separate other columns from metrics
128
+ metrics = df.copy()
129
+ other_cols = ["id", "timestamp", "run_name", "step"]
130
+ df = df[other_cols]
131
+ for col in other_cols:
132
+ del metrics[col]
133
+ # combine them all into a single metrics col
134
+ metrics = json.loads(metrics.to_json(orient="records"))
135
+ df["metrics"] = [
136
+ json.dumps(serialize_values(row)) for row in metrics
137
+ ]
138
+ df.to_sql("metrics", conn, if_exists="replace", index=False)
139
+
140
+ @staticmethod
141
+ def get_scheduler():
142
+ """
143
+ Get the scheduler for the database based on the environment variables.
144
+ This applies to both local and Spaces.
145
+ """
146
+ with SQLiteStorage._scheduler_lock:
147
+ if SQLiteStorage._current_scheduler is not None:
148
+ return SQLiteStorage._current_scheduler
149
+ hf_token = os.environ.get("HF_TOKEN")
150
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
151
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
152
+ if dataset_id is None or space_repo_name is None:
153
+ scheduler = DummyCommitScheduler()
154
+ else:
155
+ scheduler = CommitScheduler(
156
+ repo_id=dataset_id,
157
+ repo_type="dataset",
158
+ folder_path=TRACKIO_DIR,
159
+ private=True,
160
+ allow_patterns=["*.parquet", "media/**/*"],
161
+ squash_history=True,
162
+ token=hf_token,
163
+ on_before_commit=SQLiteStorage.export_to_parquet,
164
+ )
165
+ SQLiteStorage._current_scheduler = scheduler
166
+ return scheduler
167
+
168
+ @staticmethod
169
+ def log(project: str, run: str, metrics: dict, step: int | None = None):
170
+ """
171
+ Safely log metrics to the database. Before logging, this method will ensure the database exists
172
+ and is set up with the correct tables. It also uses the scheduler to lock the database so
173
+ that there is no race condition when logging / syncing to the Hugging Face Dataset.
174
+ """
175
+ db_path = SQLiteStorage.init_db(project)
176
+
177
+ with SQLiteStorage.get_scheduler().lock:
178
+ with SQLiteStorage._get_connection(db_path) as conn:
179
+ cursor = conn.cursor()
180
+
181
+ cursor.execute(
182
+ """
183
+ SELECT MAX(step)
184
+ FROM metrics
185
+ WHERE run_name = ?
186
+ """,
187
+ (run,),
188
+ )
189
+ last_step = cursor.fetchone()[0]
190
+ if step is None:
191
+ current_step = 0 if last_step is None else last_step + 1
192
+ else:
193
+ current_step = step
194
+
195
+ current_timestamp = datetime.now().isoformat()
196
+
197
+ cursor.execute(
198
+ """
199
+ INSERT INTO metrics
200
+ (timestamp, run_name, step, metrics)
201
+ VALUES (?, ?, ?, ?)
202
+ """,
203
+ (
204
+ current_timestamp,
205
+ run,
206
+ current_step,
207
+ json.dumps(serialize_values(metrics)),
208
+ ),
209
+ )
210
+ conn.commit()
211
+
212
+ @staticmethod
213
+ def bulk_log(
214
+ project: str,
215
+ run: str,
216
+ metrics_list: list[dict],
217
+ steps: list[int] | None = None,
218
+ timestamps: list[str] | None = None,
219
+ ):
220
+ """Bulk log metrics to the database with specified steps and timestamps."""
221
+ if not metrics_list:
222
+ return
223
+
224
+ if timestamps is None:
225
+ timestamps = [datetime.now().isoformat()] * len(metrics_list)
226
+
227
+ db_path = SQLiteStorage.init_db(project)
228
+ with SQLiteStorage.get_scheduler().lock:
229
+ with SQLiteStorage._get_connection(db_path) as conn:
230
+ cursor = conn.cursor()
231
+
232
+ if steps is None:
233
+ steps = list(range(len(metrics_list)))
234
+ elif any(s is None for s in steps):
235
+ cursor.execute(
236
+ "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
237
+ )
238
+ last_step = cursor.fetchone()[0]
239
+ current_step = 0 if last_step is None else last_step + 1
240
+
241
+ processed_steps = []
242
+ for step in steps:
243
+ if step is None:
244
+ processed_steps.append(current_step)
245
+ current_step += 1
246
+ else:
247
+ processed_steps.append(step)
248
+ steps = processed_steps
249
+
250
+ if len(metrics_list) != len(steps) or len(metrics_list) != len(
251
+ timestamps
252
+ ):
253
+ raise ValueError(
254
+ "metrics_list, steps, and timestamps must have the same length"
255
+ )
256
+
257
+ data = []
258
+ for i, metrics in enumerate(metrics_list):
259
+ data.append(
260
+ (
261
+ timestamps[i],
262
+ run,
263
+ steps[i],
264
+ json.dumps(serialize_values(metrics)),
265
+ )
266
+ )
267
+
268
+ cursor.executemany(
269
+ """
270
+ INSERT INTO metrics
271
+ (timestamp, run_name, step, metrics)
272
+ VALUES (?, ?, ?, ?)
273
+ """,
274
+ data,
275
+ )
276
+ conn.commit()
277
+
278
+ @staticmethod
279
+ def get_logs(project: str, run: str) -> list[dict]:
280
+ """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object)."""
281
+ db_path = SQLiteStorage.get_project_db_path(project)
282
+ if not db_path.exists():
283
+ return []
284
+
285
+ with SQLiteStorage._get_connection(db_path) as conn:
286
+ cursor = conn.cursor()
287
+ cursor.execute(
288
+ """
289
+ SELECT timestamp, step, metrics
290
+ FROM metrics
291
+ WHERE run_name = ?
292
+ ORDER BY timestamp
293
+ """,
294
+ (run,),
295
+ )
296
+
297
+ rows = cursor.fetchall()
298
+ results = []
299
+ for row in rows:
300
+ metrics = json.loads(row["metrics"])
301
+ metrics = deserialize_values(metrics)
302
+ metrics["timestamp"] = row["timestamp"]
303
+ metrics["step"] = row["step"]
304
+ results.append(metrics)
305
+ return results
306
+
307
+ @staticmethod
308
+ def load_from_dataset():
309
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
310
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
311
+ if dataset_id is not None and space_repo_name is not None:
312
+ hfapi = hf.HfApi()
313
+ updated = False
314
+ if not TRACKIO_DIR.exists():
315
+ TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
316
+ with SQLiteStorage.get_scheduler().lock:
317
+ try:
318
+ files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
319
+ for file in files:
320
+ is_media = file.startswith("media/")
321
+ is_parquet = file.endswith(".parquet")
322
+ # Download parquet and media assets
323
+ if is_media or is_parquet:
324
+ hf.hf_hub_download(
325
+ dataset_id,
326
+ file,
327
+ repo_type="dataset",
328
+ local_dir=TRACKIO_DIR,
329
+ )
330
+ updated = True
331
+ except hf.errors.EntryNotFoundError:
332
+ pass
333
+ except hf.errors.RepositoryNotFoundError:
334
+ pass
335
+ if updated:
336
+ SQLiteStorage.import_from_parquet()
337
+ SQLiteStorage._dataset_import_attempted = True
338
+
339
+ @staticmethod
340
+ def get_projects() -> list[str]:
341
+ """
342
+ Get list of all projects by scanning the database files in the trackio directory.
343
+ """
344
+ if not SQLiteStorage._dataset_import_attempted:
345
+ SQLiteStorage.load_from_dataset()
346
+
347
+ projects: set[str] = set()
348
+ if not TRACKIO_DIR.exists():
349
+ return []
350
+
351
+ for db_file in TRACKIO_DIR.glob("*.db"):
352
+ project_name = db_file.stem
353
+ projects.add(project_name)
354
+ return sorted(projects)
355
+
356
+ @staticmethod
357
+ def get_runs(project: str) -> list[str]:
358
+ """Get list of all runs for a project."""
359
+ db_path = SQLiteStorage.get_project_db_path(project)
360
+ if not db_path.exists():
361
+ return []
362
+
363
+ with SQLiteStorage._get_connection(db_path) as conn:
364
+ cursor = conn.cursor()
365
+ cursor.execute(
366
+ "SELECT DISTINCT run_name FROM metrics",
367
+ )
368
+ return [row[0] for row in cursor.fetchall()]
369
+
370
+ @staticmethod
371
+ def get_max_steps_for_runs(project: str, runs: list[str]) -> dict[str, int]:
372
+ """Efficiently get the maximum step for multiple runs in a single query."""
373
+ db_path = SQLiteStorage.get_project_db_path(project)
374
+ if not db_path.exists():
375
+ return {run: 0 for run in runs}
376
+
377
+ with SQLiteStorage._get_connection(db_path) as conn:
378
+ cursor = conn.cursor()
379
+ placeholders = ",".join("?" * len(runs))
380
+ cursor.execute(
381
+ f"""
382
+ SELECT run_name, MAX(step) as max_step
383
+ FROM metrics
384
+ WHERE run_name IN ({placeholders})
385
+ GROUP BY run_name
386
+ """,
387
+ runs,
388
+ )
389
+
390
+ results = {run: 0 for run in runs} # Default to 0 for runs with no data
391
+ for row in cursor.fetchall():
392
+ results[row["run_name"]] = row["max_step"]
393
+
394
+ return results
395
+
396
+ def finish(self):
397
+ """Cleanup when run is finished."""
398
+ pass
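
A short sketch of the storage API above, using hypothetical project and run names; log() creates the per-project database on first use, and get_logs() returns one dict per logged step with the timestamp and step merged in:

from trackio.sqlite_storage import SQLiteStorage

SQLiteStorage.log(project="demo-project", run="run-0", metrics={"loss": 0.42})
SQLiteStorage.bulk_log(
    project="demo-project",
    run="run-0",
    metrics_list=[{"loss": 0.40}, {"loss": 0.39}],
    steps=[1, 2],
)
for entry in SQLiteStorage.get_logs("demo-project", "run-0"):
    print(entry["step"], entry["loss"], entry["timestamp"])
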
table.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Any, Literal, Optional, Union
2
+
3
+ from pandas import DataFrame
4
+
5
+
6
+ class Table:
7
+ """
8
+ Initializes a Table object.
9
+
10
+ Args:
11
+ columns (`list[str]`, *optional*, defaults to `None`):
12
+ Names of the columns in the table. Optional if `data` is provided. Not
13
+ expected if `dataframe` is provided. Currently ignored.
14
+ data (`list[list[Any]]`, *optional*, defaults to `None`):
15
+ 2D row-oriented array of values.
16
+ dataframe (`pandas.DataFrame`, *optional*, defaults to `None`):
17
+ DataFrame object used to create the table. When set, `data` and `columns`
18
+ arguments are ignored.
19
+ rows (`list[list[Any]]`, *optional*, defaults to `None`):
20
+ Currently ignored.
21
+ optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
22
+ Currently ignored.
23
+ allow_mixed_types (`bool`, *optional*, defaults to `False`):
24
+ Currently ignored.
25
+ log_mode: (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
26
+ Currently ignored.
27
+ """
28
+
29
+ TYPE = "trackio.table"
30
+
31
+ def __init__(
32
+ self,
33
+ columns: Optional[list[str]] = None,
34
+ data: Optional[list[list[Any]]] = None,
35
+ dataframe: Optional[DataFrame] = None,
36
+ rows: Optional[list[list[Any]]] = None,
37
+ optional: Union[bool, list[bool]] = True,
38
+ allow_mixed_types: bool = False,
39
+ log_mode: Optional[
40
+ Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]
41
+ ] = "IMMUTABLE",
42
+ ):
43
+ # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
44
+ # for now (like `rows`) they are included for API compat but don't do anything.
45
+
46
+ if dataframe is None:
47
+ self.data = data
48
+ else:
49
+ self.data = dataframe.to_dict(orient="records")
50
+
51
+ def _to_dict(self):
52
+ return {
53
+ "_type": self.TYPE,
54
+ "_value": self.data,
55
+ }
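
A small sketch of the Table wrapper with hypothetical data; when a DataFrame is passed, it is stored as a list of record dicts and serialized via _to_dict() when logged:

import pandas as pd
from trackio.table import Table

df = pd.DataFrame({"epoch": [1, 2], "accuracy": [0.81, 0.86]})
table = Table(dataframe=df)
print(table._to_dict())  # {'_type': 'trackio.table', '_value': [{'epoch': 1, 'accuracy': 0.81}, ...]}
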
typehints.py ADDED
@@ -0,0 +1,17 @@
1
+ from typing import Any, TypedDict
2
+
3
+ from gradio import FileData
4
+
5
+
6
+ class LogEntry(TypedDict):
7
+ project: str
8
+ run: str
9
+ metrics: dict[str, Any]
10
+ step: int | None
11
+
12
+
13
+ class UploadEntry(TypedDict):
14
+ project: str
15
+ run: str
16
+ step: int | None
17
+ uploaded_file: FileData
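
These TypedDicts describe the payloads sent to the Space's /bulk_log and /bulk_upload_media endpoints; a minimal illustrative LogEntry (values are hypothetical):

from trackio.typehints import LogEntry

entry: LogEntry = {
    "project": "demo-project",
    "run": "run-0",
    "metrics": {"loss": 0.42},
    "step": 7,
}
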
ui.py ADDED
@@ -0,0 +1,771 @@
1
+ import os
2
+ import re
3
+ import shutil
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import gradio as gr
8
+ import huggingface_hub as hf
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ HfApi = hf.HfApi()
13
+
14
+ try:
15
+ import trackio.utils as utils
16
+ from trackio.file_storage import FileStorage
17
+ from trackio.media import TrackioImage, TrackioVideo
18
+ from trackio.sqlite_storage import SQLiteStorage
19
+ from trackio.table import Table
20
+ from trackio.typehints import LogEntry, UploadEntry
21
+ except: # noqa: E722
22
+ import utils
23
+ from file_storage import FileStorage
24
+ from media import TrackioImage, TrackioVideo
25
+ from sqlite_storage import SQLiteStorage
26
+ from table import Table
27
+ from typehints import LogEntry, UploadEntry
28
+
29
+
30
+ def get_project_info() -> str | None:
31
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
32
+ space_id = os.environ.get("SPACE_ID")
33
+ persistent_storage_enabled = os.environ.get(
34
+ "PERSISTANT_STORAGE_ENABLED"
35
+ ) # Space env name has a typo
36
+ if persistent_storage_enabled:
37
+ return "&#10024; Persistent Storage is enabled, logs are stored directly in this Space."
38
+ if dataset_id:
39
+ sync_status = utils.get_sync_status(SQLiteStorage.get_scheduler())
40
+ upgrade_message = f"New changes are synced every 5 min <span class='info-container'><input type='checkbox' class='info-checkbox' id='upgrade-info'><label for='upgrade-info' class='info-icon'>&#9432;</label><span class='info-expandable'> To avoid losing data between syncs, <a href='https://huggingface.co/spaces/{space_id}/settings' class='accent-link'>click here</a> to open this Space's settings and add Persistent Storage.</span></span>"
41
+ if sync_status is not None:
42
+ info = f"&#x21bb; Backed up {sync_status} min ago to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}"
43
+ else:
44
+ info = f"&#x21bb; Not backed up yet to <a href='https://huggingface.co/datasets/{dataset_id}' target='_blank' class='accent-link'>{dataset_id}</a> | {upgrade_message}"
45
+ return info
46
+ return None
47
+
48
+
49
+ def get_projects(request: gr.Request):
50
+ projects = SQLiteStorage.get_projects()
51
+ if project := request.query_params.get("project"):
52
+ interactive = False
53
+ else:
54
+ interactive = True
55
+ project = projects[0] if projects else None
56
+
57
+ return gr.Dropdown(
58
+ label="Project",
59
+ choices=projects,
60
+ value=project,
61
+ allow_custom_value=True,
62
+ interactive=interactive,
63
+ info=get_project_info(),
64
+ )
65
+
66
+
67
+ def get_runs(project) -> list[str]:
68
+ if not project:
69
+ return []
70
+ return SQLiteStorage.get_runs(project)
71
+
72
+
73
+ def get_available_metrics(project: str, runs: list[str]) -> list[str]:
74
+ """Get all available metrics across all runs for x-axis selection."""
75
+ if not project or not runs:
76
+ return ["step", "time"]
77
+
78
+ all_metrics = set()
79
+ for run in runs:
80
+ metrics = SQLiteStorage.get_logs(project, run)
81
+ if metrics:
82
+ df = pd.DataFrame(metrics)
83
+ numeric_cols = df.select_dtypes(include="number").columns
84
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
85
+ all_metrics.update(numeric_cols)
86
+
87
+ all_metrics.add("step")
88
+ all_metrics.add("time")
89
+
90
+ sorted_metrics = utils.sort_metrics_by_prefix(list(all_metrics))
91
+
92
+ result = ["step", "time"]
93
+ for metric in sorted_metrics:
94
+ if metric not in result:
95
+ result.append(metric)
96
+
97
+ return result
98
+
99
+
100
+ @dataclass
101
+ class MediaData:
102
+ caption: str | None
103
+ file_path: str
104
+
105
+
106
+ def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
107
+ media_by_key: dict[str, list[MediaData]] = {}
108
+ logs = sorted(logs, key=lambda x: x.get("step", 0))
109
+ for log in logs:
110
+ for key, value in log.items():
111
+ if isinstance(value, dict):
112
+ type = value.get("_type")
113
+ if type == TrackioImage.TYPE or type == TrackioVideo.TYPE:
114
+ if key not in media_by_key:
115
+ media_by_key[key] = []
116
+ try:
117
+ media_data = MediaData(
118
+ file_path=utils.MEDIA_DIR / value.get("file_path"),
119
+ caption=value.get("caption"),
120
+ )
121
+ media_by_key[key].append(media_data)
122
+ except Exception as e:
123
+ print(f"Media currently unavailable: {key}: {e}")
124
+ return media_by_key
125
+
126
+
127
+ def load_run_data(
128
+ project: str | None,
129
+ run: str | None,
130
+ smoothing_granularity: int,
131
+ x_axis: str,
132
+ log_scale: bool = False,
133
+ ) -> tuple[pd.DataFrame, dict]:
134
+ if not project or not run:
135
+ return None, None
136
+
137
+ logs = SQLiteStorage.get_logs(project, run)
138
+ if not logs:
139
+ return None, None
140
+
141
+ media = extract_media(logs)
142
+ df = pd.DataFrame(logs)
143
+
144
+ if "step" not in df.columns:
145
+ df["step"] = range(len(df))
146
+
147
+ if x_axis == "time" and "timestamp" in df.columns:
148
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
149
+ first_timestamp = df["timestamp"].min()
150
+ df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds()
151
+ x_column = "time"
152
+ elif x_axis == "step":
153
+ x_column = "step"
154
+ else:
155
+ x_column = x_axis
156
+
157
+ if log_scale and x_column in df.columns:
158
+ x_vals = df[x_column]
159
+ if (x_vals <= 0).any():
160
+ df[x_column] = np.log10(np.maximum(x_vals, 0) + 1)
161
+ else:
162
+ df[x_column] = np.log10(x_vals)
163
+
164
+ if smoothing_granularity > 0:
165
+ numeric_cols = df.select_dtypes(include="number").columns
166
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
167
+
168
+ df_original = df.copy()
169
+ df_original["run"] = f"{run}_original"
170
+ df_original["data_type"] = "original"
171
+
172
+ df_smoothed = df.copy()
173
+ window_size = max(3, min(smoothing_granularity, len(df)))
174
+ df_smoothed[numeric_cols] = (
175
+ df_smoothed[numeric_cols]
176
+ .rolling(window=window_size, center=True, min_periods=1)
177
+ .mean()
178
+ )
179
+ df_smoothed["run"] = f"{run}_smoothed"
180
+ df_smoothed["data_type"] = "smoothed"
181
+
182
+ combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
183
+ combined_df["x_axis"] = x_column
184
+ return combined_df, media
185
+ else:
186
+ df["run"] = run
187
+ df["data_type"] = "original"
188
+ df["x_axis"] = x_column
189
+ return df, media
190
+
191
+
192
+ def update_runs(project, filter_text, user_interacted_with_runs=False):
193
+ if project is None:
194
+ runs = []
195
+ num_runs = 0
196
+ else:
197
+ runs = get_runs(project)
198
+ num_runs = len(runs)
199
+ if filter_text:
200
+ runs = [r for r in runs if filter_text in r]
201
+ if not user_interacted_with_runs:
202
+ return gr.CheckboxGroup(choices=runs, value=runs), gr.Textbox(
203
+ label=f"Runs ({num_runs})"
204
+ )
205
+ else:
206
+ return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
207
+
208
+
209
+ def filter_runs(project, filter_text):
210
+ runs = get_runs(project)
211
+ runs = [r for r in runs if filter_text in r]
212
+ return gr.CheckboxGroup(choices=runs, value=runs)
213
+
214
+
215
+ def update_x_axis_choices(project, runs):
216
+ """Update x-axis dropdown choices based on available metrics."""
217
+ available_metrics = get_available_metrics(project, runs)
218
+ return gr.Dropdown(
219
+ label="X-axis",
220
+ choices=available_metrics,
221
+ value="step",
222
+ )
223
+
224
+
225
+ def toggle_timer(cb_value):
226
+ if cb_value:
227
+ return gr.Timer(active=True)
228
+ else:
229
+ return gr.Timer(active=False)
230
+
231
+
232
+ def check_auth(hf_token: str | None) -> None:
233
+ if os.getenv("SYSTEM") == "spaces": # if we are running in Spaces
234
+ # check auth token passed in
235
+ if hf_token is None:
236
+ raise PermissionError(
237
+ "Expected a HF_TOKEN to be provided when logging to a Space"
238
+ )
239
+ who = HfApi.whoami(hf_token)
240
+ access_token = who["auth"]["accessToken"]
241
+ owner_name = os.getenv("SPACE_AUTHOR_NAME")
242
+ repo_name = os.getenv("SPACE_REPO_NAME")
243
+ # make sure the token user is either the author of the space,
244
+ # or is a member of an org that is the author.
245
+ orgs = [o["name"] for o in who["orgs"]]
246
+ if owner_name != who["name"] and owner_name not in orgs:
247
+ raise PermissionError(
248
+ "Expected the provided hf_token to be the user owner of the space, or be a member of the org owner of the space"
249
+ )
250
+ # reject fine-grained tokens without specific repo access
251
+ if access_token["role"] == "fineGrained":
252
+ matched = False
253
+ for item in access_token["fineGrained"]["scoped"]:
254
+ if (
255
+ item["entity"]["type"] == "space"
256
+ and item["entity"]["name"] == f"{owner_name}/{repo_name}"
257
+ and "repo.write" in item["permissions"]
258
+ ):
259
+ matched = True
260
+ break
261
+ if (
262
+ (
263
+ item["entity"]["type"] == "user"
264
+ or item["entity"]["type"] == "org"
265
+ )
266
+ and item["entity"]["name"] == owner_name
267
+ and "repo.write" in item["permissions"]
268
+ ):
269
+ matched = True
270
+ break
271
+ if not matched:
272
+ raise PermissionError(
273
+ "Expected the provided hf_token with fine grained permissions to provide write access to the space"
274
+ )
275
+ # reject read-only tokens
276
+ elif access_token["role"] != "write":
277
+ raise PermissionError(
278
+ "Expected the provided hf_token to provide write permissions"
279
+ )
280
+
281
+
282
+ def upload_db_to_space(
283
+ project: str, uploaded_db: gr.FileData, hf_token: str | None
284
+ ) -> None:
285
+ check_auth(hf_token)
286
+ db_project_path = SQLiteStorage.get_project_db_path(project)
287
+ if os.path.exists(db_project_path):
288
+ raise gr.Error(
289
+ f"Trackio database file already exists for project {project}, cannot overwrite."
290
+ )
291
+ os.makedirs(os.path.dirname(db_project_path), exist_ok=True)
292
+ shutil.copy(uploaded_db["path"], db_project_path)
293
+
294
+
295
+ def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None:
296
+ check_auth(hf_token)
297
+ for upload in uploads:
298
+ media_path = FileStorage.init_project_media_path(
299
+ upload["project"], upload["run"], upload["step"]
300
+ )
301
+ shutil.copy(upload["uploaded_file"]["path"], media_path)
302
+
303
+
304
+ def log(
305
+ project: str,
306
+ run: str,
307
+ metrics: dict[str, Any],
308
+ step: int | None,
309
+ hf_token: str | None,
310
+ ) -> None:
311
+ check_auth(hf_token)
312
+ SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step)
313
+
314
+
315
+ def bulk_log(
316
+ logs: list[LogEntry],
317
+ hf_token: str | None,
318
+ ) -> None:
319
+ check_auth(hf_token)
320
+
321
+ logs_by_run = {}
322
+ for log_entry in logs:
323
+ key = (log_entry["project"], log_entry["run"])
324
+ if key not in logs_by_run:
325
+ logs_by_run[key] = {"metrics": [], "steps": []}
326
+ logs_by_run[key]["metrics"].append(log_entry["metrics"])
327
+ logs_by_run[key]["steps"].append(log_entry.get("step"))
328
+
329
+ for (project, run), data in logs_by_run.items():
330
+ SQLiteStorage.bulk_log(
331
+ project=project,
332
+ run=run,
333
+ metrics_list=data["metrics"],
334
+ steps=data["steps"],
335
+ )
336
+
337
+
338
+ def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
339
+ """
340
+ Filter metrics using regex pattern.
341
+
342
+ Args:
343
+ metrics: List of metric names to filter
344
+ filter_pattern: Regex pattern to match against metric names
345
+
346
+ Returns:
347
+ List of metric names that match the pattern
348
+ """
349
+ if not filter_pattern.strip():
350
+ return metrics
351
+
352
+ try:
353
+ pattern = re.compile(filter_pattern, re.IGNORECASE)
354
+ return [metric for metric in metrics if pattern.search(metric)]
355
+ except re.error:
356
+ return [
357
+ metric for metric in metrics if filter_pattern.lower() in metric.lower()
358
+ ]
359
+
360
+
361
+ def configure(request: gr.Request):
362
+ sidebar_param = request.query_params.get("sidebar")
363
+ match sidebar_param:
364
+ case "collapsed":
365
+ sidebar = gr.Sidebar(open=False, visible=True)
366
+ case "hidden":
367
+ sidebar = gr.Sidebar(open=False, visible=False)
368
+ case _:
369
+ sidebar = gr.Sidebar(open=True, visible=True)
370
+
371
+ if metrics := request.query_params.get("metrics"):
372
+ return metrics.split(","), sidebar
373
+ else:
374
+ return [], sidebar
375
+
376
+
377
+ def create_media_section(media_by_run: dict[str, dict[str, list[MediaData]]]):
378
+ with gr.Accordion(label="media"):
379
+ with gr.Group(elem_classes=("media-group")):
380
+ for run, media_by_key in media_by_run.items():
381
+ with gr.Tab(label=run, elem_classes=("media-tab")):
382
+ for key, media_item in media_by_key.items():
383
+ gr.Gallery(
384
+ [(item.file_path, item.caption) for item in media_item],
385
+ label=key,
386
+ columns=6,
387
+ elem_classes=("media-gallery"),
388
+ )
389
+
390
+
391
+ css = """
392
+ #run-cb .wrap { gap: 2px; }
393
+ #run-cb .wrap label {
394
+ line-height: 1;
395
+ padding: 6px;
396
+ }
397
+ .logo-light { display: block; }
398
+ .logo-dark { display: none; }
399
+ .dark .logo-light { display: none; }
400
+ .dark .logo-dark { display: block; }
401
+ .dark .caption-label { color: white; }
402
+
403
+ .info-container {
404
+ position: relative;
405
+ display: inline;
406
+ }
407
+ .info-checkbox {
408
+ position: absolute;
409
+ opacity: 0;
410
+ pointer-events: none;
411
+ }
412
+ .info-icon {
413
+ border-bottom: 1px dotted;
414
+ cursor: pointer;
415
+ user-select: none;
416
+ color: var(--color-accent);
417
+ }
418
+ .info-expandable {
419
+ display: none;
420
+ opacity: 0;
421
+ transition: opacity 0.2s ease-in-out;
422
+ }
423
+ .info-checkbox:checked ~ .info-expandable {
424
+ display: inline;
425
+ opacity: 1;
426
+ }
427
+ .info-icon:hover { opacity: 0.8; }
428
+ .accent-link { font-weight: bold; }
429
+
430
+ .media-gallery .fixed-height { min-height: 275px; }
431
+ .media-group, .media-group > div { background: none; }
432
+ .media-group .tabs { padding: 0.5em; }
433
+ .media-tab { max-height: 500px; overflow-y: scroll; }
434
+ """
435
+
436
+ gr.set_static_paths(paths=[utils.MEDIA_DIR])
437
+ with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
438
+ with gr.Sidebar(open=False) as sidebar:
439
+ logo = gr.Markdown(
440
+ f"""
441
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_light_transparent.png' width='80%' class='logo-light'>
442
+ <img src='/gradio_api/file={utils.TRACKIO_LOGO_DIR}/trackio_logo_type_dark_transparent.png' width='80%' class='logo-dark'>
443
+ """
444
+ )
445
+ project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
446
+ run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
447
+ run_cb = gr.CheckboxGroup(
448
+ label="Runs", choices=[], interactive=True, elem_id="run-cb"
449
+ )
450
+ gr.HTML("<hr>")
451
+ realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True)
452
+ smoothing_slider = gr.Slider(
453
+ label="Smoothing Factor",
454
+ minimum=0,
455
+ maximum=20,
456
+ value=10,
457
+ step=1,
458
+ info="0 = no smoothing",
459
+ )
460
+ x_axis_dd = gr.Dropdown(
461
+ label="X-axis",
462
+ choices=["step", "time"],
463
+ value="step",
464
+ )
465
+ log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False)
466
+ metric_filter_tb = gr.Textbox(
467
+ label="Metric Filter (regex)",
468
+ placeholder="e.g., loss|ndcg@10|gpu",
469
+ value="",
470
+ info="Filter metrics using regex patterns. Leave empty to show all metrics.",
471
+ )
472
+
473
+ timer = gr.Timer(value=1)
474
+ metrics_subset = gr.State([])
475
+ user_interacted_with_run_cb = gr.State(False)
476
+
477
+ gr.on([demo.load], fn=configure, outputs=[metrics_subset, sidebar])
478
+ gr.on(
479
+ [demo.load],
480
+ fn=get_projects,
481
+ outputs=project_dd,
482
+ show_progress="hidden",
483
+ )
484
+ gr.on(
485
+ [timer.tick],
486
+ fn=update_runs,
487
+ inputs=[project_dd, run_tb, user_interacted_with_run_cb],
488
+ outputs=[run_cb, run_tb],
489
+ show_progress="hidden",
490
+ )
491
+ gr.on(
492
+ [timer.tick],
493
+ fn=lambda: gr.Dropdown(info=get_project_info()),
494
+ outputs=[project_dd],
495
+ show_progress="hidden",
496
+ )
497
+ gr.on(
498
+ [demo.load, project_dd.change],
499
+ fn=update_runs,
500
+ inputs=[project_dd, run_tb],
501
+ outputs=[run_cb, run_tb],
502
+ show_progress="hidden",
503
+ )
504
+ gr.on(
505
+ [demo.load, project_dd.change, run_cb.change],
506
+ fn=update_x_axis_choices,
507
+ inputs=[project_dd, run_cb],
508
+ outputs=x_axis_dd,
509
+ show_progress="hidden",
510
+ )
511
+
512
+ realtime_cb.change(
513
+ fn=toggle_timer,
514
+ inputs=realtime_cb,
515
+ outputs=timer,
516
+ api_name="toggle_timer",
517
+ )
518
+ run_cb.input(
519
+ fn=lambda: True,
520
+ outputs=user_interacted_with_run_cb,
521
+ )
522
+ run_tb.input(
523
+ fn=filter_runs,
524
+ inputs=[project_dd, run_tb],
525
+ outputs=run_cb,
526
+ )
527
+
528
+ gr.api(
529
+ fn=upload_db_to_space,
530
+ api_name="upload_db_to_space",
531
+ )
532
+ gr.api(
533
+ fn=bulk_upload_media,
534
+ api_name="bulk_upload_media",
535
+ )
536
+ gr.api(
537
+ fn=log,
538
+ api_name="log",
539
+ )
540
+ gr.api(
541
+ fn=bulk_log,
542
+ api_name="bulk_log",
543
+ )
544
+
545
+ x_lim = gr.State(None)
546
+ last_steps = gr.State({})
547
+
548
+ def update_x_lim(select_data: gr.SelectData):
549
+ return select_data.index
550
+
551
+ def update_last_steps(project, runs):
552
+ """Update the last step from all runs to detect when new data is available."""
553
+ if not project or not runs:
554
+ return {}
555
+
556
+ return SQLiteStorage.get_max_steps_for_runs(project, runs)
557
+
558
+ timer.tick(
559
+ fn=update_last_steps,
560
+ inputs=[project_dd, run_cb],
561
+ outputs=last_steps,
562
+ show_progress="hidden",
563
+ )
564
+
565
+ @gr.render(
566
+ triggers=[
567
+ demo.load,
568
+ run_cb.change,
569
+ last_steps.change,
570
+ smoothing_slider.change,
571
+ x_lim.change,
572
+ x_axis_dd.change,
573
+ log_scale_cb.change,
574
+ metric_filter_tb.change,
575
+ ],
576
+ inputs=[
577
+ project_dd,
578
+ run_cb,
579
+ smoothing_slider,
580
+ metrics_subset,
581
+ x_lim,
582
+ x_axis_dd,
583
+ log_scale_cb,
584
+ metric_filter_tb,
585
+ ],
586
+ show_progress="hidden",
587
+ )
588
+ def update_dashboard(
589
+ project,
590
+ runs,
591
+ smoothing_granularity,
592
+ metrics_subset,
593
+ x_lim_value,
594
+ x_axis,
595
+ log_scale,
596
+ metric_filter,
597
+ ):
598
+ dfs = []
599
+ images_by_run = {}
600
+ original_runs = runs.copy()
601
+
602
+ for run in runs:
603
+ df, images_by_key = load_run_data(
604
+ project, run, smoothing_granularity, x_axis, log_scale
605
+ )
606
+ if df is not None:
607
+ dfs.append(df)
608
+ images_by_run[run] = images_by_key
609
+ if dfs:
610
+ master_df = pd.concat(dfs, ignore_index=True)
611
+ else:
612
+ master_df = pd.DataFrame()
613
+
614
+ if master_df.empty:
615
+ return
616
+
617
+ x_column = "step"
618
+ if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns:
619
+ x_column = dfs[0]["x_axis"].iloc[0]
620
+
621
+ numeric_cols = master_df.select_dtypes(include="number").columns
622
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
623
+ if x_column and x_column in numeric_cols:
624
+ numeric_cols.remove(x_column)
625
+
626
+ if metrics_subset:
627
+ numeric_cols = [c for c in numeric_cols if c in metrics_subset]
628
+
629
+ if metric_filter and metric_filter.strip():
630
+ numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter)
631
+
632
+ nested_metric_groups = utils.group_metrics_with_subprefixes(list(numeric_cols))
633
+ color_map = utils.get_color_mapping(original_runs, smoothing_granularity > 0)
634
+
635
+ metric_idx = 0
636
+ for group_name in sorted(nested_metric_groups.keys()):
637
+ group_data = nested_metric_groups[group_name]
638
+
639
+ with gr.Accordion(
640
+ label=group_name,
641
+ open=True,
642
+ key=f"accordion-{group_name}",
643
+ preserved_by_key=["value", "open"],
644
+ ):
645
+ # Render direct metrics at this level
646
+ if group_data["direct_metrics"]:
647
+ with gr.Draggable(
648
+ key=f"row-{group_name}-direct", orientation="row"
649
+ ):
650
+ for metric_name in group_data["direct_metrics"]:
651
+ metric_df = master_df.dropna(subset=[metric_name])
652
+ color = "run" if "run" in metric_df.columns else None
653
+ if not metric_df.empty:
654
+ plot = gr.LinePlot(
655
+ utils.downsample(
656
+ metric_df,
657
+ x_column,
658
+ metric_name,
659
+ color,
660
+ x_lim_value,
661
+ ),
662
+ x=x_column,
663
+ y=metric_name,
664
+ y_title=metric_name.split("/")[-1],
665
+ color=color,
666
+ color_map=color_map,
667
+ title=metric_name,
668
+ key=f"plot-{metric_idx}",
669
+ preserved_by_key=None,
670
+ x_lim=x_lim_value,
671
+ show_fullscreen_button=True,
672
+ min_width=400,
673
+ )
674
+ plot.select(
675
+ update_x_lim,
676
+ outputs=x_lim,
677
+ key=f"select-{metric_idx}",
678
+ )
679
+ plot.double_click(
680
+ lambda: None,
681
+ outputs=x_lim,
682
+ key=f"double-{metric_idx}",
683
+ )
684
+ metric_idx += 1
685
+
686
+ # If there are subgroups, create nested accordions
687
+ if group_data["subgroups"]:
688
+ for subgroup_name in sorted(group_data["subgroups"].keys()):
689
+ subgroup_metrics = group_data["subgroups"][subgroup_name]
690
+
691
+ with gr.Accordion(
692
+ label=subgroup_name,
693
+ open=True,
694
+ key=f"accordion-{group_name}-{subgroup_name}",
695
+ preserved_by_key=["value", "open"],
696
+ ):
697
+ with gr.Draggable(key=f"row-{group_name}-{subgroup_name}"):
698
+ for metric_name in subgroup_metrics:
699
+ metric_df = master_df.dropna(subset=[metric_name])
700
+ color = (
701
+ "run" if "run" in metric_df.columns else None
702
+ )
703
+ if not metric_df.empty:
704
+ plot = gr.LinePlot(
705
+ utils.downsample(
706
+ metric_df,
707
+ x_column,
708
+ metric_name,
709
+ color,
710
+ x_lim_value,
711
+ ),
712
+ x=x_column,
713
+ y=metric_name,
714
+ y_title=metric_name.split("/")[-1],
715
+ color=color,
716
+ color_map=color_map,
717
+ title=metric_name,
718
+ key=f"plot-{metric_idx}",
719
+ preserved_by_key=None,
720
+ x_lim=x_lim_value,
721
+ show_fullscreen_button=True,
722
+ min_width=400,
723
+ )
724
+ plot.select(
725
+ update_x_lim,
726
+ outputs=x_lim,
727
+ key=f"select-{metric_idx}",
728
+ )
729
+ plot.double_click(
730
+ lambda: None,
731
+ outputs=x_lim,
732
+ key=f"double-{metric_idx}",
733
+ )
734
+ metric_idx += 1
735
+ if images_by_run and any(any(images) for images in images_by_run.values()):
736
+ create_media_section(images_by_run)
737
+
738
+ table_cols = master_df.select_dtypes(include="object").columns
739
+ table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS]
740
+ if metrics_subset:
741
+ table_cols = [c for c in table_cols if c in metrics_subset]
742
+ if metric_filter and metric_filter.strip():
743
+ table_cols = filter_metrics_by_regex(list(table_cols), metric_filter)
744
+ if len(table_cols) > 0:
745
+ with gr.Accordion("tables", open=True):
746
+ with gr.Row(key="row"):
747
+ for metric_idx, metric_name in enumerate(table_cols):
748
+ metric_df = master_df.dropna(subset=[metric_name])
749
+ if not metric_df.empty:
750
+ value = metric_df[metric_name].iloc[-1]
751
+ if (
752
+ isinstance(value, dict)
753
+ and "_type" in value
754
+ and value["_type"] == Table.TYPE
755
+ ):
756
+ try:
757
+ df = pd.DataFrame(value["_value"])
758
+ gr.DataFrame(
759
+ df,
760
+ label=f"{metric_name} (latest)",
761
+ key=f"table-{metric_idx}",
762
+ wrap=True,
763
+ )
764
+ except Exception as e:
765
+ gr.Warning(
766
+ f"Column {metric_name} failed to render as a table: {e}"
767
+ )
768
+
769
+
770
+ if __name__ == "__main__":
771
+ demo.launch(allowed_paths=[utils.TRACKIO_LOGO_DIR], show_api=False, show_error=True)
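
A sketch of the regex-based metric filter defined above (assuming the module is importable as trackio.ui; note that importing it also builds the Blocks app):

from trackio.ui import filter_metrics_by_regex

metrics = ["train/loss", "val/loss", "val/ndcg@10", "gpu/memory"]
print(filter_metrics_by_regex(metrics, "loss|ndcg@10"))  # ['train/loss', 'val/loss', 'val/ndcg@10']
print(filter_metrics_by_regex(metrics, "[oops"))         # invalid regex falls back to substring matching -> []
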
utils.py ADDED
@@ -0,0 +1,637 @@
1
+ import math
2
+ import re
3
+ import sys
4
+ import time
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING
7
+
8
+ import huggingface_hub
9
+ import numpy as np
10
+ import pandas as pd
11
+ from huggingface_hub.constants import HF_HOME
12
+
13
+ if TYPE_CHECKING:
14
+ from trackio.commit_scheduler import CommitScheduler
15
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
16
+
17
+ RESERVED_KEYS = ["project", "run", "timestamp", "step", "time", "metrics"]
18
+ TRACKIO_DIR = Path(HF_HOME) / "trackio"
19
+ MEDIA_DIR = TRACKIO_DIR / "media"
20
+
21
+ TRACKIO_LOGO_DIR = Path(__file__).parent / "assets"
22
+
23
+
24
+ def generate_readable_name(used_names: list[str], space_id: str | None = None) -> str:
25
+ """
26
+ Generates a random, readable name like "dainty-sunset-0".
27
+ If space_id is provided, generates username-timestamp format instead.
28
+ """
29
+ if space_id is not None:
30
+ username = huggingface_hub.whoami()["name"]
31
+ timestamp = int(time.time())
32
+ return f"{username}-{timestamp}"
33
+ adjectives = [
34
+ "dainty",
35
+ "brave",
36
+ "calm",
37
+ "eager",
38
+ "fancy",
39
+ "gentle",
40
+ "happy",
41
+ "jolly",
42
+ "kind",
43
+ "lively",
44
+ "merry",
45
+ "nice",
46
+ "proud",
47
+ "quick",
48
+ "hugging",
49
+ "silly",
50
+ "tidy",
51
+ "witty",
52
+ "zealous",
53
+ "bright",
54
+ "shy",
55
+ "bold",
56
+ "clever",
57
+ "daring",
58
+ "elegant",
59
+ "faithful",
60
+ "graceful",
61
+ "honest",
62
+ "inventive",
63
+ "jovial",
64
+ "keen",
65
+ "lucky",
66
+ "modest",
67
+ "noble",
68
+ "optimistic",
69
+ "patient",
70
+ "quirky",
71
+ "resourceful",
72
+ "sincere",
73
+ "thoughtful",
74
+ "upbeat",
75
+ "valiant",
76
+ "warm",
77
+ "youthful",
78
+ "zesty",
79
+ "adventurous",
80
+ "breezy",
81
+ "cheerful",
82
+ "delightful",
83
+ "energetic",
84
+ "fearless",
85
+ "glad",
86
+ "hopeful",
87
+ "imaginative",
88
+ "joyful",
89
+ "kindly",
90
+ "luminous",
91
+ "mysterious",
92
+ "neat",
93
+ "outgoing",
94
+ "playful",
95
+ "radiant",
96
+ "spirited",
97
+ "tranquil",
98
+ "unique",
99
+ "vivid",
100
+ "wise",
101
+ "zany",
102
+ "artful",
103
+ "bubbly",
104
+ "charming",
105
+ "dazzling",
106
+ "earnest",
107
+ "festive",
108
+ "gentlemanly",
109
+ "hearty",
110
+ "intrepid",
111
+ "jubilant",
112
+ "knightly",
113
+ "lively",
114
+ "magnetic",
115
+ "nimble",
116
+ "orderly",
117
+ "peaceful",
118
+ "quick-witted",
119
+ "robust",
120
+ "sturdy",
121
+ "trusty",
122
+ "upstanding",
123
+ "vibrant",
124
+ "whimsical",
125
+ ]
126
+ nouns = [
127
+ "sunset",
128
+ "forest",
129
+ "river",
130
+ "mountain",
131
+ "breeze",
132
+ "meadow",
133
+ "ocean",
134
+ "valley",
135
+ "sky",
136
+ "field",
137
+ "cloud",
138
+ "star",
139
+ "rain",
140
+ "leaf",
141
+ "stone",
142
+ "flower",
143
+ "bird",
144
+ "tree",
145
+ "wave",
146
+ "trail",
147
+ "island",
148
+ "desert",
149
+ "hill",
150
+ "lake",
151
+ "pond",
152
+ "grove",
153
+ "canyon",
154
+ "reef",
155
+ "bay",
156
+ "peak",
157
+ "glade",
158
+ "marsh",
159
+ "cliff",
160
+ "dune",
161
+ "spring",
162
+ "brook",
163
+ "cave",
164
+ "plain",
165
+ "ridge",
166
+ "wood",
167
+ "blossom",
168
+ "petal",
169
+ "root",
170
+ "branch",
171
+ "seed",
172
+ "acorn",
173
+ "pine",
174
+ "willow",
175
+ "cedar",
176
+ "elm",
177
+ "falcon",
178
+ "eagle",
179
+ "sparrow",
180
+ "robin",
181
+ "owl",
182
+ "finch",
183
+ "heron",
184
+ "crane",
185
+ "duck",
186
+ "swan",
187
+ "fox",
188
+ "wolf",
189
+ "bear",
190
+ "deer",
191
+ "moose",
192
+ "otter",
193
+ "beaver",
194
+ "lynx",
195
+ "hare",
196
+ "badger",
197
+ "butterfly",
198
+ "bee",
199
+ "ant",
200
+ "beetle",
201
+ "dragonfly",
202
+ "firefly",
203
+ "ladybug",
204
+ "moth",
205
+ "spider",
206
+ "worm",
207
+ "coral",
208
+ "kelp",
209
+ "shell",
210
+ "pebble",
211
+ "face",
212
+ "boulder",
213
+ "cobble",
214
+ "sand",
215
+ "wavelet",
216
+ "tide",
217
+ "current",
218
+ "mist",
219
+ ]
220
+ number = 0
221
+ name = f"{adjectives[0]}-{nouns[0]}-{number}"
222
+ while name in used_names:
223
+ number += 1
224
+ adjective = adjectives[number % len(adjectives)]
225
+ noun = nouns[number % len(nouns)]
226
+ name = f"{adjective}-{noun}-{number}"
227
+ return name
228
+
229
+
230
+ def block_except_in_notebook():
231
+ in_notebook = bool(getattr(sys, "ps1", sys.flags.interactive))
232
+ if in_notebook:
233
+ return
234
+ try:
235
+ while True:
236
+ time.sleep(0.1)
237
+ except (KeyboardInterrupt, OSError):
238
+ print("Keyboard interruption in main thread... closing dashboard.")
239
+
240
+
241
+ def simplify_column_names(columns: list[str]) -> dict[str, str]:
242
+ """
243
+ Simplifies column names to first 10 alphanumeric or "/" characters with unique suffixes.
244
+
245
+ Args:
246
+ columns: List of original column names
247
+
248
+ Returns:
249
+ Dictionary mapping original column names to simplified names
250
+ """
251
+ simplified_names = {}
252
+ used_names = set()
253
+
254
+ for col in columns:
255
+ alphanumeric = re.sub(r"[^a-zA-Z0-9/]", "", col)
256
+ base_name = alphanumeric[:10] if alphanumeric else f"col_{len(used_names)}"
257
+
258
+ final_name = base_name
259
+ suffix = 1
260
+ while final_name in used_names:
261
+ final_name = f"{base_name}_{suffix}"
262
+ suffix += 1
263
+
264
+ simplified_names[col] = final_name
265
+ used_names.add(final_name)
266
+
267
+ return simplified_names
268
+
269
+
270
+ def print_dashboard_instructions(project: str) -> None:
271
+ """
272
+ Prints instructions for viewing the Trackio dashboard.
273
+
274
+ Args:
275
+ project: The name of the project to show dashboard for.
276
+ """
277
+     YELLOW = "\033[93m"
+     BOLD = "\033[1m"
+     RESET = "\033[0m"
+
+     print("* View dashboard by running in your terminal:")
+     print(f'{BOLD}{YELLOW}trackio show --project "{project}"{RESET}')
+     print(f'* or by running in Python: trackio.show(project="{project}")')
+
+
+ def preprocess_space_and_dataset_ids(
+     space_id: str | None, dataset_id: str | None
+ ) -> tuple[str | None, str | None]:
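+     # Prefix bare ids with the current user's namespace; if only a space id is
+     # given, derive a default dataset id of the form "<space_id>-dataset".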
+     if space_id is not None and "/" not in space_id:
+         username = huggingface_hub.whoami()["name"]
+         space_id = f"{username}/{space_id}"
+     if dataset_id is not None and "/" not in dataset_id:
+         username = huggingface_hub.whoami()["name"]
+         dataset_id = f"{username}/{dataset_id}"
+     if space_id is not None and dataset_id is None:
+         dataset_id = f"{space_id}-dataset"
+     return space_id, dataset_id
+
+
+ def fibo():
+     """Generator for Fibonacci backoff: 1, 1, 2, 3, 5, 8, ..."""
+     a, b = 1, 1
+     while True:
+         yield a
+         a, b = b, a + b
+
+
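+ # Hex colors assigned to runs, cycled by get_color_mapping below.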
+ COLOR_PALETTE = [
+     "#3B82F6",
+     "#EF4444",
+     "#10B981",
+     "#F59E0B",
+     "#8B5CF6",
+     "#EC4899",
+     "#06B6D4",
+     "#84CC16",
+     "#F97316",
+     "#6366F1",
+ ]
+
+
+ def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
+     """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
+     color_map = {}
+
+     for i, run in enumerate(runs):
+         base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
+
+         if smoothing:
+             color_map[f"{run}_smoothed"] = base_color
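+             # "4D" appends ~30% alpha so the unsmoothed series is drawn translucently.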
+             color_map[f"{run}_original"] = base_color + "4D"
+         else:
+             color_map[run] = base_color
+
+     return color_map
+
+
+ def downsample(
+     df: pd.DataFrame,
+     x: str,
+     y: str,
+     color: str | None,
+     x_lim: tuple[float, float] | None = None,
+ ) -> pd.DataFrame:
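+     # Min/max-per-bin downsampling: within each x bin, keep the rows with the
+     # smallest and largest y values so peaks and troughs survive the reduction.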
+     if df.empty:
+         return df
+
+     columns_to_keep = [x, y]
+     if color is not None and color in df.columns:
+         columns_to_keep.append(color)
+     df = df[columns_to_keep].copy()
+
+     n_bins = 100
+
+     if color is not None and color in df.columns:
+         groups = df.groupby(color)
+     else:
+         groups = [(None, df)]
+
+     downsampled_indices = []
+
+     for _, group_df in groups:
+         if group_df.empty:
+             continue
+
+         group_df = group_df.sort_values(x)
+
+         if x_lim is not None:
+             x_min, x_max = x_lim
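+             # Keep one point just outside each edge of x_lim so lines extend to the plot edges.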
+             before_point = group_df[group_df[x] < x_min].tail(1)
+             after_point = group_df[group_df[x] > x_max].head(1)
+             group_df = group_df[(group_df[x] >= x_min) & (group_df[x] <= x_max)]
+         else:
+             before_point = after_point = None
+             x_min = group_df[x].min()
+             x_max = group_df[x].max()
+
+         if before_point is not None and not before_point.empty:
+             downsampled_indices.extend(before_point.index.tolist())
+         if after_point is not None and not after_point.empty:
+             downsampled_indices.extend(after_point.index.tolist())
+
+         if group_df.empty:
+             continue
+
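+         # Degenerate range: every point shares the same x value, so keep only the y extremes.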
+         if x_min == x_max:
+             min_y_idx = group_df[y].idxmin()
+             max_y_idx = group_df[y].idxmax()
+             if min_y_idx != max_y_idx:
+                 downsampled_indices.extend([min_y_idx, max_y_idx])
+             else:
+                 downsampled_indices.append(min_y_idx)
+             continue
+
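+         # Small groups are kept in full; only larger groups are binned.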
+         if len(group_df) < 500:
+             downsampled_indices.extend(group_df.index.tolist())
+             continue
+
+         bins = np.linspace(x_min, x_max, n_bins + 1)
+         group_df["bin"] = pd.cut(
+             group_df[x], bins=bins, labels=False, include_lowest=True
+         )
+
+         for bin_idx in group_df["bin"].dropna().unique():
+             bin_data = group_df[group_df["bin"] == bin_idx]
+             if bin_data.empty:
+                 continue
+
+             min_y_idx = bin_data[y].idxmin()
+             max_y_idx = bin_data[y].idxmax()
+
+             downsampled_indices.append(min_y_idx)
+             if min_y_idx != max_y_idx:
+                 downsampled_indices.append(max_y_idx)
+
+     unique_indices = list(set(downsampled_indices))
+
+     downsampled_df = df.loc[unique_indices].copy()
+     downsampled_df = downsampled_df.sort_values(x).reset_index(drop=True)
+     downsampled_df = downsampled_df.drop(columns=["bin"], errors="ignore")
+
+     return downsampled_df
+
+
+ def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
+     """
+     Sort metrics by grouping prefixes together for dropdown/list display.
+     Metrics without prefixes come first, then grouped by prefix.
+
+     Args:
+         metrics: List of metric names
+
+     Returns:
+         List of metric names sorted by prefix
+
+     Example:
+         Input: ["train/loss", "loss", "train/acc", "val/loss"]
+         Output: ["loss", "train/acc", "train/loss", "val/loss"]
+     """
+     groups = group_metrics_by_prefix(metrics)
+     result = []
+
+     if "charts" in groups:
+         result.extend(groups["charts"])
+
+     for group_name in sorted(groups.keys()):
+         if group_name != "charts":
+             result.extend(groups[group_name])
+
+     return result
+
+
+ def group_metrics_by_prefix(metrics: list[str]) -> dict[str, list[str]]:
+     """
+     Group metrics by their prefix. Metrics without prefix go to 'charts' group.
+
+     Args:
+         metrics: List of metric names
+
+     Returns:
+         Dictionary with prefix names as keys and lists of metrics as values
+
+     Example:
+         Input: ["loss", "accuracy", "train/loss", "train/acc", "val/loss"]
+         Output: {
+             "charts": ["loss", "accuracy"],
+             "train": ["train/loss", "train/acc"],
+             "val": ["val/loss"]
+         }
+     """
+     no_prefix = []
+     with_prefix = []
+
+     for metric in metrics:
+         if "/" in metric:
+             with_prefix.append(metric)
+         else:
+             no_prefix.append(metric)
+
+     no_prefix.sort()
+
+     prefix_groups = {}
+     for metric in with_prefix:
+         prefix = metric.split("/")[0]
+         if prefix not in prefix_groups:
+             prefix_groups[prefix] = []
+         prefix_groups[prefix].append(metric)
+
+     for prefix in prefix_groups:
+         prefix_groups[prefix].sort()
+
+     groups = {}
+     if no_prefix:
+         groups["charts"] = no_prefix
+
+     for prefix in sorted(prefix_groups.keys()):
+         groups[prefix] = prefix_groups[prefix]
+
+     return groups
+
+
+ def group_metrics_with_subprefixes(metrics: list[str]) -> dict:
+     """
+     Group metrics with simple 2-level nested structure detection.
+
+     Returns a dictionary where each prefix group can have:
+     - direct_metrics: list of metrics at this level (e.g., "train/acc")
+     - subgroups: dict of subgroup name -> list of metrics (e.g., "loss" -> ["train/loss/norm", "train/loss/unnorm"])
+
+     Example:
+         Input: ["loss", "train/acc", "train/loss/normalized", "train/loss/unnormalized", "val/loss"]
+         Output: {
+             "charts": {
+                 "direct_metrics": ["loss"],
+                 "subgroups": {}
+             },
+             "train": {
+                 "direct_metrics": ["train/acc"],
+                 "subgroups": {
+                     "loss": ["train/loss/normalized", "train/loss/unnormalized"]
+                 }
+             },
+             "val": {
+                 "direct_metrics": ["val/loss"],
+                 "subgroups": {}
+             }
+         }
+     """
+     result = {}
+
+     for metric in metrics:
+         if "/" not in metric:
+             if "charts" not in result:
+                 result["charts"] = {"direct_metrics": [], "subgroups": {}}
+             result["charts"]["direct_metrics"].append(metric)
+         else:
+             parts = metric.split("/")
+             main_prefix = parts[0]
+
+             if main_prefix not in result:
+                 result[main_prefix] = {"direct_metrics": [], "subgroups": {}}
+
+             if len(parts) == 2:
+                 result[main_prefix]["direct_metrics"].append(metric)
+             else:
+                 subprefix = parts[1]
+                 if subprefix not in result[main_prefix]["subgroups"]:
+                     result[main_prefix]["subgroups"][subprefix] = []
+                 result[main_prefix]["subgroups"][subprefix].append(metric)
+
+     for group_data in result.values():
+         group_data["direct_metrics"].sort()
+         for subgroup_metrics in group_data["subgroups"].values():
+             subgroup_metrics.sort()
+
+     if "charts" in result and not result["charts"]["direct_metrics"]:
+         del result["charts"]
+
+     return result
+
+
+ def get_sync_status(scheduler: "CommitScheduler | DummyCommitScheduler") -> int | None:
+     """Get the sync status from the CommitScheduler in an integer number of minutes, or None if not synced yet."""
+     if getattr(
+         scheduler, "last_push_time", None
+     ):  # DummyCommitScheduler doesn't have last_push_time
+         time_diff = time.time() - scheduler.last_push_time
+         return int(time_diff / 60)
+     else:
+         return None
+
+
+ def serialize_values(metrics):
+     """
+     Serialize infinity and NaN values in metrics dict to make it JSON-compliant.
+     Only handles top-level float values.
+
+     Converts:
+     - float('inf') -> "Infinity"
+     - float('-inf') -> "-Infinity"
+     - float('nan') -> "NaN"
+
+     Example:
+         {"loss": float('inf'), "accuracy": 0.95} -> {"loss": "Infinity", "accuracy": 0.95}
+     """
+     if not isinstance(metrics, dict):
+         return metrics
+
+     result = {}
+     for key, value in metrics.items():
+         if isinstance(value, float):
+             if math.isinf(value):
+                 result[key] = "Infinity" if value > 0 else "-Infinity"
+             elif math.isnan(value):
+                 result[key] = "NaN"
+             else:
+                 result[key] = value
+         elif isinstance(value, np.floating):
+             float_val = float(value)
+             if math.isinf(float_val):
+                 result[key] = "Infinity" if float_val > 0 else "-Infinity"
+             elif math.isnan(float_val):
+                 result[key] = "NaN"
+             else:
+                 result[key] = float_val
+         else:
+             result[key] = value
+     return result
+
+
+ def deserialize_values(metrics):
+     """
+     Deserialize infinity and NaN string values back to their numeric forms.
+     Only handles top-level string values.
+
+     Converts:
+     - "Infinity" -> float('inf')
+     - "-Infinity" -> float('-inf')
+     - "NaN" -> float('nan')
+
+     Example:
+         {"loss": "Infinity", "accuracy": 0.95} -> {"loss": float('inf'), "accuracy": 0.95}
+     """
+     if not isinstance(metrics, dict):
+         return metrics
+
+     result = {}
+     for key, value in metrics.items():
+         if value == "Infinity":
+             result[key] = float("inf")
+         elif value == "-Infinity":
+             result[key] = float("-inf")
+         elif value == "NaN":
+             result[key] = float("nan")
+         else:
+             result[key] = value
+     return result