Spaces:
Runtime error
Runtime error
from typing import Dict, List, Tuple, Optional | |
from tqdm import tqdm | |
from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
from src.text_extractor import TextExtractor | |
from mdutils.mdutils import MdUtils | |
import torch | |
import fitz | |
import copy | |
class Summarizer():
    """Summarize the paragraph text of a PDF with a Pegasus model.

    Typical flow: ``extract_text`` turns a document into tagged slide
    content, calling the instance summarizes every paragraph entry, and
    ``convert2markdown`` writes the result out through ``MdUtils``.
    """

    # Header levels accepted when emitting Markdown headings.
    _MIN_HEADER_LEVEL = 1
    _MAX_HEADER_LEVEL = 6

    def __init__(self, model_name: str):
        """Load the tokenizer and model identified by ``model_name``.

        Args:
            model_name: Name or path passed to ``from_pretrained`` for both
                the Pegasus tokenizer and the Pegasus model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.preprocess = TextExtractor()
        # Base name of the last document given to extract_text(); stays None
        # until then so convert2markdown() can detect the missing default.
        self.filename: Optional[str] = None

    @staticmethod
    def _document_stem(name: Optional[str]) -> str:
        """Return the file name of ``name`` without directory or extension."""
        # Function-scope import keeps this block self-contained.
        import os

        base = os.path.basename(name or "")
        # splitext removes only the last extension, so "deck.v2.pdf" keeps
        # its "deck.v2" stem instead of collapsing to "deck".
        return os.path.splitext(base)[0]

    @classmethod
    def _header_level(cls, tag: str) -> int:
        """Return the heading level encoded after the first character of ``tag``.

        The leading run of digits is parsed (so ``"h10"`` is ten, not one) and
        the result is clamped to the supported range. A tag with no digits
        maps to the top level.
        """
        digits = ""
        for char in tag[1:]:
            if not char.isdecimal():
                break
            digits += char
        if not digits:
            return cls._MIN_HEADER_LEVEL
        return max(cls._MIN_HEADER_LEVEL, min(cls._MAX_HEADER_LEVEL, int(digits)))

    def extract_text(self, document: object) -> Dict[str, List[Tuple[str, str]]]:
        """Open ``document`` with PyMuPDF and return its tagged slide content.

        Also records the document's base name in ``self.filename`` for use as
        the default output name in ``convert2markdown``.

        Args:
            document: Whatever ``fitz.open`` accepts as its first argument.

        Returns:
            The value of ``TextExtractor.get_slides``; annotated as a mapping
            from slide key to a list of ``(tag, text)`` pairs.
        """
        doc = fitz.open(document)
        try:
            self.filename = self._document_stem(doc.name)
            font_counts, styles = self.preprocess.get_font_info(doc, granularity=False)
            size_tag = self.preprocess.get_font_tags(font_counts, styles)
            texts = self.preprocess.assign_tags(doc, size_tag)
            # NOTE(review): assumes get_slides returns fully materialized data
            # that does not need the document to stay open - confirm.
            return self.preprocess.get_slides(texts)
        finally:
            # Release the underlying file handle even when extraction fails.
            doc.close()

    def __call__(self, slides: Dict[str, List[Tuple[str, str]]]) -> Dict[str, List[Tuple[str, str]]]:
        """Return a deep copy of ``slides`` with paragraph entries summarized.

        Entries whose tag starts with ``'p'`` and whose text is not blank are
        replaced by the model's summary. Everything else is left as it was.
        A failure on one entry is reported on stdout and that entry keeps its
        original text (best effort).

        Args:
            slides: Mapping from slide key to a list of ``(tag, text)`` pairs.

        Returns:
            A new mapping of the same shape; ``slides`` itself is not modified.
        """
        summarized_slides = copy.deepcopy(slides)
        for contents in tqdm(summarized_slides.values()):
            for idx, (tag, content) in enumerate(contents):
                if not tag.startswith('p'):
                    continue
                # Nothing to summarize; skip the model call for blank text.
                if not content.strip():
                    continue
                try:
                    encoded = self.tokenizer(
                        content,
                        truncation=True,
                        padding="longest",
                        return_tensors="pt",
                    ).to(self.device)
                    generated = self.model.generate(**encoded)
                    summary = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    contents[idx] = (tag, summary)
                except Exception as e:
                    # Deliberately broad: one failing paragraph must not
                    # abort the rest of the document.
                    print(f"Summarization Fails, Error: {e}")
        return summarized_slides

    def convert2markdown(self, summarized_slides: Dict[str, List[Tuple[str, str]]], target_path: Optional[str]=None) -> str:
        """Write ``summarized_slides`` to a Markdown file via ``MdUtils``.

        Each slide starts with a ``---`` line. Tags starting with ``'h'``
        become headings and entries tagged exactly ``'p'`` become one line
        per ``'<n>'``-separated piece of text.

        Args:
            summarized_slides: Mapping from slide key to ``(tag, text)`` pairs.
            target_path: Output file name. When omitted or empty, the name
                recorded by ``extract_text`` is used.

        Returns:
            The value of ``MdUtils.create_md_file()``.
            NOTE(review): annotated as ``str``, but mdutils documents this
            call as returning a ``MarkDownFile`` instance - confirm.

        Raises:
            ValueError: If no ``target_path`` is given and ``extract_text``
                has not been called yet.
        """
        filename = target_path if target_path else self.filename
        if filename is None:
            raise ValueError(
                "No output file name: pass target_path or call extract_text() first."
            )

        md_file = MdUtils(file_name=filename)
        for sections in summarized_slides.values():
            md_file.new_line('---\n')
            for tag, content in sections:
                if tag.startswith('h'):
                    md_file.new_header(level=self._header_level(tag), title=content)
                # NOTE(review): __call__ matches tags with startswith('p'),
                # here only the exact tag 'p' is written - confirm intended.
                elif tag == 'p':
                    # '<n>' is presumably the model's line-break marker.
                    for line in content.split('<n>'):
                        md_file.new_line(f"{line}\n")
        return md_file.create_md_file()