Spaces:
Build error
Build error
Merging from rollback
Browse files- app.py +8 -0
- data_measurements/embeddings.py +3 -4
- data_measurements/streamlit_utils.py +14 -12
app.py
CHANGED
@@ -122,6 +122,12 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
|
|
122 |
dstats.load_or_prepare_zipf()
|
123 |
return dstats
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
|
126 |
"""
|
127 |
Loader specifically for the widgets used in the app.
|
@@ -144,6 +150,8 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
|
|
144 |
dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
|
145 |
# Don't recalculate; we're live
|
146 |
dstats.set_deployment(True)
|
|
|
|
|
147 |
# Header widget
|
148 |
dstats.load_or_prepare_dset_peek()
|
149 |
# General stats widget
|
|
|
122 |
dstats.load_or_prepare_zipf()
|
123 |
return dstats
|
124 |
|
125 |
+
@st.cache(
|
126 |
+
hash_funcs={
|
127 |
+
dataset_statistics.DatasetStatisticsCacheClass: lambda dstats: dstats.cache_path
|
128 |
+
},
|
129 |
+
allow_output_mutation=True,
|
130 |
+
)
|
131 |
def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
|
132 |
"""
|
133 |
Loader specifically for the widgets used in the app.
|
|
|
150 |
dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
|
151 |
# Don't recalculate; we're live
|
152 |
dstats.set_deployment(True)
|
153 |
+
# We need the text_dset loaded before the further load_or_prepare calls below
|
154 |
+
dstats.load_or_prepare_dataset()
|
155 |
# Header widget
|
156 |
dstats.load_or_prepare_dset_peek()
|
157 |
# General stats widget
|
data_measurements/embeddings.py
CHANGED
@@ -146,11 +146,12 @@ class Embeddings:
|
|
146 |
[(node["nid"], nid) for nid, node in enumerate(self.node_list)]
|
147 |
)
|
148 |
torch.save((self.node_list, self.nid_map), self.node_list_fid)
|
|
|
149 |
if self.use_cache and exists(self.fig_tree_fid):
|
150 |
self.fig_tree = read_json(self.fig_tree_fid)
|
151 |
else:
|
152 |
self.fig_tree = make_tree_plot(
|
153 |
-
self.node_list, self.text_dset, self.text_field_name
|
154 |
)
|
155 |
self.fig_tree.write_json(self.fig_tree_fid)
|
156 |
|
@@ -460,14 +461,12 @@ def fast_cluster(
|
|
460 |
return node_list
|
461 |
|
462 |
|
463 |
-
def make_tree_plot(node_list, text_dset, text_field_name):
|
464 |
"""
|
465 |
Makes a graphical representation of the tree encoded
|
466 |
in node-list. The hover label for each node shows the number
|
467 |
of descendants and the 5 examples that are closest to the centroid
|
468 |
"""
|
469 |
-
nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
|
470 |
-
|
471 |
for nid, node in enumerate(node_list):
|
472 |
# collect the examples associated with this node (presumably for the hover label — confirm)
|
473 |
node_examples = {}
|
|
|
146 |
[(node["nid"], nid) for nid, node in enumerate(self.node_list)]
|
147 |
)
|
148 |
torch.save((self.node_list, self.nid_map), self.node_list_fid)
|
149 |
+
print(exists(self.fig_tree_fid), self.fig_tree_fid)
|
150 |
if self.use_cache and exists(self.fig_tree_fid):
|
151 |
self.fig_tree = read_json(self.fig_tree_fid)
|
152 |
else:
|
153 |
self.fig_tree = make_tree_plot(
|
154 |
+
self.node_list, self.nid_map, self.text_dset, self.text_field_name
|
155 |
)
|
156 |
self.fig_tree.write_json(self.fig_tree_fid)
|
157 |
|
|
|
461 |
return node_list
|
462 |
|
463 |
|
464 |
+
def make_tree_plot(node_list, nid_map, text_dset, text_field_name):
|
465 |
"""
|
466 |
Makes a graphical representation of the tree encoded
|
467 |
in node-list. The hover label for each node shows the number
|
468 |
of descendants and the 5 examples that are closest to the centroid
|
469 |
"""
|
|
|
|
|
470 |
for nid, node in enumerate(node_list):
|
471 |
# collect the examples associated with this node (presumably for the hover label — confirm)
|
472 |
node_examples = {}
|
data_measurements/streamlit_utils.py
CHANGED
@@ -21,6 +21,7 @@ from st_aggrid import AgGrid, GridOptionsBuilder
|
|
21 |
|
22 |
from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
|
23 |
|
|
|
24 |
def sidebar_header():
|
25 |
st.sidebar.markdown(
|
26 |
"""
|
@@ -107,9 +108,7 @@ def expander_general_stats(dstats, column_id):
|
|
107 |
"Use this widget to check whether the terms you see most represented"
|
108 |
" in the dataset make sense for the goals of the dataset."
|
109 |
)
|
110 |
-
st.markdown(
|
111 |
-
"There are {0} total words".format(str(dstats.total_words))
|
112 |
-
)
|
113 |
st.markdown(
|
114 |
"There are {0} words after removing closed "
|
115 |
"class words".format(str(dstats.total_open_words))
|
@@ -129,14 +128,10 @@ def expander_general_stats(dstats, column_id):
|
|
129 |
st.markdown(
|
130 |
"There are {0} duplicate items in the dataset. "
|
131 |
"For more information about the duplicates, "
|
132 |
-
"click the 'Duplicates' tab below.".format(
|
133 |
-
str(dstats.dedup_total)
|
134 |
-
)
|
135 |
)
|
136 |
else:
|
137 |
-
st.markdown(
|
138 |
-
"There are 0 duplicate items in the dataset. ")
|
139 |
-
|
140 |
|
141 |
|
142 |
### Show the label distribution from the datasets
|
@@ -166,7 +161,6 @@ def expander_text_lengths(dstats, column_id):
|
|
166 |
st.markdown(
|
167 |
"### Here is the relative frequency of different text lengths in your dataset:"
|
168 |
)
|
169 |
-
# TODO: figure out a more elegant way to do this:
|
170 |
try:
|
171 |
st.image(dstats.fig_tok_length_png)
|
172 |
except:
|
@@ -181,8 +175,16 @@ def expander_text_lengths(dstats, column_id):
|
|
181 |
# This is quite a large file and is breaking our ability to navigate the app development.
|
182 |
# Just passing if it's not already there for launch v0
|
183 |
if dstats.length_df is not None:
|
184 |
-
start_id_show_lengths= st.selectbox(
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
|
188 |
### Third, use a sentence embedding model
|
|
|
21 |
|
22 |
from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
|
23 |
|
24 |
+
|
25 |
def sidebar_header():
|
26 |
st.sidebar.markdown(
|
27 |
"""
|
|
|
108 |
"Use this widget to check whether the terms you see most represented"
|
109 |
" in the dataset make sense for the goals of the dataset."
|
110 |
)
|
111 |
+
st.markdown("There are {0} total words".format(str(dstats.total_words)))
|
|
|
|
|
112 |
st.markdown(
|
113 |
"There are {0} words after removing closed "
|
114 |
"class words".format(str(dstats.total_open_words))
|
|
|
128 |
st.markdown(
|
129 |
"There are {0} duplicate items in the dataset. "
|
130 |
"For more information about the duplicates, "
|
131 |
+
"click the 'Duplicates' tab below.".format(str(dstats.dedup_total))
|
|
|
|
|
132 |
)
|
133 |
else:
|
134 |
+
st.markdown("There are 0 duplicate items in the dataset. ")
|
|
|
|
|
135 |
|
136 |
|
137 |
### Show the label distribution from the datasets
|
|
|
161 |
st.markdown(
|
162 |
"### Here is the relative frequency of different text lengths in your dataset:"
|
163 |
)
|
|
|
164 |
try:
|
165 |
st.image(dstats.fig_tok_length_png)
|
166 |
except:
|
|
|
175 |
# This is quite a large file and is breaking our ability to navigate the app development.
|
176 |
# Just passing if it's not already there for launch v0
|
177 |
if dstats.length_df is not None:
|
178 |
+
start_id_show_lengths = st.selectbox(
|
179 |
+
"Show examples of length:",
|
180 |
+
sorted(dstats.length_df["length"].unique().tolist()),
|
181 |
+
key=f"select_show_length_{column_id}",
|
182 |
+
)
|
183 |
+
st.table(
|
184 |
+
dstats.length_df[
|
185 |
+
dstats.length_df["length"] == start_id_show_lengths
|
186 |
+
].set_index("length")
|
187 |
+
)
|
188 |
|
189 |
|
190 |
### Third, use a sentence embedding model
|