Spaces:
Sleeping
Sleeping
File size: 5,106 Bytes
2c1ff7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import json
import numpy as np
from torch import nn
import torch
from torch.utils.data import random_split, DataLoader
from z_dataops import NamesDataset, transform, proxy_collate_batch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import string
class NameToLanguages(nn.Module):
    """Sequence classifier: a single-layer ``nn.RNN`` followed by a linear head.

    Args:
        feature_size: Size of the per-timestep input vector given to the RNN.
        n_classes: Number of logits produced by the linear head.
        hidden_size: Width of the RNN hidden state. Defaults to 128, the value
            that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, feature_size=26, n_classes=18, hidden_size=128):
        super().__init__()
        # create simple architecture
        self.net_rnn = nn.RNN(
            input_size=feature_size, hidden_size=hidden_size, batch_first=True
        )
        self.net_linear = nn.Linear(in_features=hidden_size, out_features=n_classes)

    def forward(self, x):
        """Return class logits computed from the RNN's final hidden state.

        Args:
            x: Input sequences; the RNN is built with ``batch_first=True``, so
                batched input is laid out (batch, seq_len, feature_size).

        Returns:
            The linear head applied to the final hidden state, i.e.
            (batch, n_classes) for batched input.
        """
        # nn.RNN returns (per-step outputs, final hidden state); only the final
        # hidden state is used. Index 0 selects the single RNN layer.
        _, last_ts = self.net_rnn(x)
        return self.net_linear(last_ts[0])
def training(model: nn.Module, train_batch: list, optimizer, loss_fn):
    """Run one optimizer step on a batch and return its mean per-sample loss.

    The loss of every ``(x, y)`` pair in ``train_batch`` is summed, and a single
    backward pass / optimizer step is taken on that sum.

    Args:
        model: Module to optimize; it is switched to train mode.
        train_batch: Sized collection of ``(x, y)`` pairs, each passed as
            ``loss_fn(model(x), y)``.
        optimizer: Optimizer holding ``model``'s parameters.
        loss_fn: Callable returning a scalar loss tensor.

    Returns:
        The summed loss divided by ``len(train_batch)`` as a Python float, or
        ``0.0`` when the batch is empty (no step is taken in that case).
    """
    model.train()
    if len(train_batch) == 0:
        # With no samples the accumulator would stay a plain int (no
        # ``.backward()``) and the average would divide by zero.
        return 0.0

    batch_loss = 0
    for x, y in train_batch:
        # predict, then accumulate this sample's loss into the batch total
        y_pred = model(x)
        batch_loss = batch_loss + loss_fn(y_pred, y)

    # reset grad
    optimizer.zero_grad()
    # calculate grad once for the whole batch; calling backward per sample on
    # the running sum would re-traverse already-freed graphs
    batch_loss.backward()
    # nn.utils.clip_grad_norm_(model.parameters(), 3)
    # step
    optimizer.step()
    return batch_loss.item() / len(train_batch)
def validation(model, dl: DataLoader, loss_fn):
    """Return the mean per-sample loss of ``model`` over ``dl`` without gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        dl: Iterable of batches, where each batch is itself an iterable of
            ``(x, y)`` pairs passed as ``loss_fn(model(x), y)``.
        loss_fn: Callable returning a scalar loss tensor.

    Returns:
        Total loss divided by the number of samples seen, as a Python float;
        ``0.0`` when ``dl`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for item in dl:
            for x, y in item:
                # predict and accumulate this sample's loss
                y_pred = model(x)
                total_loss += loss_fn(y_pred, y).item()
                n_samples += 1
    if n_samples == 0:
        return 0.0
    # Average over samples rather than over batches: dividing by len(dl) is
    # only a per-sample mean when every batch holds exactly one sample.
    return total_loss / n_samples
def plot_losses(loss_label, title, save_location="model/loss.png"):
    """Plot one curve per entry of ``loss_label`` and save the figure to disk.

    Args:
        loss_label: Mapping of legend label -> sequence of values to plot.
        title: Title drawn above the plot.
        save_location: Path the image is written to.
    """
    # Draw on a dedicated figure instead of pyplot's implicit current one, so
    # repeated calls do not pile curves onto an earlier plot.
    fig, ax = plt.subplots()
    try:
        for k, v in loss_label.items():
            ax.plot(v, label=k)
        ax.legend()
        ax.set_title(title)
        fig.savefig(save_location)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def load_labels(input_file="model/label.json"):
    """Load and return the JSON content of ``input_file``.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        The object decoded by ``json.load``. ``evaluate`` in this file indexes
        it with stringified class indices, so it is presumably a dict mapping
        ``"0"``, ``"1"``, ... to label names (confirm against the file writer).
    """
    # JSON text is UTF-8; decode explicitly rather than relying on the
    # platform's locale encoding.
    with open(input_file, "r", encoding="utf-8") as file:
        return json.load(file)
def evaluate(rnn, validation_dl, classes):
    """Save a row-normalized confusion-matrix heatmap to ``model/evaluate.png``.

    Adapted from:
    https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#evaluating-the-results

    Args:
        rnn: Model to evaluate; it is switched to eval mode. Its output for one
            sample must reduce to a single top-1 index (``idx.item()``).
        validation_dl: Iterable of batches, each an iterable of
            ``(text_tensor, label)`` pairs where ``label`` holds one class index.
        classes: Mapping from stringified class index (``"0"``, ``"1"``, ...)
            to the label drawn on the axes; its length sets the matrix size.
    """
    n_classes = len(classes)
    # Rows are true labels, columns are predicted labels.
    confusion = torch.zeros(n_classes, n_classes)

    rnn.eval()  # set to eval mode
    with torch.no_grad():  # do not record the gradients during eval phase
        for item in validation_dl:
            for text_tensor, label in item:
                output = rnn(text_tensor)
                _, idx = output.topk(1)
                confusion[label.item()][idx.item()] += 1

    # Normalize every row by its sum. Rows with no samples sum to 0 and are
    # left as zeros: counts are whole numbers, so clamping the divisor to 1
    # only affects those empty rows.
    row_totals = confusion.sum(dim=1, keepdim=True)
    confusion = confusion / row_totals.clamp(min=1)

    # Set up plot
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        # numpy works on CPU memory, so move the matrix there first
        cax = ax.matshow(confusion.cpu().numpy())
        fig.colorbar(cax)

        tag = [classes[str(i)] for i in range(n_classes)]

        # One tick per class with its label. No MultipleLocator is installed
        # afterwards: it would replace these fixed positions while the labels
        # stay indexed by tick order, which can shift labels against cells.
        ax.set_xticks(np.arange(n_classes), labels=tag, rotation=90)
        ax.set_yticks(np.arange(n_classes), labels=tag)

        fig.savefig("model/evaluate.png")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def load_labels(input_file="model/label.json"):
    """Load and return the JSON content of ``input_file``.

    NOTE(review): this redefines the identical ``load_labels`` declared earlier
    in this file; one of the two definitions can be dropped.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        The object decoded by ``json.load``.
    """
    # JSON text is UTF-8; decode explicitly rather than relying on the
    # platform's locale encoding.
    with open(input_file, "r", encoding="utf-8") as file:
        return json.load(file)
if __name__ == "__main__":
    # Train the classifier, plot the loss curves, save the model and write the
    # confusion-matrix plot.

    # One input feature per character of string.ascii_letters.
    model = NameToLanguages(feature_size=len(string.ascii_letters))

    # Optimizer, loss
    optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    n_epoch = 27

    # Data: fixed-seed 70/30 train/validation split
    ds = NamesDataset(transform=transform)
    train_ds, val_ds = random_split(
        ds, [0.7, 0.3], generator=torch.Generator().manual_seed(31)
    )
    train_dl = DataLoader(dataset=train_ds, batch_size=64, collate_fn=proxy_collate_batch)
    val_dl = DataLoader(dataset=val_ds, collate_fn=proxy_collate_batch)

    # Trackers
    train_losses, val_losses = [], []
    for epoch in range(n_epoch):
        # training() returns the per-sample mean of one batch; weight it by the
        # batch size so the epoch figure is a mean over all training samples
        # rather than whatever the last batch happened to score.
        epoch_loss_sum, epoch_samples = 0.0, 0
        for batch in train_dl:
            batch_mean = training(model, batch, optimizer, loss_fn)
            epoch_loss_sum += batch_mean * len(batch)
            epoch_samples += len(batch)
        train_loss = epoch_loss_sum / epoch_samples if epoch_samples else 0.0
        train_losses.append(train_loss)

        # report val loss
        val_loss = validation(model, val_dl, loss_fn)
        val_losses.append(val_loss)
        print(f"Epoch {epoch}: Train_loss: {train_losses[-1]}, Val_loss: {val_loss}")

    plot_losses({"train": train_losses, "val": val_losses}, "Training Loss")
    # Saves the whole pickled module (not just a state_dict); kept as-is so
    # existing loaders of "model/rnn.pth" continue to work.
    torch.save(model, "model/rnn.pth")
    classes = load_labels()
    evaluate(model, val_dl, classes)
|