中文多标签意图识别模型(BERT)
这是一个基于 bert-base-chinese
微调的多标签分类模型,支持以下任务:
对中文query进行分类
- 多分类:意图识别(chat / simple question / complex question)
- 二分类:是否时间相关、是否位置(LBS)相关
模型结构
- 基础模型:
bert-base-chinese
- 输出层:一个 5 维的 sigmoid 多标签输出向量
[意图-chat, 意图-simple, 意图-complex, 是否时间相关, 是否LBS相关]
使用方法
import torch
from transformers import BertTokenizer
from bert_classifier_3 import BertMultiLabelClassifier
# 加载 tokenizer 和模型
bert_base = "bert-base-chinese"
model_id = "Xiaoxi2333/bert_multilabel_chinese"
tokenizer = BertTokenizer.from_pretrained(model_id)
model = BertMultiLabelClassifier(pretrained_model_path=bert_base, num_labels=5)
state_dict = torch.hub.load_state_dict_from_url(
f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
map_location="cpu"
)
model.load_state_dict(state_dict)
model.eval()
# 定义标签
intent_labels = ["chat", "simple question", "complex question"]
yesno_labels = ["否", "是"]
# 定义预测函数
def predict(query):
enc = tokenizer(
query,
truncation=True,
padding="max_length",
max_length=128,
return_tensors="pt"
)
with torch.no_grad():
logits = model(enc["input_ids"], enc["attention_mask"])
probs = torch.sigmoid(logits).squeeze(0)
intent_index = torch.argmax(probs[:3]).item()
is_time = int(probs[3] > 0.5)
is_lbs = int(probs[4] > 0.5)
return {
"query": query,
"意图": intent_labels[intent_index],
"是否时间相关": yesno_labels[is_time],
"是否lbs相关": yesno_labels[is_lbs],
"原始概率": probs.tolist()
}
# 示例查询
result = predict("明天北京天气怎么样?")
print(result)
- Downloads last month
- 9
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support
Model tree for Xiaoxi2333/bert_multilabel_chinese
Base model
google-bert/bert-base-chinese