maximuspowers commited on
Commit
9b562d8
1 Parent(s): 61b8e6a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertTokenizerFast, BertForTokenClassification
3
+ import gradio as gr
4
+
5
+ # Load tokenizer and model
6
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
7
+ model = BertForTokenClassification.from_pretrained('./saved_model3')
8
+ model.eval()
9
+ model.to('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ # Define label mappings
12
+ id2label = {
13
+ 0: 'O',
14
+ 1: 'B-STEREO',
15
+ 2: 'I-STEREO',
16
+ 3: 'B-GEN',
17
+ 4: 'I-GEN',
18
+ 5: 'B-UNFAIR',
19
+ 6: 'I-UNFAIR'
20
+ }
21
+
22
+ def predict_ner_tags(sentence):
23
+ inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
24
+ input_ids = inputs['input_ids'].to(model.device)
25
+ attention_mask = inputs['attention_mask'].to(model.device)
26
+
27
+ with torch.no_grad():
28
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
29
+ logits = outputs.logits
30
+ probabilities = torch.sigmoid(logits)
31
+ predicted_labels = (probabilities > 0.5).int()
32
+
33
+ result = []
34
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
35
+ for i, token in enumerate(tokens):
36
+ if token not in tokenizer.all_special_tokens:
37
+ label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
38
+ labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
39
+ result.append((token, labels))
40
+
41
+ return result
42
+
43
+ def format_output(result):
44
+ formatted_output = ""
45
+ for token, labels in result:
46
+ formatted_output += f"{token}: {', '.join(labels)}\n"
47
+ return formatted_output
48
+
49
+ iface = gr.Interface(
50
+ fn=predict_ner_tags,
51
+ inputs="text",
52
+ outputs="text",
53
+ title="Named Entity Recognition with BERT",
54
+ description="Enter a sentence to predict NER tags using BERT model trained for multi-label classification.",
55
+ examples=["Tall men are so clumsy."],
56
+ allow_flagging="never",
57
+ interpretation="default",
58
+ postprocessing_fn=format_output
59
+ )
60
+
61
+ if __name__ == "__main__":
62
+ iface.launch()