Spaces: Runtime error
xyh1756 committed · Commit ad16774 · 1 parent: f35fe56
first commit

Files changed:
- app.py +16 -0
- bert-base-chinese/README.md +3 -0
- bert-base-chinese/config.json +25 -0
- bert-base-chinese/flax_model.msgpack +3 -0
- bert-base-chinese/pytorch_model.bin +3 -0
- bert-base-chinese/tf_model.h5 +3 -0
- bert-base-chinese/tokenizer.json +0 -0
- bert-base-chinese/tokenizer_config.json +3 -0
- bert-base-chinese/vocab.txt +0 -0
- bert/__init__.py +15 -0
- bert/modeling_jointbert.py +63 -0
- bert/module.py +23 -0
- book_model/config.json +27 -0
- book_model/pytorch_model.bin +3 -0
- book_model/training_args.bin +3 -0
- data/intent_label.txt +3 -0
- data/slot_label.txt +13 -0
- predictOnce.py +180 -0
app.py
ADDED
@@ -0,0 +1,16 @@
+import gradio as gr
+from predictOnce import Estimator
+
+
+def predict(inputText):
+    global e
+    res = e.predict(inputText)
+    return res[0], res[1]
+
+
+if __name__ == '__main__':
+    e = Estimator()
+    iface = gr.Interface(fn=e.predict, inputs=gr.inputs.Textbox(lines=2, label="输入语句", placeholder="输入要识别的语句..."),
+                         outputs=[gr.outputs.Textbox(label="意图"), gr.outputs.Textbox(label="槽值")], live=True,
+                         theme="huggingface", allow_screenshot=False, allow_flagging=False)
+    iface.launch(share=True)
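
Note: `gr.inputs.Textbox` / `gr.outputs.Textbox`, `allow_screenshot`, and the `"huggingface"` theme string belong to the Gradio 1.x/2.x API and were removed in Gradio 3+, which is a plausible cause of the Space's "Runtime error" badge. Below is a minimal sketch (not part of the commit) of the same interface against the Gradio 3.x component API; the Chinese labels are kept as in the original (输入语句 "input sentence", 意图 "intent", 槽值 "slot values").

# hypothetical port of app.py to Gradio >= 3 (sketch, not part of this commit)
import gradio as gr
from predictOnce import Estimator

if __name__ == '__main__':
    e = Estimator()
    iface = gr.Interface(
        fn=e.predict,  # still returns (intent, slots)
        inputs=gr.Textbox(lines=2, label="输入语句", placeholder="输入要识别的语句..."),
        outputs=[gr.Textbox(label="意图"), gr.Textbox(label="槽值")],
        live=True,
        allow_flagging="never",  # string-valued in Gradio 3; replaces allow_flagging=False
    )
    iface.launch()  # share=True is unnecessary when hosted on Spaces
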
bert-base-chinese/README.md
ADDED
@@ -0,0 +1,3 @@
+---
+language: zh
+---
bert-base-chinese/config.json
ADDED
@@ -0,0 +1,25 @@
+{
+  "architectures": [
+    "BertForMaskedLM"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "directionality": "bidi",
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "type_vocab_size": 2,
+  "vocab_size": 21128
+}
bert-base-chinese/flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76df8425215fb9ede22e0393e356f82a99d84e79f078cd141afbbf9277460c8e
+size 409168515
bert-base-chinese/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a693db616eaf647ed2bfe531e1fa446637358fc108a8bf04e8d4db17e837ee9
+size 411577189
bert-base-chinese/tf_model.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:612acd33db45677c3d6ba70615336619dc65cddf1ecf9d39a22dd1934af4aff2
+size 478309336
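
Note: the three weight files above (and the book_model binaries further down) are committed as Git LFS pointer files, so the diff only records the pointer (spec version, sha256 oid, byte size) while the actual weights live in LFS storage. Vendoring the stock bert-base-chinese checkpoint into the repo presumably lets `BertTokenizer.from_pretrained('bert-base-chinese')` in predictOnce.py resolve to this local folder instead of downloading from the Hub.
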
bert-base-chinese/tokenizer.json
ADDED
The diff for this file is too large to render.
bert-base-chinese/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
+{
+  "do_lower_case": false
+}
bert-base-chinese/vocab.txt
ADDED
The diff for this file is too large to render.
bert/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
bert/modeling_jointbert.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+from transformers.modeling_bert import BertPreTrainedModel, BertModel, BertConfig
+from torchcrf import CRF
+from .module import IntentClassifier, SlotClassifier
+
+
+class JointBERT(BertPreTrainedModel):
+    def __init__(self, config, args, intent_label_lst, slot_label_lst):
+        super(JointBERT, self).__init__(config)
+        self.args = args
+        self.num_intent_labels = len(intent_label_lst)
+        self.num_slot_labels = len(slot_label_lst)
+        self.bert = BertModel(config=config)  # Load pretrained bert
+
+        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
+        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
+
+        if args.use_crf:
+            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
+
+    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
+        outputs = self.bert(input_ids, attention_mask=attention_mask,
+                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
+        sequence_output = outputs[0]
+        pooled_output = outputs[1]  # [CLS]
+
+        intent_logits = self.intent_classifier(pooled_output)
+        slot_logits = self.slot_classifier(sequence_output)
+
+        total_loss = 0
+        # 1. Intent Softmax
+        if intent_label_ids is not None:
+            if self.num_intent_labels == 1:
+                intent_loss_fct = nn.MSELoss()
+                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
+            else:
+                intent_loss_fct = nn.CrossEntropyLoss()
+                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
+            total_loss += intent_loss
+
+        # 2. Slot Softmax
+        if slot_labels_ids is not None:
+            if self.args.use_crf:
+                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
+                slot_loss = -1 * slot_loss  # negative log-likelihood
+            else:
+                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
+                # Only keep active parts of the loss
+                if attention_mask is not None:
+                    active_loss = attention_mask.view(-1) == 1
+                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
+                    active_labels = slot_labels_ids.view(-1)[active_loss]
+                    slot_loss = slot_loss_fct(active_logits, active_labels)
+                else:
+                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
+            total_loss += self.args.slot_loss_coef * slot_loss
+
+        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here
+
+        outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
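
Note: `from transformers.modeling_bert import ...` is the pre-4.0 module layout; with transformers >= 4.x the equivalent classes are importable from the top-level `transformers` package (`from transformers import BertModel, BertPreTrainedModel, BertConfig`), so the Space needs either a pinned `transformers<4.0` or an adjusted import (another plausible contributor to the runtime error). As a minimal sketch of what the forward pass returns at inference time, assuming the model is built as in predictOnce.py below and using dummy tensors (shapes follow max_seq_len=50, 3 intents, 13 slot labels):

# illustrative forward pass; with both label ids None the loss term stays 0
import torch
from predictOnce import Estimator  # builds JointBERT.from_pretrained('book_model', ...)

model = Estimator().model
input_ids = torch.zeros(1, 50, dtype=torch.long)       # [batch, max_seq_len], all [PAD]
attention_mask = torch.ones(1, 50, dtype=torch.long)
token_type_ids = torch.zeros(1, 50, dtype=torch.long)

with torch.no_grad():
    outputs = model(input_ids, attention_mask, token_type_ids,
                    intent_label_ids=None, slot_labels_ids=None)
total_loss, (intent_logits, slot_logits) = outputs[:2]
# intent_logits: [1, 3]     -> argmax over UNK / query / chat
# slot_logits:   [1, 50, 13] -> per-token scores over the BIO slot labels
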
bert/module.py
ADDED
@@ -0,0 +1,23 @@
+import torch.nn as nn
+
+
+class IntentClassifier(nn.Module):
+    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
+        super(IntentClassifier, self).__init__()
+        self.dropout = nn.Dropout(dropout_rate)
+        self.linear = nn.Linear(input_dim, num_intent_labels)
+
+    def forward(self, x):
+        x = self.dropout(x)
+        return self.linear(x)
+
+
+class SlotClassifier(nn.Module):
+    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
+        super(SlotClassifier, self).__init__()
+        self.dropout = nn.Dropout(dropout_rate)
+        self.linear = nn.Linear(input_dim, num_slot_labels)
+
+    def forward(self, x):
+        x = self.dropout(x)
+        return self.linear(x)
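
Note: both heads are the same dropout-then-linear projection; they differ only in what modeling_jointbert.py above feeds them: IntentClassifier consumes the pooled [CLS] vector ([batch, 768] -> [batch, 3]), while SlotClassifier is applied token-wise to the full sequence output ([batch, 50, 768] -> [batch, 50, 13]).
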
book_model/config.json
ADDED
@@ -0,0 +1,27 @@
+{
+  "architectures": [
+    "JointBERT"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "directionality": "bidi",
+  "finetuning_task": "book",
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "type_vocab_size": 2,
+  "vocab_size": 21128
+}
book_model/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6389ddb0d25ffbbea13eae6adbeb3a8e9dde3dd71ad811abd019862f51570ede
+size 409203155
book_model/training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e2a51024072e0d7bf7f5c695ade9f2bf7b52f85696d12e73389e95a8d63fe9c
+size 1199
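
Note: book_model/ is the fine-tuned JointBERT checkpoint for the "book" task (same bert-base-chinese backbone; `"architectures": ["JointBERT"]`, `"finetuning_task": "book"`). predictOnce.py loads it via `JointBERT.from_pretrained('book_model', ...)`; training_args.bin is presumably the pickled training-argument namespace saved alongside the weights.
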
data/intent_label.txt
ADDED
@@ -0,0 +1,3 @@
+UNK
+query
+chat
data/slot_label.txt
ADDED
@@ -0,0 +1,13 @@
+PAD
+UNK
+O
+B-Author
+I-Author
+B-Book
+I-Book
+B-Press
+I-Press
+B-Tag
+I-Tag
+B-Topic
+I-Topic
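
Note: the slot labels follow the standard BIO scheme over five entity types (Author, Book, Press, Tag, Topic), with PAD and UNK reserved for padding and unknown tags. For an illustrative query such as 鲁迅的呐喊 ("Lu Xun's Call to Arms"), a character-level tagging under this scheme would read B-Author I-Author O B-Book I-Book.
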
predictOnce.py
ADDED
@@ -0,0 +1,180 @@
+import os
+import time
+
+import numpy as np
+import torch
+from transformers import BertTokenizer
+from bert.modeling_jointbert import JointBERT
+
+
+class Estimator:
+    class Args:
+        adam_epsilon = 1e-08
+        batch_size = 16
+        data_dir = 'data'
+        device = 'cpu'
+        do_eval = True
+        do_train = False
+        dropout_rate = 0.1
+        eval_batch_size = 64
+        gradient_accumulation_steps = 1
+        ignore_index = 0
+        intent_label_file = 'data/intent_label.txt'
+        learning_rate = 5e-05
+        logging_steps = 50
+        max_grad_norm = 1.0
+        max_seq_len = 50
+        max_steps = -1
+        model_dir = 'book_model'
+        model_name_or_path = 'bert-base-chinese'
+        model_type = 'bert-chinese'
+        no_cuda = False
+        num_train_epochs = 5.0
+        save_steps = 200
+        seed = 1234
+        slot_label_file = 'data/slot_label.txt'
+        slot_loss_coef = 1.0
+        slot_pad_label = 'PAD'
+        task = 'book'
+        train_batch_size = 32
+        use_crf = False
+        warmup_steps = 0
+        weight_decay = 0.0
+
+    def __init__(self, args=Args):
+        self.intent_label_lst = [label.strip() for label in open(args.intent_label_file, 'r', encoding='utf-8')]
+        self.slot_label_lst = [label.strip() for label in open(args.slot_label_file, 'r', encoding='utf-8')]
+
+        # Check whether model exists
+        if not os.path.exists(args.model_dir):
+            raise Exception("Model doesn't exists! Train first!")
+
+        self.model = JointBERT.from_pretrained(args.model_dir,
+                                               args=args,
+                                               intent_label_lst=self.intent_label_lst,
+                                               slot_label_lst=self.slot_label_lst)
+        self.model.to(args.device)
+        self.model.eval()
+        self.args = args
+        self.tokenizer = BertTokenizer.from_pretrained(self.args.model_name_or_path)
+
+    def convert_input_to_tensor_data(self, input, tokenizer, pad_token_label_id,
+                                     cls_token_segment_id=0,
+                                     pad_token_segment_id=0,
+                                     sequence_a_segment_id=0,
+                                     mask_padding_with_zero=True):
+        # Setting based on the current model type
+        cls_token = tokenizer.cls_token
+        sep_token = tokenizer.sep_token
+        unk_token = tokenizer.unk_token
+        pad_token_id = tokenizer.pad_token_id
+
+        slot_label_mask = []
+
+        words = list(input)
+        tokens = []
+        for word in words:
+            word_tokens = tokenizer.tokenize(word)
+            if not word_tokens:
+                word_tokens = [unk_token]  # For handling the bad-encoded word
+            tokens.extend(word_tokens)
+            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+            slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))
+
+        # Account for [CLS] and [SEP]
+        special_tokens_count = 2
+        if len(tokens) > self.args.max_seq_len - special_tokens_count:
+            tokens = tokens[: (self.args.max_seq_len - special_tokens_count)]
+            slot_label_mask = slot_label_mask[:(self.args.max_seq_len - special_tokens_count)]
+
+        # Add [SEP] token
+        tokens += [sep_token]
+        token_type_ids = [sequence_a_segment_id] * len(tokens)
+        slot_label_mask += [pad_token_label_id]
+
+        # Add [CLS] token
+        tokens = [cls_token] + tokens
+        token_type_ids = [cls_token_segment_id] + token_type_ids
+        slot_label_mask = [pad_token_label_id] + slot_label_mask
+
+        input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
+        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+        # Zero-pad up to the sequence length.
+        padding_length = self.args.max_seq_len - len(input_ids)
+        input_ids = input_ids + ([pad_token_id] * padding_length)
+        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
+        slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)
+
+        # Change to Tensor
+        input_ids = torch.tensor([input_ids], dtype=torch.long)
+        attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+        token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
+        slot_label_mask = torch.tensor([slot_label_mask], dtype=torch.long)
+
+        data = [input_ids, attention_mask, token_type_ids, slot_label_mask]
+
+        return data
+
+    def predict(self, input):
+        # Convert input file to TensorDataset
+        pad_token_label_id = self.args.ignore_index
+        batch = self.convert_input_to_tensor_data(input, self.tokenizer, pad_token_label_id)
+
+        # Predict
+        batch = tuple(t.to(self.args.device) for t in batch)
+        with torch.no_grad():
+            inputs = {"input_ids": batch[0],
+                      "attention_mask": batch[1],
+                      "token_type_ids": batch[2],
+                      "intent_label_ids": None,
+                      "slot_labels_ids": None}
+            outputs = self.model(**inputs)
+            _, (intent_logits, slot_logits) = outputs[:2]
+
+            # Intent Prediction
+            intent_pred = intent_logits.detach().cpu().numpy()
+
+            # Slot prediction
+            if self.args.use_crf:
+                # decode() in `torchcrf` returns list with best index directly
+                slot_preds = np.array(self.model.crf.decode(slot_logits))
+            else:
+                slot_preds = slot_logits.detach().cpu().numpy()
+            all_slot_label_mask = batch[3].detach().cpu().numpy()
+
+        intent_pred = np.argmax(intent_pred, axis=1)[0]
+
+        if not self.args.use_crf:
+            slot_preds = np.argmax(slot_preds, axis=2)
+
+        slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
+        slot_preds_list = []
+
+        for i in range(slot_preds.shape[1]):
+            if all_slot_label_mask[0, i] != pad_token_label_id:
+                slot_preds_list.append(slot_label_map[slot_preds[0][i]])
+
+        words = list(input)
+        slots = dict()
+        slot = str()
+        for i in range(len(words)):
+            if slot_preds_list[i] == 'O':
+                if slot == '':
+                    continue
+                slots[slot_preds_list[i - 1].split('-')[1]] = slot
+                slot = str()
+            else:
+                slot += words[i]
+        if slot != '':
+            slots[slot_preds_list[len(words) - 1].split('-')[1]] = slot
+        return self.intent_label_lst[intent_pred], slots
+
+
+if __name__ == "__main__":
+    e = Estimator()
+    while True:
+        print(e.predict(input(">>")))
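
For reference, a minimal usage sketch of the class above (run from the repo root so the relative paths in Args resolve; the example sentence and the output shown in the comment are illustrative, not actual model output):

from predictOnce import Estimator

e = Estimator()  # loads book_model/ weights and the bert-base-chinese tokenizer, CPU only
intent, slots = e.predict("有没有余华写的活着")
print(intent, slots)
# e.g. something like: query {'Author': '余华', 'Book': '活着'}
# app.py feeds these two values into the 意图 (intent) and 槽值 (slot values) textboxes.
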