|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
import torch |
|
import re |
|
import spaces |
|
|
|
def split_into_sentences(text):
    """Split ``text`` into sentences.

    A split happens at a whitespace character that directly follows ``.``,
    ``?`` or ``!``, except when the characters before the whitespace look
    like ``<word char>.<word char><any char>`` (e.g. "e.g.") or like an
    uppercase letter, a lowercase letter and a period (e.g. "Mr.").

    Args:
        text: The paragraph to split.

    Returns:
        The sentences in their original order, each stripped of surrounding
        whitespace. Fragments that are empty after stripping are omitted.
    """
    sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
    # Strip *before* filtering: a fragment consisting only of whitespace
    # (e.g. after a sentence followed by several spaces) is truthy, but
    # becomes an empty string once stripped and must not be returned.
    stripped = (fragment.strip() for fragment in sentence_endings.split(text))
    return [fragment for fragment in stripped if fragment]
|
|
|
@spaces.GPU
def process_paragraph(paragraph, progress=gr.Progress()):
    """Run the language model on every sentence of ``paragraph``.

    The paragraph is split with ``split_into_sentences``; each sentence is
    tokenized on its own, passed to ``model.generate`` and the generated
    text is normalised into a category label.

    Args:
        paragraph: The text to analyse.
        progress: Gradio progress tracker, called once per sentence with the
            fraction of sentences reached so far. The ``gr.Progress()``
            default is the form Gradio expects for progress tracking.

    Returns:
        A list of ``(sentence, category)`` tuples in sentence order, where
        ``category`` is the generated text stripped, lower-cased and with
        spaces replaced by underscores.
    """
    sentences = split_into_sentences(paragraph)
    results = []
    total_sentences = len(sentences)

    for i, sentence in enumerate(sentences):
        progress((i + 1) / total_sentences)

        # NOTE(review): `messages` is a module-level list that is only ever
        # appended to here and is never passed to the tokenizer or model in
        # this function, so it grows with every processed sentence. Kept
        # because other code may read it — confirm whether it is needed.
        messages.append({"role": "user", "content": sentence})

        inputs = tokenizer(sentence, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=300, temperature=0.7, top_p=0.9, top_k=50)

        # For a causal LM, generate() returns the prompt tokens followed by
        # the continuation. Decode only the new tokens; otherwise the input
        # sentence itself would be part of the "category".
        generated_tokens = output[0][prompt_length:]
        sentence_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        category = sentence_response.strip().lower().replace(' ', '_')
        results.append((sentence, category))

        messages.append({"role": "assistant", "content": sentence_response})

        # Release cached GPU memory between sentences.
        torch.cuda.empty_cache()

    return results
|
|
|
|
|
# Module-level message log; process_paragraph appends one user entry and one
# assistant entry per processed sentence.
messages = []

# Checkpoint identifier; the tokenizer and the weights are both loaded from
# it once, at import time.
model_name = "princeton-nlp/Llama-3-Instruct-8B-SimPO"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Use CUDA when torch reports it as available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# NOTE(review): the model was loaded with device_map="auto", so this explicit
# move may be redundant — confirm against the deployment environment before
# changing it.
model.to(device)
|
|
|
|
|
# Colour for each category label; handed to gr.HighlightedText as its
# ``color_map`` when the UI is built.
label_to_color = dict(
    (
        ("fair", "green"),
        ("limitation_of_liability", "red"),
        ("unilateral_termination", "orange"),
        ("unilateral_change", "yellow"),
        ("content_removal", "purple"),
        ("contract_by_using", "blue"),
        ("choice_of_law", "cyan"),
        ("jurisdiction", "magenta"),
        ("arbitration", "brown"),
    )
)
|
|
|
|
|
# Two-column layout: paragraph input and button on the left, highlighted
# result on the right.
with gr.Blocks() as demo:

    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(label="Input Paragraph", lines=10, placeholder="Enter the paragraph here...")
            btn = gr.Button("Process")
        with gr.Column():
            output = gr.HighlightedText(label="Processed Paragraph", color_map=label_to_color)

    def on_click(paragraph, progress=gr.Progress()):
        """Handle a click on the "Process" button.

        Args:
            paragraph: Current contents of the input textbox.
            progress: Progress tracker. It is declared as a ``gr.Progress()``
                default on the handler itself because that is how Gradio
                recognises that the event wants progress tracking; an
                instance created elsewhere and passed in by hand is not
                connected to the running event.

        Returns:
            The ``(sentence, category)`` pairs produced by
            ``process_paragraph``, rendered by the HighlightedText output.
        """
        return process_paragraph(paragraph, progress=progress)

    btn.click(on_click, inputs=input_text, outputs=[output])

# Start the web server and request a public share link.
demo.launch(share=True)
|
|