data-silence commited on
Commit
5396265
1 Parent(s): 8380ce4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -25
app.py CHANGED
@@ -1,30 +1,46 @@
1
- from transformers import pipeline
 
 
2
 
3
- # Загрузка модели через pipeline
4
- classifier = pipeline("text-classification", model="data-silence/news_classifier_ft")
 
 
5
 
6
- # Словарь для преобразования меток
7
- category_mapper = {
8
- 'LABEL_0': 'climate',
9
- 'LABEL_1': 'conflicts',
10
- 'LABEL_2': 'culture',
11
- 'LABEL_3': 'economy',
12
- 'LABEL_4': 'gloss',
13
- 'LABEL_5': 'health',
14
- 'LABEL_6': 'politics',
15
- 'LABEL_7': 'science',
16
- 'LABEL_8': 'society',
17
- 'LABEL_9': 'sports',
18
- 'LABEL_10': 'travel'
19
  }
20
 
21
- def classify(text):
22
- result = classifier(text)
23
- category = category_mapper[result[0]['label']]
24
- score = result[0]['score']
25
- return {"category": category, "confidence": score}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Для Gradio интерфейса
28
- def run_inference(text):
29
- result = classify(text)
30
- return f"Predicted category: {result['category']} (confidence: {result['confidence']:.2f})"
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
+ # Загрузка модели и токенизатора
6
+ model_name = "your-username/your-model-name" # Замените на путь к вашей модели на HF
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
+ # Перевод модели в режим оценки
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model = model.to(device)
13
+ model.eval()
14
+
15
+ # Словарь для маппинга индексов на категории
16
+ id2label = {
17
+ 0: 'climate', 1: 'conflicts', 2: 'culture', 3: 'economy', 4: 'gloss',
18
+ 5: 'health', 6: 'politics', 7: 'science', 8: 'society', 9: 'sports', 10: 'travel'
 
 
 
 
19
  }
20
 
21
+ def predict(text):
22
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ predicted_label_id = outputs.logits.argmax(-1).item()
26
+ predicted_label = id2label[predicted_label_id]
27
+
28
+ # Получаем вероятности для всех классов
29
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
30
+ probs_dict = {id2label[i]: float(prob) for i, prob in enumerate(probabilities[0])}
31
+
32
+ return predicted_label, probs_dict
33
+
34
+ # Создание интерфейса Gradio
35
+ iface = gr.Interface(
36
+ fn=predict,
37
+ inputs=gr.Textbox(lines=5, label="Введите текст новости"),
38
+ outputs=[
39
+ gr.Label(label="Предсказанная категория"),
40
+ gr.Label(label="Вероятности категорий")
41
+ ],
42
+ title="Классификатор новостей",
43
+ description="Введите текст новости, и модель предскажет её категорию."
44
+ )
45
 
46
+ iface.launch()