# you_might_speak / z_modelops.py
# Deepak Sahu - "training; app" (2c1ff7f)
import json
import numpy as np
from torch import nn
import torch
from torch.utils.data import random_split, DataLoader
from z_dataops import NamesDataset, transform, proxy_collate_batch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import string
class NameToLanguages(nn.Module):
    """Sequence classifier: a single-layer ``nn.RNN`` followed by a linear head.

    The input sequence is run through the RNN (``batch_first=True``) and the
    final hidden state is projected to ``n_classes`` scores.

    Args:
        feature_size: Size of each per-timestep input vector.
        n_classes: Number of scores produced per input sequence.
        hidden_size: Width of the RNN hidden state. Defaults to 128, the value
            that used to be hard-coded.
    """

    def __init__(self, feature_size=26, n_classes=18, hidden_size=128):
        super().__init__()
        # create simple architecture
        self.net_rnn = nn.RNN(
            input_size=feature_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        # The head consumes the RNN hidden state, so its input width must
        # match ``hidden_size``.
        self.net_linear = nn.Linear(in_features=hidden_size, out_features=n_classes)

    def forward(self, x):
        """Return the class scores (logits) for ``x``.

        Args:
            x: Input accepted by ``nn.RNN`` with ``batch_first=True``;
                presumably ``(batch, seq_len, feature_size)`` -- TODO confirm
                against the dataset transform.

        Returns:
            The linear head applied to the final hidden state of the RNN.
        """
        # nn.RNN returns (per-step outputs, final hidden state); only the
        # final hidden state is needed for classification.
        _, last_ts = self.net_rnn(x)
        # The first axis of the hidden state enumerates layers; index 0 picks
        # the only layer of this network.
        return self.net_linear(last_ts[0])
def training(model: nn.Module, train_batch: list, optimizer, loss_fn):
    """Take one optimizer step on ``train_batch`` and return its mean loss.

    Every ``(x, y)`` pair is pushed through the model on its own; the
    per-sample losses are summed, back-propagated once, and followed by a
    single ``optimizer.step()``.

    Args:
        model: Network to train; switched to train mode by this call.
        train_batch: Non-empty sequence of ``(input, target)`` pairs.
        optimizer: Optimizer holding the parameters of ``model``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.

    Returns:
        The summed loss divided by ``len(train_batch)``, as a Python float.

    Raises:
        ValueError: If ``train_batch`` is empty. Without this check the sum
            below stays the int ``0`` and ``.backward()`` fails with an
            unrelated-looking ``AttributeError``.
    """
    n_samples = len(train_batch)
    if n_samples == 0:
        raise ValueError("training() needs at least one (x, y) pair in train_batch")

    model.train()

    # predict + compute loss for every sample, accumulated into one scalar
    batch_loss = sum(loss_fn(model(x), y) for x, y in train_batch)

    # reset grad
    optimizer.zero_grad()
    # calculate grad
    batch_loss.backward()
    # nn.utils.clip_grad_norm_(model.parameters(), 3)
    # step
    optimizer.step()

    return batch_loss.item() / n_samples
def validation(model, dl: DataLoader, loss_fn):
    """Return the mean per-sample loss of ``model`` over ``dl``.

    Each item yielded by ``dl`` is itself iterated as ``(x, y)`` pairs, and
    every pair is scored individually with gradients disabled.

    The total is divided by the number of samples seen, not the number of
    batches, so the value has the same meaning as the one returned by
    ``training`` whatever batch size ``dl`` uses. With one sample per batch
    the two divisors are the same.

    Args:
        model: Network to evaluate; switched to eval mode by this call.
        dl: Iterable of batches, each an iterable of ``(input, target)`` pairs.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar
            tensor.

    Returns:
        Mean loss per sample as a Python float.

    Raises:
        ValueError: If ``dl`` yields no samples at all.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for item in dl:
            for x, y in item:
                # predict, then accumulate the scalar loss
                total_loss += loss_fn(model(x), y).item()
                n_samples += 1

    if n_samples == 0:
        raise ValueError("validation() received a data loader with no samples")
    return total_loss / n_samples
def plot_losses(loss_label, title, save_location="model/loss.png"):
    """Plot one curve per entry of ``loss_label`` and save the figure.

    Args:
        loss_label: Mapping from legend label to the sequence of values to
            plot for that label.
        title: Title placed above the plot.
        save_location: Path the image is written to.
    """
    # Draw on a dedicated figure instead of pyplot's implicit "current"
    # figure, so repeated calls (or earlier plotting elsewhere) cannot bleed
    # curves into this image.
    fig, ax = plt.subplots()
    try:
        for label, values in loss_label.items():
            ax.plot(values, label=label)
        ax.legend()
        ax.set_title(title)
        fig.savefig(save_location)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
def load_labels(input_file="model/label.json"):
    """Load the label dictionary stored as JSON in ``input_file``.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        The object decoded from the file. ``evaluate`` in this module looks
        entries up by stringified class index (``classes[str(i)]``).
    """
    # NOTE(review): a second ``load_labels`` is defined further down this file
    # and replaces this one at import time; one of the two can be removed.
    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's default text encoding.
    with open(input_file, "r", encoding="utf-8") as file:
        return json.load(file)
def _confusion_counts(rnn, validation_dl, n_classes):
    """Count (true label, predicted label) pairs produced by ``rnn``."""
    confusion = torch.zeros(n_classes, n_classes)
    rnn.eval()  # set to eval mode
    with torch.no_grad():  # do not record the gradients during eval phase
        for item in validation_dl:
            for text_tensor, label in item:
                output = rnn(text_tensor)
                # index of the highest score; ``.item()`` requires the model
                # to have produced scores for exactly one sample
                _, idx = output.topk(1)
                confusion[label.item()][idx.item()] += 1
    return confusion


def evaluate(rnn, validation_dl, classes, save_location="model/evaluate.png"):
    """Plot a row-normalised confusion matrix of ``rnn`` and save it.

    Adapted from:
    https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#evaluating-the-results

    Args:
        rnn: Model to evaluate; switched to eval mode by this call.
        validation_dl: Iterable of batches, each an iterable of
            ``(text_tensor, label)`` pairs with a single-element ``label``.
        classes: Mapping from stringified class index (``"0"``, ``"1"``, ...)
            to the class name used on the axes.
        save_location: Path the image is written to. Defaults to the location
            that used to be hard-coded.

    Raises:
        KeyError: If ``classes`` lacks an entry for one of the indices
            ``0 .. len(classes) - 1``.
    """
    n_classes = len(classes)
    # Resolve the axis labels first so an incomplete label map fails before
    # the (slower) pass over the validation data.
    tag = [classes[str(i)] for i in range(n_classes)]

    confusion = _confusion_counts(rnn, validation_dl, n_classes)

    # Normalize by dividing every row by its sum. Entries are whole counts,
    # so clamping the divisor to 1 only affects all-zero rows, which stay 0.
    row_totals = confusion.sum(dim=1, keepdim=True)
    confusion = confusion / row_totals.clamp(min=1)

    # Set up plot
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        # numpy needs a CPU tensor
        cax = ax.matshow(confusion.cpu().numpy())
        fig.colorbar(cax)

        # Set up axes
        ax.set_xticks(np.arange(n_classes), labels=tag, rotation=90)
        ax.set_yticks(np.arange(n_classes), labels=tag)

        # Force label at every tick
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

        # sphinx_gallery_thumbnail_number = 2
        fig.savefig(save_location)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
def load_labels(input_file="model/label.json"):
    """Read the label dictionary from the JSON file at ``input_file``.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        The object decoded from the file. ``evaluate`` in this module indexes
        it with stringified class indices (``classes[str(i)]``).
    """
    # NOTE(review): this redefines the ``load_labels`` declared earlier in this
    # file; being the later definition, this is the one callers actually get.
    # The encoding is pinned to UTF-8 so decoding does not depend on the
    # platform's default text encoding.
    with open(input_file, "r", encoding="utf-8") as file:
        dictionary = json.load(file)
    return dictionary
if __name__ == "__main__":
    # One input feature per ASCII letter (upper and lower case).
    model = NameToLanguages(feature_size=len(string.ascii_letters))

    # Sanity check of the model (kept for reference):
    # x = torch.randn((1, 7, 26)) # (batch, word_length, one-hot-ascii-char)
    # model.eval()
    # with torch.no_grad():
    #     out = model(x)
    # print(out.shape)

    # Optimizer, loss
    optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    n_epoch = 27

    # Data: 70/30 train/validation split with a fixed seed so the split is
    # reproducible between runs.
    ds = NamesDataset(transform=transform)
    train_ds, val_ds = random_split(
        ds, [0.7, 0.3], generator=torch.Generator().manual_seed(31)
    )
    train_dl = DataLoader(dataset=train_ds, batch_size=64, collate_fn=proxy_collate_batch)
    val_dl = DataLoader(dataset=val_ds, collate_fn=proxy_collate_batch)

    # Trackers: one entry per epoch in each list
    train_losses, val_losses = [], []

    # Training loop
    for epoch in range(n_epoch):
        # ``training`` returns the mean loss per sample of one batch. Weight
        # it by the batch size so the epoch figure is a per-sample mean over
        # all batches rather than the value of whichever batch came last.
        loss_sum, n_seen = 0.0, 0
        for batch in train_dl:
            loss_sum += training(model, batch, optimizer, loss_fn) * len(batch)
            n_seen += len(batch)
        if n_seen == 0:
            raise RuntimeError("the training DataLoader produced no samples")
        train_loss = loss_sum / n_seen

        # report val loss
        train_losses.append(train_loss)
        val_loss = validation(model, val_dl, loss_fn)
        val_losses.append(val_loss)
        print(f"Epoch {epoch}: Train_loss: {train_losses[-1]}, Val_loss: {val_loss}")

    plot_losses({"train": train_losses, "val": val_losses}, "Training Loss")
    # NOTE(review): this pickles the whole module object, so loading the file
    # requires the NameToLanguages class to be importable.
    torch.save(model, "model/rnn.pth")

    classes = load_labels()
    evaluate(model, val_dl, classes)