AritORR commited on
Commit
f753300
·
2 Parent(s): c53610b 32785a3

Merge remote-tracking branch 'origin/main'

Browse files
Files changed (1) hide show
  1. app.py +43 -63
app.py CHANGED
@@ -1,66 +1,46 @@
1
- import datasets
2
- import evaluate
3
- import pandas as pd
4
- import numpy as np
5
- from datasets import Dataset
6
- from sklearn.model_selection import train_test_split
7
- from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
8
- TrainingArguments, Trainer)
9
 
10
  model_name = "DeepPavlov/rubert-base-cased"
11
-
12
- # Login using e.g. `huggingface-cli login` to access this dataset
13
- splits = {'train': 'data/train-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
14
- df = pd.read_parquet("hf://datasets/mteb/RuSciBenchOECDClassification/" + splits["train"])
15
-
16
- # Конвертируем датафрейм в Dataset
17
- train, test = train_test_split(df, test_size=0.2)
18
- train = Dataset.from_pandas(train)
19
- test = Dataset.from_pandas(test)
20
-
21
- # Выполняем предобработку текста
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
-
24
- def tokenize_function(examples):
25
- return tokenizer(examples['text'], padding='max_length', truncation=True)
26
-
27
- tokenized_train = train.map(tokenize_function)
28
- tokenized_test = test.map(tokenize_function)
29
-
30
- # Загружаем предобученную модель
31
- model = AutoModelForSequenceClassification.from_pretrained(
32
- model_name,
33
- num_labels=28)
34
-
35
- # Задаем параметры обучения
36
- training_args = TrainingArguments(
37
- output_dir = 'test_trainer_log',
38
- evaluation_strategy = 'epoch',
39
- per_device_train_batch_size = 6,
40
- per_device_eval_batch_size = 6,
41
- num_train_epochs = 5,
42
- report_to='none')
43
-
44
- # Определяем как считать метрику
45
- metric = evaluate.load('f1')
46
- def compute_metrics(eval_pred):
47
- logits, labels = eval_pred
48
- predictions = np.argmax(logits, axis=-1)
49
- return metric.compute(predictions=predictions, references=labels)
50
-
51
- # Выполняем обучение
52
- trainer = Trainer(
53
- model = model,
54
- args = training_args,
55
- train_dataset = tokenized_train,
56
- eval_dataset = tokenized_test,
57
- compute_metrics = compute_metrics)
58
-
59
- trainer.train()
60
-
61
- # Сохраняем модель
62
- save_directory = './pt_save_pretrained'
63
- #tokenizer.save_pretrained(save_directory)
64
- model.save_pretrained(save_directory)
65
- #alternatively save the trainer
66
- #trainer.save_model('CustomModels/CustomHamSpam')
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
 
 
 
 
 
4
 
5
  model_name = "DeepPavlov/rubert-base-cased"
 
 
 
 
 
 
 
 
 
 
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
+
9
+ texts = [
10
+ "Я хочу купить дом у своей тёти, как мне это сделать?",
11
+ меня прорвало трубу в доме, звонил в ЖКХ, они не отвечают.",
12
+ убил человека и совершал много плохих действий"
13
+ ]
14
+
15
+ inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
16
+
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+ predictions = torch.softmax(outputs.logits, dim=1)
20
+
21
+ num_labels = model.config.num_labels
22
+ labels = ["купля-продажа", "нарушение закона", "проблема с трубопроводом"][:num_labels]
23
+
24
+ for text, pred in zip(texts, predictions):
25
+ print(f"Текст: {text}")
26
+ for i, score in enumerate(pred):
27
+ if i < len(labels):
28
+ print(f"{labels[i]}: {score:.4f}")
29
+ else:
30
+ print(f"Класс {i}: {score:.4f} (метка не определена)")
31
+ print("---")
32
+
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("## Результаты классификации")
35
+ for text, pred in zip(texts, predictions):
36
+ with gr.Group():
37
+ gr.Textbox(text, label="Исходный текст", interactive=False)
38
+ for i, score in enumerate(pred):
39
+ if i < len(labels):
40
+ gr.Textbox(f"{labels[i]}: {score:.4f}",
41
+ label=f"Вероятность класса {i}",
42
+ interactive=False)
43
+
44
+ gr.Markdown("### Логи работы модели")
45
+
46
+ demo.launch()