GigachatProj / app.py
han7ter's picture
more trestings
32785a3
raw
history blame
1.89 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model_name = "DeepPavlov/rubert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
texts = [
"Я хочу купить дом у своей тёти, как мне это сделать?",
"У меня прорвало трубу в доме, звонил в ЖКХ, они не отвечают.",
"Я убил человека и совершал много плохих действий"
]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
with torch.no_grad():
outputs = model(**inputs)
predictions = torch.softmax(outputs.logits, dim=1)
num_labels = model.config.num_labels
labels = ["купля-продажа", "нарушение закона", "проблема с трубопроводом"][:num_labels]
for text, pred in zip(texts, predictions):
print(f"Текст: {text}")
for i, score in enumerate(pred):
if i < len(labels):
print(f"{labels[i]}: {score:.4f}")
else:
print(f"Класс {i}: {score:.4f} (метка не определена)")
print("---")
with gr.Blocks() as demo:
gr.Markdown("## Результаты классификации")
for text, pred in zip(texts, predictions):
with gr.Group():
gr.Textbox(text, label="Исходный текст", interactive=False)
for i, score in enumerate(pred):
if i < len(labels):
gr.Textbox(f"{labels[i]}: {score:.4f}",
label=f"Вероятность класса {i}",
interactive=False)
gr.Markdown("### Логи работы модели")
demo.launch()