ibrahimgiki committed on
Commit
f1234fd
·
verified ·
1 Parent(s): 6df2d37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -1,15 +1,20 @@
1
  # Load the GPT-2 large model and tokenizer
2
  model_name = "gpt2-large"
3
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
4
  model = AutoModelForCausalLM.from_pretrained(model_name)
5
 
6
  # Function to generate a blog post based on a topic title
7
  def generate_blog(topic_title, max_length=300):
8
  # Step 1: Encode the input
9
- input_ids = tokenizer.encode(topic_title, return_tensors='pt')
 
 
10
 
11
  # Step 2: Generate model output
12
- output_ids = model.generate(input_ids, max_length=max_length, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
13
 
14
  # Step 3: Decode the output
15
  blog_post = tokenizer.decode(output_ids[0], skip_special_tokens=True)
@@ -17,7 +22,7 @@ def generate_blog(topic_title, max_length=300):
17
  return blog_post
18
 
19
  # Example usage
20
- topic_title = input("Enter a topic title for the blog post: ")
21
  blog_post = generate_blog(topic_title)
22
  print("\nGenerated Blog Post:\n")
23
- print(blog_post)
 
1
# Import the Transformers classes used below. The file previously started
# with no imports at all, so AutoTokenizer/AutoModelForCausalLM were
# undefined names and the script crashed with a NameError on line 3.
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the GPT-2 large model and tokenizer
model_name = "gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a dedicated padding token; reuse EOS so that any
# padding-aware tokenizer call (and generate's pad_token_id) is well defined.
tokenizer.pad_token = tokenizer.eos_token  # Set padding token to EOS token

model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
# Function to generate a blog post based on a topic title
def generate_blog(topic_title, max_length=300):
    """Generate a blog post conditioned on *topic_title* with the module-level GPT-2 model.

    Parameters
    ----------
    topic_title : str
        Prompt text the generated post is conditioned on.
    max_length : int, optional
        Maximum total token length (prompt + continuation) passed to
        ``model.generate``. Defaults to 300.

    Returns
    -------
    str
        The decoded generated text with special tokens stripped.
    """
    # Step 1: Encode the input. Calling the tokenizer directly supersedes the
    # deprecated encode_plus and returns the same mapping with both
    # input_ids and attention_mask. (padding=True was a no-op for a single
    # sequence, so it is dropped.)
    inputs = tokenizer(topic_title, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Step 2: Generate model output. pad_token_id is passed explicitly
    # because GPT-2 has no native pad token.
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Step 3: Decode the first (and only) returned sequence back to text.
    blog_post = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return blog_post
23
 
24
# Example usage: prompt interactively for a title, then print the result.
# Guarded so importing this module does not block on input() — only a
# direct `python app.py` run triggers the interactive flow.
if __name__ == "__main__":
    topic_title = input("Enter title for the blog: ")
    blog_post = generate_blog(topic_title)
    print("\nGenerated Blog Post:\n")
    print(blog_post)