Nainglinthu commited on
Commit
3d56e9e
·
verified ·
1 Parent(s): ae9a20f

Upload legaltextclassification.py

Browse files
Files changed (1) hide show
  1. legaltextclassification.py +232 -0
legaltextclassification.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """LegalTextClassification.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1x6EcLSN3qEgm6sVIcmX0bYeXj7AdDQlW
8
+ !pip install gradio
9
+ import gradio as gr
10
+
11
+ def greet(name):
12
+ return "Hello " + name + "!!"
13
+
14
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
15
+ demo.launch()
16
+
17
+ #About Data
18
+ The dataset contains a total of 25000 legal cases in the form of text documents. Each document has been annotated with catchphrases, citations sentences, citation catchphrases, and citation classes. Citation classes indicate the type of treatment given to the cases cited by the present case.
19
+
20
+ The Legal Citation Text Classification dataset is provided in CSV format. The dataset has ***four columns***, ***namely Case ID, Case Outcome, Case Title, and Case Text***. The Case ID column contains a unique identifier for each legal case, the Case Outcome column indicates the outcome of the case, the Case Title column contains the title of the legal case, and the Case Text column contains the text of the legal case.
21
+
22
+ Kaggle Dataset Link: https://www.kaggle.com/datasets/amohankumar/legal-text-classification-dataset/data
23
+
24
+ #Importing Data
25
+ """
26
+
27
+ from google.colab import files
28
+ import pandas as pd
29
+
30
+ df = pd.read_csv('legal_text_classification.csv')
31
+ df.head()
32
+
33
+ """#Data Preprocessing and Description"""
34
+
35
+ print(df.columns) # Lists all column names
36
+ print(len(df.columns)) # Shows the number of columns
37
+
38
+ print(df.shape) # Output: (rows, columns)
39
+
40
+ print(df.isnull().sum())
41
+
42
+ df = df.dropna(subset=['case_text'])
43
+
44
+ df = df.drop(columns=["case_id", "case_title"])
45
+
46
+ print(df.isnull().sum())
47
+
48
+ import re
49
+
50
+ def text_ready(text):
51
+ text = text.lower() #lowercase
52
+ text = re.sub(r'[^\w\s]', '', text) #special char
53
+ text = re.sub(r'\s+', ' ', text).strip() #whitespace
54
+ return text
55
+
56
+ df["text_ready"] = df["case_text"].apply(text_ready)
57
+
58
+ import matplotlib.pyplot as plt
59
+
60
+ text_data = df['text_ready']
61
+ word_count = [len(text.split()) for text in text_data]
62
+
63
+ plt.hist(word_count, bins=50, color='skyblue', edgecolor='black')
64
+ plt.title('Distribution of Word Counts in text_ready')
65
+ plt.xlabel('Word Count')
66
+ plt.ylabel('Frequency')
67
+ plt.show()
68
+
69
+ print(df.shape) # Output: (rows, columns)
70
+
71
+ df.describe()
72
+
73
+ df['text']=df['text_ready']
74
+ df['label']=df['case_outcome']
75
+ data=df[['text','label']]
76
+
77
+ df = df.drop(columns=["case_outcome", "case_text"])
78
+
79
+ df.head()
80
+
81
+ df = df.drop(columns=["text_ready"])
82
+
83
+ df.head()
84
+
85
+ data['label'].value_counts()
86
+
87
+ class_label=sorted(data['label'].unique())
88
+ lbl2id={label:id for id,label in enumerate(class_label)}
89
+ id2lb={id:label for label,id in lbl2id.items()}
90
+ print(lbl2id)
91
+ print(id2lb)
92
+
93
+
94
+
95
+ data.head()
96
+
97
+ data['label']=data['label'].map(lbl2id)
98
+ data.head()
99
+
100
+ data.label.value_counts()
101
+
102
+ import matplotlib.pyplot as plt
103
+
104
+ df['label'].value_counts().plot.bar()
105
+ plt.show()
106
+
107
+ from transformers import AutoModelForSequenceClassification,AutoTokenizer
108
+ model_name='nlpaueb/legal-bert-base-uncased'
109
+ tokenizer=AutoTokenizer.from_pretrained(model_name)
110
+
111
+ from transformers import AutoModelForSequenceClassification
112
+ model = AutoModelForSequenceClassification.from_pretrained(
113
+ model_name,
114
+ num_labels=len(id2lb),
115
+ id2label=id2lb,
116
+ label2id=lbl2id
117
+ )
118
+
119
+ !pip install datasets
120
+ from datasets import Dataset
121
+ ds=Dataset.from_pandas(data)
122
+ ds
123
+
124
+ ds['label'][:11]
125
+
126
+ from datasets import ClassLabel
127
+ unique_labels = sorted(set(ds['label']))
128
+ print(f"Unique labels in Y: {unique_labels}")
129
+
130
+ new_features = ds.features.copy()
131
+ new_features['label'] = ClassLabel(names=unique_labels)
132
+
133
+ ds = ds.cast(new_features)
134
+ data = ds.train_test_split(test_size=0.2, shuffle=True, seed=42)
135
+ data
136
+
137
+ split_ds = data['test'].remove_columns('__index_level_0__').train_test_split(test_size=0.5, shuffle=True, seed=42)
138
+ split_ds
139
+
140
+ train_data=data['train']
141
+ test_data=split_ds['train']
142
+ val_data=split_ds['test']
143
+
144
+ train_data[0]
145
+
146
+ def tokenize_fun(data):
147
+ return tokenizer(data['text'],padding=True,truncation=True,return_tensors='pt')
148
+
149
+ tokenized_train_data=train_data.map(tokenize_fun,batched=True)
150
+
151
+ tokenized_train_data.features
152
+
153
+ !pip install evaluate
154
+ import evaluate
155
+ accuracy=evaluate.load('accuracy')
156
+
157
+ import numpy as np
158
+
159
+ def compute_metrics(eval_pred):
160
+ predictions, labels = eval_pred
161
+ predictions = np.argmax(predictions, axis=1)
162
+ return accuracy.compute(predictions=predictions, references=labels)
163
+
164
+ tokenized_test_data=test_data.map(tokenize_fun,batched=True)
165
+ tokenized_val_data=val_data.map(tokenize_fun,batched=True)
166
+
167
+ from huggingface_hub import login
168
+ login()
169
+
170
+ from transformers import Trainer,TrainingArguments
171
+
172
+ training_args=TrainingArguments(
173
+ output_dir='./quest_model',
174
+ learning_rate=2e-3,
175
+ per_device_eval_batch_size=16,
176
+ per_device_train_batch_size=16,
177
+ num_train_epochs=2,
178
+ weight_decay=0.01,
179
+ eval_strategy='epoch',
180
+ save_strategy='epoch',
181
+ load_best_model_at_end=True,
182
+ push_to_hub=True
183
+ )
184
+
185
+ trainer=Trainer(
186
+ model=model,
187
+ tokenizer=tokenizer,
188
+ args=training_args,
189
+ train_dataset=tokenized_train_data,
190
+ eval_dataset=tokenized_val_data,
191
+ compute_metrics=compute_metrics
192
+ )
193
+ trainer.train()
194
+
195
+ model.config.id2label
196
+
197
+ import os
198
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
199
+
200
+ model.save_pretrained('./quest_model')
201
+ tokenizer.save_pretrained("./quest_model")
202
+
203
+ tokenized_train_data[0]['text']
204
+
205
+ from transformers import pipeline
206
+ pipe=pipeline('text-classification',model='Nainglinthu/quest_model')
207
+ output=pipe('Hexal Australia Pty Ltd v Roche Therapeutics Inc (2005) 66 IPR 325, the likelihood of irreparable harm was regarded by Stone J as, indeed, a separate element that had to be established by an applicant for an interlocutory injunction.')
208
+ output
209
+
210
+ !pip install --upgrade gradio
211
+ import gradio as gr
212
+ from transformers import pipeline
213
+
214
+ # Initialize the pipeline
215
+ pipe = pipeline('text-classification', model='Nainglinthu/quest_model')
216
+
217
+ # Function to classify text
218
+ def classify_text(input_text):
219
+ output = pipe(input_text)
220
+ return output
221
+
222
+ # Create Gradio interface
223
+ interface = gr.Interface(
224
+ fn=classify_text, # Function to call
225
+ inputs="text", # Input type (text box)
226
+ outputs="json", # Output type (JSON for displaying result)
227
+ title="Legal Text Classifier", # Title of the Gradio app
228
+ description="Classify legal text using the Nainglinthu/quest_model!", # Description
229
+ )
230
+
231
+ # Launch the Gradio app
232
+ interface.launch()