Spaces:
Sleeping
Sleeping
| # demo_phobert_gradio.py | |
| # -*- coding: utf-8 -*- | |
| import gradio as gr | |
| import torch | |
| import re | |
| import json | |
| import emoji | |
| import numpy as np | |
| from underthesea import word_tokenize | |
| from transformers import ( | |
| AutoConfig, | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification | |
| ) | |
| ############################################################################### | |
| # TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN | |
| ############################################################################### | |
| emoji_mapping = { | |
| "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]", | |
| "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]", | |
| "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]", | |
| "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]", | |
| "🤑": "[satisfaction]", | |
| "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]", | |
| "😏": "[sarcasm]", | |
| "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]", | |
| "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]", | |
| "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]", | |
| "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]", | |
| "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]", | |
| "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]", | |
| "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]", | |
| "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]", | |
| "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]", | |
| "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]" | |
| } | |
| ############################################################################### | |
| # HÀM XỬ LÝ (COPY TỪ FILE TRAIN) | |
| ############################################################################### | |
| def replace_emojis(sentence, emoji_mapping): | |
| processed_sentence = [] | |
| for char in sentence: | |
| if char in emoji_mapping: | |
| processed_sentence.append(emoji_mapping[char]) | |
| elif not emoji.is_emoji(char): | |
| processed_sentence.append(char) | |
| return ''.join(processed_sentence) | |
| def remove_profanity(sentence): | |
| profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"] | |
| words = sentence.split() | |
| filtered = [w for w in words if w.lower() not in profane_words] | |
| return ' '.join(filtered) | |
| def remove_special_characters(sentence): | |
| return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence) | |
| def normalize_whitespace(sentence): | |
| return ' '.join(sentence.split()) | |
| def remove_repeated_characters(sentence): | |
| return re.sub(r"(.)\1{2,}", r"\1", sentence) | |
| def replace_numbers(sentence): | |
| return re.sub(r"\d+", "[number]", sentence) | |
| def tokenize_underthesea(sentence): | |
| tokens = word_tokenize(sentence) | |
| return " ".join(tokens) | |
| # Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng. | |
| try: | |
| with open("abbreviations.json", "r", encoding="utf-8") as f: | |
| abbreviations = json.load(f) | |
| except: | |
| abbreviations = {} | |
| def preprocess_sentence(sentence): | |
| # hạ thấp | |
| sentence = sentence.lower() | |
| # thay thế emoji | |
| sentence = replace_emojis(sentence, emoji_mapping) | |
| # loại bỏ từ nhạy cảm | |
| sentence = remove_profanity(sentence) | |
| # bỏ ký tự đặc biệt | |
| sentence = remove_special_characters(sentence) | |
| # chuẩn hoá khoảng trắng | |
| sentence = normalize_whitespace(sentence) | |
| # thay thế viết tắt | |
| words = sentence.split() | |
| replaced = [] | |
| for w in words: | |
| if w in abbreviations: | |
| replaced.append(" ".join(abbreviations[w])) | |
| else: | |
| replaced.append(w) | |
| sentence = " ".join(replaced) | |
| # bỏ bớt kí tự lặp | |
| sentence = remove_repeated_characters(sentence) | |
| # thay số thành [number] | |
| sentence = replace_numbers(sentence) | |
| # tokenize tiếng Việt | |
| sentence = tokenize_underthesea(sentence) | |
| return sentence | |
| ############################################################################### | |
| # LOAD CHECKPOINT | |
| ############################################################################### | |
| checkpoint_dir = "./checkpoint" # Folder checkpoint nằm trong cùng thư mục với file script | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print("Loading config...") | |
| config = AutoConfig.from_pretrained(checkpoint_dir) | |
| # Mapping id to label theo thứ tự bạn cung cấp | |
| custom_id2label = { | |
| 0: 'Anger', | |
| 1: 'Disgust', | |
| 2: 'Enjoyment', | |
| 3: 'Fear', | |
| 4: 'Other', | |
| 5: 'Sadness', | |
| 6: 'Surprise' | |
| } | |
| # Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng | |
| if hasattr(config, "id2label") and config.id2label: | |
| # Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping | |
| if all(label.startswith("LABEL_") for label in config.id2label.values()): | |
| id2label = custom_id2label | |
| else: | |
| id2label = {int(k): v for k, v in config.id2label.items()} | |
| else: | |
| id2label = custom_id2label # Sử dụng mapping mặc định nếu config không có id2label | |
| print("id2label loaded:", id2label) | |
| print("Loading tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) | |
| print("Loading model...") | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config) | |
| model.to(device) | |
| model.eval() | |
| ############################################################################### | |
| # HÀM PREDICT | |
| ############################################################################### | |
| # Mapping từ label đến thông điệp tương ứng | |
| label2message = { | |
| 'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.', | |
| 'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.', | |
| 'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!', | |
| 'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.', | |
| 'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.', | |
| 'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.', | |
| 'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.' | |
| } | |
| def predict_text(text: str) -> str: | |
| """Tiền xử lý, token hoá và chạy model => trả về label và thông điệp.""" | |
| text_proc = preprocess_sentence(text) | |
| inputs = tokenizer( | |
| [text_proc], | |
| padding=True, | |
| truncation=True, | |
| max_length=256, | |
| return_tensors="pt" | |
| ).to(device) | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| pred_id = outputs.logits.argmax(dim=-1).item() | |
| if pred_id in id2label: | |
| label = id2label[pred_id] | |
| message = label2message.get(label, "") | |
| if message: | |
| return f"Dự đoán cảm xúc: {label}. {message}" | |
| else: | |
| return f"Dự đoán cảm xúc: {label}." | |
| else: | |
| return f"Nhãn không xác định (id={pred_id})" | |
| ############################################################################### | |
| # GRADIO APP | |
| ############################################################################### | |
| def run_demo(input_text): | |
| predicted_emotion = predict_text(input_text) | |
| return predicted_emotion | |
| demo = gr.Interface( | |
| fn=run_demo, | |
| inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"), | |
| outputs=gr.Textbox(label="Kết quả"), | |
| title="PhoBERT Emotion Classification", | |
| description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc." | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=True) | |