Cicciokr commited on
Commit
6634984
·
verified ·
1 Parent(s): 7dfff7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -52,24 +52,23 @@ input_text = st.text_area(
52
  model_name = "Cicciokr/BART-la-s"
53
  model = BartForConditionalGeneration.from_pretrained(model_name)
54
  tokenizer = BartTokenizer.from_pretrained(model_name)
55
- #tokenizer.pad_token_id = tokenizer.eos_token_id
56
 
57
  model_name_mio = "Cicciokr/BART-CC100-la"
58
  model_mio = BartForConditionalGeneration.from_pretrained(model_name_mio)
59
  tokenizer_mio = BartTokenizer.from_pretrained(model_name_mio)
60
- #tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
61
 
62
  generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
63
- #generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
64
 
65
  # Se l'utente ha inserito (o selezionato) un testo
66
  if input_text:
67
  # Sostituiamo [MASK] con <mask> (lo tokenizer Roberta se lo aspetta così)
68
  #prompt = "Sostituisci la scritta [MASK] con le parole in latino mancanti per completare la frase: "+input_text
69
- inputs = tokenizer(input_text, return_tensors="pt")
70
  output = generator(
71
- input_text,
72
- forced_bos_token_id=tokenizer.bos_token_id,
73
  num_return_sequences=1,
74
  top_k=50, # 🔹 Maggiore varietà nelle scelte
75
  top_p=0.95, # 🔹 Nucleus sampling per migliorare il realismo
@@ -88,10 +87,9 @@ if input_text:
88
  #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
89
  #print(output)
90
  st.write(f" -----------------------------------------------------------\n")
91
- inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
92
- output_mio = model_mio.generate(
93
- **inputs_mio,
94
- forced_bos_token_id=tokenizer.bos_token_id,
95
  num_return_sequences=1,
96
  top_k=50, # 🔹 Maggiore varietà nelle scelte
97
  top_p=0.95, # 🔹 Nucleus sampling per migliorare il realismo
@@ -100,7 +98,7 @@ if input_text:
100
  max_length=50 # 🔹 Previene ripetizioni infinite
101
  )
102
  print(output_mio)
103
- generated_text_mio = tokenizer_mio.decode(output[0], skip_special_tokens=True)
104
- #generated_text_mio = output_mio[0]["generated_text"]
105
  st.subheader("Risultato BART CC100:")
106
  st.write(f" Frase predetta: {generated_text_mio}\n")
 
52
  model_name = "Cicciokr/BART-la-s"
53
  model = BartForConditionalGeneration.from_pretrained(model_name)
54
  tokenizer = BartTokenizer.from_pretrained(model_name)
55
+ tokenizer.pad_token_id = tokenizer.eos_token_id
56
 
57
  model_name_mio = "Cicciokr/BART-CC100-la"
58
  model_mio = BartForConditionalGeneration.from_pretrained(model_name_mio)
59
  tokenizer_mio = BartTokenizer.from_pretrained(model_name_mio)
60
+ tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
61
 
62
  generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
63
+ generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
64
 
65
  # Se l'utente ha inserito (o selezionato) un testo
66
  if input_text:
67
  # Sostituiamo [MASK] con <mask> (lo tokenizer Roberta se lo aspetta così)
68
  #prompt = "Sostituisci la scritta [MASK] con le parole in latino mancanti per completare la frase: "+input_text
69
+ #inputs = tokenizer(input_text, return_tensors="pt")
70
  output = generator(
71
+ input_text,
 
72
  num_return_sequences=1,
73
  top_k=50, # 🔹 Maggiore varietà nelle scelte
74
  top_p=0.95, # 🔹 Nucleus sampling per migliorare il realismo
 
87
  #st.write(f" Frase predetta: {tokenizer.decode(output[0], skip_special_tokens=True)}\n")
88
  #print(output)
89
  st.write(f" -----------------------------------------------------------\n")
90
+ #inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
91
+ output_mio = generator_mio(
92
+ input_text,
 
93
  num_return_sequences=1,
94
  top_k=50, # 🔹 Maggiore varietà nelle scelte
95
  top_p=0.95, # 🔹 Nucleus sampling per migliorare il realismo
 
98
  max_length=50 # 🔹 Previene ripetizioni infinite
99
  )
100
  print(output_mio)
101
+ #generated_text_mio = tokenizer_mio.decode(output[0], skip_special_tokens=True)
102
+ generated_text_mio = output_mio[0]["generated_text"]
103
  st.subheader("Risultato BART CC100:")
104
  st.write(f" Frase predetta: {generated_text_mio}\n")