JotunnBurton committed on
Commit 2359fbe · verified · 1 Parent(s): 0fc12b1

Delete utils.py

Files changed (1)
  1. utils.py +0 -356
utils.py DELETED
@@ -1,356 +0,0 @@
- import os
- import glob
- import argparse
- import logging
- import json
- import subprocess
- import numpy as np
- from scipy.io.wavfile import read
- import torch
-
- MATPLOTLIB_FLAG = False
-
- logger = logging.getLogger(__name__)
-
-
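- # Restore model (and optionally optimizer) state from a .pth checkpoint,
- # tolerating keys that are missing from or reshaped in older checkpoints.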
- def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
-     assert os.path.isfile(checkpoint_path)
-     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-     iteration = checkpoint_dict["iteration"]
-     learning_rate = checkpoint_dict["learning_rate"]
-     if (
-         optimizer is not None
-         and not skip_optimizer
-         and checkpoint_dict["optimizer"] is not None
-     ):
-         optimizer.load_state_dict(checkpoint_dict["optimizer"])
-     elif optimizer is None and not skip_optimizer:
-         # Alternate resume path: keep the new optimizer's parameter references
-         # but adopt the saved param_group settings. Note this block needs an
-         # optimizer instance, so it only works when toggled manually in place
-         # of the branch above (e.g. when resuming rather than inferring).
-         new_opt_dict = optimizer.state_dict()
-         new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
-         new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
-         new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
-         optimizer.load_state_dict(new_opt_dict)
-
-     saved_state_dict = checkpoint_dict["model"]
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-
-     new_state_dict = {}
-     for k, v in state_dict.items():
-         try:
-             # assert "emb_g" not in k
-             new_state_dict[k] = saved_state_dict[k]
-             assert saved_state_dict[k].shape == v.shape, (
-                 saved_state_dict[k].shape,
-                 v.shape,
-             )
-         except (KeyError, AssertionError):
-             # For upgrading from the old version
-             if "ja_bert_proj" in k:
-                 v = torch.zeros_like(v)
-                 logger.warning(
-                     f"It looks like you are using an older version of the model; {k} is zero-initialized for backward compatibility"
-                 )
-             else:
-                 logger.error(f"{k} is not in the checkpoint")
-
-             new_state_dict[k] = v
-
-     if hasattr(model, "module"):
-         model.module.load_state_dict(new_state_dict, strict=False)
-     else:
-         model.load_state_dict(new_state_dict, strict=False)
-
-     logger.info(
-         "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
-     )
-
-     return model, optimizer, learning_rate, iteration
-
-
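- # Bundle model weights, optimizer state, learning rate, and iteration
- # count into a single .pth file.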
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-     logger.info(
-         "Saving model and optimizer state at iteration {} to {}".format(
-             iteration, checkpoint_path
-         )
-     )
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     torch.save(
-         {
-             "model": state_dict,
-             "iteration": iteration,
-             "optimizer": optimizer.state_dict(),
-             "learning_rate": learning_rate,
-         },
-         checkpoint_path,
-     )
-
-
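- # Log scalars, histograms, images, and audio clips to a TensorBoard
- # SummaryWriter in one call.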
- def summarize(
-     writer,
-     global_step,
-     scalars={},
-     histograms={},
-     images={},
-     audios={},
-     audio_sampling_rate=22050,
- ):
-     for k, v in scalars.items():
-         writer.add_scalar(k, v, global_step)
-     for k, v in histograms.items():
-         writer.add_histogram(k, v, global_step)
-     for k, v in images.items():
-         writer.add_image(k, v, global_step, dataformats="HWC")
-     for k, v in audios.items():
-         writer.add_audio(k, v, global_step, audio_sampling_rate)
-
-
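- # Return the latest checkpoint in dir_path, sorting paths by the digits
- # embedded in them (i.e. the training step in the filename).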
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-     f_list = glob.glob(os.path.join(dir_path, regex))
-     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-     x = f_list[-1]
-     return x
-
-
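- # Render a spectrogram to an RGB numpy array with a headless (Agg)
- # matplotlib backend, e.g. for TensorBoard image logging.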
- def plot_spectrogram_to_numpy(spectrogram):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-     plt.colorbar(im, ax=ax)
-     plt.xlabel("Frames")
-     plt.ylabel("Channels")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def plot_alignment_to_numpy(alignment, info=None):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-
-     fig, ax = plt.subplots(figsize=(6, 4))
-     im = ax.imshow(
-         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-     )
-     fig.colorbar(im, ax=ax)
-     xlabel = "Decoder timestep"
-     if info is not None:
-         xlabel += "\n\n" + info
-     plt.xlabel(xlabel)
-     plt.ylabel("Encoder timestep")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
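- # Read a WAV file into a float32 torch tensor plus its sampling rate.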
- def load_wav_to_torch(full_path):
-     sampling_rate, data = read(full_path)
-     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-
-
- def load_filepaths_and_text(filename, split="|"):
-     with open(filename, encoding="utf-8") as f:
-         filepaths_and_text = [line.strip().split(split) for line in f]
-     return filepaths_and_text
-
-
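- # Parse CLI arguments, copy the config JSON into ./logs/<model>/, and
- # return it as an HParams object.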
- def get_hparams(init=True):
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "-c",
-         "--config",
-         type=str,
-         default="./configs/base.json",
-         help="JSON file for configuration",
-     )
-     parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
-
-     args = parser.parse_args()
-     model_dir = os.path.join("./logs", args.model)
-
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-
-     config_path = args.config
-     config_save_path = os.path.join(model_dir, "config.json")
-     if init:
-         with open(config_path, "r", encoding="utf-8") as f:
-             data = f.read()
-         with open(config_save_path, "w", encoding="utf-8") as f:
-             f.write(data)
-     else:
-         with open(config_save_path, "r", encoding="utf-8") as f:
-             data = f.read()
-     config = json.loads(data)
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
-     """Free up space by deleting saved checkpoints.
-
-     Arguments:
-     path_to_models -- Path to the model directory
-     n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
-     sort_by_time -- True -> delete the oldest checkpoints first (by mtime)
-                     False -> delete the lowest-numbered checkpoints first (by filename)
-     """
-     import re
-
-     ckpts_files = [
-         f
-         for f in os.listdir(path_to_models)
-         if os.path.isfile(os.path.join(path_to_models, f))
-     ]
-
-     def name_key(_f):
-         return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
-
-     def time_key(_f):
-         return os.path.getmtime(os.path.join(path_to_models, _f))
-
-     sort_key = time_key if sort_by_time else name_key
-
-     def x_sorted(_x):
-         return sorted(
-             [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
-             key=sort_key,
-         )
-
-     to_del = [
-         os.path.join(path_to_models, fn)
-         for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
-     ]
-
-     for fn in to_del:
-         os.remove(fn)
-         logger.info(f".. Free up space by deleting ckpt {fn}")
-
-
- def get_hparams_from_dir(model_dir):
-     config_save_path = os.path.join(model_dir, "config.json")
-     with open(config_save_path, "r", encoding="utf-8") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def get_hparams_from_file(config_path):
-     with open(config_path, "r", encoding="utf-8") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     return hparams
-
-
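- # Warn when the current git commit differs from the hash recorded in the
- # model directory; record the hash on first run.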
- def check_git_hash(model_dir):
-     source_dir = os.path.dirname(os.path.realpath(__file__))
-     if not os.path.exists(os.path.join(source_dir, ".git")):
-         logger.warning(
-             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                 source_dir
-             )
-         )
-         return
-
-     cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-     path = os.path.join(model_dir, "githash")
-     if os.path.exists(path):
-         with open(path) as f:
-             saved_hash = f.read()
-         if saved_hash != cur_hash:
-             logger.warning(
-                 "git hash values are different. {}(saved) != {}(current)".format(
-                     saved_hash[:8], cur_hash[:8]
-                 )
-             )
-     else:
-         with open(path, "w") as f:
-             f.write(cur_hash)
-
-
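- # Configure the module-level logger with a DEBUG file handler writing to
- # model_dir/<filename>.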
- def get_logger(model_dir, filename="train.log"):
-     global logger
-     logger = logging.getLogger(os.path.basename(model_dir))
-     logger.setLevel(logging.DEBUG)
-
-     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-     h = logging.FileHandler(os.path.join(model_dir, filename))
-     h.setLevel(logging.DEBUG)
-     h.setFormatter(formatter)
-     logger.addHandler(h)
-     return logger
-
-
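- # Dict-like hyperparameter container with attribute access; nested dicts
- # are converted to nested HParams instances.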
- class HParams:
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if isinstance(v, dict):
-                 v = HParams(**v)
-             self[k] = v
-
-     def keys(self):
-         return self.__dict__.keys()
-
-     def items(self):
-         return self.__dict__.items()
-
-     def values(self):
-         return self.__dict__.values()
-
-     def __len__(self):
-         return len(self.__dict__)
-
-     def __getitem__(self, key):
-         return getattr(self, key)
-
-     def __setitem__(self, key, value):
-         return setattr(self, key, value)
-
-     def __contains__(self, key):
-         return key in self.__dict__
-
-     def __repr__(self):
-         return self.__dict__.__repr__()