Chinese Multi-Label Intent Recognition Model (BERT)

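This is a multi-label classification model fine-tuned from bert-base-chinese. It classifies Chinese queries along the following tasks:

  • Multi-class: intent recognition (chat / simple question / complex question)
  • Binary: whether the query is time-related, and whether it is location (LBS) related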

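Model Architecture

  • Base model: bert-base-chinese
  • Output layer: a 5-dimensional multi-label output vector, with a sigmoid applied to each dimension
    • [intent-chat, intent-simple, intent-complex, time-related, LBS-related]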
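The output layout above can be illustrated without loading the model. The following is a minimal sketch in plain Python, not the author's training or inference code: `encode_target` and `decode_logits` are hypothetical helper names, the one-hot encoding of the intent slots is an assumption about how targets might be built, and the 0.5 threshold mirrors the one used in the usage example.

```python
import math

# Label order taken from the model card: three intent slots, then two binary flags.
INTENT_LABELS = ["chat", "simple question", "complex question"]


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def encode_target(intent, is_time, is_lbs):
    """Build a 5-dim target vector (assumption: one-hot intent plus two 0/1 flags)."""
    if intent not in INTENT_LABELS:
        raise ValueError(f"unknown intent: {intent!r}")
    intent_part = [1.0 if name == intent else 0.0 for name in INTENT_LABELS]
    return intent_part + [float(is_time), float(is_lbs)]


def decode_logits(logits, threshold=0.5):
    """Map 5 raw logits to (intent, is_time, is_lbs)."""
    if len(logits) != 5:
        raise ValueError("expected exactly 5 logits")
    probs = [sigmoid(x) for x in logits]
    # Intent: the highest-scoring of the first three outputs.
    intent_index = max(range(3), key=lambda i: probs[i])
    # Binary flags: each thresholded independently.
    return INTENT_LABELS[intent_index], probs[3] > threshold, probs[4] > threshold


print(encode_target("simple question", True, False))
print(decode_logits([-2.0, 3.0, 0.5, 1.2, -0.7]))
```

Because the three intent outputs pass through independent sigmoids rather than a softmax, their probabilities need not sum to 1; taking the argmax over them still yields a single intent.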

Usage

import torch
from transformers import BertTokenizer
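# Assumption: bert_classifier_3.py, which defines BertMultiLabelClassifier, is importable locally
from bert_classifier_3 import BertMultiLabelClassifier

# Load the tokenizer and model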
bert_base = "bert-base-chinese"
model_id = "Xiaoxi2333/bert_multilabel_chinese"
tokenizer = BertTokenizer.from_pretrained(model_id)
model = BertMultiLabelClassifier(pretrained_model_path=bert_base, num_labels=5)
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
    map_location="cpu"
)
model.load_state_dict(state_dict)
model.eval()

# Label names
intent_labels = ["chat", "simple question", "complex question"]
yesno_labels = ["否", "是"]  # "no", "yes"

# Prediction function
def predict(query):
    enc = tokenizer(
        query,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
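        probs = torch.sigmoid(logits).squeeze(0)
        # Intent: take the highest-scoring of the first three outputs
        intent_index = torch.argmax(probs[:3]).item()
        # Binary flags: threshold each probability at 0.5
        is_time = int(probs[3] > 0.5)
        is_lbs = int(probs[4] > 0.5)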

        return {
            "query": query,
            "意图": intent_labels[intent_index],  # intent
            "是否时间相关": yesno_labels[is_time],  # time-related?
            "是否lbs相关": yesno_labels[is_lbs],  # location (LBS) related?
            "原始概率": probs.tolist()  # raw sigmoid probabilities
        }

# Example query
result = predict("明天北京天气怎么样?")  # "What will the weather be like in Beijing tomorrow?"
print(result)