Upload 10 files

- __init__.py +0 -0
- build_vocab.py +80 -0
- callbacks.py +1066 -0
- dataset.py +1 -1
- logger.py +71 -0
- models_debugger.py +816 -0
- tcn.py +83 -0
__init__.py
ADDED
File without changes
build_vocab.py
ADDED
@@ -0,0 +1,80 @@

import pickle
from collections import Counter
import json


class JsonReader(object):
    def __init__(self, json_file):
        self.data = self.__read_json(json_file)
        self.keys = list(self.data.keys())

    def __read_json(self, filename):
        with open(filename, 'r') as f:
            data = json.load(f)
        return data

    def __getitem__(self, item):
        # Iteration (as in build_vocab below) passes integer indices, so
        # look the key up first; direct key access is kept for reference.
        return self.data[self.keys[item]]
        # return self.data[item]

    def __len__(self):
        return len(self.data)


class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.id2word = {}
        self.idx = 0
        self.add_word('<pad>')
        self.add_word('<end>')
        self.add_word('<start>')
        self.add_word('<unk>')

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.id2word[self.idx] = word
            self.idx += 1

    def get_word_by_id(self, id):
        return self.id2word[id]

    def __call__(self, word):
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)


def build_vocab(json_file, threshold):
    caption_reader = JsonReader(json_file)
    counter = Counter()

    for items in caption_reader:
        text = items.replace('.', '').replace(',', '')
        counter.update(text.lower().split(' '))
    words = [word for word, cnt in counter.items() if cnt > threshold and word != '']
    vocab = Vocabulary()

    for word in words:
        print(word)
        vocab.add_word(word)
    return vocab


def main(json_file, threshold, vocab_path):
    vocab = build_vocab(json_file=json_file,
                        threshold=threshold)
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved vocabulary to {}".format(vocab_path))


if __name__ == '__main__':
    main(json_file='../data/new_data/debugging_captions.json',
         threshold=0,
         vocab_path='../data/new_data/debug_vocab.pkl')
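For orientation, a minimal sketch of the `Vocabulary` round trip this file provides. The example tokens are illustrative, not from the repo's data:

```python
from build_vocab import Vocabulary

vocab = Vocabulary()                     # ids 0-3 are <pad>, <end>, <start>, <unk>
for w in 'the lungs are clear'.split():
    vocab.add_word(w)

ids = ([vocab('<start>')]
       + [vocab(w) for w in 'the lungs are clear'.split()]
       + [vocab('<end>')])
print(ids)                                      # [2, 4, 5, 6, 7, 1]
print([vocab.get_word_by_id(i) for i in ids])   # round-trip back to tokens
print(vocab('pneumothorax'))                    # unseen word falls back to <unk> -> 3
```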
callbacks.py
ADDED
@@ -0,0 +1,1066 @@

"""Callbacks: utilities called at certain points during model training.

# Adapted from

- https://github.com/keras-team/keras
- https://github.com/bstriner/keras-tqdm/blob/master/keras_tqdm/tqdm_callback.py

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import csv
import six

import numpy as np
import time
import json
import warnings
from sys import stderr
from tqdm import tqdm

from collections import deque
from collections import OrderedDict
try:
    from collections.abc import Iterable  # Python 3.3+
except ImportError:
    from collections import Iterable  # Python 2

try:
    import requests
except ImportError:
    requests = None


class CallbackList(object):
    """Container abstracting a list of callbacks.

    # Arguments
        callbacks: List of `Callback` instances.
        queue_length: Queue length for keeping
            running statistics over callback execution time.
    """

    def __init__(self, callbacks=None, queue_length=10):
        callbacks = callbacks or []
        self.callbacks = [c for c in callbacks]
        self.queue_length = queue_length

    def append(self, callback):
        self.callbacks.append(callback)

    def set_params(self, params):
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model):
        for callback in self.callbacks:
            callback.set_model(model)

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.

        # Arguments
            epoch: integer, index of epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)
        self._delta_t_batch = 0.
        self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
        self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        # Arguments
            epoch: integer, index of epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch, logs=None):
        """Called right before processing a batch.

        # Arguments
            batch: integer, index of batch within the current epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        t_before_callbacks = time.time()
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)
        self._delta_ts_batch_begin.append(time.time() - t_before_callbacks)
        delta_t_median = np.median(self._delta_ts_batch_begin)
        if (self._delta_t_batch > 0. and
                delta_t_median > 0.95 * self._delta_t_batch and
                delta_t_median > 0.1):
            warnings.warn('Method on_batch_begin() is slow compared '
                          'to the batch update (%f). Check your callbacks.'
                          % delta_t_median)
        self._t_enter_batch = time.time()

    def on_batch_end(self, batch, logs=None):
        """Called at the end of a batch.

        # Arguments
            batch: integer, index of batch within the current epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        if not hasattr(self, '_t_enter_batch'):
            self._t_enter_batch = time.time()
        self._delta_t_batch = time.time() - self._t_enter_batch
        t_before_callbacks = time.time()
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)
        self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
        delta_t_median = np.median(self._delta_ts_batch_end)
        if (self._delta_t_batch > 0. and
                (delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
            warnings.warn('Method on_batch_end() is slow compared '
                          'to the batch update (%f). Check your callbacks.'
                          % delta_t_median)

    def on_train_begin(self, logs=None):
        """Called at the beginning of training.

        # Arguments
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Called at the end of training.

        # Arguments
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_end(logs)

    def __iter__(self):
        return iter(self.callbacks)


class Callback(object):
    """Abstract base class used to build new callbacks.

    # Properties
        params: dict. Training parameters
            (e.g. verbosity, batch size, number of epochs...).
        model: instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch.

    Currently, the `.fit()` method of the `Sequential` model class
    will include the following quantities in the `logs` that
    it passes to its callbacks:

        on_epoch_end: logs include `acc` and `loss`, and
            optionally include `val_loss`
            (if validation is enabled in `fit`), and `val_acc`
            (if validation and accuracy monitoring are enabled).
        on_batch_begin: logs include `size`,
            the number of samples in the current batch.
        on_batch_end: logs include `loss`, and optionally `acc`
            (if accuracy monitoring is enabled).
    """

    def __init__(self):
        self.validation_data = None
        self.model = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class BaseLogger(Callback):
    """Callback that accumulates epoch averages of metrics.

    This callback is automatically applied to every Keras model.
    """

    def on_epoch_begin(self, epoch, logs=None):
        self.seen = 0
        self.totals = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params['metrics']:
                if k in self.totals:
                    # Make value available to next callbacks.
                    logs[k] = self.totals[k] / self.seen


class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered."""

    def __init__(self):
        super(TerminateOnNaN, self).__init__()

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                print('Batch %d: Invalid loss, terminating training' % (batch))
                self.model.stop_training = True


class History(Callback):
    """Callback that records events into a `History` object.

    This callback is automatically applied to
    every Keras model. The `History` object
    gets returned by the `fit` method of models.
    """

    def on_train_begin(self, logs=None):
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)


class ModelCheckpoint(Callback):
    """Save the model after every epoch.

    `filepath` can contain named formatting options,
    which will be filled with the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).

    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.

    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`torch.save(self.model.state_dict(), filepath)`), else the
            full model is saved (`torch.save(self.model, filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        import torch
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            torch.save(self.model.state_dict(), filepath)
                        else:
                            # Save the full model, not just the weights.
                            torch.save(self.model, filepath)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    torch.save(self.model.state_dict(), filepath)
                else:
                    torch.save(self.model, filepath)


class EarlyStopping(Callback):
    """Stop training when a monitored quantity has stopped improving.

    # Arguments
        monitor: quantity to be monitored.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than min_delta will count as no
            improvement.
        patience: number of epochs with no improvement
            after which training will be stopped.
        verbose: verbosity mode.
        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity.
    """

    def __init__(self, monitor='val_loss',
                 min_delta=0, patience=0, verbose=0, mode='auto'):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                'Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))


class RemoteMonitor(Callback):
    """Callback used to stream events to a server.

    Requires the `requests` library.
    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
    HTTP POST, with an `images` argument which is a
    JSON-encoded dictionary of event images.

    # Arguments
        root: String; root url of the target server.
        path: String; path relative to `root` to which the events will be sent.
        field: String; JSON field under which the images will be stored.
        headers: Dictionary; optional custom HTTP headers.
    """

    def __init__(self,
                 root='http://localhost:9000',
                 path='/publish/epoch/end/',
                 field='images',
                 headers=None):
        super(RemoteMonitor, self).__init__()

        self.root = root
        self.path = path
        self.field = field
        self.headers = headers

    def on_epoch_end(self, epoch, logs=None):
        if requests is None:
            raise ImportError('RemoteMonitor requires '
                              'the `requests` library.')
        logs = logs or {}
        send = {}
        send['epoch'] = epoch
        for k, v in logs.items():
            if isinstance(v, (np.ndarray, np.generic)):
                send[k] = v.item()
            else:
                send[k] = v
        try:
            requests.post(self.root + self.path,
                          {self.field: json.dumps(send)},
                          headers=self.headers)
        except requests.exceptions.RequestException:
            warnings.warn('Warning: could not reach RemoteMonitor '
                          'root server at ' + str(self.root))


class TensorBoard(Callback):
    """TensorBoard basic visualizations.

    [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
    is a visualization tool provided with TensorFlow.

    This callback writes a log for TensorBoard, which allows
    you to visualize dynamic graphs of your training and test
    metrics, as well as activation histograms for the different
    layers in your model.

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:
    ```sh
    tensorboard --logdir=/full_path_to_your_logs
    ```

    When using a backend other than TensorFlow, TensorBoard will still work
    (if you have TensorFlow installed), but the only feature available will
    be the display of the losses and metrics plots.

    # Arguments
        log_dir: the path of the directory where to save the log
            files to be parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation
            and weight histograms for the layers of the model. If set to 0,
            histograms won't be computed. Validation images (or split) must be
            specified for histogram visualizations.
        write_graph: whether to visualize the graph in TensorBoard.
            The log file can become quite large when
            write_graph is set to True.
        write_grads: whether to visualize gradient histograms in TensorBoard.
            `histogram_freq` must be greater than 0.
        batch_size: size of batch of inputs to feed to the network
            for histograms computation.
        write_images: whether to write model weights to visualize as
            image in TensorBoard.
        embeddings_freq: frequency (in epochs) at which selected embedding
            layers will be saved.
        embeddings_layer_names: a list of names of layers to keep an eye on.
            If None or an empty list, all the embedding layers will be watched.
        embeddings_metadata: a dictionary which maps layer name to a file name
            in which metadata for this embedding layer is saved. See the
            [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
            about metadata file format. If the same metadata file is
            used for all embedding layers, a single string can be passed.
    """

    def __init__(self, log_dir='./logs',
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=False,
                 write_images=False,
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None):
        super(TensorBoard, self).__init__()
        global tf, projector, K
        try:
            import tensorflow as tf
            from tensorflow.contrib.tensorboard.plugins import projector
            # Keras backend: needed for K.backend(), K.get_session(), etc. below.
            from keras import backend as K
        except ImportError:
            raise ImportError('You need the TensorFlow module installed to use TensorBoard.')

        if K.backend() != 'tensorflow':
            if histogram_freq != 0:
                warnings.warn('You are not using the TensorFlow backend. '
                              'histogram_freq was set to 0')
                histogram_freq = 0
            if write_graph:
                warnings.warn('You are not using the TensorFlow backend. '
                              'write_graph was set to False')
                write_graph = False
            if write_images:
                warnings.warn('You are not using the TensorFlow backend. '
                              'write_images was set to False')
                write_images = False
            if embeddings_freq != 0:
                warnings.warn('You are not using the TensorFlow backend. '
                              'embeddings_freq was set to 0')
                embeddings_freq = 0

        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.merged = None
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata or {}
        self.batch_size = batch_size

    def set_model(self, model):
        self.model = model
        if K.backend() == 'tensorflow':
            self.sess = K.get_session()
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:

                for weight in layer.weights:
                    mapped_weight_name = weight.name.replace(':', '_')
                    tf.summary.histogram(mapped_weight_name, weight)
                    if self.write_grads:
                        grads = model.optimizer.get_gradients(model.total_loss,
                                                              weight)

                        def is_indexed_slices(grad):
                            return type(grad).__name__ == 'IndexedSlices'
                        grads = [
                            grad.values if is_indexed_slices(grad) else grad
                            for grad in grads]
                        tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads)
                    if self.write_images:
                        w_img = tf.squeeze(weight)
                        shape = K.int_shape(w_img)
                        if len(shape) == 2:  # dense layer kernel case
                            if shape[0] > shape[1]:
                                w_img = tf.transpose(w_img)
                                shape = K.int_shape(w_img)
                            w_img = tf.reshape(w_img, [1,
                                                       shape[0],
                                                       shape[1],
                                                       1])
                        elif len(shape) == 3:  # convnet case
                            if K.image_data_format() == 'channels_last':
                                # switch to channels_first to display
                                # every kernel as a separate image
                                w_img = tf.transpose(w_img, perm=[2, 0, 1])
                                shape = K.int_shape(w_img)
                            w_img = tf.reshape(w_img, [shape[0],
                                                       shape[1],
                                                       shape[2],
                                                       1])
                        elif len(shape) == 1:  # bias case
                            w_img = tf.reshape(w_img, [1,
                                                       shape[0],
                                                       1,
                                                       1])
                        else:
                            # not possible to handle 3D convnets etc.
                            continue

                        shape = K.int_shape(w_img)
                        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                        tf.summary.image(mapped_weight_name, w_img)

                if hasattr(layer, 'output'):
                    tf.summary.histogram('{}_out'.format(layer.name),
                                         layer.output)
        self.merged = tf.summary.merge_all()

        if self.write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.embeddings_freq:
            embeddings_layer_names = self.embeddings_layer_names

            if not embeddings_layer_names:
                embeddings_layer_names = [layer.name for layer in self.model.layers
                                          if type(layer).__name__ == 'Embedding']

            embeddings = {layer.name: layer.weights[0]
                          for layer in self.model.layers
                          if layer.name in embeddings_layer_names}

            self.saver = tf.train.Saver(list(embeddings.values()))

            embeddings_metadata = {}

            if not isinstance(self.embeddings_metadata, str):
                embeddings_metadata = self.embeddings_metadata
            else:
                embeddings_metadata = {layer_name: self.embeddings_metadata
                                       for layer_name in embeddings.keys()}

            config = projector.ProjectorConfig()
            self.embeddings_ckpt_path = os.path.join(self.log_dir,
                                                     'keras_embedding.ckpt')

            for layer_name, tensor in embeddings.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if layer_name in embeddings_metadata:
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        if not self.validation_data and self.histogram_freq:
            raise ValueError('If printing histograms, validation_data must be '
                             'provided, and cannot be a generator.')
        if self.validation_data and self.histogram_freq:
            if epoch % self.histogram_freq == 0:

                val_data = self.validation_data
                tensors = (self.model.inputs +
                           self.model.targets +
                           self.model.sample_weights)

                if self.model.uses_learning_phase:
                    tensors += [K.learning_phase()]

                assert len(val_data) == len(tensors)
                val_size = val_data[0].shape[0]
                i = 0
                while i < val_size:
                    step = min(self.batch_size, val_size - i)
                    if self.model.uses_learning_phase:
                        # do not slice the learning phase
                        batch_val = [x[i:i + step] for x in val_data[:-1]]
                        batch_val.append(val_data[-1])
                    else:
                        batch_val = [x[i:i + step] for x in val_data]
                    assert len(batch_val) == len(tensors)
                    feed_dict = dict(zip(tensors, batch_val))
                    result = self.sess.run([self.merged], feed_dict=feed_dict)
                    summary_str = result[0]
                    self.writer.add_summary(summary_str, epoch)
                    i += self.batch_size

        if self.embeddings_freq and self.embeddings_ckpt_path:
            if epoch % self.embeddings_freq == 0:
                self.saver.save(self.sess,
                                self.embeddings_ckpt_path,
                                epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_train_end(self, _):
        self.writer.close()


class CSVLogger(Callback):
    """Callback that streams epoch results to a csv file.

    Supports all values that can be represented as a string,
    including 1D iterables such as np.ndarray.

    # Example

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```

    # Arguments
        filename: filename of the csv file, e.g. 'run/log.csv'.
        separator: string used to separate elements in the csv file.
        append: True: append if file exists (useful for continuing
            training). False: overwrite existing file.
        output_on_train_end: an additional output stream to write the
            log to when training ends. An example is
            CSVLogger(filename='./mylog.csv', output_on_train_end=os.sys.stdout)
    """

    def __init__(self, filename, separator=',', append=False, output_on_train_end=None):
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
        self.output_on_train_end = output_on_train_end
        super(CSVLogger, self).__init__()

    def on_train_begin(self, logs=None):
        if self.append:
            if os.path.exists(self.filename):
                with open(self.filename, 'r' + self.file_flags) as f:
                    self.append_header = not bool(len(f.readline()))
            self.csv_file = open(self.filename, 'a' + self.file_flags)
        else:
            self.csv_file = open(self.filename, 'w' + self.file_flags)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        def handle_value(k):
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, six.string_types):
                return k
            elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
                return '"[%s]"' % (', '.join(map(str, k)))
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())

        if self.model is not None and getattr(self.model, 'stop_training', False):
            # We set NA so that csv parsers do not fail for this last epoch.
            logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])

        if not self.writer:
            class CustomDialect(csv.excel):
                delimiter = self.sep

            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=['epoch'] + self.keys,
                                         dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()

        row_dict = OrderedDict({'epoch': epoch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        self.csv_file.close()
        if os.path.exists(self.filename):
            with open(self.filename, 'r' + self.file_flags) as f:
                print(f.read(), file=self.output_on_train_end)
        self.writer = None


class LambdaCallback(Callback):
    r"""Callback for creating simple, custom callbacks on-the-fly.

    This callback is constructed with anonymous functions that will be called
    at the appropriate time. Note that the callbacks expect positional
    arguments, as:

    - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
      `epoch`, `logs`
    - `on_batch_begin` and `on_batch_end` expect two positional arguments:
      `batch`, `logs`
    - `on_train_begin` and `on_train_end` expect one positional argument:
      `logs`

    # Arguments
        on_epoch_begin: called at the beginning of every epoch.
        on_epoch_end: called at the end of every epoch.
        on_batch_begin: called at the beginning of every batch.
        on_batch_end: called at the end of every batch.
        on_train_begin: called at the beginning of model training.
        on_train_end: called at the end of model training.

    # Example

    ```python
    # Print the batch number at the beginning of every batch.
    batch_print_callback = LambdaCallback(
        on_batch_begin=lambda batch, logs: print(batch))

    # Stream the epoch loss to a file in JSON format. The file content
    # is not well-formed JSON but rather has a JSON object per line.
    import json
    json_log = open('loss_log.json', mode='wt', buffering=1)
    json_logging_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: json_log.write(
            json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
        on_train_end=lambda logs: json_log.close()
    )

    # Terminate some processes after having finished model training.
    processes = ...
    cleanup_callback = LambdaCallback(
        on_train_end=lambda logs: [
            p.terminate() for p in processes if p.is_alive()])

    model.fit(...,
              callbacks=[batch_print_callback,
                         json_logging_callback,
                         cleanup_callback])
    ```
    """

    def __init__(self,
                 on_epoch_begin=None,
                 on_epoch_end=None,
                 on_batch_begin=None,
                 on_batch_end=None,
                 on_train_begin=None,
                 on_train_end=None,
                 **kwargs):
        super(LambdaCallback, self).__init__()
        self.__dict__.update(kwargs)
        if on_epoch_begin is not None:
            self.on_epoch_begin = on_epoch_begin
        else:
            self.on_epoch_begin = lambda epoch, logs: None
        if on_epoch_end is not None:
            self.on_epoch_end = on_epoch_end
        else:
            self.on_epoch_end = lambda epoch, logs: None
        if on_batch_begin is not None:
            self.on_batch_begin = on_batch_begin
        else:
            self.on_batch_begin = lambda batch, logs: None
        if on_batch_end is not None:
            self.on_batch_end = on_batch_end
        else:
            self.on_batch_end = lambda batch, logs: None
        if on_train_begin is not None:
            self.on_train_begin = on_train_begin
        else:
            self.on_train_begin = lambda logs: None
        if on_train_end is not None:
            self.on_train_end = on_train_end
        else:
            self.on_train_end = lambda logs: None


class TQDMCallback(Callback):
    def __init__(self, outer_description="Training",
                 inner_description_initial="Epoch: {epoch}",
                 inner_description_update="Epoch: {epoch} - {metrics}",
                 metric_format="{name}: {value:0.3f}",
                 separator=", ",
                 leave_inner=True,
                 leave_outer=True,
                 show_inner=True,
                 show_outer=True,
                 output_file=stderr,
                 initial=0):
        """
        Construct a callback that will create and update progress bars.

        :param outer_description: string for outer progress bar
        :param inner_description_initial: initial format for epoch ("Epoch: {epoch}")
        :param inner_description_update: format after metrics collected ("Epoch: {epoch} - {metrics}")
        :param metric_format: format for each metric name/value pair ("{name}: {value:0.3f}")
        :param separator: separator between metrics (", ")
        :param leave_inner: True to leave inner bars
        :param leave_outer: True to leave outer bars
        :param show_inner: False to hide inner bars
        :param show_outer: False to hide outer bar
        :param output_file: output file (default sys.stderr)
        :param initial: Initial counter state
        """
        self.outer_description = outer_description
        self.inner_description_initial = inner_description_initial
        self.inner_description_update = inner_description_update
        self.metric_format = metric_format
        self.separator = separator
        self.leave_inner = leave_inner
        self.leave_outer = leave_outer
        self.show_inner = show_inner
        self.show_outer = show_outer
        self.output_file = output_file
        self.tqdm_outer = None
        self.tqdm_inner = None
        self.epoch = None
        self.running_logs = None
        self.inner_count = None
        self.initial = initial

    def tqdm(self, desc, total, leave, initial=0):
        """
        Extension point. Override to provide custom options to tqdm initializer.

        :param desc: Description string
        :param total: Total number of updates
        :param leave: Leave progress bar when done
        :param initial: Initial counter state
        :return: new progress bar
        """
        return tqdm(desc=desc, total=total, leave=leave, file=self.output_file, initial=initial)

    def build_tqdm_outer(self, desc, total):
        """
        Extension point. Override to provide custom options to outer progress bars (epoch loop).

        :param desc: Description
        :param total: Number of epochs
        :return: new progress bar
        """
        return self.tqdm(desc=desc, total=total, leave=self.leave_outer, initial=self.initial)

    def build_tqdm_inner(self, desc, total):
        """
        Extension point. Override to provide custom options to inner progress bars (batch loop).

        :param desc: Description
        :param total: Number of batches
        :return: new progress bar
        """
        return self.tqdm(desc=desc, total=total, leave=self.leave_inner)

    def on_epoch_begin(self, epoch, logs={}):
        self.epoch = epoch
        desc = self.inner_description_initial.format(epoch=self.epoch)
        self.mode = 0  # samples
        if 'samples' in self.params:
            self.inner_total = self.params['samples']
        elif 'nb_sample' in self.params:
            self.inner_total = self.params['nb_sample']
        else:
            self.mode = 1  # steps
            self.inner_total = self.params['steps']
        if self.show_inner:
            self.tqdm_inner = self.build_tqdm_inner(desc=desc, total=self.inner_total)
        self.inner_count = 0
        self.running_logs = {}

    def on_epoch_end(self, epoch, logs={}):
        metrics = self.format_metrics(logs)
        desc = self.inner_description_update.format(epoch=epoch, metrics=metrics)
        if self.show_inner:
            self.tqdm_inner.desc = desc
            # set miniters and mininterval to 0 so the last update displays
            self.tqdm_inner.miniters = 0
            self.tqdm_inner.mininterval = 0
            self.tqdm_inner.update(self.inner_total - self.tqdm_inner.n)
            self.tqdm_inner.close()
        if self.show_outer:
            self.tqdm_outer.update(1)

    def on_batch_begin(self, batch, logs={}):
        pass

    def on_batch_end(self, batch, logs={}):
        if self.mode == 0:
            update = logs['size']
        else:
            update = 1
        self.inner_count += update
        if self.inner_count < self.inner_total:
            self.append_logs(logs)
            metrics = self.format_metrics(self.running_logs)
            desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics)
            if self.show_inner:
                self.tqdm_inner.desc = desc
                self.tqdm_inner.update(update)

    def on_train_begin(self, logs={}):
        if self.show_outer:
            epochs = (self.params['epochs'] if 'epochs' in self.params
                      else self.params['nb_epoch'])
            self.tqdm_outer = self.build_tqdm_outer(desc=self.outer_description,
                                                    total=epochs)

    def on_train_end(self, logs={}):
        if self.show_outer:
            self.tqdm_outer.close()

    def append_logs(self, logs):
        metrics = self.params['metrics']
        for metric, value in six.iteritems(logs):
            if metric in metrics:
                # np.asarray makes this robust to plain Python floats as well
                # as numpy scalars.
                value = np.asarray(value)[()]
                if metric in self.running_logs:
                    self.running_logs[metric].append(value)
                else:
                    self.running_logs[metric] = [value]

    def format_metrics(self, logs):
        metrics = self.params['metrics']
        strings = [self.metric_format.format(name=metric,
                                             value=np.mean(logs[metric], axis=None))
                   for metric in metrics if metric in logs]
        return self.separator.join(strings)
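Because these Keras-style callbacks are driven here from a plain PyTorch loop rather than Keras's `fit`, a minimal wiring sketch may help. Everything below (the linear model, random batches, the `weights.{epoch:02d}.pth` pattern) is a hypothetical stand-in, not this repo's training script; the key contract is that `BaseLogger` runs first so an epoch-level `loss` exists for `ModelCheckpoint` and `EarlyStopping`, and that `params['metrics']` lists the keys written into `logs`:

```python
import torch
import torch.nn as nn
from callbacks import (CallbackList, BaseLogger, History,
                       ModelCheckpoint, EarlyStopping)

# Hypothetical model/data; any nn.Module and iterable of (x, y) batches works.
model = nn.Linear(10, 1)
model.stop_training = False          # EarlyStopping sets this flag on the model
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.MSELoss()
loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]

callbacks = CallbackList([
    BaseLogger(),                    # must come first: fills epoch-mean 'loss'
    History(),
    ModelCheckpoint('weights.{epoch:02d}.pth', monitor='loss',
                    save_best_only=True, save_weights_only=True),
    EarlyStopping(monitor='loss', patience=2),
])
callbacks.set_model(model)
callbacks.set_params({'epochs': 10, 'steps': len(loader), 'metrics': ['loss']})

callbacks.on_train_begin()
for epoch in range(10):
    callbacks.on_epoch_begin(epoch)
    for batch, (x, y) in enumerate(loader):
        callbacks.on_batch_begin(batch, logs={'size': x.size(0)})
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        callbacks.on_batch_end(batch, logs={'size': x.size(0),
                                            'loss': loss.item()})
    epoch_logs = {}
    callbacks.on_epoch_end(epoch, logs=epoch_logs)  # BaseLogger writes 'loss' here
    if model.stop_training:
        break
callbacks.on_train_end()
```

Note that `stop_training` is just an attribute the callbacks set on whatever object was passed to `set_model`; the training loop has to check it itself.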
dataset.py
CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import Dataset
 from PIL import Image
 import os
 import json
-from build_vocab import Vocabulary, JsonReader
+from utils.build_vocab import Vocabulary, JsonReader
 import numpy as np
 from torchvision import transforms
 import pickle
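One side effect of this one-line import change is worth flagging: `pickle` records the defining module of a class, so a vocabulary pickled when the class was importable as plain `build_vocab` may raise `ModuleNotFoundError` once the code only exposes `utils.build_vocab`. A hedged workaround sketch (the module alias is a common convention, not something this commit adds):

```python
import sys
import pickle
import utils.build_vocab

# If the vocab pickle was created when the class lived at top-level
# `build_vocab`, alias the new module under the old name before loading.
sys.modules.setdefault('build_vocab', utils.build_vocab)

with open('../data/new_data/debug_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)
```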
logger.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc  # scipy.misc.toimage was removed in newer SciPy releases

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class Logger(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a buffer (StringIO on Python 2, BytesIO on Python 3,
            # where the StringIO name is undefined and raises NameError).
            try:
                s = StringIO()
            except NameError:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""
        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
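A minimal usage sketch for the Logger above, assuming a TensorFlow 1.x environment (tf.summary.FileWriter and tf.Summary are TF1 APIs; under TF2 they exist only behind tf.compat.v1). Tag names and values here are illustrative:

import numpy as np
from logger import Logger

logger = Logger('./logs')  # writes TensorBoard event files into ./logs

# Scalar: one value per training step.
for step in range(100):
    loss = 1.0 / (step + 1)  # dummy loss curve
    logger.scalar_summary('train/loss', loss, step)

# Histogram: e.g. the weights of one layer at a given step.
weights = np.random.randn(256)
logger.histo_summary('fc/weights', weights, step=99)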
models_debugger.py
ADDED
@@ -0,0 +1,816 @@
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.autograd import Variable
from torchvision.models.vgg import model_urls as vgg_model_urls
import torchvision.models as models

from utils.tcn import *


class DenseNet121(nn.Module):
    def __init__(self, classes=14, pretrained=True):
        super(DenseNet121, self).__init__()
        self.model = torchvision.models.densenet121(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
            # nn.Sigmoid()
        )

    def forward(self, x):
        # The backbone lives in self.model (with the replaced classifier head).
        x = self.model(x)
        return x


class DenseNet161(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(DenseNet161, self).__init__()
        self.model = torchvision.models.densenet161(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class DenseNet169(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(DenseNet169, self).__init__()
        self.model = torchvision.models.densenet169(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class DenseNet201(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(DenseNet201, self).__init__()
        self.model = torchvision.models.densenet201(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class ResNet18(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet18, self).__init__()
        self.model = torchvision.models.resnet18(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class ResNet34(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet34, self).__init__()
        self.model = torchvision.models.resnet34(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class ResNet50(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet50, self).__init__()
        self.model = torchvision.models.resnet50(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class ResNet101(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet101, self).__init__()
        self.model = torchvision.models.resnet101(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class ResNet152(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(ResNet152, self).__init__()
        self.model = torchvision.models.resnet152(pretrained=pretrained)
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class VGG19(nn.Module):
    def __init__(self, classes=14, pretrained=True):
        super(VGG19, self).__init__()
        self.model = torchvision.models.vgg19_bn(pretrained=pretrained)
        # 25088 = 512 * 7 * 7, the flattened feature map for 224x224 inputs.
        self.model.classifier = nn.Sequential(
            self.__init_linear(in_features=25088, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            self.__init_linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            self.__init_linear(in_features=4096, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class VGG(nn.Module):
    def __init__(self, tags_num):
        super(VGG, self).__init__()
        vgg_model_urls['vgg19'] = vgg_model_urls['vgg19'].replace('https://', 'http://')
        self.vgg19 = models.vgg19(pretrained=True)
        # Keep all classifier layers except the final 1000-way head.
        vgg19_classifier = list(self.vgg19.classifier.children())[:-1]
        self.classifier = nn.Sequential(*vgg19_classifier)
        self.fc = nn.Linear(4096, tags_num)
        self.fc.apply(self.init_weights)
        self.bn = nn.BatchNorm1d(tags_num, momentum=0.1)
        # self.init_weights()

    def init_weights(self, m):
        if type(m) == nn.Linear:
            self.fc.weight.data.normal_(0, 0.1)
            self.fc.bias.data.fill_(0)

    def forward(self, images):
        visual_feats = self.vgg19.features(images)
        tags_classifier = visual_feats.view(visual_feats.size(0), -1)
        tags_classifier = self.bn(self.fc(self.classifier(tags_classifier)))
        return tags_classifier


class InceptionV3(nn.Module):
    def __init__(self, classes=156, pretrained=True):
        super(InceptionV3, self).__init__()
        self.model = torchvision.models.inception_v3(pretrained=pretrained)
        # Inception v3 keeps its classification head in `fc`; it has no
        # `classifier` attribute.
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.model(x)
        return x


class CheXNetDenseNet121(nn.Module):
    def __init__(self, classes=14, pretrained=True):
        super(CheXNetDenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=pretrained)
        num_in_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return x


class CheXNet(nn.Module):
    def __init__(self, classes=156):
        super(CheXNet, self).__init__()
        # Load the 14-class CheXNet checkpoint, then swap in a new head.
        self.densenet121 = CheXNetDenseNet121(classes=14)
        self.densenet121 = torch.nn.DataParallel(self.densenet121).cuda()
        self.densenet121.load_state_dict(torch.load('./models/CheXNet.pth.tar')['state_dict'])
        self.densenet121.module.densenet121.classifier = nn.Sequential(
            self.__init_linear(1024, classes),
            nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x):
        x = self.densenet121(x)
        return x


class ModelFactory(object):
    """Build one of the classifier backbones above from its name."""

    def __init__(self, model_name, pretrained, classes):
        self.model_name = model_name
        self.pretrained = pretrained
        self.classes = classes

    def create_model(self):
        if self.model_name == 'VGG19':
            _model = VGG19(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet121':
            _model = DenseNet121(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet161':
            _model = DenseNet161(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet169':
            _model = DenseNet169(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'DenseNet201':
            _model = DenseNet201(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'CheXNet':
            _model = CheXNet(classes=self.classes)
        elif self.model_name == 'ResNet18':
            _model = ResNet18(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet34':
            _model = ResNet34(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet50':
            _model = ResNet50(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet101':
            _model = ResNet101(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'ResNet152':
            _model = ResNet152(pretrained=self.pretrained, classes=self.classes)
        elif self.model_name == 'VGG':
            _model = VGG(tags_num=self.classes)
        else:
            _model = CheXNet(classes=self.classes)

        return _model


class EncoderCNN(nn.Module):
    def __init__(self, embed_size, pretrained=True):
        super(EncoderCNN, self).__init__()
        # TODO Extract image features from CNNs other than ResNet-152
        resnet = models.resnet152(pretrained=pretrained)
        modules = list(resnet.children())[:-1]  # drop the final fc layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)
        self.__init_weights()

    def __init_weights(self):
        self.linear.weight.data.normal_(0.0, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, images):
        features = self.resnet(images)
        # Detach features from the CNN graph (legacy Variable API).
        features = Variable(features.data)
        features = features.view(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max

    def __init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def sample(self, features, start_tokens):
        # Greedy decoding for at most n_max steps.
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        predicted = start_tokens
        embeddings = features
        embeddings = embeddings.unsqueeze(1)

        for i in range(self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted
            predicted = predicted.unsqueeze(1)
        return sampled_ids


class VisualFeatureExtractor(nn.Module):
    def __init__(self, pretrained=False):
        super(VisualFeatureExtractor, self).__init__()
        resnet = models.resnet152(pretrained=pretrained)
        modules = list(resnet.children())[:-1]  # keep everything up to avgpool
        self.resnet = nn.Sequential(*modules)
        self.out_features = resnet.fc.in_features

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return features


class MLC(nn.Module):
    """Multi-label tag classifier; also returns embeddings of the top-k tags."""

    def __init__(self, classes=156, sementic_features_dim=512, fc_in_features=2048, k=10):
        super(MLC, self).__init__()
        self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
        self.embed = nn.Embedding(classes, sementic_features_dim)
        self.k = k
        self.softmax = nn.Softmax()

    def forward(self, visual_features):
        tags = self.softmax(self.classifier(visual_features))
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features


class CoAttention(nn.Module):
    def __init__(self, embed_size=512, hidden_size=512, visual_size=2048):
        super(CoAttention, self).__init__()
        self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
        self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a = nn.BatchNorm1d(num_features=10, momentum=0.1)

        self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
        self.bn_a_att = nn.BatchNorm1d(num_features=10, momentum=0.1)

        self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
        self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()

    def forward(self, visual_features, semantic_features, h_sent):
        """Only used during training."""
        W_v = self.bn_v(self.W_v(visual_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))

        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        v_att = torch.mul(alpha_v, visual_features)
        # v_att = torch.mul(alpha_v, visual_features).sum(1).unsqueeze(1)

        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        # a_att = (alpha_a * semantic_features).sum(1)
        ctx = self.bn_fc(self.W_fc(torch.cat([v_att, a_att], dim=1)))
        # return self.W_fc(self.bn_fc(torch.cat([v_att, a_att], dim=1)))
        return ctx, v_att


class SentenceLSTM(nn.Module):
    def __init__(self, embed_size=512, hidden_size=512, num_layers=1):
        super(SentenceLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
        self.W_t_h = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_t_ctx = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop_s_1 = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop_s = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_topic = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_topic_2 = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic_2 = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    # def forward(self, ctx, prev_hidden_state, states=None):
    #     """v1 (sigmoid gating), only training."""
    #     ctx = ctx.unsqueeze(1)
    #     hidden_state, states = self.lstm(ctx, states)
    #     topic = self.bn_topic(self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
    #                                                     + self.bn_t_ctx(self.W_t_ctx(ctx)))))
    #     p_stop = self.bn_stop(self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
    #                                                    + self.bn_stop_s(self.W_stop_s(hidden_state)))))
    #     return topic, p_stop, hidden_state, states

    def forward(self, ctx, prev_hidden_state, states=None):
        """v2 (tanh gating)."""
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
                                                                 + self.W_t_ctx(ctx)))))
        p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
                                                                   + self.W_stop_s(hidden_state)))))
        return topic, p_stop, hidden_state, states


class SentenceTCN(nn.Module):
    def __init__(self,
                 input_channel=10,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0):
        super(SentenceTCN, self).__init__()
        channel_sizes = [nhid] * levels
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=channel_sizes,
                       kernel_size=kernel_size,
                       dropout=dropout)
        self.W_t_h = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_t_ctx = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop_s_1 = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop_s = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.t_w = nn.Linear(in_features=5120, out_features=2, bias=True)
        self.tanh = nn.Tanh()

    def forward(self, ctx, prev_output):
        output = self.tcn.forward(ctx)
        topic = self.tanh(self.W_t_h(output) + self.W_t_ctx(ctx[:, -1, :]).squeeze(1))
        p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_output) + self.W_stop_s(output)))
        return topic, p_stop, output


class WordLSTM(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        super(WordLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max
        self.vocab_size = vocab_size

    def __init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, topic_vec, captions):
        embeddings = self.embed(captions)
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def val(self, features, start_tokens):
        # Greedy decoding that keeps the full logits at every step.
        samples = torch.zeros((np.shape(features)[0], self.n_max, self.vocab_size))
        samples[:, 0, start_tokens[0]] = 1
        predicted = start_tokens
        embeddings = features

        for i in range(1, self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            samples[:, i, :] = outputs
            predicted = torch.max(outputs, 1)[1]
            predicted = predicted.unsqueeze(1)
        return samples

    def sample(self, features, start_tokens):
        # Greedy decoding that keeps only the argmax word ids.
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        sampled_ids[:, 0] = start_tokens.view(-1,)
        predicted = start_tokens
        embeddings = features

        for i in range(1, self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted
            predicted = predicted.unsqueeze(1)
        return sampled_ids


class WordTCN(nn.Module):
    def __init__(self,
                 input_channel=11,
                 vocab_size=1000,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0,
                 n_max=50):
        super(WordTCN, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.output_size = output_size
        channel_sizes = [nhid] * levels
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.n_max = n_max
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.W_out = nn.Linear(in_features=output_size, out_features=vocab_size, bias=True)
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=channel_sizes,
                       kernel_size=kernel_size,
                       dropout=dropout)

    def forward(self, topic_vec, captions):
        captions = self.embed(captions)
        embeddings = torch.cat([topic_vec, captions], dim=1)
        output = self.tcn.forward(embeddings)
        words = self.W_out(output)
        return words


if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")
    images = torch.randn((4, 3, 224, 224))
    captions = torch.ones((4, 10)).long()
    hidden_state = torch.randn((4, 1, 512))

    print("images:{}".format(images.shape))
    print("captions:{}".format(captions.shape))
    print("hidden_states:{}".format(hidden_state.shape))

    extractor = VisualFeatureExtractor()
    visual_features = extractor.forward(images)
    print("visual_features:{}".format(visual_features.shape))

    mlc = MLC()
    tags, semantic_features = mlc.forward(visual_features)
    print("tags:{}".format(tags.shape))
    print("semantic_features:{}".format(semantic_features.shape))

    co_att = CoAttention()
    ctx, v_att = co_att.forward(visual_features, semantic_features, hidden_state)
    print("ctx:{}".format(ctx.shape))
    print("v_att:{}".format(v_att.shape))

    sent_lstm = SentenceLSTM()
    topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
    print("Topic:{}".format(topic.shape))
    print("P_STOP:{}".format(p_stop.shape))

    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
    words = word_lstm.forward(topic, captions)
    print("words:{}".format(words.shape))

    # Expected Output
    # images: torch.Size([4, 3, 224, 224])
    # captions: torch.Size([4, 10])
    # hidden_states: torch.Size([4, 1, 512])
    # visual_features: torch.Size([4, 2048])
    # tags: torch.Size([4, 156])
    # semantic_features: torch.Size([4, 10, 512])
    # ctx: torch.Size([4, 512])
    # Topic: torch.Size([4, 1, 512])
    # P_STOP: torch.Size([4, 1, 2])
    # words: torch.Size([4, 100])

    # images = torch.randn((4, 3, 224, 224))
    # captions = torch.ones((4, 3, 10)).long()
    # prev_outputs = torch.randn((4, 512))
    # now_words = torch.ones((4, 1))
    #
    # ctx_records = torch.zeros((4, 10, 512))
    # captions = torch.zeros((4, 10)).long()
    #
    # print("images:{}".format(images.shape))
    # print("captions:{}".format(captions.shape))
    # print("hidden_states:{}".format(prev_outputs.shape))
    #
    # extractor = VisualFeatureExtractor()
    # visual_features = extractor.forward(images)
    # print("visual_features:{}".format(visual_features.shape))
    #
    # mlc = MLC()
    # tags, semantic_features = mlc.forward(visual_features)
    # print("tags:{}".format(tags.shape))
    # print("semantic_features:{}".format(semantic_features.shape))
    #
    # co_att = CoAttention()
    # ctx = co_att.forward(visual_features, semantic_features, prev_outputs)
    # print("ctx:{}".format(ctx.shape))
    #
    # ctx_records[:, 0, :] = ctx
    #
    # sent_tcn = SentenceTCN()
    # topic, p_stop, prev_outputs = sent_tcn.forward(ctx_records, prev_outputs)
    # print("Topic:{}".format(topic.shape))
    # print("P_STOP:{}".format(p_stop.shape))
    # print("Prev_Outputs:{}".format(prev_outputs.shape))
    #
    # captions[:, 0] = now_words.view(-1,)
    #
    # word_tcn = WordTCN()
    # words = word_tcn.forward(topic, captions)
    # print("words:{}".format(words.shape))
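A short usage sketch for ModelFactory, assuming the utils package layout above so that from utils.tcn import * resolves. The model_name strings mirror the branches in create_model; the batch and class sizes are illustrative, and CheXNet is the one backbone that additionally needs CUDA plus the ./models/CheXNet.pth.tar checkpoint:

import torch
from models_debugger import ModelFactory

# Build a 14-class DenseNet121 classifier on top of ImageNet weights.
model = ModelFactory(model_name='DenseNet121', pretrained=True, classes=14).create_model()
model.eval()

images = torch.randn(4, 3, 224, 224)  # dummy batch of RGB images
with torch.no_grad():
    logits = model(images)
print(logits.shape)  # torch.Size([4, 14])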
tcn.py
ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


class Chomp1d(nn.Module):
    """Trim the trailing padding so each convolution stays causal."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU(inplace=False)
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU(inplace=False)
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        # 1x1 convolution on the residual branch when channel counts differ.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU(inplace=False)
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i  # dilation doubles at every level
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1,
                                     dilation=dilation_size,
                                     padding=(kernel_size - 1) * dilation_size,
                                     dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class TCN(nn.Module):
    def __init__(self, input_size, output_size, num_channels, kernel_size=2, dropout=0):
        super(TCN, self).__init__()
        self.tcn = TemporalConvNet(num_inputs=input_size,
                                   num_channels=num_channels,
                                   kernel_size=kernel_size,
                                   dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)
        self.init_weights()

    def init_weights(self):
        self.linear.weight.data.normal_(0, 0.01)
        self.linear.bias.data.fill_(0)

    def forward(self, inputs):
        y = self.tcn.forward(inputs)
        # Read out the representation at the last time step.
        output = self.linear(y[:, :, -1])
        return output
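A quick shape check for the TCN head above. Inputs follow the nn.Conv1d convention (batch, channels, length), and the final linear layer reads the representation at the last time step; the sizes below mirror the defaults SentenceTCN passes in and are otherwise illustrative:

import torch
from tcn import TCN

x = torch.randn(4, 10, 512)  # 4 sequences, 10 input channels, 512 time steps
model = TCN(input_size=10, output_size=512,
            num_channels=[512] * 8,  # 8 levels with dilations 1, 2, ..., 128
            kernel_size=2, dropout=0.0)
y = model(x)
print(y.shape)  # torch.Size([4, 512])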