AritORR committed on
Commit bf645f4 · 1 Parent(s): f753300
Files changed (1)
  1. app.py +63 -43
app.py CHANGED
@@ -1,46 +1,66 @@
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch

  model_name = "DeepPavlov/rubert-base-cased"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
-
- texts = [
-     "Я хочу купить дом у своей тёти, как мне это сделать?",
-     "У меня прорвало трубу в доме, звонил в ЖКХ, они не отвечают.",
-     "Я убил человека и совершал много плохих действий"
- ]
-
- inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
-
- with torch.no_grad():
-     outputs = model(**inputs)
-     predictions = torch.softmax(outputs.logits, dim=1)
-
- num_labels = model.config.num_labels
- labels = ["купля-продажа", "нарушение закона", "проблема с трубопроводом"][:num_labels]
-
- for text, pred in zip(texts, predictions):
-     print(f"Текст: {text}")
-     for i, score in enumerate(pred):
-         if i < len(labels):
-             print(f"{labels[i]}: {score:.4f}")
-         else:
-             print(f"Класс {i}: {score:.4f} (метка не определена)")
-     print("---")
-
- with gr.Blocks() as demo:
-     gr.Markdown("## Результаты классификации")
-     for text, pred in zip(texts, predictions):
-         with gr.Group():
-             gr.Textbox(text, label="Исходный текст", interactive=False)
-             for i, score in enumerate(pred):
-                 if i < len(labels):
-                     gr.Textbox(f"{labels[i]}: {score:.4f}",
-                                label=f"Вероятность класса {i}",
-                                interactive=False)
-
-     gr.Markdown("### Логи работы модели")
-
- demo.launch()
+ import datasets
+ import evaluate
+ import pandas as pd
+ import numpy as np
+ from datasets import Dataset
+ from sklearn.model_selection import train_test_split
+ from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
+                           TrainingArguments, Trainer)

  model_name = "DeepPavlov/rubert-base-cased"
+
+ # Login using e.g. `huggingface-cli login` to access this dataset
+ splits = {'train': 'data/train-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
+ df = pd.read_parquet("hf://datasets/mteb/RuSciBenchOECDClassification/" + splits["train"])
+
+ # Convert the dataframe into Dataset objects
+ train, test = train_test_split(df, test_size=0.2)
+ train = Dataset.from_pandas(train)
+ test = Dataset.from_pandas(test)
+
+ # Preprocess (tokenize) the text
  tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize_function(examples):
+     return tokenizer(examples['text'], padding='max_length', truncation=True)
+
+ tokenized_train = train.map(tokenize_function)
+ tokenized_test = test.map(tokenize_function)
+
+ # Load the pretrained model
+ model = AutoModelForSequenceClassification.from_pretrained(
+     model_name,
+     num_labels=28)
+
+ # Set the training parameters
+ training_args = TrainingArguments(
+     output_dir='test_trainer_log',
+     evaluation_strategy='epoch',
+     per_device_train_batch_size=6,
+     per_device_eval_batch_size=6,
+     num_train_epochs=5,
+     report_to='none')
+
+ # Define how the metric is computed
+ metric = evaluate.load('f1')
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     # f1 defaults to average='binary', which errors out on this multi-class task,
+     # so an explicit multi-class average is required
+     return metric.compute(predictions=predictions, references=labels, average='macro')
+
+ # Run the training
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_train,
+     eval_dataset=tokenized_test,
+     compute_metrics=compute_metrics)
+
+ trainer.train()
+
+ # Save the model
+ save_directory = './pt_save_pretrained'
+ # tokenizer.save_pretrained(save_directory)
+ model.save_pretrained(save_directory)
+ # alternatively, save the trainer
+ # trainer.save_model('CustomModels/CustomHamSpam')
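Since `tokenizer.save_pretrained(save_directory)` is left commented out, only the model weights and config land in `./pt_save_pretrained`, and the commit defines no `id2label` mapping, so predictions come back as raw class ids. A minimal inference sketch under those assumptions; the sample text is a hypothetical placeholder, and the tokenizer is reloaded from the base checkpoint rather than from the save directory:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# The tokenizer was not saved alongside the model, so reload it from the base checkpoint
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("./pt_save_pretrained")
model.eval()

text = "Пример аннотации научной статьи"  # hypothetical sample input
inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    probs = torch.softmax(model(**inputs).logits, dim=-1)
pred = int(probs.argmax(dim=-1))
print(f"Predicted class id: {pred} (p={probs[0, pred]:.4f})")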