File size: 3,565 Bytes
00c4428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d566459
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import json
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from collections import namedtuple
from component2date import time2date

fields = ['device', 'model_name', 'max_source_length', 'max_target_length', 'beam_size']
params = namedtuple('params', field_names=fields)

args = params(
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name='facebook/mbart-large-50-many-to-many-mmt',
    max_source_length=256,
    max_target_length=256,
    beam_size=1
)

model = AutoModelForSeq2SeqLM.from_pretrained(
    "Huy1432884/db_retrieval",
    use_auth_token="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"
)

model.eval()

if "mbart" in args.model_name.lower():
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, src_lang="vi_VN", tgt_lang="vi_VN"
    )
else:
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

with open("output2url.json") as f:
    output2url = json.loads(list(f)[0])


def text_analysis(text):
    text = text.lower()

    inputs = tokenizer(
        [text],
        text_target=None,
        padding="longest",
        max_length=args.max_source_length,
        truncation=True,
        return_tensors="pt",
    )

    for k, v in inputs.items():
        inputs[k] = v.to(args.device)


    if "mbart" in args.model_name:
        inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]

    outputs = model.generate(
        **inputs,
        max_length=args.max_target_length,
        num_beams=args.beam_size,
        early_stopping=True,
    )

    output_sentences = tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    out = json.loads("{" + output_sentences[0] + "}")
    if out['LOẠI BIỂU ĐỒ']=='dashboard':
        if out['CHU KỲ THỜI GIAN']!='tháng':
            chu_ky_in = 'ngày'
        else:
            chu_ky_in = 'tháng'
        out['CHU KỲ THỜI GIAN']='ngày' if out['CHU KỲ THỜI GIAN'] not in ['ngày', 'tháng'] else out['CHU KỲ THỜI GIAN']
        check_dashboard = out['ĐƠN VỊ']+"_"+chu_ky_in
        out['DB URL'] = output2url[check_dashboard]
        out['DATE'] = time2date(out)
        out['FINAL URL'] = "https://vsds.viettel.vn"+ out['DB URL'] + "?toDate=" + str(out['DATE']).replace("-", "").replace("-", "")
        show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN', 'DB URL', 'DATE', 'FINAL URL']}
    elif out['LOẠI BIỂU ĐỒ']=='biểu đồ':
        show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN']}
    else:
        show = out
    return show

demo = gr.Interface(
    text_analysis,
    gr.Textbox(placeholder="Enter sentence here..."),
    ["json"],
    examples=[
        ["Mở dashboard vtc ngày hôm qua"],
        ["Mở biểu đồ cột td ngày này"],
        ["Hãy mở biểu đồ cơ cấu của tập đoàn trong ngày hôm nay"],
        ["Tháng này, vtc cần tôi mở biểu đồ rank để cập nhật danh sách khách hàng"],
        ["Các thông số NAT ngày hôm qua đã được ghi nhận trên đát bọt"],
        ["Hôm nay hãy mở của Viettel tt không gian mạng Viettel vtcc để kiểm tra"],
        ["Mở DB CTM ngày gốc"],
        ["Tôi đã sử dụng Dashboard để truy cập thông tin qti vào ngày hôm nay"],
        ["Trưởng phòng đã ra lệnh mở biểu đồ kết hợp đường và cột cho toàn tập đoàn vào hôm nay"]
    ],
)

demo.launch()