Rehan3024 committed on
Commit
a7af8b0
·
verified ·
1 Parent(s): d218527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, AutoModelForQuestionAnswering
3
  from sentence_transformers import SentenceTransformer
4
  import fitz # PyMuPDF
5
  import os
@@ -9,10 +9,9 @@ summarization_model_name = 'facebook/bart-large-cnn'
9
  tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
10
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
11
 
12
- qa_model_name = 'distilbert-base-uncased-distilled-squad'
13
  qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
14
  qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
15
- qa_pipeline = pipeline('question-answering', model=qa_model, tokenizer=qa_tokenizer)
16
 
17
  # Function to extract text from a PDF file
18
  def extract_text_from_pdf(file):
@@ -28,6 +27,15 @@ def summarize_document(document):
28
  summary_ids = summarization_model.generate(inputs['input_ids'], max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
29
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
30
 
 
 
 
 
 
 
 
 
 
31
  # Streamlit app
32
  st.title("PDF Summarizer and Q&A")
33
  st.write("Upload a PDF file to get a summary and ask questions about the content.")
@@ -57,9 +65,9 @@ if uploaded_file is not None:
57
  if st.button("Get Answer"):
58
  if question:
59
  with st.spinner('Generating answer...'):
60
- answer = qa_pipeline({'question': question, 'context': document_text})
61
  st.write("**Answer:**")
62
- st.write(answer['answer'])
63
  else:
64
  st.write("Please enter a question.")
65
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForQuestionAnswering
3
  from sentence_transformers import SentenceTransformer
4
  import fitz # PyMuPDF
5
  import os
 
9
  tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
10
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
11
 
12
+ qa_model_name = 'deepset/bert-large-uncased-whole-word-masking-squad2'
13
  qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
14
  qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
 
15
 
16
  # Function to extract text from a PDF file
17
  def extract_text_from_pdf(file):
 
27
  summary_ids = summarization_model.generate(inputs['input_ids'], max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
28
  return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
29
 
30
# Function to get answer to question
def get_answer(question, context):
    """Return the answer span that the QA model extracts from *context*.

    The question and context are encoded together with the module-level
    ``qa_tokenizer``, scored by the module-level ``qa_model``, and the tokens
    between the highest-scoring start and end positions are converted back
    into a string.

    Args:
        question: The question to answer.
        context: The text to search for the answer.

    Returns:
        The predicted answer text. The string is empty when the predicted
        end position does not come after the predicted start position.

    Note:
        The context is truncated to the model's maximum input length, so
        only the leading part of a long document is searched.
    """
    # Local import so this function does not depend on a file-level
    # ``import torch`` being present.
    import torch

    # Truncate only the context (the second sequence) so the question is
    # always kept intact and the encoded input never exceeds the number of
    # positions the model can embed.
    max_length = getattr(qa_model.config, "max_position_embeddings", 512)
    inputs = qa_tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
        max_length=max_length,
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = qa_model(**inputs)

    # The model returns a ModelOutput; read the logits by attribute instead
    # of tuple-unpacking it, which would yield the field names.
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1  # exclusive end

    answer_ids = inputs["input_ids"][0][answer_start:answer_end]
    answer_tokens = qa_tokenizer.convert_ids_to_tokens(answer_ids)
    return qa_tokenizer.convert_tokens_to_string(answer_tokens)
38
+
39
  # Streamlit app
40
  st.title("PDF Summarizer and Q&A")
41
  st.write("Upload a PDF file to get a summary and ask questions about the content.")
 
65
  if st.button("Get Answer"):
66
  if question:
67
  with st.spinner('Generating answer...'):
68
+ answer = get_answer(question, document_text)
69
  st.write("**Answer:**")
70
+ st.write(answer)
71
  else:
72
  st.write("Please enter a question.")
73