Davidsamuel101 commited on
Commit
f0a8738
·
1 Parent(s): ae81020

Add application file

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text_extractor import TextExtractor
2
+ from tqdm import tqdm
3
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
4
+ from transformers import pipeline
5
+ from mdutils.mdutils import MdUtils
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ import fitz
10
+ import torch
11
+ import copy
12
+ import os
13
+
14
+ FILENAME = ""
15
+
16
+ preprocess = TextExtractor()
17
+ model_name = "google/pegasus-cnn_dailymail"
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ tokenizer = PegasusTokenizer.from_pretrained(model_name, max_length=500)
20
+ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
21
+ summarizer = pipeline(task="summarization", model="google/pegasus-cnn_dailymail", tokenizer=tokenizer, batch_size=1, device=-1)
22
+
23
+ def summarize(slides):
24
+ generated_slides = copy.deepcopy(slides)
25
+ for page, contents in tqdm(generated_slides.items()):
26
+ for idx, (tag, content) in enumerate(contents):
27
+ if tag.startswith('p'):
28
+ try:
29
+ input = tokenizer(content, truncation=True, padding="longest", return_tensors="pt").to(device)
30
+ tensor = model.generate(**input)
31
+ summary = tokenizer.batch_decode(tensor, skip_special_tokens=True)[0]
32
+ contents[idx] = (tag, summary)
33
+ except Exception as e:
34
+ print(e)
35
+ print("Summarization Fails")
36
+ return generated_slides
37
+
38
+
39
+ def convert2markdown(generate_slides):
40
+ # save_path = f"tmp/{FILENAME}"
41
+ mdFile = MdUtils(file_name=FILENAME, title=f'{FILENAME} Presentation')
42
+ for k, v in generate_slides.items():
43
+ mdFile.new_paragraph('---')
44
+ for section in v:
45
+ tag = section[0]
46
+ content = section[1]
47
+ if tag.startswith('h'):
48
+ mdFile.new_header(level=int(tag[1]), title=content)
49
+ if tag == 'p':
50
+ contents = content.split('<n>')
51
+ for content in contents:
52
+ mdFile.new_paragraph(content)
53
+ mdFile.create_md_file()
54
+ return f"{FILENAME}.md"
55
+
56
+ def inference(document):
57
+ global FILENAME
58
+ doc = fitz.open(document)
59
+ FILENAME = Path(doc.name).stem
60
+ print(FILENAME)
61
+ font_counts, styles = preprocess.get_font_info(doc, granularity=False)
62
+ size_tag = preprocess.get_font_tags(font_counts, styles)
63
+ texts = preprocess.assign_tags(doc, size_tag)
64
+ slides = preprocess.get_slides(texts)
65
+ generated_slides = summarize(slides)
66
+ markdown_path = convert2markdown(generated_slides)
67
+
68
+ return markdown_path
69
+
70
+
71
+ with gr.Blocks() as demo:
72
+ inp = gr.File( file_types=['pdf'])
73
+ out = gr.File(type="file", label="Markdown")
74
+ inference_btn = gr.Button("Summarized PDF")
75
+ inference_btn.click(fn=inference, inputs=inp, outputs=out, show_progress=True, api_name="summarize")
76
+
77
+ demo.launch()