Spaces: Runtime error
import pandas as pd
import torch
import torch.nn.functional as F
import streamlit as st
from transformers import AutoTokenizer, BertForSequenceClassification

option = st.selectbox("Select a text analysis model:", ("BERT", "Fine-tuned BERT"))
bert_path = "bert-base-uncased"

if option == "BERT":
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    model = BertForSequenceClassification.from_pretrained(bert_path, num_labels=6)
else:
    tweets_raw = pd.read_csv("train.csv", nrows=20)
    label_set = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    # NOTE: this branch never loads a fine-tuned tokenizer/model, so the code
    # below raises a NameError when "Fine-tuned BERT" is selected.

# `txt` was undefined in the original snippet; collect it from the user here.
txt = st.text_area("Enter text to classify:")

if txt:
    # Run the encoded input through the model to get classification logits.
    encoding = tokenizer.encode(txt, return_tensors="pt")
    with torch.no_grad():
        result = model(encoding)

    # Pad a single-logit output to two columns, then turn logits into probabilities.
    logits = result.logits
    if logits.size(dim=1) < 2:
        logits = F.pad(logits, (0, 1), "constant", 0)
    prediction = F.softmax(logits, dim=-1)

    # These indices were undefined in the snippet; 0/1 assumes a [neutral, toxic] ordering.
    neutralIndex, toxicIndex = 0, 1
    neutralProb = prediction[0][neutralIndex].item()
    toxicProb = prediction[0][toxicIndex].item()

    # Write results.
    st.write("Classification Probabilities")
    st.write(f"{neutralProb:.4f} - NEUTRAL")
    st.write(f"{toxicProb:.4f} - TOXIC")