from typing import Callable, Dict, Iterator, List, Optional, Text, Tuple
from collections import defaultdict


class HookManager(object):
    """Registry of named callback hooks with child managers ("forks") for sub-scopes.

    Hook names are dot-separated strings. ``fork(name)`` creates a child
    manager holding every hook whose name starts with ``name + '.'`` (prefix
    stripped). ``fork_iterative(name, i)`` does the same for the prefix
    ``name + '.' + str(i) + '.'`` and also for the wildcard prefix
    ``name + '.*.'``. Once a fork exists, hooks registered or unregistered on
    the parent under a matching name are routed to every matching fork instead
    of being stored locally.
    """

    def __init__(self, hook_dict: Optional[Dict[Text, List[Callable]]] = None):
        # A non-empty mapping passed by the caller is used as-is (not copied);
        # registration uses setdefault, so a plain dict works as well as a
        # defaultdict.
        self.hook_dict = hook_dict or defaultdict(list)
        # Number of times each hook name has been invoked through __call__.
        self.called = defaultdict(int)
        # Fork header -> child HookManager.
        self.forks = dict()
        # Fork header -> wildcard prefix ("<name>.*") for forks created by
        # fork_iterative. Plain forks have no entry here.
        self._wildcard_prefixes = dict()

    def _fork_targets(self, name: Text) -> Iterator[Tuple['HookManager', Text]]:
        """Yield ``(child, sub_name)`` for every fork whose scope covers ``name``.

        ``sub_name`` is exactly the key the child would have received had the
        hook existed when the fork was created. Registering before or after
        forking therefore leaves the child in the same state.
        """
        for header, child in self.forks.items():
            if name.startswith(header + '.'):
                yield child, name[len(header) + 1:]
                continue
            wildcard = self._wildcard_prefixes.get(header)
            if wildcard is not None and name.startswith(wildcard + '.'):
                yield child, name[len(wildcard) + 1:]

    def register(self, name: Text, func: Callable):
        """Register ``func`` under hook ``name``.

        If ``name`` falls inside the scope of one or more forks, the function
        is registered on each of those forks (under the stripped sub-name)
        and not on this manager.
        """
        assert name
        routed = False
        for child, sub_name in self._fork_targets(name):
            child.register(sub_name, func)
            routed = True
        if not routed:
            self.hook_dict.setdefault(name, []).append(func)

    def unregister(self, name: Text, func: Callable):
        """Remove one registration of ``func`` from hook ``name``.

        The call is routed to matching forks exactly as ``register`` routes
        registration. Unknown names or functions are ignored. A name whose
        function list becomes empty is dropped, so ``finalize`` does not
        report it as unused.
        """
        assert name
        routed = False
        for child, sub_name in self._fork_targets(name):
            child.unregister(sub_name, func)
            routed = True
        if routed:
            return
        # .get avoids creating an empty entry in a defaultdict, which would
        # later make finalize() report a hook that was never registered.
        funcs = self.hook_dict.get(name)
        if funcs and func in funcs:
            funcs.remove(func)
            if not funcs:
                del self.hook_dict[name]

    def __call__(self, name: Text, **kwargs):
        """Invoke every function registered under ``name`` with ``**kwargs``.

        Returns the value returned by the last registered function. When no
        function is registered under ``name``, returns ``kwargs['ret']``
        unchanged. In that case ``KeyError`` is raised if ``ret`` was not
        passed.
        """
        if name not in self.hook_dict:
            return kwargs['ret']
        self.called[name] += 1
        funcs = self.hook_dict[name]
        if not funcs:
            return kwargs['ret']
        ret = None
        # Iterate over a snapshot so a hook that unregisters itself does not
        # cause the next hook in the list to be skipped.
        for function in list(funcs):
            ret = function(**kwargs)
        return ret

    def _build_child(self, entries) -> 'HookManager':
        """Create a child manager from ``(sub_name, funcs)`` pairs, merging duplicate keys in order."""
        hooks = defaultdict(list)
        for key, value in entries:
            if isinstance(value, list):
                hooks[key].extend(value)
            else:
                hooks[key].append(value)
        return HookManager(hooks)

    def fork(self, name: Text) -> 'HookManager':
        """Create and return a child manager for the hooks under ``name + '.'``.

        Raises:
            ValueError: if a fork with this exact header already exists.
        """
        if name in self.forks:
            raise ValueError(f'Forking with the same name is not allowed. Already forked with {name}.')
        prefix = name + '.'
        entries = [(k[len(prefix):], v) for k, v in self.hook_dict.items() if k.startswith(prefix)]
        child = self._build_child(entries)
        self.forks[name] = child
        # A plain fork never matches wildcard names.
        self._wildcard_prefixes.pop(name, None)
        return child

    def fork_iterative(self, name: Text, iteration) -> 'HookManager':
        """Create and return a child manager for iteration ``iteration`` of ``name``.

        The child receives hooks under ``name.<iteration>.`` first, then
        hooks under the wildcard ``name.*.``, with both prefixes stripped.
        An existing fork with the same header is replaced.
        """
        header = name + '.' + str(iteration)
        wildcard = name + '.*'
        entries = [(k[len(header) + 1:], v) for k, v in self.hook_dict.items()
                   if k.startswith(header + '.')]
        entries += [(k[len(wildcard) + 1:], v) for k, v in self.hook_dict.items()
                    if k.startswith(wildcard + '.')]
        child = self._build_child(entries)
        self.forks[header] = child
        self._wildcard_prefixes[header] = wildcard
        return child

    def finalize(self):
        """Raise ``ValueError`` if any hook stored on this manager was never invoked.

        Forks are not checked here.
        """
        for name in self.hook_dict:
            if self.called[name] == 0:
                raise ValueError(f'Hook {name} was registered but never used!')