Moleys committed on
Commit 9278a67 · verified · 1 Parent(s): 39829ff

Update app.py

Files changed (1)
  1. app.py +38 -23
app.py CHANGED
@@ -2,50 +2,65 @@ import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr

- # Load the model and tokenizer
model_name = "b3x0m/hirashiba-xomdich-tokenizer"
device = "cuda" if torch.cuda.is_available() else "cpu"

- tokenizer = AutoTokenizer.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+ # Make sure a pad_token exists so padding does not fail
+ if tokenizer.pad_token is None:
+     # Prefer the eos_token as pad if one is available
+     if tokenizer.eos_token is not None:
+         tokenizer.pad_token = tokenizer.eos_token
+     else:
+         tokenizer.add_special_tokens({"pad_token": "<pad>"})
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
+ # If a new pad_token was just added, update the model config to match
+ model.config.pad_token_id = tokenizer.pad_token_id

def translate_text(input_text):
-     lines = input_text.split('\n')  # Split into individual lines
+     lines = input_text.split('\n')
    translated_lines = []
-
+
    for line in lines:
        raw_text = line.strip()
        if not raw_text:
-             translated_lines.append('')  # Keep blank lines
+             translated_lines.append('')
            continue
-
-         # Tokenize input
-         inputs = tokenizer(raw_text, return_tensors="pt", padding=True, truncation=True).to(device)
-
-         # Translate with the model (no gradient computation needed)
+
+         # Do NOT return token_type_ids, to avoid errors
+         inputs = tokenizer(
+             raw_text,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=1024,  # avoid truncation warnings
+             return_token_type_ids=False
+         ).to(device)
+
        with torch.no_grad():
-             output_tokens = model.generate(**inputs, max_length=512)
-
-         # Decode the result
+             # Use max_new_tokens instead of max_length when generating
+             output_tokens = model.generate(
+                 **inputs,
+                 max_new_tokens=512
+             )
+
        translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
        translated_lines.append(translated_text)
-
+
    return '\n'.join(translated_lines)

if __name__ == '__main__':
    with gr.Blocks() as app:
        gr.Markdown('## Chinese to Vietnamese Translation')
-
+
        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(label='Input Chinese Text', lines=5, placeholder='Enter Chinese text here...')
                translate_button = gr.Button('Translate')
                output_text = gr.Textbox(label='Output Vietnamese Text', lines=5, interactive=False)
-
-         translate_button.click(
-             fn=translate_text,
-             inputs=input_text,
-             outputs=output_text
-         )
-
-         app.launch()
+
+         translate_button.click(fn=translate_text, inputs=input_text, outputs=output_text)
+
+         app.launch()
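
For a quick check outside the Gradio UI, here is a minimal sketch of how the updated function could be exercised after this commit; it is not part of the diff, the sample input string is made up for illustration, and importing app assumes the checkpoint downloads successfully (app.launch() stays guarded by __main__, so no UI starts on import).

    # Minimal sketch, not part of the commit: call translate_text directly.
    # Importing app runs the module-level tokenizer/model setup from this diff.
    from app import translate_text, tokenizer

    # The fallback above guarantees a pad token is defined.
    assert tokenizer.pad_token_id is not None

    # Blank lines are preserved; each non-empty line is translated separately.
    sample = "你好，世界。\n\n这是第二行。"  # made-up example input
    print(translate_text(sample))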