choging commited on
Commit
6fc5a2a
·
verified ·
1 Parent(s): d5ee1d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -3,25 +3,37 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  model_name = "winninghealth/WiNGPT-Babel"
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) #, torch_dtype="auto", device_map="auto"
7
 
8
  def translate(text):
9
  prompt = f"<|im_start|>system\n中英互译下面的内容<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
10
- inputs = tokenizer([prompt], return_tensors="pt") #.to(model.device)
11
  outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
12
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
13
  return translated_text
14
 
15
- def predict(text):
16
- return translate(text)
 
 
 
 
 
 
17
 
 
 
 
18
  iface = gr.Interface(
19
- fn=predict,
20
- inputs=gr.Textbox(lines=5, label="输入文本 (支持中英互译)"),
21
- outputs=gr.Textbox(label="翻译结果"),
 
 
 
 
22
  title="WiNGPT-Babel 翻译 Demo",
23
  description="基于 WiNGPT-Babel 模型的翻译演示。支持中英互译。",
24
- examples=[["Hello, world!"], ["你好,世界!"]],
25
  )
26
 
27
  iface.launch()
 
3
 
4
  model_name = "winninghealth/WiNGPT-Babel"
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
7
 
8
  def translate(text):
9
  prompt = f"<|im_start|>system\n中英互译下面的内容<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
10
+ inputs = tokenizer([prompt], return_tensors="pt")
11
  outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
12
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
13
  return translated_text
14
 
15
+ def custom_api(text_list, source_lang, target_lang):
16
+ # 假设你的模型只支持中英互译
17
+ if source_lang == "zh-CN" and target_lang == "en":
18
+ translated_list = [translate(text) for text in text_list]
19
+ elif source_lang == "en" and target_lang == "zh-CN":
20
+ translated_list = [translate(text) for text in text_list]
21
+ else:
22
+ return {"error": "Unsupported language pair"}
23
 
24
+ return {"translations": [{"detected_source_lang": source_lang, "text": translated_text} for translated_text in translated_list]}
25
+
26
+ # 创建 Gradio 接口
27
  iface = gr.Interface(
28
+ fn=custom_api,
29
+ inputs=[
30
+ gr.Textbox(lines=5, label="输入文本列表 (支持中英互译)", placeholder='["Hello", "World"]'),
31
+ gr.Textbox(label="源语言", placeholder="zh-CN"),
32
+ gr.Textbox(label="目标语言", placeholder="en")
33
+ ],
34
+ outputs=gr.JSON(label="翻译结果"),
35
  title="WiNGPT-Babel 翻译 Demo",
36
  description="基于 WiNGPT-Babel 模型的翻译演示。支持中英互译。",
 
37
  )
38
 
39
  iface.launch()