Spaces:
Runtime error
Runtime error
add hash function, temperature, random seed
Browse files
- app/app.py +34 -14
- requirements.txt +2 -1
app/app.py
CHANGED
@@ -5,6 +5,7 @@ from prompts import PROMPT_LIST
|
|
5 |
import random
|
6 |
import time
|
7 |
from transformers import pipeline, set_seed
|
|
|
8 |
|
9 |
# st.set_page_config(page_title="Image Search")
|
10 |
|
@@ -19,13 +20,14 @@ def get_generator():
|
|
19 |
return text_generator
|
20 |
|
21 |
|
22 |
-
|
23 |
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
24 |
-
temperature: float = 1.0, max_time: float =
|
25 |
-
|
26 |
set_seed(seed)
|
27 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
28 |
-
top_k=top_k, top_p=top_p, temperature=temperature,
|
|
|
29 |
return result
|
30 |
|
31 |
|
@@ -65,37 +67,55 @@ max_length = st.sidebar.number_input(
|
|
65 |
help="The maximum length of the sequence to be generated."
|
66 |
)
|
67 |
|
68 |
-
|
69 |
"Temperature",
|
70 |
value=1.0,
|
71 |
min_value=0.0,
|
72 |
-
max_value=
|
73 |
)
|
74 |
|
75 |
-
|
76 |
-
"
|
77 |
-
value=
|
78 |
)
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
)
|
84 |
|
|
|
85 |
text_generator = get_generator()
|
86 |
if st.button("Run"):
|
87 |
with st.spinner(text="Getting results..."):
|
88 |
st.subheader("Result")
|
89 |
time_start = time.time()
|
90 |
-
result = process(text=session_state.text, max_length=int(max_length),
|
|
|
|
|
91 |
time_end = time.time()
|
92 |
time_diff = time_end-time_start
|
93 |
-
#print(f"Text generated in {time_diff} seconds")
|
94 |
result = result[0]["generated_text"]
|
95 |
st.write(result.replace("\n", " \n"))
|
96 |
st.text("Translation")
|
97 |
translation = translate(result, "en", "id")
|
98 |
st.write(translation.replace("\n", " \n"))
|
|
|
|
|
99 |
|
100 |
# Reset state
|
101 |
session_state.prompt = None
|
|
|
5 |
import random
|
6 |
import time
|
7 |
from transformers import pipeline, set_seed
|
8 |
+
import tokenizers
|
9 |
|
10 |
# st.set_page_config(page_title="Image Search")
|
11 |
|
|
|
20 |
return text_generator
|
21 |
|
22 |
|
23 |
+
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
|
24 |
def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
25 |
+
temperature: float = 1.0, max_time: float = 10.0, seed=42):
|
26 |
+
st.write("Cache miss: process")
|
27 |
set_seed(seed)
|
28 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
29 |
+
top_k=top_k, top_p=top_p, temperature=temperature,
|
30 |
+
max_time=max_time)
|
31 |
return result
|
32 |
|
33 |
|
|
|
67 |
help="The maximum length of the sequence to be generated."
|
68 |
)
|
69 |
|
70 |
+
temperature = st.sidebar.slider(
|
71 |
"Temperature",
|
72 |
value=1.0,
|
73 |
min_value=0.0,
|
74 |
+
max_value=10.0
|
75 |
)
|
76 |
|
77 |
+
do_sample = st.sidebar.checkbox(
|
78 |
+
"Use sampling",
|
79 |
+
value=True
|
80 |
)
|
81 |
|
82 |
+
top_k = 25
|
83 |
+
top_p = 0.95
|
84 |
+
|
85 |
+
if do_sample:
|
86 |
+
top_k = st.sidebar.number_input(
|
87 |
+
"Top k",
|
88 |
+
value=top_k
|
89 |
+
)
|
90 |
+
top_p = st.sidebar.number_input(
|
91 |
+
"Top p",
|
92 |
+
value=top_p
|
93 |
+
)
|
94 |
+
|
95 |
+
seed = st.sidebar.number_input(
|
96 |
+
"Random Seed",
|
97 |
+
value=25,
|
98 |
+
help="The number used to initialize a pseudorandom number generator"
|
99 |
)
|
100 |
|
101 |
+
|
102 |
text_generator = get_generator()
|
103 |
if st.button("Run"):
|
104 |
with st.spinner(text="Getting results..."):
|
105 |
st.subheader("Result")
|
106 |
time_start = time.time()
|
107 |
+
result = process(text=session_state.text, max_length=int(max_length),
|
108 |
+
temperature=temperature, do_sample=do_sample,
|
109 |
+
top_k=int(top_k), top_p=float(top_p), seed=seed)
|
110 |
time_end = time.time()
|
111 |
time_diff = time_end-time_start
|
|
|
112 |
result = result[0]["generated_text"]
|
113 |
st.write(result.replace("\n", " \n"))
|
114 |
st.text("Translation")
|
115 |
translation = translate(result, "en", "id")
|
116 |
st.write(translation.replace("\n", " \n"))
|
117 |
+
# st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
|
118 |
+
st.write(f"*Text generated in {time_diff:.5} seconds*")
|
119 |
|
120 |
# Reset state
|
121 |
session_state.prompt = None
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ transformers
|
|
4 |
datasets
|
5 |
mtranslate
|
6 |
# streamlit version 0.67.1 is needed due to issue with caching
|
7 |
-
streamlit==0.67.1
|
|
|
|
4 |
datasets
|
5 |
mtranslate
|
6 |
# streamlit version 0.67.1 is needed due to issue with caching
|
7 |
+
# streamlit==0.67.1
|
8 |
+
streamlit
|