Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from functools import reduce, partial | |
| from operator import mul | |
| from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations | |
def chain_functions(*functions):
    """Compose ``functions`` left to right into one single-argument callable.

    The returned callable passes its argument to the first function, feeds
    that result to the second, and so on, returning the last result. When no
    functions are given, the argument is returned unchanged.
    """

    def pipeline(initial):
        value = initial
        for step in functions:
            value = step(value)
        return value

    return pipeline
def remove_fx_parametrisation(fx):
    """Strip every registered parametrization from ``fx`` and return ``fx``.

    ``fx`` must provide ``apply`` (as ``torch.nn.Module`` does); the visitor
    below is handed to it. For each visited module that reports being
    parametrized, ``remove_parametrizations`` is called with its default
    arguments once per parametrized tensor name. The modules are modified in
    place and the same ``fx`` object is returned.
    """

    def strip(module):
        if is_parametrized(module):
            # Snapshot the names before removing anything, so the container
            # is never iterated while entries are being removed from it.
            for tensor_name in tuple(module.parametrizations.keys()):
                remove_parametrizations(module, tensor_name)

    fx.apply(strip)
    return fx
def get_chunks(keys, original_shapes):
    """Compute split sizes for a flat vector that omits U's lower triangle.

    The first entry of ``keys`` containing ``"U.original"`` is treated as the
    U matrix. Its lower triangle *including the diagonal* is excluded from
    the flat vector, so its chunk size is reduced by that many elements;
    every other entry keeps its full element count.

    Args:
        keys: Sequence of entry names (strings), aligned with
            ``original_shapes``.
        original_shapes: Sequence of shapes, one per key. A 0-dim shape
            ``()`` counts as a single element.

    Returns:
        A tuple ``(selected_chunks, position, U_matrix_shape,
        dimensions_not_need)`` where

        * ``selected_chunks`` is the list of per-entry sizes after the
          lower-triangular entries of U have been dropped,
        * ``position`` is the index of the U entry in ``keys``,
        * ``U_matrix_shape`` is ``original_shapes[position]``,
        * ``dimensions_not_need`` is a NumPy array holding the row-major flat
          indices of U's lower triangle (diagonal included), offset by the
          total size of all entries that precede U.

    Raises:
        ValueError: If no key contains ``"U.original"``.
    """
    position = next(
        (index for index, key in enumerate(keys) if "U.original" in key),
        None,
    )
    if position is None:
        raise ValueError('no key containing "U.original" was found')

    # Fold with an initial value of 1 so that a scalar shape () yields one
    # element instead of making reduce() fail on an empty sequence.
    original_chunks = [reduce(mul, shape, 1) for shape in original_shapes]

    U_matrix_shape = original_shapes[position]
    offset = sum(original_chunks[:position])

    # Row-major flat positions of the lower triangle (k=0 keeps the
    # diagonal), shifted by the size of everything stored before U.
    lower_triangle = np.tril_indices(n=U_matrix_shape[0], m=U_matrix_shape[1])
    dimensions_not_need = (
        np.ravel_multi_index(lower_triangle, U_matrix_shape) + offset
    )

    selected_chunks = (
        original_chunks[:position]
        + [original_chunks[position] - dimensions_not_need.size]
        + original_chunks[position + 1 :]
    )
    return selected_chunks, position, U_matrix_shape, dimensions_not_need
def vec2statedict(
    x: torch.Tensor,
    keys,
    original_shapes,
    selected_chunks,
    position,
    U_matrix_shape,
):
    """Rebuild a name -> tensor mapping from the flat vector ``x``.

    ``x`` is split along dim 0 into pieces of sizes ``selected_chunks``. The
    piece at ``position`` holds only the strictly upper-triangular entries
    of the U matrix (row-major order); it is scattered into a zero-filled
    buffer of U's full size, so the diagonal and lower triangle stay zero.
    Every piece is then reshaped to its entry in ``original_shapes``.

    Args:
        x: Flat tensor to unpack (expected to be 1-D, with length equal to
            ``sum(selected_chunks)``).
        keys: Entry names, aligned with ``original_shapes``.
        original_shapes: Target shape for each entry; a 0-dim shape ``()``
            is supported.
        selected_chunks: Per-entry split sizes for ``x``.
        position: Index of the U entry within ``keys``.
        U_matrix_shape: Two-dimensional shape of the full U matrix.

    Returns:
        A dict mapping each key to its reshaped tensor.
    """
    chunks = list(torch.split(x, selected_chunks))

    # Row-major flat positions of the strictly upper triangle (k=1 skips the
    # diagonal) inside the full U matrix.
    upper_flat_index = np.ravel_multi_index(
        np.triu_indices(n=U_matrix_shape[0], k=1, m=U_matrix_shape[1]),
        U_matrix_shape,
    )

    # new_zeros keeps the dtype and device of x for the rebuilt matrix.
    U = x.new_zeros(reduce(mul, U_matrix_shape))
    U[upper_flat_index] = chunks[position]
    chunks[position] = U

    # Pass the shape as one tuple: unpacking an empty shape would call
    # reshape() with no arguments, which fails for 0-dim entries.
    return {
        key: chunk.reshape(tuple(shape))
        for key, chunk, shape in zip(keys, chunks, original_shapes)
    }