# Author: darkbreakerk
# Commit 280d87f: "Refactor + convert onnx model"
from transformers import BertForTokenClassification
from .config_train import device, pretrain_model_name
from .load_data import tag2idx
# Fine-tuning BERT for token classification.
#
# The model is built once, at import time. It loads the checkpoint named by
# ``config_train.pretrain_model_name`` (for example
# "bert-base-multilingual-cased") and attaches a token-classification head
# with one output per tag in ``tag2idx``. It then moves the model to ``device``.
# NOTE(review): if the checkpoint is a plain encoder rather than a
# token-classification checkpoint, transformers initializes the head randomly.
# That is expected before fine-tuning.
model = BertForTokenClassification.from_pretrained(
    pretrain_model_name,
    num_labels=len(tag2idx),  # classification head size = number of tags
    # Do not return per-layer attention weights or hidden states from forward().
    output_attentions=False,
    output_hidden_states=False,
)
# nn.Module.to() moves the parameters in place; the return value need not be
# reassigned.
model.to(device)