Lauraayu commited on
Commit
6412832
·
verified ·
1 Parent(s): 5888fc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -18
app.py CHANGED
@@ -1,29 +1,46 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelWithLMHead
 
3
 
4
- # 加载模型和分词器
5
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
6
  model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
7
 
8
- # 定义摘要函数
 
 
 
 
 
 
 
 
 
9
  def summarize(text, max_length=150):
10
  input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
11
  generated_ids = model.generate(input_ids=input_ids, num_beams=2, max_length=max_length, repetition_penalty=2.5, length_penalty=1.0, early_stopping=True)
12
  preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
13
  return preds[0]
14
 
15
- # Streamlit 应用程序界面
16
- st.title("News Summarization App")
17
- st.write("Enter the news article text below to generate a summary.")
18
-
19
- article = st.text_area("News Article", height=300)
20
- max_len = st.slider("Max Length of Summary", min_value=50, max_value=300, value=150)
21
-
22
- if st.button("Summarize"):
23
- if article:
24
- with st.spinner("Generating summary..."):
25
- summary = summarize(article, max_length=max_len)
26
- st.write("**Summary:**")
27
- st.write(summary)
28
- else:
29
- st.error("Please enter some text to summarize.")
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
+ # Load the tokenizer and model for classification
6
  tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
7
  model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
8
 
9
+ tokenizer_bb = AutoTokenizer.from_pretrained("your-username/your-model-name")
10
+ model_bb = AutoModelForSequenceClassification.from_pretrained("your-username/your-model-name")
11
+
12
+ # Streamlit application title
13
+ st.title("News Article Summarizer and Classifier")
14
+ st.write("Enter a news article text to get its summary and category.")
15
+
16
+ # Text input for user to enter the news article text
17
+ text = st.text_area("Enter the news article text here:")
18
+
19
  def summarize(text, max_length=150):
20
  input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
21
  generated_ids = model.generate(input_ids=input_ids, num_beams=2, max_length=max_length, repetition_penalty=2.5, length_penalty=1.0, early_stopping=True)
22
  preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
23
  return preds[0]
24
 
25
+
26
+ # Perform summarization and classification when the user clicks the "Classify" button
27
+ if st.button("Classify"):
28
+ # Perform text summarization
29
+ with st.spinner("Generating summary..."):
30
+ summary = summarize(article)
31
+
32
+ # Tokenize the summarized text
33
+ inputs = tokenizer_bb(summary, return_tensors="pt", truncation=True, padding=True, max_length=512)
34
+
35
+ # Perform text classification
36
+ with torch.no_grad():
37
+ outputs = model_bb(**inputs)
38
+
39
+ # Get the predicted label
40
+ predicted_label_id = torch.argmax(outputs.logits, dim=-1).item()
41
+ label_mapping = model_bb.config.id2label
42
+ predicted_label = label_mapping[predicted_label_id]
43
+
44
+ # Display the summary and classification result
45
+ st.write("Summary:", summary)
46
+ st.write("Category:", predicted_label)