Update app.py
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, MT5Model, AutoModelForSeq2SeqLM, BartForConditionalGeneration
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, MT5Model, AutoModelForSeq2SeqLM, BartForConditionalGeneration, BartTokenizer
 
 # Example sentences
 examples = [
@@ -51,16 +51,16 @@ input_text = st.text_area(
 #model_name = "morenolq/bart-it"
 model_name = "Cicciokr/BART-la-s"
 model = BartForConditionalGeneration.from_pretrained(model_name)
-tokenizer =
-tokenizer.pad_token_id = tokenizer.eos_token_id
+tokenizer = BartTokenizer.from_pretrained(model_name)
+#tokenizer.pad_token_id = tokenizer.eos_token_id
 
 model_name_mio = "Cicciokr/BART-CC100-la"
 model_mio = BartForConditionalGeneration.from_pretrained(model_name_mio)
-tokenizer_mio =
-tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
+tokenizer_mio = BartTokenizer.from_pretrained(model_name_mio)
+#tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
 
-generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
-generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
+#generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+#generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
 
 # If the user has entered (or selected) a text
 if input_text:
@@ -73,19 +73,15 @@ if input_text:
     #generated_text = output[0]["generated_text"]
     if 'input_text_value_correct' in st.session_state:
         st.write(f" Parola corretta: {st.session_state['input_text_value_correct']}\n")
+    st.write(f" Parola corretta: {input_text}\n")
+    st.write(f" -----------------------------------------------------------\n")
     st.subheader("Risultato BART TheLatinLibrary:")
     st.write(f" Frase predetta: {generated_text}\n")
     #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
     #print(output)
-
+    st.write(f" -----------------------------------------------------------\n")
     inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
-    output_mio = model_mio.generate(
-        **inputs_mio,
-        max_length=64,
-        do_sample=True,
-        num_beams=1,
-        forced_bos_token_id=tokenizer.bos_token_id
-    )
+    output_mio = model_mio.generate(**inputs_mio, forced_bos_token_id=tokenizer.bos_token_id)
     print(output_mio)
     generated_text_mio = tokenizer_mio.decode(output[0], skip_special_tokens=True)
     #generated_text_mio = output_mio[0]["generated_text"]
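Two apparent bugs survive this commit: the final decode reads output[0] (the first model's generation, produced earlier in the file) where output_mio[0] is presumably intended, and forced_bos_token_id is taken from tokenizer (the BART-la-s tokenizer) rather than tokenizer_mio. A minimal corrected sketch of the new generation path, assuming only the checkpoint name shown in the diff and a hypothetical placeholder for the Streamlit text_area value:

from transformers import BartForConditionalGeneration, BartTokenizer

# Load the second checkpoint exactly as the updated app.py does.
model_name_mio = "Cicciokr/BART-CC100-la"
model_mio = BartForConditionalGeneration.from_pretrained(model_name_mio)
tokenizer_mio = BartTokenizer.from_pretrained(model_name_mio)

input_text = "lorem ipsum"  # hypothetical stand-in for the st.text_area input

inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
output_mio = model_mio.generate(
    **inputs_mio,
    # Fix: take the BOS id from this model's own tokenizer,
    # not from the TheLatinLibrary tokenizer as in the diff.
    forced_bos_token_id=tokenizer_mio.bos_token_id,
)
# Fix: decode output_mio, not the first model's output.
generated_text_mio = tokenizer_mio.decode(output_mio[0], skip_special_tokens=True)
print(generated_text_mio)

For two BART checkpoints the special-token ids may well coincide, so the forced_bos_token_id fix is largely defensive; the output/output_mio mix-up, by contrast, would make the second result section display the first model's prediction.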