import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
# Set up logging (optional, but helpful for debugging)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
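# Note: DEBUG logs every generated output in full. For a deployed Space a
# quieter level is usually enough; a sketch of the alternative (not required):
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')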
# Load the Flan-T5 Small model and tokenizer
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
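# Optional: move the model to a GPU when one is available. Flan-T5 Small runs
# acceptably on CPU, so this is left as a hedged sketch; if enabled, the
# tokenizer outputs in each function below would also need .to(device):
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)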
def correct_htr(raw_htr_text, max_new_tokens, temperature):
    try:
        if not raw_htr_text:
            raise ValueError("Input text cannot be empty.")
        logging.info("Processing HTR correction with Flan-T5 Small...")
        prompt = f"Correct this text: {raw_htr_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # max_new_tokens bounds only the generated tokens, so the prompt length
        # does not need to be added by hand; do_sample=True is required for
        # temperature to have any effect.
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature)
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated output for HTR correction: {corrected_text}")
        return corrected_text
    except ValueError as ve:
        logging.warning(f"Validation error: {ve}")
        return str(ve)
    except Exception as e:
        logging.error(f"Error in HTR correction: {e}", exc_info=True)
        return "An error occurred while processing the text."
def summarize_text(legal_text, max_new_tokens, temperature):
    try:
        if not legal_text:
            raise ValueError("Input text cannot be empty.")
        logging.info("Processing summarization with Flan-T5 Small...")
        prompt = f"Summarize the following legal text: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature)
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated summary: {summary}")
        return summary
    except ValueError as ve:
        logging.warning(f"Validation error: {ve}")
        return str(ve)
    except Exception as e:
        logging.error(f"Error in summarization: {e}", exc_info=True)
        return "An error occurred while summarizing the text."
def answer_question(legal_text, question, max_new_tokens, temperature):
    try:
        if not legal_text or not question:
            raise ValueError("Both legal text and question must be provided.")
        logging.info("Processing question-answering with Flan-T5 Small...")
        prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated answer: {answer}")
        return answer
    except ValueError as ve:
        logging.warning(f"Validation error: {ve}")
        return str(ve)
    except Exception as e:
        logging.error(f"Error in question-answering: {e}", exc_info=True)
        return "An error occurred while answering the question."
def clear_fields():
    # Used by the question-answering tab, which clears three fields; the
    # two-field tabs use a two-value lambda in their click handlers instead.
    return "", "", ""
# Create the Gradio Blocks interface
with gr.Blocks(css=".block .input-slider { color: blue !important }") as demo:
    gr.Markdown("# Flan-T5 Small Legal Assistant")
    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")
    with gr.Row():
        gr.HTML('''
        <div style="display: flex; gap: 10px;">
            <div style="border: 2px solid black; padding: 10px; display: inline-block;">
                <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" target="_blank">
                    <button style="font-weight:bold;">Admiralty Court Legal Glossary</button>
                </a>
            </div>
            <div style="border: 2px solid black; padding: 10px; display: inline-block;">
                <a href="https://raw.githubusercontent.com/Addaci/HCA/refs/heads/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" target="_blank">
                    <button style="font-weight:bold;">HCA 13/70 Ground Truth (1654-55)</button>
                </a>
            </div>
        </div>
        ''')
with gr.Tab("Correct HTR"):
gr.Markdown("### Correct Raw HTR Text")
raw_htr_input = gr.Textbox(lines=5, placeholder="Enter raw HTR text here...")
corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
correct_button = gr.Button("Correct HTR")
clear_button = gr.Button("Clear")
correct_max_new_tokens = gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens")
correct_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
correct_button.click(correct_htr, inputs=[raw_htr_input, correct_max_new_tokens, correct_temperature], outputs=corrected_output)
clear_button.click(clear_fields, outputs=[raw_htr_input, corrected_output])
gr.Markdown("### Set Parameters")
correct_max_new_tokens.render()
correct_temperature.render()
with gr.Tab("Summarize Legal Text"):
gr.Markdown("### Summarize Legal Text")
legal_text_input = gr.Textbox(lines=10, placeholder="Enter legal text to summarize...")
summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
summarize_button = gr.Button("Summarize Text")
clear_button = gr.Button("Clear")
summarize_max_new_tokens = gr.Slider(minimum=10, maximum=1024, value=256, step=1, label="Max New Tokens")
summarize_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Temperature")
summarize_button.click(summarize_text, inputs=[legal_text_input, summarize_max_new_tokens, summarize_temperature], outputs=summary_output)
clear_button.click(clear_fields, outputs=[legal_text_input, summary_output])
gr.Markdown("### Set Parameters")
summarize_max_new_tokens.render()
summarize_temperature.render()
with gr.Tab("Answer Legal Question"):
gr.Markdown("### Answer a Question Based on Legal Text")
legal_text_input_q = gr.Textbox(lines=10, placeholder="Enter legal text...")
question_input = gr.Textbox(lines=2, placeholder="Enter your question...")
answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
answer_button = gr.Button("Get Answer")
clear_button = gr.Button("Clear")
answer_max_new_tokens = gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max New Tokens")
answer_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature")
answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, answer_max_new_tokens, answer_temperature], outputs=answer_output)
clear_button.click(clear_fields, outputs=[legal_text_input_q, question_input, answer_output])
gr.Markdown("### Set Parameters")
answer_max_new_tokens.render()
answer_temperature.render()
# Model warm-up (optional, but avoids a slow first request)
model.generate(**tokenizer("Warm-up", return_tensors="pt"), max_new_tokens=10)
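# Optional: Gradio's built-in request queue serializes inference so concurrent
# users do not contend for the model. A hedged sketch using the standard
# Blocks.queue() API, to be called before launch():
# demo.queue(max_size=16)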
# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch()