Spaces:
Sleeping
Sleeping
import json | |
import numpy as np | |
from torch import nn | |
import torch | |
from torch.utils.data import random_split, DataLoader | |
from z_dataops import NamesDataset, transform, proxy_collate_batch | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as ticker | |
import string | |
class NameToLanguages(nn.Module): | |
def __init__(self, feature_size=26, n_classes=18): | |
super().__init__() | |
# create simple architecture | |
self.net_rnn = nn.RNN(input_size=feature_size, hidden_size=128, batch_first=True) | |
self.net_linear = nn.Linear(in_features=128, out_features=n_classes) | |
def forward(self, x): | |
rnn_out, last_ts = self.net_rnn(x) | |
output = self.net_linear(last_ts[0]) | |
return output | |
def training(model: nn.Module, train_batch: list, optimizer, loss_fn): | |
model.train() | |
batch_loss = 0 | |
for x, y in train_batch: | |
# predict | |
y_pred = model(x) | |
# compute loss | |
curr_loss = loss_fn(y_pred, y) | |
batch_loss += curr_loss | |
# reset grad | |
optimizer.zero_grad() | |
# calculate grad | |
batch_loss.backward() | |
# nn.utils.clip_grad_norm_(model.parameters(), 3) | |
# step | |
optimizer.step() | |
return batch_loss.item() / len(train_batch) | |
def validation(model, dl: DataLoader, loss_fn): | |
model.eval() | |
batch_loss = 0 | |
with torch.no_grad(): | |
for item in dl: | |
for x, y in item: | |
# predict | |
y_pred = model(x) | |
# loss | |
curr_loss = loss_fn(y_pred, y) | |
batch_loss += curr_loss | |
return batch_loss.item() / len(dl) | |
def plot_losses(loss_label, title, save_location="model/loss.png"): | |
for k, v in loss_label.items(): | |
plt.plot(v, label=k) | |
plt.legend() | |
plt.title(title) | |
plt.savefig(save_location) | |
def load_labels(input_file="model/label.json"): | |
# Read the dictionary from the file | |
with open(input_file, 'r') as file: | |
dictionary = json.load(file) | |
return dictionary | |
def evaluate(rnn, validation_dl, classes): | |
# CODE AS IS FROM: https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#evaluating-the-results | |
confusion = torch.zeros(len(classes), len(classes)) | |
rnn.eval() #set to eval mode | |
with torch.no_grad(): # do not record the gradients during eval phase | |
for item in validation_dl: | |
for text_tensor, label in item: | |
output = rnn(text_tensor) | |
# | |
_, idx = output.topk(1) | |
guess, guess_i = classes[str(idx.item())], idx.item() | |
label_i = label.item() | |
confusion[label_i][guess_i] += 1 | |
# Normalize by dividing every row by its sum | |
for i in range(len(classes)): | |
denom = confusion[i].sum() | |
if denom > 0: | |
confusion[i] = confusion[i] / denom | |
# Set up plot | |
fig = plt.figure() | |
ax = fig.add_subplot(111) | |
cax = ax.matshow(confusion.cpu().numpy()) #numpy uses cpu here so we need to use a cpu version | |
fig.colorbar(cax) | |
tag = [classes[str(i)] for i in range(len(classes))] | |
# Set up axes | |
ax.set_xticks(np.arange(len(classes)), labels=tag, rotation=90) | |
ax.set_yticks(np.arange(len(classes)), labels=tag) | |
# Force label at every tick | |
ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) | |
ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) | |
# sphinx_gallery_thumbnail_number = 2 | |
plt.savefig("model/evaluate.png") | |
def load_labels(input_file="model/label.json"): | |
# Read the dictionary from the file | |
with open(input_file, 'r') as file: | |
dictionary = json.load(file) | |
return dictionary | |
if __name__=="__main__": | |
model = NameToLanguages(feature_size=len(string.ascii_letters)) | |
# #Sanity Check Model | |
# x = torch.randn((1, 7, 26)) # (batch, word_length, one-hot-ascii-char) | |
# model.eval() | |
# with torch.no_grad(): | |
# out = model(x) | |
# print(out.shape) | |
# #Optimziers, Loss | |
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3) | |
loss_fn = nn.CrossEntropyLoss() | |
n_epoch = 27 | |
# #Training Loop | |
ds = NamesDataset(transform=transform) | |
train_ds, val_ds = random_split(ds, [0.7, 0.3], generator=torch.Generator().manual_seed(31)) | |
train_dl = DataLoader(dataset=train_ds, batch_size=64, collate_fn=proxy_collate_batch) | |
val_dl = DataLoader(dataset=val_ds, collate_fn=proxy_collate_batch) | |
# #Trackers | |
train_losses, val_losses = [], [] | |
for epoch in range(n_epoch): | |
for batch in train_dl: | |
train_loss = training(model, batch, optimizer, loss_fn) | |
# report val loss | |
train_losses.append(train_loss) | |
val_loss = validation(model, val_dl, loss_fn) | |
val_losses.append(val_loss) | |
print(f"Epoch {epoch}: Train_loss: {train_losses[-1]}, Val_loss: {val_loss}") | |
plot_losses({"train": train_losses, "val": val_losses}, "Training Loss") | |
torch.save(model, "model/rnn.pth") | |
classes = load_labels() | |
evaluate(model, val_dl, classes) | |