Spaces:
Paused
Paused
| from __future__ import annotations | |
| import dataclasses | |
| import threading | |
| from functools import partial | |
| from typing import List, NamedTuple, Optional, Sequence, Tuple | |
| from hivemind import DHT, PeerID | |
| from hivemind.utils.logging import get_logger, use_hivemind_log_handler | |
| from src.data_structures import ModuleUID, RemoteModuleInfo | |
| from src.dht_utils import _get_remote_module_infos | |
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class Span(NamedTuple):
    """A half-open range of block indices [start, end) hosted by a single peer."""

    start: int
    end: Optional[int]
    peer_id: PeerID
# TODO[borzunov@]: this is not a dataclass (the annotations below only describe instance attributes)
class RemoteSequenceInfo:
    """Keeps and updates the meta-information about which peers host which blocks.

    ``update_()`` queries the DHT for every block uid, stores the valid per-block infos and
    groups consecutive blocks hosted by the same peer into ``Span`` objects.
    """

    dht: DHT
    block_uids: List[ModuleUID]
    block_infos: List[Optional[RemoteModuleInfo]]  # one entry per block; None until a valid info is found
    spans_by_priority: List[Span]  # sorted from best to worst (longest span first)
    spans_containing_block: Tuple[List[Span], ...]  # [i] = spans covering block i, in priority order
    lock_changes: threading.Lock  # held while block_infos / spans are being recomputed

    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
        """
        :param dht: DHT instance used to look up remote module infos
        :param block_uids: uids of the blocks that make up the sequence, in order
        :raises AssertionError: if no usable remote info (with active peers) was found for some block
        """
        self.dht = dht
        self.block_uids = list(block_uids)
        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
        self.spans_by_priority = []
        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
        self.lock_changes = threading.Lock()
        self.update_()

        # update_block_infos_ leaves an entry as None when the DHT returned no usable info for it
        for uid, info in zip(self.block_uids, self.block_infos):
            assert info is not None, f"Found no remote peers for block {uid}"
        assert self.spans_by_priority and self.spans_containing_block

    def update_(self):
        """Refresh block infos from the DHT and recompute spans while holding ``lock_changes``."""
        with self.lock_changes:
            self.update_block_infos_()
            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)

    def update_block_infos_(self):
        """Fetch the latest info for every block and store the valid ones in ``self.block_infos``.

        Missing, malformed, peer-less or mismatched entries are logged and skipped, so the
        previous value for that block (None before the first successful lookup) is kept.
        """
        new_block_infos: Sequence[Optional[RemoteModuleInfo]] = self.dht.run_coroutine(
            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
        )
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            # each check below must skip the entry: later checks dereference attributes of info
            if info is None:
                logger.warning(f"Found no block info for block {uid}")
                continue
            if not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
                continue
            if not info.peer_ids:
                logger.warning(f"Found no active peers for block {uid}")
                continue
            if info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
                continue
            if not isinstance(info.peer_ids, set):
                # still usable: compute_spans only iterates over peer_ids and tests membership
                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
            self.block_infos[block_index] = info

    @staticmethod
    def compute_spans(block_infos: Sequence[Optional[RemoteModuleInfo]]):
        """Group consecutive blocks hosted by the same peer into maximal spans.

        :param block_infos: per-block infos, in block order; None entries are treated as blocks with no peers
        :returns: a pair ``(spans_by_priority, spans_containing_block)`` where ``spans_by_priority`` lists
          all spans sorted by length (longest first) and ``spans_containing_block[i]`` lists, in that
          same order, every span that covers block ``i``
        """
        closed_spans = []
        active_spans = {}  # peer_id -> span that is still being extended
        for block_index, info in enumerate(block_infos):
            peer_ids = info.peer_ids if info is not None else ()
            for peer_id in peer_ids:
                if peer_id not in active_spans:
                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                else:  # peer_id in active_spans
                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)

            for peer_id in list(active_spans.keys()):
                # a span ends when its peer does not host this block, or when this is the last block
                if peer_id not in peer_ids or block_index == len(block_infos) - 1:
                    closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans

        # stable sort: spans of equal length keep their discovery order
        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

        return closed_spans, spans_containing_block

    def __len__(self):
        """Return the number of blocks in the sequence."""
        return len(self.block_uids)