Spaces:
Runtime error
Runtime error
Commit
·
cc7c5b8
1
Parent(s):
3e4ffd8
fix non-alphabets
Browse files
app.py
CHANGED
@@ -27,6 +27,8 @@ def predict_with_count(starting_char:str, num_words):
|
|
27 |
g = torch.Generator().manual_seed(SEED)
|
28 |
output = []
|
29 |
for _ in range(num_words):
|
|
|
|
|
30 |
prev = ctoi[starting_char]
|
31 |
out = []
|
32 |
out.append(starting_char)
|
@@ -45,6 +47,8 @@ def predict_with_single_layer_nn(starting_char:str, num_words):
|
|
45 |
output = []
|
46 |
for _ in range(num_words):
|
47 |
out = []
|
|
|
|
|
48 |
ix = ctoi[starting_char]
|
49 |
out.append(starting_char)
|
50 |
while True:
|
@@ -61,10 +65,13 @@ def predict_with_single_layer_nn(starting_char:str, num_words):
|
|
61 |
return output
|
62 |
|
63 |
def predict(query, num_words):
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
|
69 |
# title and description
|
70 |
st.title("""
|
@@ -79,4 +86,4 @@ query = st.text_input("Please input the starting character...", "", max_chars=1)
|
|
79 |
num_words = st.slider("Number of names to generate:", min_value=1, max_value=50, value=5)
|
80 |
|
81 |
if query != "":
|
82 |
-
predict(query, num_words)
|
|
|
27 |
g = torch.Generator().manual_seed(SEED)
|
28 |
output = []
|
29 |
for _ in range(num_words):
|
30 |
+
if starting_char not in ctoi:
|
31 |
+
raise ValueError("Starting Character is not a valid alphabet. Please input a valid alphabet.")
|
32 |
prev = ctoi[starting_char]
|
33 |
out = []
|
34 |
out.append(starting_char)
|
|
|
47 |
output = []
|
48 |
for _ in range(num_words):
|
49 |
out = []
|
50 |
+
if starting_char not in ctoi:
|
51 |
+
raise ValueError("Starting Character is not a valid alphabet. Please input a valid alphabet.")
|
52 |
ix = ctoi[starting_char]
|
53 |
out.append(starting_char)
|
54 |
while True:
|
|
|
65 |
return output
|
66 |
|
67 |
def predict(query, num_words):
|
68 |
+
try:
|
69 |
+
preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words)]
|
70 |
+
labels = ["Count Based Language Model", "Single Linear Layer Language Model"]
|
71 |
+
results = {labels[idx]: preds[idx] for idx in range(len(preds))}
|
72 |
+
st.write(pd.DataFrame(results, index=range(num_words)))
|
73 |
+
except ValueError as e:
|
74 |
+
st.write(f"ERROR: {e.args[0]}")
|
75 |
|
76 |
# title and description
|
77 |
st.title("""
|
|
|
86 |
num_words = st.slider("Number of names to generate:", min_value=1, max_value=50, value=5)
|
87 |
|
88 |
if query != "":
|
89 |
+
predict(query.lower(), num_words)
|