Cicciokr committed on
Commit
10e698d
·
verified ·
1 Parent(s): f646995

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -2
app.py CHANGED
@@ -50,10 +50,14 @@ input_text = st.text_area(
50
  #tokenizer = GPT2Tokenizer.from_pretrained("Cicciokr/GPT2-Latin-GenText")
51
  model_name = "facebook/mbart-large-50"
52
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
53
- tokenizer = AutoTokenizer.from_pretrained(model_name) # Latino (la_XX)
 
 
 
 
54
  #tokenizer.pad_token_id = tokenizer.eos_token_id
55
  generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
56
-
57
 
58
  # Se l'utente ha inserito (o selezionato) un testo
59
  if input_text:
@@ -77,3 +81,22 @@ if input_text:
77
  st.write(f" Frase predetta: {generated_text}\n")
78
  #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
79
  print(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  #tokenizer = GPT2Tokenizer.from_pretrained("Cicciokr/GPT2-Latin-GenText")
51
  model_name = "facebook/mbart-large-50"
52
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
53
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
54
+
55
+ model_name_mio = "Cicciokr/mbart50-large-latin"
56
+ model_mio = AutoModelForSeq2SeqLM.from_pretrained(model_name_mio)
57
+ tokenizer_mio = AutoTokenizer.from_pretrained(model_name_mio)
58
  #tokenizer.pad_token_id = tokenizer.eos_token_id
59
  generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
60
+ generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
61
 
62
  # Se l'utente ha inserito (o selezionato) un testo
63
  if input_text:
 
81
  st.write(f" Frase predetta: {generated_text}\n")
82
  #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
83
  print(output)
84
+
85
+ inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
86
+ #output = model.generate(**inputs, max_length=512, num_return_sequences=1)
87
+ output_mio = model_mio.generate(
88
+ **inputs_mio,
89
+ max_length=512,
90
+ num_return_sequences=1,
91
+ do_sample=True,
92
+ temperature=0.8,
93
+ top_k=50,
94
+ top_p=0.95
95
+ )
96
+ generated_text_mio = tokenizer_mio.decode(output_mio[0], skip_special_tokens=True)
97
+ st.subheader("Risultato Mio:")
98
+ if 'input_text_value_correct' in st.session_state:
99
+ st.write(f" Parola corretta: {st.session_state['input_text_value_correct']}\n")
100
+ st.write(f" Frase predetta: {generated_text_mio}\n")
101
+ #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
102
+ print(output_mio)