wop committed on
Commit
31935be
1 Parent(s): 24aba97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -37
app.py CHANGED
@@ -13,47 +13,35 @@ model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
13
  # Set model to evaluation mode
14
  model.eval()
15
 
16
# Function to generate text in a stream-based manner
def generate_text(prompt):
    """Stream a sampled continuation of *prompt*, one token at a time.

    Parameters
    ----------
    prompt : str
        The input text to continue.

    Yields
    ------
    str
        The full decoded text (prompt plus every token generated so far),
        re-emitted after each new token so the UI can update live.
    """
    # Tokenize and encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    max_new_tokens = 50  # Maximum number of tokens to generate

    # Generate continuation with streaming tokens
    output_ids = input_ids
    for _ in range(max_new_tokens):
        # Generate the next token only (without generating the full budget).
        # FIX: ask for exactly one new token via max_new_tokens=1 instead of
        # the original max_length=output_ids.shape[-1] + 1 arithmetic —
        # same effect, but states the intent directly.
        with torch.no_grad():
            next_token_ids = model.generate(
                output_ids,
                max_new_tokens=1,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                use_cache=True,
            )

        # Get the newly generated token (last position of the returned sequence)
        next_token = next_token_ids[:, -1:]

        # Append the new token to the running output sequence
        output_ids = torch.cat((output_ids, next_token), dim=-1)

        # Decode and yield the output text incrementally
        yield tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # FIX: stop once the model emits EOS instead of always burning
        # through the whole token budget after generation has ended.
        if next_token.item() == tokenizer.eos_token_id:
            break
48
# Assemble the Gradio UI configuration, then build and serve the app.
# live=True makes Gradio re-invoke the streaming generator as input changes.
_ui_config = dict(
    fn=generate_text,  # Callback invoked with the prompt text
    inputs="text",  # Single-line text input
    outputs=gr.Textbox(),  # Textbox output for real-time streamed updates
    title="Quble Text Generation",
    description="Enter a prompt to generate text using Quble with live streaming.",
    live=True,
)
interface = gr.Interface(**_ui_config)

# Launch the Gradio app
interface.launch()
 
13
  # Set model to evaluation mode
14
  model.eval()
15
 
16
# Function to generate text based on an input prompt
def generate_text(prompt, max_length=50):
    """Generate a sampled continuation of *prompt* in one shot.

    Parameters
    ----------
    prompt : str
        Text to continue.
    max_length : int, optional
        Maximum total sequence length (prompt + continuation) in tokens.
        Defaults to 50, matching the previously hard-coded limit, so
        existing callers are unaffected.

    Returns
    -------
    str
        The decoded prompt-plus-continuation with special tokens removed.
    """
    # Tokenize and encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate continuation; no_grad avoids building an autograd graph
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            max_length=max_length,  # Maximum length of generated text
            num_return_sequences=1,  # Generate 1 sequence
            pad_token_id=tokenizer.eos_token_id,  # Use EOS token for padding
            do_sample=True,  # Enable sampling
            top_k=50,  # Top-k sampling
            top_p=0.95,  # Nucleus sampling
        )

    # Decode the generated text
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
36
+
37
# Assemble the Gradio UI configuration, then build and serve the app.
_ui_config = dict(
    fn=generate_text,  # Callback invoked with the prompt text
    inputs="text",  # Single-line text input
    outputs="text",  # Plain-text output (the generated continuation)
    title="Quble Text Generation",
    description="Enter a prompt to generate text using Quble.",
)
interface = gr.Interface(**_ui_config)

# Launch the Gradio app
interface.launch()