Nainglinthu committed on
Commit
d47933d
·
verified ·
1 Parent(s): 96e0a00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -225
app.py CHANGED
@@ -1,232 +1,25 @@
1
- # -*- coding: utf-8 -*-
2
- """LegalTextClassification.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1x6EcLSN3qEgm6sVIcmX0bYeXj7AdDQlW
8
-
9
- import gradio as gr
10
-
11
- def greet(name):
12
- return "Hello " + name + "!!"
13
-
14
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
15
- demo.launch()
16
-
17
- #About Data
18
- The dataset contains a total of 25000 legal cases in the form of text documents. Each document has been annotated with catchphrases, citation sentences, citation catchphrases, and citation classes. Citation classes indicate the type of treatment given to the cases cited by the present case.
19
-
20
- The Legal Citation Text Classification dataset is provided in CSV format. The dataset has ***four columns***, ***namely Case ID, Case Outcome, Case Title, and Case Text***. The Case ID column contains a unique identifier for each legal case, the Case Outcome column indicates the outcome of the case, the Case Title column contains the title of the legal case, and the Case Text column contains the text of the legal case.
21
-
22
- Kaggle Dataset Link: https://www.kaggle.com/datasets/amohankumar/legal-text-classification-dataset/data
23
-
24
- #Importing Data
25
- """
26
-
27
-
28
- import pandas as pd
29
-
30
- df = pd.read_csv('legal_text_classification.csv')
31
- df.head()
32
-
33
- """#Data Preprocessing and Description"""
34
-
35
- print(df.columns) # Lists all column names
36
- print(len(df.columns)) # Shows the number of columns
37
-
38
- print(df.shape) # Output: (rows, columns)
39
-
40
- print(df.isnull().sum())
41
-
42
- df = df.dropna(subset=['case_text'])
43
-
44
- df = df.drop(columns=["case_id", "case_title"])
45
-
46
- print(df.isnull().sum())
47
-
48
- import re
49
-
50
- def text_ready(text):
51
- text = text.lower() #lowercase
52
- text = re.sub(r'[^\w\s]', '', text) #special char
53
- text = re.sub(r'\s+', ' ', text).strip() #whitespace
54
- return text
55
-
56
- df["text_ready"] = df["case_text"].apply(text_ready)
57
-
58
- import matplotlib.pyplot as plt
59
-
60
- text_data = df['text_ready']
61
- word_count = [len(text.split()) for text in text_data]
62
-
63
- plt.hist(word_count, bins=50, color='skyblue', edgecolor='black')
64
- plt.title('Distribution of Word Counts in text_ready')
65
- plt.xlabel('Word Count')
66
- plt.ylabel('Frequency')
67
- plt.show()
68
-
69
- print(df.shape) # Output: (rows, columns)
70
-
71
- df.describe()
72
-
73
- df['text']=df['text_ready']
74
- df['label']=df['case_outcome']
75
- data=df[['text','label']]
76
-
77
- df = df.drop(columns=["case_outcome", "case_text"])
78
-
79
- df.head()
80
-
81
- df = df.drop(columns=["text_ready"])
82
-
83
- df.head()
84
-
85
- data['label'].value_counts()
86
-
87
- class_label=sorted(data['label'].unique())
88
- lbl2id={label:id for id,label in enumerate(class_label)}
89
- id2lb={id:label for label,id in lbl2id.items()}
90
- print(lbl2id)
91
- print(id2lb)
92
-
93
-
94
-
95
- data.head()
96
-
97
- data['label']=data['label'].map(lbl2id)
98
- data.head()
99
-
100
- data.label.value_counts()
101
-
102
- import matplotlib.pyplot as plt
103
-
104
- df['label'].value_counts().plot.bar()
105
- plt.show()
106
-
107
- from transformers import AutoModelForSequenceClassification,AutoTokenizer
108
- model_name='nlpaueb/legal-bert-base-uncased'
109
- tokenizer=AutoTokenizer.from_pretrained(model_name)
110
-
111
- from transformers import AutoModelForSequenceClassification
112
- model = AutoModelForSequenceClassification.from_pretrained(
113
- model_name,
114
- num_labels=len(id2lb),
115
- id2label=id2lb,
116
- label2id=lbl2id
117
- )
118
-
119
-
120
- from datasets import Dataset
121
- ds=Dataset.from_pandas(data)
122
- ds
123
-
124
- ds['label'][:11]
125
-
126
- from datasets import ClassLabel
127
- unique_labels = sorted(set(ds['label']))
128
- print(f"Unique labels in Y: {unique_labels}")
129
-
130
- new_features = ds.features.copy()
131
- new_features['label'] = ClassLabel(names=unique_labels)
132
-
133
- ds = ds.cast(new_features)
134
- data = ds.train_test_split(test_size=0.2, shuffle=True, seed=42)
135
- data
136
-
137
- split_ds = data['test'].remove_columns('__index_level_0__').train_test_split(test_size=0.5, shuffle=True, seed=42)
138
- split_ds
139
-
140
- train_data=data['train']
141
- test_data=split_ds['train']
142
- val_data=split_ds['test']
143
-
144
- train_data[0]
145
-
146
- def tokenize_fun(data):
147
- return tokenizer(data['text'],padding=True,truncation=True,return_tensors='pt')
148
-
149
- tokenized_train_data=train_data.map(tokenize_fun,batched=True)
150
-
151
- tokenized_train_data.features
152
-
153
-
154
- import evaluate
155
- accuracy=evaluate.load('accuracy')
156
-
157
- import numpy as np
158
-
159
- def compute_metrics(eval_pred):
160
- predictions, labels = eval_pred
161
- predictions = np.argmax(predictions, axis=1)
162
- return accuracy.compute(predictions=predictions, references=labels)
163
-
164
- tokenized_test_data=test_data.map(tokenize_fun,batched=True)
165
- tokenized_val_data=val_data.map(tokenize_fun,batched=True)
166
-
167
- from huggingface_hub import login
168
- login()
169
-
170
- from transformers import Trainer,TrainingArguments
171
-
172
- training_args=TrainingArguments(
173
- output_dir='./quest_model',
174
- learning_rate=2e-3,
175
- per_device_eval_batch_size=16,
176
- per_device_train_batch_size=16,
177
- num_train_epochs=2,
178
- weight_decay=0.01,
179
- eval_strategy='epoch',
180
- save_strategy='epoch',
181
- load_best_model_at_end=True,
182
- push_to_hub=True
183
- )
184
-
185
- trainer=Trainer(
186
- model=model,
187
- tokenizer=tokenizer,
188
- args=training_args,
189
- train_dataset=tokenized_train_data,
190
- eval_dataset=tokenized_val_data,
191
- compute_metrics=compute_metrics
192
- )
193
- trainer.train()
194
-
195
- model.config.id2label
196
-
197
- import os
198
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
199
-
200
- model.save_pretrained('./quest_model')
201
- tokenizer.save_pretrained("./quest_model")
202
-
203
- tokenized_train_data[0]['text']
204
-
205
  from transformers import pipeline
206
- pipe=pipeline('text-classification',model='Nainglinthu/quest_model')
207
- output=pipe('Hexal Australia Pty Ltd v Roche Therapeutics Inc (2005) 66 IPR 325, the likelihood of irreparable harm was regarded by Stone J as, indeed, a separate element that had to be established by an applicant for an interlocutory injunction.')
208
- output
209
-
210
-
211
  import gradio as gr
212
- from transformers import pipeline
213
 
214
- # Initialize the pipeline
215
- pipe = pipeline('text-classification', model='Nainglinthu/quest_model')
 
 
 
216
 
217
- # Function to classify text
218
- def classify_text(input_text):
219
- output = pipe(input_text)
220
- return output
221
 
222
- # Create Gradio interface
223
- interface = gr.Interface(
224
- fn=classify_text, # Function to call
225
- inputs="text", # Input type (text box)
226
- outputs="json", # Output type (JSON for displaying result)
227
- title="Legal Text Classifier", # Title of the Gradio app
228
- description="Classify legal text using the Nainglinthu/quest_model!", # Description
229
  )
230
 
231
- # Launch the Gradio app
232
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline
 
 
 
 
 
2
  import gradio as gr
 
3
 
4
+ # Load your model & tokenizer from your saved local folder or HF repo
5
+ model_path = "Nainglinthu/quest_model" # your Hugging Face model repo name
6
+
7
+ # Initialize pipeline once
8
+ classifier = pipeline("text-classification", model=model_path)
9
 
10
+ # Define function to classify text
11
+ def classify_text(text):
12
+ results = classifier(text)
13
+ return results
14
 
15
+ # Gradio interface setup
16
+ iface = gr.Interface(
17
+ fn=classify_text,
18
+ inputs=gr.Textbox(lines=5, placeholder="Enter legal text here..."),
19
+ outputs=gr.JSON(),
20
+ title="Legal Text Classification",
21
+ description="Classify legal text using your fine-tuned Legal BERT model."
22
  )
23
 
24
+ if __name__ == "__main__":
25
+ iface.launch()