Cicciokr commited on
Commit
a022c55
·
verified ·
1 Parent(s): 6aba9f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, MT5Model, AutoModelForSeq2SeqLM, BartForConditionalGeneration
3
 
4
  # Frasi di esempio
5
  examples = [
@@ -51,16 +51,16 @@ input_text = st.text_area(
51
  #model_name = "morenolq/bart-it"
52
  model_name = "Cicciokr/BART-la-s"
53
  model = BartForConditionalGeneration.from_pretrained(model_name)
54
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
55
- tokenizer.pad_token_id = tokenizer.eos_token_id
56
 
57
  model_name_mio = "Cicciokr/BART-CC100-la"
58
  model_mio = BartForConditionalGeneration.from_pretrained(model_name_mio)
59
- tokenizer_mio = AutoTokenizer.from_pretrained(model_name_mio, use_fast=False)
60
- tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
61
 
62
- generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
63
- generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
64
 
65
  # Se l'utente ha inserito (o selezionato) un testo
66
  if input_text:
@@ -73,19 +73,15 @@ if input_text:
73
  #generated_text = output[0]["generated_text"]
74
  if 'input_text_value_correct' in st.session_state:
75
  st.write(f" Parola corretta: {st.session_state['input_text_value_correct']}\n")
 
 
76
  st.subheader("Risultato BART TheLatinLibrary:")
77
  st.write(f" Frase predetta: {generated_text}\n")
78
  #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
79
  #print(output)
80
-
81
  inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
82
- output_mio = model_mio.generate(
83
- **inputs_mio,
84
- max_length=64,
85
- do_sample=True,
86
- num_beams=1,
87
- forced_bos_token_id=tokenizer.bos_token_id
88
- )
89
  print(output_mio)
90
  generated_text_mio = tokenizer_mio.decode(output[0], skip_special_tokens=True)
91
  #generated_text_mio = output_mio[0]["generated_text"]
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, MT5Model, AutoModelForSeq2SeqLM, BartForConditionalGeneration, BartTokenizer
3
 
4
  # Frasi di esempio
5
  examples = [
 
51
  #model_name = "morenolq/bart-it"
52
  model_name = "Cicciokr/BART-la-s"
53
  model = BartForConditionalGeneration.from_pretrained(model_name)
54
+ tokenizer = BartTokenizer.from_pretrained(model_name)
55
+ #tokenizer.pad_token_id = tokenizer.eos_token_id
56
 
57
  model_name_mio = "Cicciokr/BART-CC100-la"
58
  model_mio = BartForConditionalGeneration.from_pretrained(model_name_mio)
59
+ tokenizer_mio = BartTokenizer.from_pretrained(model_name_mio)
60
+ #tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
61
 
62
+ #generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
63
+ #generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
64
 
65
  # Se l'utente ha inserito (o selezionato) un testo
66
  if input_text:
 
73
  #generated_text = output[0]["generated_text"]
74
  if 'input_text_value_correct' in st.session_state:
75
  st.write(f" Parola corretta: {st.session_state['input_text_value_correct']}\n")
76
+ st.write(f" Parola corretta: {input_text}\n")
77
+ st.write(f" -----------------------------------------------------------\n")
78
  st.subheader("Risultato BART TheLatinLibrary:")
79
  st.write(f" Frase predetta: {generated_text}\n")
80
  #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
81
  #print(output)
82
+ st.write(f" -----------------------------------------------------------\n")
83
  inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
84
+ output_mio = model_mio.generate(**inputs_mio, forced_bos_token_id=tokenizer.bos_token_id)
 
 
 
 
 
 
85
  print(output_mio)
86
  generated_text_mio = tokenizer_mio.decode(output[0], skip_special_tokens=True)
87
  #generated_text_mio = output_mio[0]["generated_text"]