import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging

# Set up logging (optional, but helpful for debugging)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
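# Note: DEBUG also surfaces internal chatter from gradio and transformers;
# switching to logging.INFO is a quieter default once the app is stable.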

# Load the Flan-T5 Small model and tokenizer
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
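# flan-t5-small (~80M parameters) is light enough for CPU inference, so no explicit
# device placement is done here; for a larger checkpoint, the usual pattern would be
# model.to("cuda") plus moving each tokenized batch to the same device.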

def correct_htr(raw_htr_text, max_new_tokens, temperature):
    try:
        if not raw_htr_text:
            raise ValueError("Input text cannot be empty.")
        
        logging.info("Processing HTR correction with Flan-T5 Small...")
        prompt = f"Correct this text: {raw_htr_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # max_new_tokens caps only the generated continuation, independent of prompt
        # length; do_sample=True is required for temperature to affect decoding.
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature)
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated output for HTR correction: {corrected_text}")
        return corrected_text
    except ValueError as ve:
        logging.warning(f"Validation error: {ve}")
        return str(ve)
    except Exception as e:
        logging.error(f"Error in HTR correction: {e}", exc_info=True)
        return "An error occurred while processing the text."

def summarize_text(legal_text, max_new_tokens, temperature):
    try:
        if not legal_text:
            raise ValueError("Input text cannot be empty.")
        
        logging.info("Processing summarization with Flan-T5 Small...")
        prompt = f"Summarize the following legal text: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Same decoding setup as correct_htr: sample a bounded continuation.
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature)
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated summary: {summary}")
        return summary
    except ValueError as ve:
        logging.warning(f"Validation error: {ve}")
        return str(ve)
    except Exception as e:
        logging.error(f"Error in summarization: {e}", exc_info=True)
        return "An error occurred while summarizing the text."

def answer_question(legal_text, question, max_new_tokens, temperature):
    try:
        if not legal_text or not question:
            raise ValueError("Both legal text and question must be provided.")
        
        logging.info("Processing question-answering with Flan-T5 Small...")
        prompt = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
        inputs = tokenizer(prompt, return_tensors="pt")
        # Same decoding setup as correct_htr: sample a bounded continuation.
        outputs = model.generate(**inputs, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=temperature)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.debug(f"Generated answer: {answer}")
        return answer
    except ValueError as ve:
        logging.warning(f"Validation error: {ve}")
        return str(ve)
    except Exception as e:
        logging.error(f"Error in question-answering: {e}", exc_info=True)
        return "An error occurred while answering the question."

# Gradio requires the number of returned values to match the number of output
# components, so each tab uses a clear helper of matching arity.
def clear_two_fields():
    return "", ""

def clear_three_fields():
    return "", "", ""

# Create the Gradio Blocks interface
with gr.Blocks(css=".block .input-slider { color: blue !important }") as demo:
    gr.Markdown("# Flan-T5 Small Legal Assistant")
    gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases (powered by Flan-T5 Small).")

    with gr.Row():
        gr.HTML('''
            <div style="display: flex; gap: 10px;">
                <div style="border: 2px solid black; padding: 10px; display: inline-block;">
                    <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" target="_blank">
                        <button style="font-weight:bold;">Admiralty Court Legal Glossary</button>
                    </a>
                </div>
                <div style="border: 2px solid black; padding: 10px; display: inline-block;">
                    <a href="https://raw.githubusercontent.com/Addaci/HCA/refs/heads/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" target="_blank">
                        <button style="font-weight:bold;">HCA 13/70 Ground Truth (1654-55)</button>
                    </a>
                </div>
            </div>
        ''')

    with gr.Tab("Correct HTR"):
        gr.Markdown("### Correct Raw HTR Text")
        raw_htr_input = gr.Textbox(lines=5, placeholder="Enter raw HTR text here...")
        corrected_output = gr.Textbox(lines=5, placeholder="Corrected HTR text")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")
        # render=False defers placement; these sliders are rendered under "Set Parameters" below.
        correct_max_new_tokens = gr.Slider(minimum=10, maximum=512, value=128, step=1, label="Max New Tokens", render=False)
        correct_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature", render=False)

        correct_button.click(correct_htr, inputs=[raw_htr_input, correct_max_new_tokens, correct_temperature], outputs=corrected_output)
        clear_button.click(clear_two_fields, outputs=[raw_htr_input, corrected_output])
        
        gr.Markdown("### Set Parameters")
        correct_max_new_tokens.render()
        correct_temperature.render()

    with gr.Tab("Summarize Legal Text"):
        gr.Markdown("### Summarize Legal Text")
        legal_text_input = gr.Textbox(lines=10, placeholder="Enter legal text to summarize...")
        summary_output = gr.Textbox(lines=5, placeholder="Summary of legal text")
        summarize_button = gr.Button("Summarize Text")
        clear_button = gr.Button("Clear")
        summarize_max_new_tokens = gr.Slider(minimum=10, maximum=1024, value=256, step=1, label="Max New Tokens", render=False)
        summarize_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.1, label="Temperature", render=False)

        summarize_button.click(summarize_text, inputs=[legal_text_input, summarize_max_new_tokens, summarize_temperature], outputs=summary_output)
        clear_button.click(clear_two_fields, outputs=[legal_text_input, summary_output])
        
        gr.Markdown("### Set Parameters")
        summarize_max_new_tokens.render()
        summarize_temperature.render()

    with gr.Tab("Answer Legal Question"):
        gr.Markdown("### Answer a Question Based on Legal Text")
        legal_text_input_q = gr.Textbox(lines=10, placeholder="Enter legal text...")
        question_input = gr.Textbox(lines=2, placeholder="Enter your question...")
        answer_output = gr.Textbox(lines=5, placeholder="Answer to your question")
        answer_button = gr.Button("Get Answer")
        clear_button = gr.Button("Clear")
        answer_max_new_tokens = gr.Slider(minimum=10, maximum=512, value=150, step=1, label="Max New Tokens", render=False)
        answer_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature", render=False)

        answer_button.click(answer_question, inputs=[legal_text_input_q, question_input, answer_max_new_tokens, answer_temperature], outputs=answer_output)
        clear_button.click(clear_three_fields, outputs=[legal_text_input_q, question_input, answer_output])
        
        gr.Markdown("### Set Parameters")
        answer_max_new_tokens.render()
        answer_temperature.render()

# Model warm-up (optional): the first generate call pays one-time initialization
# costs, so a tiny generation before launch keeps the first real request responsive.
model.generate(**tokenizer("Warm-up", return_tensors="pt"), max_new_tokens=5)

# Launch the Gradio interface
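# demo.launch() starts a local server, which is all a Hugging Face Space needs;
# when running elsewhere, share=True would create a temporary public link, and
# demo.queue() can be enabled first to handle concurrent users.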
if __name__ == "__main__":
    demo.launch()