File size: 2,760 Bytes
81d4014
 
ff287e4
81d4014
 
 
 
 
 
 
ac24b2e
81d4014
 
 
 
 
 
 
 
d064352
ac24b2e
81d4014
51a8540
 
 
 
81d4014
ac24b2e
51a8540
 
 
 
 
 
81d4014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8853ea7
81d4014
 
 
 
8853ea7
81d4014
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# from responses import start
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub checkpoint served by this demo.
model_id = "BSC-LT/salamandraTA-2b"

# The tokenizer and the half-precision weights are loaded once, at import
# time. device_map="auto" leaves device placement to the loader, so no
# explicit .to(device) call is made here.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Language names offered by the source/target dropdowns. The selected name is
# inserted verbatim into the prompt tags built by translate().
languages = [
    "Spanish",
    "Catalan",
    "English",
    "French",
    "German",
    "Italian",
    "Portuguese",
    "Euskera",
    "Galician",
    "Bulgarian",
    "Czech",
    "Lithuanian",
    "Croatian",
    "Dutch",
    "Romanian",
    "Danish",
    "Greek",
    "Finnish",
    "Hungarian",
    "Slovak",
    "Slovenian",
    "Estonian",
    "Polish",
    "Latvian",
    "Swedish",
    "Maltese",
    "Irish",
    "Aranese",
    "Aragonese",
    "Asturian",
]

# Sample input shown in the UI's examples widget.
example_sentence = [
    "Ahir se'n va anar, va agafar les seves coses i es va posar a navegar.",
]


@spaces.GPU(duration=120)
def translate(input_text, source, target):
    """Translate ``input_text`` line by line from ``source`` into ``target``.

    Every non-blank line is wrapped in a ``[source] <line> <newline>[target]``
    prompt and decoded with 5-beam search. Blank or whitespace-only lines are
    not sent to the model; they are emitted as empty lines so the paragraph
    layout of the input is kept and no text is generated from an empty prompt.

    Args:
        input_text: Text to translate; lines are separated by newlines.
            ``None`` or an empty string yields an empty translation.
        source: Source language name, placed in the leading prompt tag.
        target: Target language name, placed in the trailing prompt tag.

    Returns:
        A ``(translation, info)`` tuple: the translated lines joined with
        newlines, and an empty string for the second (info) output.
    """
    translated_lines = []
    for sentence in (input_text or "").split('\n'):
        if not sentence.strip():
            # Nothing to translate: keep the line break, skip the GPU call.
            translated_lines.append("")
            continue

        prompt = f'[{source}] {sentence} \n[{target}]'
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)

        # max_new_tokens budgets only the generated continuation. The previous
        # max_length=500 also counted the prompt tokens, so long input lines
        # were left with little or no room for their translation.
        output_ids = model.generate(input_ids, max_new_tokens=500, num_beams=5)

        # The returned sequence starts with the prompt tokens; decode only
        # what follows them.
        prompt_length = input_ids.shape[1]
        translated_lines.append(
            tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True).strip()
        )

    return '\n'.join(translated_lines), ""


# Build the page: a centered title, two side-by-side columns (source on the
# left, target on the right), an info area, the translate button and an
# examples widget that fills the input box.
with gr.Blocks() as demo:
    gr.HTML(
        value=(
            "<html>\n"
            "  <head>\n"
            "    <style>\n"
            "      h1 {\n"
            "        text-align: center;\n"
            "      }\n"
            "    </style>\n"
            "  </head>\n"
            "  <body>\n"
            "    <h1>SalamandraTA 2B Translate</h1>\n"
            "  </body>\n"
            "</html>"
        )
    )

    with gr.Row():
        # Left column: language to translate from, and the text to translate.
        with gr.Column():
            source_language_dropdown = gr.Dropdown(
                label="Source Language",
                choices=languages,
                value="English",
            )
            input_textbox = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to translate",
                lines=5,
            )

        # Right column: language to translate into, and the result.
        with gr.Column():
            target_language_dropdown = gr.Dropdown(
                label="Target Language",
                choices=languages,
                value="Greek",
            )
            translated_textbox = gr.Textbox(
                label="Translated Text",
                placeholder="",
                lines=5,
            )

    # Receives the second value returned by translate().
    info_label = gr.HTML("")

    btn = gr.Button("Translate")
    btn.click(
        fn=translate,
        inputs=[input_textbox, source_language_dropdown, target_language_dropdown],
        outputs=[translated_textbox, info_label],
    )

    gr.Examples(examples=example_sentence, inputs=[input_textbox])

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()