Spaces:
Runtime error
Runtime error
File size: 3,565 Bytes
00c4428 d566459 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import gradio as gr
import json
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from collections import namedtuple
from component2date import time2date
fields = ['device', 'model_name', 'max_source_length', 'max_target_length', 'beam_size']
params = namedtuple('params', field_names=fields)
args = params(
device="cuda" if torch.cuda.is_available() else "cpu",
model_name='facebook/mbart-large-50-many-to-many-mmt',
max_source_length=256,
max_target_length=256,
beam_size=1
)
model = AutoModelForSeq2SeqLM.from_pretrained(
"Huy1432884/db_retrieval",
use_auth_token="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"
)
model.eval()
if "mbart" in args.model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(
args.model_name, src_lang="vi_VN", tgt_lang="vi_VN"
)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
with open("output2url.json") as f:
output2url = json.loads(list(f)[0])
def text_analysis(text):
text = text.lower()
inputs = tokenizer(
[text],
text_target=None,
padding="longest",
max_length=args.max_source_length,
truncation=True,
return_tensors="pt",
)
for k, v in inputs.items():
inputs[k] = v.to(args.device)
if "mbart" in args.model_name:
inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]
outputs = model.generate(
**inputs,
max_length=args.max_target_length,
num_beams=args.beam_size,
early_stopping=True,
)
output_sentences = tokenizer.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
out = json.loads("{" + output_sentences[0] + "}")
if out['LOẠI BIỂU ĐỒ']=='dashboard':
if out['CHU KỲ THỜI GIAN']!='tháng':
chu_ky_in = 'ngày'
else:
chu_ky_in = 'tháng'
out['CHU KỲ THỜI GIAN']='ngày' if out['CHU KỲ THỜI GIAN'] not in ['ngày', 'tháng'] else out['CHU KỲ THỜI GIAN']
check_dashboard = out['ĐƠN VỊ']+"_"+chu_ky_in
out['DB URL'] = output2url[check_dashboard]
out['DATE'] = time2date(out)
out['FINAL URL'] = "https://vsds.viettel.vn"+ out['DB URL'] + "?toDate=" + str(out['DATE']).replace("-", "").replace("-", "")
show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN', 'DB URL', 'DATE', 'FINAL URL']}
elif out['LOẠI BIỂU ĐỒ']=='biểu đồ':
show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN']}
else:
show = out
return show
demo = gr.Interface(
text_analysis,
gr.Textbox(placeholder="Enter sentence here..."),
["json"],
examples=[
["Mở dashboard vtc ngày hôm qua"],
["Mở biểu đồ cột td ngày này"],
["Hãy mở biểu đồ cơ cấu của tập đoàn trong ngày hôm nay"],
["Tháng này, vtc cần tôi mở biểu đồ rank để cập nhật danh sách khách hàng"],
["Các thông số NAT ngày hôm qua đã được ghi nhận trên đát bọt"],
["Hôm nay hãy mở của Viettel tt không gian mạng Viettel vtcc để kiểm tra"],
["Mở DB CTM ngày gốc"],
["Tôi đã sử dụng Dashboard để truy cập thông tin qti vào ngày hôm nay"],
["Trưởng phòng đã ra lệnh mở biểu đồ kết hợp đường và cột cho toàn tập đoàn vào hôm nay"]
],
)
demo.launch() |