import gradio as gr import json import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from collections import namedtuple from component2date import time2date fields = ['device', 'model_name', 'max_source_length', 'max_target_length', 'beam_size'] params = namedtuple('params', field_names=fields) args = params( device="cuda" if torch.cuda.is_available() else "cpu", model_name='facebook/mbart-large-50-many-to-many-mmt', max_source_length=256, max_target_length=256, beam_size=1 ) model = AutoModelForSeq2SeqLM.from_pretrained( "Huy1432884/db_retrieval", use_auth_token="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" ) model.eval() if "mbart" in args.model_name.lower(): tokenizer = AutoTokenizer.from_pretrained( args.model_name, src_lang="vi_VN", tgt_lang="vi_VN" ) else: tokenizer = AutoTokenizer.from_pretrained(args.model_name) with open("output2url.json") as f: output2url = json.loads(list(f)[0]) def text_analysis(text): text = text.lower() inputs = tokenizer( [text], text_target=None, padding="longest", max_length=args.max_source_length, truncation=True, return_tensors="pt", ) for k, v in inputs.items(): inputs[k] = v.to(args.device) if "mbart" in args.model_name: inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"] outputs = model.generate( **inputs, max_length=args.max_target_length, num_beams=args.beam_size, early_stopping=True, ) output_sentences = tokenizer.batch_decode( outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True ) out = json.loads("{" + output_sentences[0] + "}") if out['LOẠI BIỂU ĐỒ']=='dashboard': if out['CHU KỲ THỜI GIAN']!='tháng': chu_ky_in = 'ngày' else: chu_ky_in = 'tháng' out['CHU KỲ THỜI GIAN']='ngày' if out['CHU KỲ THỜI GIAN'] not in ['ngày', 'tháng'] else out['CHU KỲ THỜI GIAN'] check_dashboard = out['ĐƠN VỊ']+"_"+chu_ky_in out['DB URL'] = output2url[check_dashboard] out['DATE'] = time2date(out) out['FINAL URL'] = "https://vsds.viettel.vn"+ out['DB URL'] + "?toDate=" + str(out['DATE']).replace("-", "").replace("-", "") show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN', 'DB URL', 'DATE', 'FINAL URL']} elif out['LOẠI BIỂU ĐỒ']=='biểu đồ': show = {i: out[i] for i in ['LOẠI BIỂU ĐỒ', 'ĐƠN VỊ', 'CHU KỲ THỜI GIAN']} else: show = out return show demo = gr.Interface( text_analysis, gr.Textbox(placeholder="Enter sentence here..."), ["json"], examples=[ ["Mở dashboard vtc ngày hôm qua"], ["Mở biểu đồ cột td ngày này"], ["Hãy mở biểu đồ cơ cấu của tập đoàn trong ngày hôm nay"], ["Tháng này, vtc cần tôi mở biểu đồ rank để cập nhật danh sách khách hàng"], ["Các thông số NAT ngày hôm qua đã được ghi nhận trên đát bọt"], ["Hôm nay hãy mở của Viettel tt không gian mạng Viettel vtcc để kiểm tra"], ["Mở DB CTM ngày gốc"], ["Tôi đã sử dụng Dashboard để truy cập thông tin qti vào ngày hôm nay"], ["Trưởng phòng đã ra lệnh mở biểu đồ kết hợp đường và cột cho toàn tập đoàn vào hôm nay"] ], ) demo.launch()