import transformers
import re
from transformers import AutoTokenizer, pipeline
import torch
import html
import gradio as gr
import tempfile
import os
import pandas as pd
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load models
editorial_model = "LLMDH/Estienne"
bibliography_model = "PleIAs/Bibliography-Formatter"
bibliography_style = "PleIAs/Bibliography-Classifier"
tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
editorial_classifier = pipeline(
    "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
)
bibliography_classifier = pipeline(
    "token-classification", model=bibliography_model, aggregation_strategy="simple", device=device
)
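# Note: with aggregation_strategy="simple", each token-classification pipeline returns
# one dict per detected span, with keys such as 'entity_group', 'word', 'score',
# 'start' and 'end'; the code below only relies on 'entity_group' and 'word'.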
# Helper functions
def preprocess_text(text):
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
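# Illustrative example for preprocess_text (not executed):
#   preprocess_text("<p>Hello,\n  world</p>")  ->  "Hello, world"
# Note: this helper is currently not called anywhere else in this script.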
def split_text(text, max_tokens=500):
    parts = text.split("\n")
    chunks = []
    current_chunk = ""
    for part in parts:
        temp_chunk = current_chunk + "\n" + part if current_chunk else part
        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
                split_point += 1
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)
    return chunks
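# How split_text behaves (illustrative, not executed): newline-separated parts are packed
# greedily into chunks of at most `max_tokens` tokenizer tokens; if the text has no usable
# newlines and a single oversized chunk remains, it is repeatedly halved at whitespace
# boundaries until every chunk fits under the limit.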
def disambiguate_bibtex_ids(bibtex_entries):
    id_count = {}
    disambiguated_entries = []
    for entry in bibtex_entries:
        # Extract the current ID
        match = re.search(r'@\w+{(\w+),', entry)
        if not match:
            disambiguated_entries.append(entry)
            continue
        original_id = match.group(1)
        # Check if this ID has been seen before
        if original_id in id_count:
            id_count[original_id] += 1
            new_id = f"{original_id}{chr(96 + id_count[original_id])}"  # 'a', 'b', 'c', etc.
            new_entry = re.sub(r'(@\w+{)(\w+)(,)', f'\\1{new_id}\\3', entry, count=1)
            disambiguated_entries.append(new_entry)
        else:
            id_count[original_id] = 0
            disambiguated_entries.append(entry)
    return disambiguated_entries
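# Illustrative example for disambiguate_bibtex_ids (not executed): given three entries that
# all use the key "smith1999", the first keeps "smith1999" and the next two are rewritten
# as "smith1999a" and "smith1999b".
# Note: this helper is currently not called anywhere else in this script.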
def remove_punctuation(text):
    return re.sub(r'[^\w\s]', '', text)
def extract_year(text):
    year_match = re.search(r'\b(\d{4})\b', text)
    return year_match.group(1) if year_match else None
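# Illustrative example (not executed): extract_year("Paris, Gallimard, 1972, 324 p.") -> "1972".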
def create_bibtex_entry(data):
    if 'journal' in data:
        entry_type = 'article'
    elif 'booktitle' in data:
        entry_type = 'inproceedings'
    else:
        entry_type = 'book'
    none_content = data.pop('none', '')
    year = extract_year(none_content)
    if year and 'year' not in data:
        data['year'] = year
    if "year" in data:
        match_year = re.search(r'(\d{4})', data['year'])
        if match_year:
            data['year'] = match_year.group(1)
            year = data['year']
        else:
            data.pop('year', '')
    # Normalize the pages field to a plain "start-end" digit range.
    if 'pages' in data:
        match = re.search(r'(\d+(-\d+)?)', data['pages'])
        if match:
            data['pages'] = match.group(1)
        else:
            data.pop('pages', '')
    author_words = data.get('author', '').split()
    first_author = author_words[0] if author_words else 'unknown'
    bibtex_id = f"{first_author}{year}" if year else first_author
    bibtex_id = remove_punctuation(bibtex_id.lower())
    bibtex = f"@{entry_type}{{{bibtex_id},\n"
    for key, value in data.items():
        if value.strip():
            if key in ['volume', 'year']:
                value = remove_punctuation(value)
            if key == 'pages':
                value = value.replace('p. ', '')
            if key != "separator":
                bibtex += f" {key.lower()} = {{{value.strip()}}},\n"
    bibtex = bibtex.rstrip(',\n') + "\n}"
    return bibtex
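# Illustrative sketch of create_bibtex_entry (not executed), assuming the classifier produced
# {'author': 'Duby , Georges', 'title': 'Le dimanche de Bouvines', 'publisher': 'Gallimard',
# 'year': '1973'}: the entry type falls back to @book (no 'journal' or 'booktitle'), the key
# becomes "duby1973", and each non-empty field is emitted as a lowercase "key = {value}" line.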
def save_bibtex(bibtex_content):
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.bib') as temp_file:
        temp_file.write(bibtex_content)
    return temp_file.name
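# save_bibtex writes the generated BibTeX to a temporary .bib file and returns its path,
# which is what the gr.File output component below expects for downloads.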
class CombinedProcessor:
    def process(self, user_message):
        # Prepend a heading as a precaution to reinforce bibliography detection.
        editorial_text = "Bibliography\n" + user_message
        # Workaround for DeBERTa's lack of newline handling: encode newlines as pilcrows.
        editorial_text = re.sub("\n", " ¶ ", editorial_text)
        print(editorial_text)
        num_tokens = len(tokenizer.tokenize(editorial_text))
        batch_prompts = split_text(editorial_text, max_tokens=500) if num_tokens > 500 else [editorial_text]
        editorial_out = editorial_classifier(batch_prompts)
        editorial_df = pd.concat([pd.DataFrame(classification) for classification in editorial_out])
        # Keep only the spans tagged as bibliography entries
        bibliography_entries = editorial_df[editorial_df['entity_group'] == 'bibliography']['word'].tolist()
        bibtex_entries = []
        list_style = []
        for entry in bibliography_entries:
            print(entry)
            entry = re.sub(r'- ?[\n¶] ?', r'', entry)
            entry = re.sub(r' ?[\n¶] ?', r' ', entry)
            #style = pd.DataFrame(style_classifier(entry, truncation=True, padding=True, top_k=1))
            #list_style.append(style)
            entry = re.sub(r'\s*([;:,\.])\s*', r' \1 ', entry)
            #print(entry)
            bib_out = bibliography_classifier(entry)
            bib_df = pd.DataFrame(bib_out)
            bibtex_data = {}
            current_entity = None
            for _, row in bib_df.iterrows():
                entity_group = row['entity_group']
                word = row['word']
                if entity_group != 'None':
                    if entity_group in bibtex_data:
                        print(entity_group)
                        if entity_group == "author":
                            bibtex_data[entity_group] += ', ' + word
                        else:
                            bibtex_data[entity_group] += ' ' + word
                    else:
                        bibtex_data[entity_group] = word
                    current_entity = entity_group
                else:
                    if current_entity:
                        if current_entity == "author":
                            bibtex_data[current_entity] += ', ' + word
                        else:
                            bibtex_data[current_entity] += ' ' + word
                    else:
                        bibtex_data['None'] = bibtex_data.get('None', '') + ' ' + word
            bibtex_entry = create_bibtex_entry(bibtex_data)
            bibtex_entries.append(bibtex_entry)
        #list_style = pd.concat(list_style)
        #list_style = list_style.groupby('label')['score'].mean().sort_values(ascending=False).reset_index()
        #top_style = list_style.iloc[0]['label']
        #top_style_score = list_style.iloc[0]['score']
        # Create the style information string
        #style_info = f"Top bibliography style: {top_style} (Mean score: {top_style_score:.6f})"
        # Join BibTeX entries
        bibtex_content = "\n\n".join(bibtex_entries)
        #return style_info, bibtex_content
        return bibtex_content
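# Illustrative usage (not executed):
#   bibtex = CombinedProcessor().process("Duby, Georges. Le dimanche de Bouvines. Paris, Gallimard, 1973.")
# The result is a single string with one BibTeX entry per detected reference, separated by blank lines.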
# Create the processor instance
processor = CombinedProcessor()
# Define the Gradio interface
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Reversed Zotero</h1>""")
    text_input = gr.Textbox(label="Your text", type="text", lines=10)
    text_button = gr.Button("Process Text")
    bibtex_output = gr.Textbox(label="BibTeX Entries", lines=15)
    export_button = gr.Button("Export BibTeX")
    export_output = gr.File(label="Exported BibTeX File")
    text_button.click(processor.process, inputs=text_input, outputs=[bibtex_output])
    export_button.click(save_bibtex, inputs=[bibtex_output], outputs=[export_output])
if __name__ == "__main__":
    demo.queue().launch()