Paula Leonova
committed on
Commit
·
a6b5529
1
Parent(s):
0a49db3
Clean up notes
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# Reference: https://huggingface.co/spaces/team-zero-shot-nli/zero-shot-nli/blob/main/app.py
|
| 2 |
|
| 3 |
from os import write
|
| 4 |
import pandas as pd
|
|
@@ -8,7 +7,6 @@ import streamlit as st
|
|
| 8 |
|
| 9 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
| 10 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
| 11 |
-
# from utils import plot_result, examples_load, example_long_text_load, to_excel
|
| 12 |
import json
|
| 13 |
|
| 14 |
|
|
@@ -31,8 +29,6 @@ if __name__ == '__main__':
|
|
| 31 |
if text_input == display_text:
|
| 32 |
text_input = example_text
|
| 33 |
|
| 34 |
-
# minimum_tokens = 30
|
| 35 |
-
# maximum_tokens = 100
|
| 36 |
labels = st.text_input('Possible labels (comma-separated):',ex_labels, max_chars=1000)
|
| 37 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
| 38 |
submit_button = st.form_submit_button(label='Submit')
|
|
@@ -41,8 +37,6 @@ if __name__ == '__main__':
|
|
| 41 |
if len(labels) == 0:
|
| 42 |
st.write('Enter some text and at least one possible topic to see predictions.')
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
# For each body of text, create text chunks of a certain token size required for the transformer
|
| 47 |
nested_sentences = create_nest_sentences(document = text_input, token_max_length = 1024)
|
| 48 |
|
|
@@ -69,21 +63,17 @@ if __name__ == '__main__':
|
|
| 69 |
st.markdown(final_summary)
|
| 70 |
|
| 71 |
topics, scores = classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
|
| 72 |
-
|
| 73 |
# st.markdown("### Top Label Predictions: Combined Summary")
|
| 74 |
# plot_result(topics[::-1][:], scores[::-1][:])
|
| 75 |
-
|
| 76 |
# st.markdown("### Download Data")
|
| 77 |
data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
|
| 78 |
# st.dataframe(data)
|
| 79 |
-
|
| 80 |
# coded_data = base64.b64encode(data.to_csv(index = False). encode ()).decode()
|
| 81 |
# st.markdown(
|
| 82 |
# f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Download Data</a>',
|
| 83 |
# unsafe_allow_html = True
|
| 84 |
# )
|
| 85 |
|
| 86 |
-
|
| 87 |
st.markdown("### Top Label Predictions: Summary & Full Text")
|
| 88 |
topics_ex_text, scores_ex_text = classifier_zero(classifier, sequence=example_text, labels=labels, multi_class=True)
|
| 89 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
|
|
|
|
|
|
| 1 |
|
| 2 |
from os import write
|
| 3 |
import pandas as pd
|
|
|
|
| 7 |
|
| 8 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
| 9 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
|
|
|
| 10 |
import json
|
| 11 |
|
| 12 |
|
|
|
|
| 29 |
if text_input == display_text:
|
| 30 |
text_input = example_text
|
| 31 |
|
|
|
|
|
|
|
| 32 |
labels = st.text_input('Possible labels (comma-separated):',ex_labels, max_chars=1000)
|
| 33 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
| 34 |
submit_button = st.form_submit_button(label='Submit')
|
|
|
|
| 37 |
if len(labels) == 0:
|
| 38 |
st.write('Enter some text and at least one possible topic to see predictions.')
|
| 39 |
|
|
|
|
|
|
|
| 40 |
# For each body of text, create text chunks of a certain token size required for the transformer
|
| 41 |
nested_sentences = create_nest_sentences(document = text_input, token_max_length = 1024)
|
| 42 |
|
|
|
|
| 63 |
st.markdown(final_summary)
|
| 64 |
|
| 65 |
topics, scores = classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
|
|
|
|
| 66 |
# st.markdown("### Top Label Predictions: Combined Summary")
|
| 67 |
# plot_result(topics[::-1][:], scores[::-1][:])
|
|
|
|
| 68 |
# st.markdown("### Download Data")
|
| 69 |
data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
|
| 70 |
# st.dataframe(data)
|
|
|
|
| 71 |
# coded_data = base64.b64encode(data.to_csv(index = False). encode ()).decode()
|
| 72 |
# st.markdown(
|
| 73 |
# f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Download Data</a>',
|
| 74 |
# unsafe_allow_html = True
|
| 75 |
# )
|
| 76 |
|
|
|
|
| 77 |
st.markdown("### Top Label Predictions: Summary & Full Text")
|
| 78 |
topics_ex_text, scores_ex_text = classifier_zero(classifier, sequence=example_text, labels=labels, multi_class=True)
|
| 79 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
models.py
CHANGED
|
@@ -33,7 +33,6 @@ def load_summary_model():
|
|
| 33 |
summarizer = pipeline(task='summarization', model=model_name)
|
| 34 |
return summarizer
|
| 35 |
|
| 36 |
-
|
| 37 |
# def load_summary_model():
|
| 38 |
# model_name = "facebook/bart-large-mnli"
|
| 39 |
# tokenizer = BartTokenizer.from_pretrained(model_name)
|
|
@@ -41,7 +40,6 @@ def load_summary_model():
|
|
| 41 |
# summarizer = pipeline(task='summarization', model=model, tokenizer=tokenizer, framework='pt')
|
| 42 |
# return summarizer
|
| 43 |
|
| 44 |
-
|
| 45 |
def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:int):
|
| 46 |
output = summarizer(sequence, num_beams=4, max_length=maximum_tokens, min_length=minimum_tokens, do_sample=False)
|
| 47 |
return output[0].get('summary_text')
|
|
|
|
| 33 |
summarizer = pipeline(task='summarization', model=model_name)
|
| 34 |
return summarizer
|
| 35 |
|
|
|
|
| 36 |
# def load_summary_model():
|
| 37 |
# model_name = "facebook/bart-large-mnli"
|
| 38 |
# tokenizer = BartTokenizer.from_pretrained(model_name)
|
|
|
|
| 40 |
# summarizer = pipeline(task='summarization', model=model, tokenizer=tokenizer, framework='pt')
|
| 41 |
# return summarizer
|
| 42 |
|
|
|
|
| 43 |
def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:int):
|
| 44 |
output = summarizer(sequence, num_beams=4, max_length=maximum_tokens, min_length=minimum_tokens, do_sample=False)
|
| 45 |
return output[0].get('summary_text')
|