Cicciokr committed (verified) · Commit 68c1bf8 · 1 parent: d696144

Update app.py

Files changed (1): app.py (+25 -8)
app.py CHANGED
@@ -6,8 +6,7 @@ examples = [
     "Caesar in Gallia bellum gessit. <mask> instructae erant, sed subito",
     "Est autem et aliud genus testudinis, quod <mask>, quemadmodum quae supra scripta sunt",
     "Quemadmodum vero minores rotae duriores et <mask>, sic phalangae et iuga, in quibus partibus habent minora a centro ad capita intervalla",
-    "illud additur, ne, qui certum ordinem ex <mask>, ulli vos alteri hominum generi haerere vereamini nec timeatis vos",
-    "Gli italiani perdono le partite di calcio come se <mask> e perdono le guerre come se fossero partite di calcio"
+    "illud additur, ne, qui certum ordinem ex <mask>, ulli vos alteri hominum generi haerere vereamini nec timeatis vos"
 ]
 examples_correct = [
     "Omnes legiones",
@@ -53,9 +52,13 @@ input_text = st.text_area(
 model_name = "Cicciokr/BART-la-s"
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+tokenizer.pad_token_id = tokenizer.eos_token_id
 
-#model_name_mio = "Cicciokr/mbart50-large-latin"
-#tokenizer.pad_token_id = tokenizer.eos_token_id
+model_name_mio = "Cicciokr/BART-CC100-la"
+model_mio = AutoModelForSeq2SeqLM.from_pretrained(model_name_mio)
+tokenizer_mio = AutoTokenizer.from_pretrained(model_name_mio, use_fast=False)
+tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
+
 #generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 #generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
 
@@ -67,16 +70,30 @@ if input_text:
     output = model.generate(
         **inputs,
         max_length=512,
-        # num_beams=4,
-        # num_return_sequences=1,
+        num_beams=4,
+        num_return_sequences=1,
         do_sample=True,
         temperature=0.9
-        # top_k=1,
+        top_k=1,
     )
     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
     if 'input_text_value_correct' in st.session_state:
         st.write(f" Parola corretta: {st.session_state['input_text_value_correct']}\n")
-    st.subheader("Risultato BART:")
+    st.subheader("Risultato BART TheLatinLibrary:")
     st.write(f" Frase predetta: {generated_text}\n")
     #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
     #print(output)
+
+    inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
+    output_mio = model_mio.generate(
+        **inputs_mio,
+        max_length=512,
+        num_beams=4,
+        num_return_sequences=1,
+        do_sample=True,
+        temperature=0.9
+        top_k=1,
+    )
+    generated_text_mio = tokenizer_mio.decode(output[0], skip_special_tokens=True)
+    st.subheader("Risultato BART CC100:")
+    st.write(f" Frase predetta: {generated_text_mio}\n")