# sqlite_storage.py
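"""SQLite-backed storage for Trackio metrics.

Each project gets its own "<project>.db" file under TRACKIO_DIR, and data can
be exported to / imported from Parquet files for syncing with a Hugging Face
Dataset.

Hypothetical usage sketch (internal API, names and values illustrative only):

    SQLiteStorage.bulk_log(
        project="my-project",
        run="run-1",
        metrics_list=[{"loss": 0.5}, {"loss": 0.4}],
    )
    logs = SQLiteStorage.get_logs("my-project", "run-1")
"""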
import fcntl
import json
import os
import sqlite3
import time
from datetime import datetime
from pathlib import Path
from threading import Lock
import huggingface_hub as hf
import pandas as pd
try: # absolute imports when installed
from trackio.commit_scheduler import CommitScheduler
from trackio.dummy_commit_scheduler import DummyCommitScheduler
from trackio.utils import (
TRACKIO_DIR,
deserialize_values,
serialize_values,
)
except Exception: # relative imports for local execution on Spaces
from commit_scheduler import CommitScheduler
from dummy_commit_scheduler import DummyCommitScheduler
from utils import TRACKIO_DIR, deserialize_values, serialize_values
class ProcessLock:
"""A simple file-based lock that works across processes."""
def __init__(self, lockfile_path: Path):
self.lockfile_path = lockfile_path
self.lockfile = None
def __enter__(self):
"""Acquire the lock with retry logic."""
self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
self.lockfile = open(self.lockfile_path, "w")
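        # Retry a non-blocking exclusive flock: up to 100 attempts, 0.1 s apart (~10 s total),
        # so concurrent processes queue up instead of failing immediately.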
max_retries = 100
for attempt in range(max_retries):
try:
fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
return self
except IOError:
if attempt < max_retries - 1:
time.sleep(0.1)
else:
raise IOError("Could not acquire database lock after 10 seconds")
def __exit__(self, exc_type, exc_val, exc_tb):
"""Release the lock."""
if self.lockfile:
fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
self.lockfile.close()
class SQLiteStorage:
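    """Per-project SQLite storage for Trackio metrics, with Parquet
    export/import used to sync data to and from a Hugging Face Dataset."""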
_dataset_import_attempted = False
_current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
_scheduler_lock = Lock()
@staticmethod
def _get_connection(db_path: Path) -> sqlite3.Connection:
conn = sqlite3.connect(str(db_path), timeout=30.0)
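        # WAL mode lets concurrent readers proceed while a single writer appends rows.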
conn.execute("PRAGMA journal_mode = WAL")
conn.row_factory = sqlite3.Row
return conn
@staticmethod
def _get_process_lock(project: str) -> ProcessLock:
lockfile_path = TRACKIO_DIR / f"{project}.lock"
return ProcessLock(lockfile_path)
@staticmethod
    def get_project_db_filename(project: str) -> str:
"""Get the database filename for a specific project."""
safe_project_name = "".join(
c for c in project if c.isalnum() or c in ("-", "_")
).rstrip()
if not safe_project_name:
safe_project_name = "default"
return f"{safe_project_name}.db"
@staticmethod
def get_project_db_path(project: str) -> Path:
"""Get the database path for a specific project."""
filename = SQLiteStorage.get_project_db_filename(project)
return TRACKIO_DIR / filename
@staticmethod
def init_db(project: str) -> Path:
"""
        Initialize the SQLite database with the metrics table and index,
        creating them if they do not already exist.
Returns the database path.
"""
db_path = SQLiteStorage.get_project_db_path(project)
db_path.parent.mkdir(parents=True, exist_ok=True)
with SQLiteStorage._get_process_lock(project):
with sqlite3.connect(db_path, timeout=30.0) as conn:
conn.execute("PRAGMA journal_mode = WAL")
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS metrics (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
run_name TEXT NOT NULL,
step INTEGER NOT NULL,
metrics TEXT NOT NULL
)
""")
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_metrics_run_step
ON metrics(run_name, step)
"""
)
conn.commit()
return db_path
@staticmethod
def export_to_parquet():
"""
Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
"""
# don't attempt to export (potentially wrong/blank) data before importing for the first time
if not SQLiteStorage._dataset_import_attempted:
return
all_paths = os.listdir(TRACKIO_DIR)
db_paths = [f for f in all_paths if f.endswith(".db")]
for db_path in db_paths:
db_path = TRACKIO_DIR / db_path
parquet_path = db_path.with_suffix(".parquet")
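            # Re-export only if the Parquet copy is missing or older than the .db file.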
if (not parquet_path.exists()) or (
db_path.stat().st_mtime > parquet_path.stat().st_mtime
):
with sqlite3.connect(db_path) as conn:
                    df = pd.read_sql("SELECT * FROM metrics", conn)
# break out the single JSON metrics column into individual columns
metrics = df["metrics"].copy()
metrics = pd.DataFrame(
metrics.apply(
lambda x: deserialize_values(json.loads(x))
).values.tolist(),
index=df.index,
)
del df["metrics"]
for col in metrics.columns:
df[col] = metrics[col]
df.to_parquet(parquet_path)
@staticmethod
def import_from_parquet():
"""
        Imports every ".parquet" file in the Trackio directory into a DB file at the same path but with extension ".db".
"""
all_paths = os.listdir(TRACKIO_DIR)
parquet_paths = [f for f in all_paths if f.endswith(".parquet")]
for parquet_path in parquet_paths:
parquet_path = TRACKIO_DIR / parquet_path
db_path = parquet_path.with_suffix(".db")
df = pd.read_parquet(parquet_path)
with sqlite3.connect(db_path) as conn:
# fix up df to have a single JSON metrics column
if "metrics" not in df.columns:
# separate other columns from metrics
metrics = df.copy()
other_cols = ["id", "timestamp", "run_name", "step"]
df = df[other_cols]
for col in other_cols:
del metrics[col]
# combine them all into a single metrics col
metrics = json.loads(metrics.to_json(orient="records"))
df["metrics"] = [
json.dumps(serialize_values(row)) for row in metrics
]
df.to_sql("metrics", conn, if_exists="replace", index=False)
@staticmethod
def get_scheduler():
"""
Get the scheduler for the database based on the environment variables.
This applies to both local and Spaces.
"""
with SQLiteStorage._scheduler_lock:
if SQLiteStorage._current_scheduler is not None:
return SQLiteStorage._current_scheduler
hf_token = os.environ.get("HF_TOKEN")
dataset_id = os.environ.get("TRACKIO_DATASET_ID")
space_repo_name = os.environ.get("SPACE_REPO_NAME")
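            # Without a dataset ID or a Space context, fall back to a no-op scheduler.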
if dataset_id is None or space_repo_name is None:
scheduler = DummyCommitScheduler()
else:
scheduler = CommitScheduler(
repo_id=dataset_id,
repo_type="dataset",
folder_path=TRACKIO_DIR,
private=True,
allow_patterns=["*.parquet", "media/**/*"],
squash_history=True,
token=hf_token,
on_before_commit=SQLiteStorage.export_to_parquet,
)
SQLiteStorage._current_scheduler = scheduler
return scheduler
@staticmethod
def log(project: str, run: str, metrics: dict, step: int | None = None):
"""
Safely log metrics to the database. Before logging, this method will ensure the database exists
and is set up with the correct tables. It also uses a cross-process lock to prevent
database locking errors when multiple processes access the same database.
This method is not used in the latest versions of Trackio (replaced by bulk_log) but
is kept for backwards compatibility for users who are connecting to a newer version of
a Trackio Spaces dashboard with an older version of Trackio installed locally.
"""
db_path = SQLiteStorage.init_db(project)
with SQLiteStorage._get_process_lock(project):
with SQLiteStorage._get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT MAX(step)
FROM metrics
WHERE run_name = ?
""",
(run,),
)
last_step = cursor.fetchone()[0]
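                # If no explicit step was given, continue from the last recorded step for this run.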
if step is None:
current_step = 0 if last_step is None else last_step + 1
else:
current_step = step
current_timestamp = datetime.now().isoformat()
cursor.execute(
"""
INSERT INTO metrics
(timestamp, run_name, step, metrics)
VALUES (?, ?, ?, ?)
""",
(
current_timestamp,
run,
current_step,
json.dumps(serialize_values(metrics)),
),
)
conn.commit()
@staticmethod
def bulk_log(
project: str,
run: str,
metrics_list: list[dict],
steps: list[int] | None = None,
timestamps: list[str] | None = None,
):
"""
Safely log bulk metrics to the database. Before logging, this method will ensure the database exists
and is set up with the correct tables. It also uses a cross-process lock to prevent
database locking errors when multiple processes access the same database.
"""
if not metrics_list:
return
if timestamps is None:
timestamps = [datetime.now().isoformat()] * len(metrics_list)
db_path = SQLiteStorage.init_db(project)
with SQLiteStorage._get_process_lock(project):
with SQLiteStorage._get_connection(db_path) as conn:
cursor = conn.cursor()
if steps is None:
steps = list(range(len(metrics_list)))
elif any(s is None for s in steps):
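                    # Fill any None steps by continuing from the run's last recorded step.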
cursor.execute(
"SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
)
last_step = cursor.fetchone()[0]
current_step = 0 if last_step is None else last_step + 1
processed_steps = []
for step in steps:
if step is None:
processed_steps.append(current_step)
current_step += 1
else:
processed_steps.append(step)
steps = processed_steps
if len(metrics_list) != len(steps) or len(metrics_list) != len(
timestamps
):
raise ValueError(
"metrics_list, steps, and timestamps must have the same length"
)
data = []
for i, metrics in enumerate(metrics_list):
data.append(
(
timestamps[i],
run,
steps[i],
json.dumps(serialize_values(metrics)),
)
)
cursor.executemany(
"""
INSERT INTO metrics
(timestamp, run_name, step, metrics)
VALUES (?, ?, ?, ?)
""",
data,
)
conn.commit()
@staticmethod
def get_logs(project: str, run: str) -> list[dict]:
"""Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object)."""
db_path = SQLiteStorage.get_project_db_path(project)
if not db_path.exists():
return []
with SQLiteStorage._get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT timestamp, step, metrics
FROM metrics
WHERE run_name = ?
ORDER BY timestamp
""",
(run,),
)
rows = cursor.fetchall()
results = []
for row in rows:
metrics = json.loads(row["metrics"])
metrics = deserialize_values(metrics)
metrics["timestamp"] = row["timestamp"]
metrics["step"] = row["step"]
results.append(metrics)
return results
@staticmethod
def load_from_dataset():
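        """Download Parquet and media files from the configured HF dataset (if any)
        and import them into the local SQLite databases."""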
dataset_id = os.environ.get("TRACKIO_DATASET_ID")
space_repo_name = os.environ.get("SPACE_REPO_NAME")
if dataset_id is not None and space_repo_name is not None:
hfapi = hf.HfApi()
updated = False
if not TRACKIO_DIR.exists():
TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
with SQLiteStorage.get_scheduler().lock:
try:
files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
for file in files:
# Download parquet and media assets
if not (file.endswith(".parquet") or file.startswith("media/")):
continue
if (TRACKIO_DIR / file).exists():
continue
hf.hf_hub_download(
dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
)
updated = True
except hf.errors.EntryNotFoundError:
pass
except hf.errors.RepositoryNotFoundError:
pass
if updated:
SQLiteStorage.import_from_parquet()
SQLiteStorage._dataset_import_attempted = True
@staticmethod
def get_projects() -> list[str]:
"""
Get list of all projects by scanning the database files in the trackio directory.
"""
if not SQLiteStorage._dataset_import_attempted:
SQLiteStorage.load_from_dataset()
projects: set[str] = set()
if not TRACKIO_DIR.exists():
return []
for db_file in TRACKIO_DIR.glob("*.db"):
project_name = db_file.stem
projects.add(project_name)
return sorted(projects)
@staticmethod
def get_runs(project: str) -> list[str]:
"""Get list of all runs for a project."""
db_path = SQLiteStorage.get_project_db_path(project)
if not db_path.exists():
return []
with SQLiteStorage._get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT DISTINCT run_name FROM metrics",
)
return [row[0] for row in cursor.fetchall()]
@staticmethod
def get_max_steps_for_runs(project: str) -> dict[str, int]:
"""Get the maximum step for each run in a project."""
db_path = SQLiteStorage.get_project_db_path(project)
if not db_path.exists():
return {}
with SQLiteStorage._get_connection(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT run_name, MAX(step) as max_step
FROM metrics
GROUP BY run_name
"""
)
results = {}
for row in cursor.fetchall():
results[row["run_name"]] = row["max_step"]
return results
def finish(self):
"""Cleanup when run is finished."""
pass