meg-huggingface committed on

Commit a2ae370
1 Parent(s): 335424f

More modularizing; npmi and labels

Files changed:
- app.py +5 -12
- data_measurements/dataset_statistics.py +20 -20
- data_measurements/streamlit_utils.py +4 -5
app.py
CHANGED

@@ -118,9 +118,8 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
     if show_embeddings:
         logs.warning("Loading Embeddings")
         dstats.load_or_prepare_embeddings()
-
-
-    dstats.load_or_prepare_npmi_terms()
+    logs.warning("Loading nPMI")
+    dstats.load_or_prepare_npmi()
     logs.warning("Loading Zipf")
     dstats.load_or_prepare_zipf()
     return dstats
@@ -156,6 +155,8 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
         # Embeddings widget
         dstats.load_or_prepare_embeddings()
     dstats.load_or_prepare_text_duplicates()
+    dstats.load_or_prepare_npmi()
+    dstats.load_or_prepare_zipf()
 
 def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
     """
@@ -179,17 +180,9 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
     st_utils.expander_label_distribution(dstats.fig_labels, column_id)
     st_utils.expander_text_lengths(dstats, column_id)
     st_utils.expander_text_duplicates(dstats, column_id)
-
-    # We do the loading of these after the others in order to have some time
-    # to compute while the user works with the details above.
     # Uses an interaction; handled a bit differently than other widgets.
     logs.info("showing npmi widget")
-    npmi_stats = …(
-        dstats, use_cache=use_cache
-    )
-    available_terms = npmi_stats.get_available_terms()
-    st_utils.npmi_widget(
-        column_id, available_terms, npmi_stats, _MIN_VOCAB_COUNT)
+    st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)
     logs.info("showing zipf")
     st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
     if show_embeddings:
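Net effect of the app.py changes: nPMI preparation moves out of show_column and into the load-or-prepare step, so the widget only reads the cached dstats.npmi_stats. Below is a minimal sketch of that flow under stated assumptions: the import path, the wrapper function, and the placeholder constant value are invented for illustration; only the load_or_prepare_npmi() and npmi_widget() calls come from the diff.

# Illustrative sketch only -- wiring implied by this diff, not code from the commit.
from data_measurements import streamlit_utils as st_utils  # import path assumed

_MIN_VOCAB_COUNT = 20  # placeholder; the real constant is defined in app.py

def render_npmi_column(dstats, column_id=""):
    # Preparation happens once, next to the other cached measurements; it
    # populates dstats.npmi_stats (see load_or_prepare_npmi in
    # dataset_statistics.py below).
    dstats.load_or_prepare_npmi()
    # The widget now takes the cached stats object directly, instead of a
    # separately computed available_terms list.
    st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)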
    	
data_measurements/dataset_statistics.py
CHANGED

@@ -231,10 +231,6 @@ class DatasetStatisticsCacheClass:
         # nPMI
         # Holds a nPMIStatisticsCacheClass object
         self.npmi_stats = None
-        # TODO: Users ideally can type in whatever words they want.
-        self.termlist = _IDENTITY_TERMS
-        # termlist terms that are available more than _MIN_VOCAB_COUNT times
-        self.available_terms = _IDENTITY_TERMS
         # TODO: Have lowercase be an option for a user to set.
         self.to_lowercase = True
         # The minimum amount of times a word should occur to be included in
@@ -627,24 +623,27 @@ class DatasetStatisticsCacheClass:
             if save:
                 write_plotly(self.fig_labels, self.fig_labels_fid)
         else:
-            self.get_base_dataset()
-            self.label_dset = self.dset.map(
-                lambda examples: extract_field(
-                    examples, self.label_field, OUR_LABEL_FIELD
-                ),
-                batched=True,
-                remove_columns=list(self.dset.features),
-            )
-            self.label_df = self.label_dset.to_pandas()
-            self.fig_labels = make_fig_labels(
-                self.label_df, self.label_names, OUR_LABEL_FIELD
-            )
+            self.prepare_labels()
             if save:
                 # save extracted label instances
                 self.label_dset.save_to_disk(self.label_dset_fid)
                 write_plotly(self.fig_labels, self.fig_labels_fid)
 
-    def load_or_prepare_npmi_terms(self):
+    def prepare_labels(self):
+        self.get_base_dataset()
+        self.label_dset = self.dset.map(
+            lambda examples: extract_field(
+                examples, self.label_field, OUR_LABEL_FIELD
+            ),
+            batched=True,
+            remove_columns=list(self.dset.features),
+        )
+        self.label_df = self.label_dset.to_pandas()
+        self.fig_labels = make_fig_labels(
+            self.label_df, self.label_names, OUR_LABEL_FIELD
+        )
+
+    def load_or_prepare_npmi(self):
         self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
         self.npmi_stats.load_or_prepare_npmi_terms()
 
@@ -693,7 +692,10 @@ class nPMIStatisticsCacheClass:
             # We need to preprocess everything.
             mkdir(self.pmi_cache_path)
         self.joint_npmi_df_dict = {}
-
+        # TODO: Users ideally can type in whatever words they want.
+        self.termlist = _IDENTITY_TERMS
+        # termlist terms that are available more than _MIN_VOCAB_COUNT times
+        self.available_terms = _IDENTITY_TERMS
         logs.info(self.termlist)
         self.use_cache = use_cache
         # TODO: Let users specify
@@ -701,8 +703,6 @@ class nPMIStatisticsCacheClass:
         self.min_vocab_count = self.dstats.min_vocab_count
         self.subgroup_files = {}
         self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
-        self.available_terms = self.dstats.available_terms
-        logs.info(self.available_terms)
 
     def load_or_prepare_npmi_terms(self):
         """
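The new prepare_labels() method factors the label-extraction pipeline out of the load-or-prepare branch. It relies on the datasets-library pattern of a batched map() that returns only the extracted field and drops the original columns via remove_columns. A self-contained toy version of that pattern is sketched below; the dataset, the column names, and the inline lambda standing in for extract_field() are made up for illustration and are not code from this commit.

# Toy illustration of the Dataset.map pattern used by prepare_labels() above.
from datasets import Dataset

dset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 0]})

label_dset = dset.map(
    lambda examples: {"labels": examples["label"]},  # stand-in for extract_field()
    batched=True,
    remove_columns=list(dset.features),  # drop every original column
)

# As in prepare_labels(), the result is converted to pandas for plotting.
label_df = label_dset.to_pandas()
print(label_df)  # a single "labels" column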
    	
data_measurements/streamlit_utils.py
CHANGED

@@ -273,7 +273,6 @@ def expander_text_duplicates(dstats, column_id):
         st.write(
             "### Here is the list of all the duplicated items and their counts in your dataset:"
         )
-        # Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
         if dstats.dup_counts_df is None:
             st.write("There are no duplicates in this dataset! 🥳")
         else:
@@ -393,7 +392,7 @@ with an ideal α value of 1."""
 
 
 ### Finally finally finally, show nPMI stuff.
-def npmi_widget(column_id, available_terms, npmi_stats, min_vocab):
+def npmi_widget(npmi_stats, min_vocab, column_id):
     """
     Part of the main app, but uses a user interaction so pulled out as its own f'n.
     :param use_cache:
@@ -403,16 +402,16 @@ def npmi_widget(column_id, available_terms, npmi_stats, min_vocab):
     :return:
     """
     with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
-        if len(available_terms) > 0:
+        if len(npmi_stats.available_terms) > 0:
             expander_npmi_description(min_vocab)
             st.markdown("-----")
             term1 = st.selectbox(
                 f"What is the first term you want to select?{column_id}",
-                available_terms,
+                npmi_stats.available_terms,
            )
             term2 = st.selectbox(
                 f"What is the second term you want to select?{column_id}",
-                reversed(available_terms),
+                reversed(npmi_stats.available_terms),
             )
             # We calculate/grab nPMI data based on a canonical (alphabetic)
             # subgroup ordering.
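With the new signature, npmi_widget() receives the cached stats object and reads available_terms from it, so callers no longer pre-compute a term list. The stand-in below only shows the duck-typed interface the refactored widget relies on; the class, its default terms, and the min_vocab value are invented for illustration, and the call is left commented out because the widget needs a running Streamlit session.

# Hypothetical stand-in for dstats.npmi_stats, not part of the commit.
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeNpmiStats:
    # The widget reads .available_terms directly off this object.
    available_terms: List[str] = field(default_factory=lambda: ["woman", "man"])

# In the app this object is dstats.npmi_stats, built by
# DatasetStatisticsCacheClass.load_or_prepare_npmi(); min_vocab=20 is a
# placeholder for _MIN_VOCAB_COUNT. Requires a Streamlit session to render:
# npmi_widget(FakeNpmiStats(), min_vocab=20, column_id="")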
