ProCreations committed
Commit 3f0fa4c · verified · 1 Parent(s): ca5e029

Update simple_lm.pth

Files changed (1)
simple_lm.pth +73 -0
simple_lm.pth CHANGED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import json
+
+ # Define a simple LSTM-based language model
+ class SimpleLM(nn.Module):
+     def __init__(self, vocab_size, embedding_dim, hidden_dim):
+         super(SimpleLM, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, embedding_dim)
+         self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
+         self.linear = nn.Linear(hidden_dim, vocab_size)
+
+     def forward(self, x, hidden):
+         embedded = self.embedding(x)
+         output, hidden = self.lstm(embedded, hidden)
+         output = self.linear(output)
+         return output, hidden
+
+ # Define a custom dataset class (expects a JSON list of equal-length token-ID sequences)
+ class CustomDataset(Dataset):
+     def __init__(self, data_path):
+         self.data = json.load(open(data_path, 'r'))
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         text = self.data[idx]
+         return torch.tensor(text, dtype=torch.long)
+
+ # Define training parameters
+ vocab_size = 10000  # Example vocabulary size
+ embedding_dim = 128
+ hidden_dim = 256
+ batch_size = 32
+ num_epochs = 10
+
+ # Initialize the LM
+ lm = SimpleLM(vocab_size, embedding_dim, hidden_dim)
+
+ # Load data
+ dataset = CustomDataset('training_data.json')
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ # Define loss function and optimizer
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(lm.parameters(), lr=0.001)
+
+ # Training loop
+ for epoch in range(num_epochs):
+     total_loss = 0
+     for batch in dataloader:
+         optimizer.zero_grad()
+         input_data = batch[:, :-1]  # Input sequence
+         target = batch[:, 1:]  # Target sequence shifted by one
+         hidden = None
+
+         output, hidden = lm(input_data, hidden)
+         output = output.reshape(-1, vocab_size)
+         target = target.reshape(-1)  # reshape, not view: the slice above is non-contiguous
+
+         loss = criterion(output, target)
+         loss.backward()
+         optimizer.step()
+
+         total_loss += loss.item()
+
+     print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')
+
+ # Save the trained LM
+ torch.save(lm.state_dict(), 'simple_lm.pth')
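Note on usage (not part of the commit): the training script assumes 'training_data.json' holds a list of equal-length token-ID sequences, since the default DataLoader collate stacks each batch into one tensor. The sketch below illustrates that assumed format and how the saved state dict could be reloaded for greedy next-token prediction; it presumes the SimpleLM class above is in scope, and the example token IDs are illustrative only.

import json
import torch

# Assumed format of training_data.json: a JSON list of equal-length
# token-ID sequences (the specific IDs here are made up for illustration).
sample = [[3, 17, 42, 9, 2], [5, 8, 23, 11, 2]]
with open('training_data.json', 'w') as f:
    json.dump(sample, f)

# Reload the trained weights (hyperparameters must match the training run),
# assuming the SimpleLM class defined above is importable or already in scope.
model = SimpleLM(vocab_size=10000, embedding_dim=128, hidden_dim=256)
model.load_state_dict(torch.load('simple_lm.pth'))
model.eval()

# Greedy next-token prediction for one example prompt of token IDs.
with torch.no_grad():
    logits, _ = model(torch.tensor([[3, 17, 42, 9]]), None)
    next_token = logits[0, -1].argmax().item()
    print(next_token)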