mc0c0z committed
Commit a307848 · 1 Parent(s): ec6055f

Update space
app.py CHANGED
@@ -1,63 +1,290 @@
 
 
 
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)

  if __name__ == "__main__":
-    demo.launch()
 
 import gradio as gr
+import html
+import torch
+
+from transformers import MBartForConditionalGeneration, AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, AutoModelForTokenClassification, pipeline
+
+from torch import nn
+from underthesea import word_tokenize
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+# Load the multi-task (sentiment + translation) BARTpho models
+bartpho_mt_base = MBartForConditionalGeneration.from_pretrained("mc0c0z/BARTPho-multi-task")
+bartpho_mt_base_tokenizer = AutoTokenizer.from_pretrained("mc0c0z/BARTPho-multi-task")
+bartpho_mt_base.to(device)
+
+bartpho_mt = MBartForConditionalGeneration.from_pretrained("mc0c0z/BARTPho-Large-multi-task")
+bartpho_mt_tokenizer = AutoTokenizer.from_pretrained("mc0c0z/BARTPho-Large-multi-task")
+bartpho_mt.to(device)
+
+def segmenter(text):
+    text = html.unescape(text)
+    tokens = word_tokenize(text)
+    result = []
+    for token in tokens:
+        if ' ' in token:
+            result.append(token.replace(' ', '_'))
+        else:
+            result.append(token)
+    return result
+
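+# Illustrative example (the exact split depends on underthesea's segmenter):
+#   segmenter("Tôi yêu Hà Nội") -> ['Tôi', 'yêu', 'Hà_Nội']
+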
+class MultiTaskModel:
+    def __init__(self, model, tokenizer, device):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = device
+
+    def get_prompt(self, task):
+        if task == 'sa':
+            return "Classify the sentiment: "
+        elif task == 'mt-en-vi':
+            return "Translate English to Vietnamese: "
+        elif task == 'mt-vi-en':
+            return "Translate Vietnamese to English: "
+        else:
+            return ""
+
+    def inference(self, task, sentence, device):
+        # Preprocess the input sentence the same way as in CustomDataset
+        tokenized_text = segmenter(sentence)
+        source = self.get_prompt(task) + " ".join(tokenized_text)
+
+        # Tokenize the input
+        inputs = self.tokenizer(source, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
+
+        # Move the inputs to the target device
+        input_ids = inputs["input_ids"].to(device)
+        attention_mask = inputs["attention_mask"].to(device)
+
+        # Generate the prediction
+        self.model.eval()
+        with torch.no_grad():
+            generated_output = self.model.generate(input_ids, attention_mask=attention_mask, max_length=128)
+
+        # Decode the prediction
+        prediction = self.tokenizer.decode(generated_output[0], skip_special_tokens=True)
+
+        if task == 'sa':
+            class_names = ["Negative", "Positive"]
+            return class_names[int(prediction[0])]
+        return html.unescape(prediction)
+
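+# Illustrative usage, with the models loaded above:
+#   MultiTaskModel(bartpho_mt, bartpho_mt_tokenizer, device).inference("mt-en-vi", "Hello!", device)
+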
+# Load SA (sentiment analysis) models
+class CustomModel(nn.Module):
+    def __init__(self, bert_model):
+        super(CustomModel, self).__init__()
+        self.bert = bert_model
+        self.mlp = nn.Sequential(
+            nn.Linear(768 * 5, 512),  # 768 * 5: concatenated [CLS] states from the last 5 layers
+            nn.ReLU(),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Linear(256, 3)  # 3 = number of sentiment classes
+        )
+
+    def forward(self, input_ids, attention_mask):
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+
+        # Take the [CLS] token from each of the last 5 hidden layers
+        last_hidden_states = outputs.hidden_states[-5:]
+        cls_embeddings = torch.cat([state[:, 0, :] for state in last_hidden_states], dim=1)
+
+        # Pass through the MLP head
+        logits = self.mlp(cls_embeddings)
+        return logits
+
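+# Shape sketch (batch size B, hidden size 768): each of the last 5 hidden states
+# is (B, seq_len, 768); the [CLS] slices concatenate to (B, 768 * 5), which the
+# MLP maps to (B, 3) logits.
+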
+## PhoBERT
+phobert_sa = AutoModel.from_pretrained("vinai/phobert-base", output_hidden_states=True)
+phobert_sa_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
+phobert_sa = CustomModel(phobert_sa)
+phobert_sa.load_state_dict(torch.load('sa_model/phobert_sentiment_analysis.pth', map_location=device))
+phobert_sa.to(device)
+
+## PhoBERTv2
+phobertv2_sa = AutoModel.from_pretrained("vinai/phobert-base-v2", output_hidden_states=True)
+phobertv2_sa_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
+phobertv2_sa = CustomModel(phobertv2_sa)
+phobertv2_sa.load_state_dict(torch.load('sa_model/phobertv2_sentiment_analysis.pth', map_location=device))
+phobertv2_sa.to(device)
+
+## Multilingual BERT
+m_bert_sa = AutoModel.from_pretrained("google-bert/bert-base-multilingual-cased", output_hidden_states=True)
+m_bert_sa_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-cased")
+m_bert_sa = CustomModel(m_bert_sa)
+m_bert_sa.load_state_dict(torch.load('sa_model/bert_model_sentiment_analysis.pth', map_location=device))
+m_bert_sa.to(device)
+
+# Load Q&A model
+roberta_qa = AutoModelForQuestionAnswering.from_pretrained("HungLV2512/Vietnamese-QA-fine-tuned")
+roberta_qa_tokenizer = AutoTokenizer.from_pretrained("HungLV2512/Vietnamese-QA-fine-tuned")
+roberta_qa.to(device)
+
+# Load NER models
+label_map = {
+    'B-LOC': 0,
+    'B-MISC': 1,
+    'B-ORG': 2,
+    'B-PER': 3,
+    'I-LOC': 4,
+    'I-MISC': 5,
+    'I-ORG': 6,
+    'I-PER': 7,
+    'O': 8
+}
+
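+# BIO scheme: B- opens an entity span, I- continues it, O marks tokens outside
+# any entity; the entity types are PER, LOC, ORG, and MISC.
+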
+## PhoBERT
+phobert_ner = AutoModelForTokenClassification.from_pretrained("DrRinS/NER-PhoBERT", num_labels=len(label_map))
+phobert_ner_tokenizer = AutoTokenizer.from_pretrained("DrRinS/NER-PhoBERT")
+phobert_ner.to(device)
+
+## PhoBERTv2
+phobertv2_ner = AutoModelForTokenClassification.from_pretrained("DrRinS/NER-PhoBERTv2", num_labels=len(label_map))
+phobertv2_ner_tokenizer = AutoTokenizer.from_pretrained("DrRinS/NER-PhoBERTv2")
+phobertv2_ner.to(device)
+
+## Multilingual BERT
+m_bert_ner = AutoModelForTokenClassification.from_pretrained("DrRinS/NER_MultilingualBERT", num_labels=len(label_map))
+m_bert_ner_tokenizer = AutoTokenizer.from_pretrained("DrRinS/NER_MultilingualBERT")
+m_bert_ner.to(device)
+
+# Inference functions
+def sentiment_inference(model, tokenizer, text, device):
+    # Word-segment the input text
+    text = " ".join(segmenter(text))
+
+    # Tokenize the segmented text
+    inputs = tokenizer(
+        text,
+        padding='max_length',
+        truncation=True,
+        max_length=128,
+        return_tensors='pt'
+    )
+
+    # Move inputs to the correct device
+    input_ids = inputs['input_ids'].to(device)
+    attention_mask = inputs['attention_mask'].to(device)
+
+    # Ensure inputs have a batch dimension
+    input_ids = input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids
+    attention_mask = attention_mask.unsqueeze(0) if attention_mask.dim() == 1 else attention_mask
+
+    # Perform inference
+    model.eval()
+    with torch.no_grad():
+        outputs = model(input_ids, attention_mask)
+        _, preds = torch.max(outputs, dim=1)
+
+    # Map predictions to class names
+    class_names = ["Negative", "Positive", "Neutral"]
+    return class_names[preds.cpu().item()]
+
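+# Note: class_names above is assumed to match the label encoding used when the
+# .pth checkpoints were trained; a mismatch would silently swap the labels.
+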
+def multitask_inference(model, tokenizer, text, task, device):
+    multitask_model = MultiTaskModel(model, tokenizer, device)
+    return multitask_model.inference(task, text, device)
+
+def qa_inference(model, tokenizer, question, context, device):
+    qa_pipeline = pipeline('question-answering', model=model, tokenizer=tokenizer)
+    res = qa_pipeline(question=question, context=context)
+    return res['answer']
+
+def ner_inference(model, tokenizer, text, device):
+    # Tokenize the text
+    inputs = tokenizer(
+        text,
+        padding='max_length',
+        truncation=True,
+        max_length=128,
+        return_tensors='pt'
+    )
+
+    # Move inputs to the correct device
+    input_ids = inputs['input_ids'].to(device)
+    attention_mask = inputs['attention_mask'].to(device)
+
+    # Perform inference
+    model.eval()
+    with torch.no_grad():
+        outputs = model(input_ids, attention_mask)
+        _, preds = torch.max(outputs.logits, dim=2)
+
+    # Convert predictions to labels for the attended (non-padding) positions
+    id_to_label = {v: k for k, v in label_map.items()}
+    predictions = preds[attention_mask.bool()].cpu().numpy().flatten()
+    labels = [id_to_label[p] for p in predictions]
+
+    # Decode the input ids to tokens; keep special tokens so tokens and labels
+    # stay aligned, then drop the special-token pairs
+    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+    special = set(tokenizer.all_special_tokens)
+    ner_tags = [(tok, lab) for tok, lab in zip(tokens, labels) if tok not in special]
+
+    return ner_tags
+
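+# Illustrative output: a list of (token, label) pairs, e.g.
+#   [('Hà', 'B-LOC'), ('Nội', 'I-LOC'), ('rất', 'O'), ...]
+# Tokens are model subwords, so multi-piece words appear as separate entries.
+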
+def process_input(input_text, context, task):
+    results = {}
+
+    if task == "Sentiment Analysis":
+        results["PhoBERT"] = sentiment_inference(phobert_sa, phobert_sa_tokenizer, input_text, device)
+        results["PhoBERTv2"] = sentiment_inference(phobertv2_sa, phobertv2_sa_tokenizer, input_text, device)
+        results["Multilingual BERT"] = sentiment_inference(m_bert_sa, m_bert_sa_tokenizer, input_text, device)
+        results["BARTPho Base"] = multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, "sa", device)
+        results["BARTPho Large"] = multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, "sa", device)
+    elif task == "English to Vietnamese":
+        results["BARTPho Base"] = multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, "mt-en-vi", device)
+        results["BARTPho Large"] = multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, "mt-en-vi", device)
+    elif task == "Vietnamese to English":
+        results["BARTPho Base"] = multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, "mt-vi-en", device)
+        results["BARTPho Large"] = multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, "mt-vi-en", device)
+    elif task == "Question Answering":
+        results["RoBERTa"] = qa_inference(roberta_qa, roberta_qa_tokenizer, input_text, context, device)
+    elif task == "Named Entity Recognition":
+        results["PhoBERT"] = ner_inference(phobert_ner, phobert_ner_tokenizer, input_text, device)
+        results["PhoBERTv2"] = ner_inference(phobertv2_ner, phobertv2_ner_tokenizer, input_text, device)
+        results["Multilingual BERT"] = ner_inference(m_bert_ner, m_bert_ner_tokenizer, input_text, device)
+    return results
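+# Illustrative call (actual predictions depend on the checkpoints):
+#   process_input("Tôi rất thích bộ phim này", "", "Sentiment Analysis")
+#   -> {"PhoBERT": "...", "PhoBERTv2": "...", "Multilingual BERT": "...", ...}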
 
+with gr.Blocks() as iface:
+    gr.Markdown("# Multi-task NLP Demo")
+    gr.Markdown("Perform sentiment analysis, machine translation, question answering, or named entity recognition using various models.")
+
+    with gr.Row():
+        task = gr.Radio(["Sentiment Analysis", "Question Answering", "Named Entity Recognition", "English to Vietnamese", "Vietnamese to English"], label="Task")
+
+    with gr.Row():
+        input_text = gr.Textbox(label="Input Text")
+        context = gr.Textbox(label="Context", visible=False)
+
+    output = gr.JSON(label="Results")
+
+    submit = gr.Button("Submit")
+
+    def on_task_change(task):
+        if task == "Question Answering":
+            return {
+                input_text: gr.update(label="Question", visible=True),
+                context: gr.update(visible=True)
+            }
+        else:
+            return {
+                input_text: gr.update(label="Input Text", visible=True),
+                context: gr.update(visible=False)
+            }
+
+    task.change(on_task_change, task, [input_text, context])
+
+    submit.click(
+        process_input,
+        inputs=[input_text, context, task],
+        outputs=output
+    )
+
 if __name__ == "__main__":
+    iface.launch(share=True)
sa_model/bert_model_sentiment_analysis.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:852b53ae6d6f1db4129b1de8a87eee9d12b3a2407ec2c9c827d523194103e879
+size 719896142
sa_model/phobert_sentiment_analysis.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e98704ecef05aaef8209231fdaf73040a5ca01ca9dc3baad9a2a31d20c257c3
+size 548474843
sa_model/phobertv2_sentiment_analysis.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b96944fa9531778d34bbd299c5d1ba6581dbb638486867002edc31c3ce15696
+size 548475261