|
|
import streamlit as st |
|
|
from transformers import pipeline |
|
|
import time |
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_classifier():
    """Build the zero-shot classification pipeline once and reuse it.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the model would be reloaded on every button click.
    Requires Streamlit >= 1.18 (``st.cache_resource``).
    """
    return pipeline("zero-shot-classification", model="cross-encoder/nli-distilroberta-base")


classifier = _load_classifier()

st.title("Text Classification App")

# Text file whose lines are the keywords to classify.
uploaded_file = st.file_uploader("Upload a text file containing keywords", type=["txt"])
|
|
|
|
|
if uploaded_file is not None:
    # Streamlit re-executes this whole script on every widget interaction, so
    # plain variables are reset by every button click. Everything that must
    # survive a click (pause/stop flags, saved position, collected words) is
    # kept in st.session_state instead.

    try:
        # getvalue() returns the whole upload regardless of the stream
        # position; "utf-8-sig" drops a leading byte-order mark that would
        # otherwise end up glued onto the first keyword.
        content = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        st.stop()

    # One keyword per non-blank line.
    keywords = [line.strip() for line in content.splitlines() if line.strip()]

    categories = ["shopping", "gaming", "streaming"]

    def update_results():
        """Publish every result bucket to session state as newline-joined text.

        Writes the ``<bucket>_text`` keys (shopping_text, gaming_text,
        streaming_text, unknown_text) shown by the result text areas.
        """
        for bucket, words in st.session_state.results.items():
            st.session_state[f"{bucket}_text"] = "\n".join(words)

    def _reset_progress():
        """Clear the flags, saved position and collected words (in place)."""
        st.session_state.paused = False
        st.session_state.stopped = False
        st.session_state.current_index = 0
        for words in st.session_state.results.values():
            words.clear()

    if st.session_state.get("keywords_source") != content:
        # First run for this upload, or a different file: rebuild all state so
        # progress from another file is never resumed against this one.
        st.session_state.keywords_source = content
        st.session_state.results = {name: [] for name in (*categories, "unknown")}
        _reset_progress()
        update_results()

    # Aliases to the persistent bucket lists under their original names. They
    # stay valid because the lists are only ever cleared in place.
    shopping_words = st.session_state.results["shopping"]
    gaming_words = st.session_state.results["gaming"]
    streaming_words = st.session_state.results["streaming"]
    unknown_words = st.session_state.results["unknown"]

    total = len(keywords)
    # Show the saved progress (not 0) after a Pause/Stop rerun.
    progress_bar = st.progress(st.session_state.current_index / total if total else 0)
    pause_button = st.button("Pause")
    stop_button = st.button("Stop")
    continue_button = st.button("Continue")

    def classify_keywords(keywords, categories, start_index=0, threshold=0.5):
        """Classify ``keywords[start_index:]`` with the zero-shot classifier.

        Each keyword is filed under its top label when that label's score is
        above ``threshold`` and the label has a result bucket; otherwise it is
        filed as "unknown". After every keyword the saved position
        (``st.session_state.current_index``, the index of the next keyword to
        process), the result texts and the progress bar are updated.

        Clicking a button while this loop runs makes Streamlit rerun the
        script, which ends this run at a later ``st.*`` call. Progress is
        saved before that call, so Continue can resume without skipping or
        repeating a keyword.

        Args:
            keywords: keywords to classify, in order.
            categories: candidate labels passed to the classifier.
            start_index: index of the first keyword to process.
            threshold: minimum top-label score needed to avoid "unknown".
        """
        results = st.session_state.results
        total_keywords = len(keywords)
        for i in range(start_index, total_keywords):
            if st.session_state.stopped or st.session_state.paused:
                # Stop here instead of skipping words; the saved position
                # still points at keyword i.
                break

            word = keywords[i]
            result = classifier(word, categories)
            best_category = result["labels"][0]
            score = result["scores"][0]

            if score > threshold and best_category in results:
                results[best_category].append(word)
            else:
                results["unknown"].append(word)

            # Persist progress and texts before the next st.* call, which is
            # where an interrupted run actually ends.
            st.session_state.current_index = i + 1
            update_results()
            progress_bar.progress((i + 1) / total_keywords)

            # Short delay between keywords, kept from the original code.
            # NOTE(review): presumably gives the UI time to register clicks —
            # confirm it is still wanted.
            time.sleep(0.1)

    if st.button("Start"):
        _reset_progress()
        update_results()
        progress_bar.progress(0)
        classify_keywords(keywords, categories, start_index=0)

    if pause_button:
        st.session_state.paused = True
        st.write("Classification paused.")

    if continue_button and st.session_state.paused and not st.session_state.stopped:
        st.session_state.paused = False
        st.write("Classification resumed.")
        classify_keywords(keywords, categories, start_index=st.session_state.current_index)

    if stop_button:
        st.session_state.stopped = True
        st.write("Classification stopped.")
|
|
|
|
|
|
|
|
def _show_keyword_area(header, text_key, label, widget_key):
    """Render one result section: a header and a text area fed from session state.

    ``text_key`` is seeded with an empty string on first use so the lookup
    that fills the text area never fails.
    """
    st.header(header)
    if text_key not in st.session_state:
        st.session_state[text_key] = ""
    st.text_area(label, value=st.session_state[text_key], height=200, key=widget_key)


_show_keyword_area("Shopping Keywords", "shopping_text", "Copy the shopping keywords here:", "shopping")
_show_keyword_area("Gaming Keywords", "gaming_text", "Copy the gaming keywords here:", "gaming")
|
|
|
|
|
st.header("Streaming Keywords" |