Spaces:
Runtime error
Runtime error
Commit
·
e81673f
1
Parent(s):
cc7c5b8
implement Embedding Based Single Layer LM
Browse files
app.py
CHANGED
@@ -14,6 +14,11 @@ def init_count_model():
|
|
14 |
def init_single_layer_model():
    """Load and return the object stored in ``single_layer.pt``."""
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # ship with the app. map_location="cpu" lets a checkpoint that was saved
    # from a GPU session load on a CPU-only host.
    return torch.load("single_layer.pt", map_location="cpu")
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
@st.cache_resource
|
18 |
def init_char_index_mappings():
|
19 |
with open("ctoi.json") as ci, open("itoc.json") as ic:
|
@@ -21,6 +26,7 @@ def init_char_index_mappings():
|
|
21 |
|
22 |
# Module-level model state, populated by the init_* loaders when the script runs.
# NOTE(review): presumably read as globals by the predict_* functions - confirm.
count_p = init_count_model()
single_layer_w = init_single_layer_model()
ctoi, itoc = init_char_index_mappings()
|
25 |
|
26 |
def predict_with_count(starting_char:str, num_words):
|
@@ -64,10 +70,35 @@ def predict_with_single_layer_nn(starting_char:str, num_words):
|
|
64 |
output.append(''.join(out[:-1]))
|
65 |
return output
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def predict(query, num_words):
|
68 |
try:
|
69 |
-
preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words)]
|
70 |
-
labels = ["Count Based
|
71 |
results = {labels[idx]: preds[idx] for idx in range(len(preds))}
|
72 |
st.write(pd.DataFrame(results, index=range(num_words)))
|
73 |
except ValueError as e:
|
|
|
14 |
def init_single_layer_model():
    """Load and return the object stored in ``single_layer.pt``."""
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # ship with the app. map_location="cpu" lets a checkpoint that was saved
    # from a GPU session load on a CPU-only host.
    return torch.load("single_layer.pt", map_location="cpu")
|
16 |
|
17 |
+
@st.cache_resource
def init_mlp():
    """Load the MLP language-model parameters from ``mlp.pt``.

    Returns:
        The tuple ``(emb, w1, b1, w2, b2)``, read from the checkpoint
        entries stored under those same keys.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # ship with the app. map_location="cpu" lets a checkpoint that was saved
    # from a GPU session load on a CPU-only host.
    mlp_layers = torch.load("mlp.pt", map_location="cpu")
    return tuple(mlp_layers[key] for key in ("emb", "w1", "b1", "w2", "b2"))
|
21 |
+
|
22 |
@st.cache_resource
|
23 |
def init_char_index_mappings():
|
24 |
with open("ctoi.json") as ci, open("itoc.json") as ic:
|
|
|
26 |
|
27 |
# Module-level model state, populated by the init_* loaders when the script runs.
# predict_with_mlp reads the mlp_* tensors and the ctoi/itoc mappings as globals.
count_p = init_count_model()
single_layer_w = init_single_layer_model()
# init_mlp returns the checkpoint entries in the order emb, w1, b1, w2, b2.
mlp_emb, mlp_w1, mlp_b1, mlp_w2, mlp_b2 = init_mlp()
ctoi, itoc = init_char_index_mappings()
|
31 |
|
32 |
def predict_with_count(starting_char:str, num_words):
|
|
|
70 |
output.append(''.join(out[:-1]))
|
71 |
return output
|
72 |
|
73 |
+
def predict_with_mlp(starting_char:str, num_words):
    """Sample words from the embedding-based single-hidden-layer model.

    Every word begins with ``starting_char`` and is extended one character at
    a time until the model samples index 0, which ends the word and is not
    included in the returned string.

    Args:
        starting_char: First character of every generated word; must be a key
            of the module-level ``ctoi`` mapping.
        num_words: Number of words to generate.

    Returns:
        A list with one generated string per requested word.

    Raises:
        ValueError: If ``starting_char`` is not a key of ``ctoi``. The check
            runs once, before any sampling.
    """
    # The check does not depend on the loop variable, so do it once up front.
    if starting_char not in ctoi:
        raise ValueError("Starting Character is not a valid alphabet. Please input a valid alphabet.")

    context_length = 3  # number of character indices fed to the model per step
    # Left-pad with index 0 so the starting character fills the last slot.
    initial_context = [0] * (context_length - 1) + [ctoi[starting_char]]

    # One generator, seeded from SEED, is shared by all words in this call.
    g = torch.Generator().manual_seed(SEED)
    output = []
    # Inference only: skip autograd bookkeeping in case the loaded weights
    # were saved with requires_grad=True.
    with torch.no_grad():
        for _ in range(num_words):
            context = list(initial_context)
            chars = [starting_char]
            while True:
                emb = mlp_emb[torch.tensor([context])]
                # Flatten the context embeddings into one row (batch size 1).
                h = torch.tanh(emb.view(1, -1) @ mlp_w1 + mlp_b1)
                logits = h @ mlp_w2 + mlp_b2
                probs = F.softmax(logits, dim=1)
                ix = torch.multinomial(probs, num_samples=1, generator=g).item()
                if ix == 0:
                    # Index 0 terminates the word and is never emitted.
                    break
                # Slide the window: drop the oldest index, append the new one.
                context = context[1:] + [ix]
                chars.append(itoc[str(ix)])
            output.append(''.join(chars))
    return output
|
97 |
+
|
98 |
def predict(query, num_words):
|
99 |
try:
|
100 |
+
preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words), predict_with_mlp(query, num_words)]
|
101 |
+
labels = ["Count Based LM", "Single Linear Layer LM", "Embedding Based Single Hidden Layer LM"]
|
102 |
results = {labels[idx]: preds[idx] for idx in range(len(preds))}
|
103 |
st.write(pd.DataFrame(results, index=range(num_words)))
|
104 |
except ValueError as e:
|
mlp.pt
ADDED
Binary file (49.3 kB). View file
|
|