Upload _base_explorers.py
Browse files
audiocraft/grids/_base_explorers.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
import time
|
| 9 |
+
import typing as tp
|
| 10 |
+
from dora import Explorer
|
| 11 |
+
import treetable as tt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_sheep_ping(sheep) -> tp.Optional[str]:
|
| 15 |
+
"""Return the amount of time since the Sheep made some update
|
| 16 |
+
to its log. Returns a str using the relevant time unit."""
|
| 17 |
+
ping = None
|
| 18 |
+
if sheep.log is not None and sheep.log.exists():
|
| 19 |
+
delta = time.time() - sheep.log.stat().st_mtime
|
| 20 |
+
if delta > 3600 * 24:
|
| 21 |
+
ping = f'{delta / (3600 * 24):.1f}d'
|
| 22 |
+
elif delta > 3600:
|
| 23 |
+
ping = f'{delta / (3600):.1f}h'
|
| 24 |
+
elif delta > 60:
|
| 25 |
+
ping = f'{delta / 60:.1f}m'
|
| 26 |
+
else:
|
| 27 |
+
ping = f'{delta:.1f}s'
|
| 28 |
+
return ping
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BaseExplorer(ABC, Explorer):
|
| 32 |
+
"""Base explorer for AudioCraft grids.
|
| 33 |
+
|
| 34 |
+
All task specific solvers are expected to implement the `get_grid_metrics`
|
| 35 |
+
method to specify logic about metrics to display for a given task.
|
| 36 |
+
|
| 37 |
+
If additional stages are used, the child explorer must define how to handle
|
| 38 |
+
these new stages in the `process_history` and `process_sheep` methods.
|
| 39 |
+
"""
|
| 40 |
+
def stages(self):
|
| 41 |
+
return ["train", "valid", "evaluate"]
|
| 42 |
+
|
| 43 |
+
def get_grid_meta(self):
|
| 44 |
+
"""Returns the list of Meta information to display for each XP/job.
|
| 45 |
+
"""
|
| 46 |
+
return [
|
| 47 |
+
tt.leaf("index", align=">"),
|
| 48 |
+
tt.leaf("name", wrap=140),
|
| 49 |
+
tt.leaf("state"),
|
| 50 |
+
tt.leaf("sig", align=">"),
|
| 51 |
+
tt.leaf("sid", align="<"),
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
@abstractmethod
|
| 55 |
+
def get_grid_metrics(self):
|
| 56 |
+
"""Return the metrics that should be displayed in the tracking table.
|
| 57 |
+
"""
|
| 58 |
+
...
|
| 59 |
+
|
| 60 |
+
def process_sheep(self, sheep, history):
|
| 61 |
+
train = {
|
| 62 |
+
"epoch": len(history),
|
| 63 |
+
}
|
| 64 |
+
parts = {"train": train}
|
| 65 |
+
for metrics in history:
|
| 66 |
+
for key, sub in metrics.items():
|
| 67 |
+
part = parts.get(key, {})
|
| 68 |
+
if 'duration' in sub:
|
| 69 |
+
# Convert to minutes for readability.
|
| 70 |
+
sub['duration'] = sub['duration'] / 60.
|
| 71 |
+
part.update(sub)
|
| 72 |
+
parts[key] = part
|
| 73 |
+
ping = get_sheep_ping(sheep)
|
| 74 |
+
if ping is not None:
|
| 75 |
+
for name in self.stages():
|
| 76 |
+
if name not in parts:
|
| 77 |
+
parts[name] = {}
|
| 78 |
+
# Add the ping to each part for convenience.
|
| 79 |
+
parts[name]['ping'] = ping
|
| 80 |
+
return parts
|