# Hugging Face Space (status when captured: Sleeping)
import functools
import json

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from tow.model_byt5.model import Transformer_byt5
from tow.model_byt5.tokenizer import Tokenizer_byt5
# Hub repository that hosts both the checkpoint and its configuration.
_REPO_ID = "towing/byt5-base-alibi-mt"

# hf_hub_download returns the local path of the fetched file.
model_weights_path = hf_hub_download(
    repo_id=_REPO_ID,
    filename="pytorch_model.bin",
)
model_config_path = hf_hub_download(
    repo_id=_REPO_ID,
    filename="config.json",
)
_MAX_LENGTH = 512  # length the input ids are padded to; also passed to generate()
_PAD_ID = 0  # id used to right-pad the input (presumably the pad token — confirm)
_DEVICE = torch.device("cpu")


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the ByT5 model once and return it in eval mode.

    Reads the JSON config from ``model_config_path`` and the weights from
    ``model_weights_path``. The result is cached so that the checkpoint is
    deserialised on the first request only instead of on every call.
    """
    with open(model_config_path, "r") as f:
        config = json.load(f)
    # NOTE(review): torch.load unpickles the file; the checkpoint comes from
    # the hub repository downloaded at module import, so it is trusted here.
    state_dict = torch.load(model_weights_path, map_location=_DEVICE)
    model = Transformer_byt5(config=config)
    model.load_state_dict(state_dict)
    return model.eval()


def translate(inputs):
    """Translate ``inputs`` and stream the decoded text as it is generated.

    Args:
        inputs: Source text. The demo's default value and description suggest
            it is prefixed with a task tag such as ``"zh2en:"``.

    Yields:
        The text decoded from each value produced by the streaming
        ``model.generate`` call; every yield replaces the previous one.

    Returns:
        The last decoded text (``''`` if nothing was generated). Because this
        is a generator, the value is only visible via ``StopIteration.value``
        or ``yield from``.
    """
    # NOTE(review): the model instance is now shared between calls; this
    # assumes generate() keeps no per-request state on the model — confirm.
    model = _load_model()
    tokenizer = Tokenizer_byt5()

    ids = tokenizer(inputs, max_length=_MAX_LENGTH)
    len_pad = _MAX_LENGTH - len(ids)
    if len_pad > 0:
        # Right-pad to the fixed length.
        ids = ids + [_PAD_ID] * len_pad

    # Wrap in a list to add a leading batch dimension of size 1.
    batch = torch.tensor([ids]).to(_DEVICE)
    outputs = model.generate(batch, max_length=_MAX_LENGTH, stream=True)

    text = ''
    for value in outputs:
        # Decode the first (only) sequence of the batch.
        text = tokenizer.ids2text(value.tolist()[0])
        yield text
    return text
# Single text input, pre-filled with a Chinese-to-English example request.
_input_box = gr.components.Textbox(
    label="input",
    value="zh2en:一个描述实际事物的函数,其中的高频信息往往对应着很小的 “振幅”, 否则整个函数会很奇怪是个压扁的 “弹簧” ,不具实际意义。",
)

_description = (
    "Support tasks: en2es, en2ja, en2zh, ja2zh, es2zh, es2ja, "
    "as well as their reverse language pairs."
)

# Web UI that feeds the textbox content to translate() and shows its text output.
demo = gr.Interface(
    fn=translate,
    inputs=[_input_box],
    outputs=["text"],
    cache_examples=False,
    title="Translation",
    description=_description,
)

# Start the server on all interfaces, with debug output and a share link requested.
_launch_options = {
    "debug": True,
    "share": True,
    "server_name": "0.0.0.0",
}
demo.launch(**_launch_options)