zay12121 commited on
Commit
ae69e94
·
verified ·
1 Parent(s): ba26220

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -30
app.py CHANGED
@@ -2,49 +2,36 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
- st.title("AI Model Text Classifier")
6
 
7
- # Use valid model names
8
  model_list = [
9
- ("roberta-base", "RoBERTa Base"),
10
- ("bert-base-uncased", "BERT Base"),
11
- ("distilbert-base-uncased", "DistilBERT Base")
12
  ]
13
 
14
  @st.cache_resource
15
- def load_model(model_name):
16
- return AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
17
 
18
- @st.cache_resource
19
- def load_tokenizer(model_name):
20
- return AutoTokenizer.from_pretrained(model_name)
21
-
22
- # Load models + tokenizers
23
- models = [(load_model(name), load_tokenizer(name), label) for name, label in model_list]
24
 
25
- # User input
26
- text_input = st.text_area("Enter text to classify:")
27
 
28
- # Choose model
29
- selected_model_label = st.selectbox("Select a model:", [label for _, _, label in models])
30
-
31
- # Find selected model
32
- for model, tokenizer, label in models:
33
- if label == selected_model_label:
34
- selected_model = model
35
- selected_tokenizer = tokenizer
36
- break
37
 
38
  if st.button("Classify"):
39
- if text_input.strip() == "":
40
  st.warning("Please enter some text!")
41
  else:
42
- inputs = selected_tokenizer(text_input, return_tensors="pt", truncation=True)
 
43
  with torch.no_grad():
44
- outputs = selected_model(**inputs)
45
- logits = outputs.logits
46
  probs = torch.softmax(logits, dim=-1).squeeze().tolist()
47
-
48
- st.write("### Classification probabilities:")
49
  for i, prob in enumerate(probs):
50
  st.write(f"Class {i}: {prob:.4f}")
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
+ st.title("Text Sentiment Classifier")
6
 
7
+ # Valid fine-tuned models
8
  model_list = [
9
+ ("distilbert-base-uncased-finetuned-sst-2-english", "DistilBERT (SST-2)"),
10
+ ("textattack/roberta-base-imdb", "RoBERTa (IMDB Sentiment)")
 
11
  ]
12
 
13
  @st.cache_resource
14
+ def load_model_and_tokenizer(model_name):
15
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ return model, tokenizer
18
 
19
+ models = {label: load_model_and_tokenizer(name) for name, label in model_list}
 
 
 
 
 
20
 
21
+ # UI
22
+ text_input = st.text_area("Enter text:")
23
 
24
+ model_choice = st.selectbox("Choose model:", list(models.keys()))
 
 
 
 
 
 
 
 
25
 
26
  if st.button("Classify"):
27
+ if not text_input.strip():
28
  st.warning("Please enter some text!")
29
  else:
30
+ model, tokenizer = models[model_choice]
31
+ inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
32
  with torch.no_grad():
33
+ logits = model(**inputs).logits
 
34
  probs = torch.softmax(logits, dim=-1).squeeze().tolist()
35
+ st.write("### Results:")
 
36
  for i, prob in enumerate(probs):
37
  st.write(f"Class {i}: {prob:.4f}")