Spaces:
Runtime error
Runtime error
Commit
·
e81673f
1
Parent(s):
cc7c5b8
implement Embedding Based Single Layer LM
Browse files
app.py
CHANGED
|
@@ -14,6 +14,11 @@ def init_count_model():
|
|
| 14 |
def init_single_layer_model():
|
| 15 |
return torch.load("single_layer.pt")
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
@st.cache_resource
|
| 18 |
def init_char_index_mappings():
|
| 19 |
with open("ctoi.json") as ci, open("itoc.json") as ic:
|
|
@@ -21,6 +26,7 @@ def init_char_index_mappings():
|
|
| 21 |
|
| 22 |
count_p = init_count_model()
|
| 23 |
single_layer_w = init_single_layer_model()
|
|
|
|
| 24 |
ctoi, itoc = init_char_index_mappings()
|
| 25 |
|
| 26 |
def predict_with_count(starting_char:str, num_words):
|
|
@@ -64,10 +70,35 @@ def predict_with_single_layer_nn(starting_char:str, num_words):
|
|
| 64 |
output.append(''.join(out[:-1]))
|
| 65 |
return output
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def predict(query, num_words):
|
| 68 |
try:
|
| 69 |
-
preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words)]
|
| 70 |
-
labels = ["Count Based
|
| 71 |
results = {labels[idx]: preds[idx] for idx in range(len(preds))}
|
| 72 |
st.write(pd.DataFrame(results, index=range(num_words)))
|
| 73 |
except ValueError as e:
|
|
|
|
| 14 |
def init_single_layer_model():
|
| 15 |
return torch.load("single_layer.pt")
|
| 16 |
|
| 17 |
+
@st.cache_resource
|
| 18 |
+
def init_mlp():
|
| 19 |
+
mlp_layers = torch.load("mlp.pt")
|
| 20 |
+
return mlp_layers["emb"], mlp_layers['w1'], mlp_layers['b1'], mlp_layers['w2'], mlp_layers['b2']
|
| 21 |
+
|
| 22 |
@st.cache_resource
|
| 23 |
def init_char_index_mappings():
|
| 24 |
with open("ctoi.json") as ci, open("itoc.json") as ic:
|
|
|
|
| 26 |
|
| 27 |
count_p = init_count_model()
|
| 28 |
single_layer_w = init_single_layer_model()
|
| 29 |
+
mlp_emb, mlp_w1, mlp_b1, mlp_w2, mlp_b2 = init_mlp()
|
| 30 |
ctoi, itoc = init_char_index_mappings()
|
| 31 |
|
| 32 |
def predict_with_count(starting_char:str, num_words):
|
|
|
|
| 70 |
output.append(''.join(out[:-1]))
|
| 71 |
return output
|
| 72 |
|
| 73 |
+
def predict_with_mlp(starting_char:str, num_words):
|
| 74 |
+
g = torch.Generator().manual_seed(SEED)
|
| 75 |
+
output = []
|
| 76 |
+
context_length = 3
|
| 77 |
+
for _ in range(num_words):
|
| 78 |
+
out = []
|
| 79 |
+
context = [0]*(context_length-1)
|
| 80 |
+
if starting_char not in ctoi:
|
| 81 |
+
raise ValueError("Starting Character is not a valid alphabet. Please input a valid alphabet.")
|
| 82 |
+
ix = ctoi[starting_char]
|
| 83 |
+
out.append(starting_char)
|
| 84 |
+
context+=[ix]
|
| 85 |
+
while True:
|
| 86 |
+
emb = mlp_emb[torch.tensor([context])]
|
| 87 |
+
h = torch.tanh(emb.view(1,-1) @ mlp_w1 + mlp_b1) # create batch_size 1
|
| 88 |
+
logits = h @ mlp_w2 + mlp_b2
|
| 89 |
+
probs = F.softmax(logits, dim=1)
|
| 90 |
+
ix = torch.multinomial(probs, num_samples=1, generator=g).item()
|
| 91 |
+
context = context[1:] + [ix]
|
| 92 |
+
out.append(itoc[str(ix)])
|
| 93 |
+
if ix == 0:
|
| 94 |
+
break
|
| 95 |
+
output.append(''.join(out[:-1]))
|
| 96 |
+
return output
|
| 97 |
+
|
| 98 |
def predict(query, num_words):
|
| 99 |
try:
|
| 100 |
+
preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words), predict_with_mlp(query, num_words)]
|
| 101 |
+
labels = ["Count Based LM", "Single Linear Layer LM", "Embedding Based Single Hidden Layer LM"]
|
| 102 |
results = {labels[idx]: preds[idx] for idx in range(len(preds))}
|
| 103 |
st.write(pd.DataFrame(results, index=range(num_words)))
|
| 104 |
except ValueError as e:
|
mlp.pt
ADDED
|
Binary file (49.3 kB). View file
|
|
|