skylersterling commited on
Commit
ad4d3e1
·
verified ·
1 Parent(s): 1a13068

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -11,7 +11,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
11
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2', use_auth_token=HF_TOKEN)
12
  model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', use_auth_token=HF_TOKEN)
13
  model.eval()
14
- model.to("cpu")
15
 
16
  # Define the function that generates text from a prompt
17
  def generate_text(prompt):
@@ -30,13 +30,12 @@ def generate_text(prompt):
30
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
31
 
32
  decoded_token = tokenizer.decode(next_token.item())
33
- # Print each token as it is generated
34
- print(decoded_token, end='', flush=True)
35
 
36
- # Decode the generated tokens to a string
37
  generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
38
- return generated_text
39
 
40
  # Create a Gradio interface with a text input and a text output
41
- interface = gr.Interface(fn=generate_text, inputs='text', outputs='text')
42
  interface.launch()
 
11
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2', use_auth_token=HF_TOKEN)
12
  model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', use_auth_token=HF_TOKEN)
13
  model.eval()
14
+ model.to('cpu')
15
 
16
  # Define the function that generates text from a prompt
17
  def generate_text(prompt):
 
30
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
31
 
32
  decoded_token = tokenizer.decode(next_token.item())
33
+ yield decoded_token # Yield each token as it is generated
 
34
 
35
+ # Decode the entire generated text
36
  generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
37
+ yield generated_text
38
 
39
  # Create a Gradio interface with a text input and a text output
40
+ interface = gr.Interface(fn=generate_text, inputs='text', outputs='text', live=True)
41
  interface.launch()