Cicciokr committed on
Commit
58f0b57
·
verified ·
1 Parent(s): 8a4bbe5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -51,7 +51,7 @@ input_text = st.text_area(
51
  model_name_mio = "Cicciokr/BART-la-s"
52
  model_mio = AutoModelForSeq2SeqLM.from_pretrained(model_name_mio)
53
  tokenizer_mio = AutoTokenizer.from_pretrained(model_name_mio)
54
- tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
55
 
56
  #generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
57
  generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
@@ -63,7 +63,10 @@ if input_text:
63
  st.write(f" -----------------------------------------------------------\n")
64
 
65
  inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
66
- output_mio = generator_mio(input_text
 
 
 
67
  #num_return_sequences=1,
68
  #top_k=50, # 🔹 Maggiore varietà nelle scelte
69
  #top_p=0.95, # 🔹 Nucleus sampling per migliorare il realismo
@@ -71,7 +74,7 @@ if input_text:
71
  #repetition_penalty=1.2, # 🔹 Evita ripetizioni e loop
72
  #max_length=50 # 🔹 Previene ripetizioni infinite
73
  )
74
- #generated_text_mio = tokenizer_mio.decode(output_mio[0], skip_special_tokens=True)
75
- generated_text_mio = output_mio[0]["generated_text"]
76
  st.subheader("Risultato BART CC100:")
77
  st.write(f" Frase predetta: {generated_text_mio}\n")
 
51
  model_name_mio = "Cicciokr/BART-la-s"
52
  model_mio = AutoModelForSeq2SeqLM.from_pretrained(model_name_mio)
53
  tokenizer_mio = AutoTokenizer.from_pretrained(model_name_mio)
54
+ #tokenizer_mio.pad_token_id = tokenizer_mio.eos_token_id
55
 
56
  #generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
57
  generator_mio = pipeline("text2text-generation", model=model_mio, tokenizer=tokenizer_mio)
 
63
  st.write(f" -----------------------------------------------------------\n")
64
 
65
  inputs_mio = tokenizer_mio(input_text, return_tensors="pt")
66
+ output_mio = model_mio.generate(
67
+ **inputs_mio,
68
+ forced_bos_token_id=tokenizer_mio.bos_token_id,
69
+ max_length=20, do_sample=True, top_p=0.96, num_return_sequences=5
70
  #num_return_sequences=1,
71
  #top_k=50, # 🔹 Maggiore varietà nelle scelte
72
  #top_p=0.95, # 🔹 Nucleus sampling per migliorare il realismo
 
74
  #repetition_penalty=1.2, # 🔹 Evita ripetizioni e loop
75
  #max_length=50 # 🔹 Previene ripetizioni infinite
76
  )
77
+ generated_text_mio = tokenizer_mio.batch_decode(output_mio, skip_special_tokens=True)
78
+ #generated_text_mio = output_mio[0]["generated_text"]
79
  st.subheader("Risultato BART CC100:")
80
  st.write(f" Frase predetta: {generated_text_mio}\n")