import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import gradio as gr
# Check if a GPU is available and use it, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained model and tokenizer from the Hugging Face Hub
model_path = "Blexus/Quble_Test_Model_v1_Pretrain"
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
# Set model to evaluation mode
model.eval()
# Function to generate text based on an input prompt
def generate_text(prompt):
    # Tokenize and encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate a continuation without tracking gradients
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            max_length=50,                        # Maximum total length (prompt + continuation) in tokens
            num_return_sequences=1,               # Generate a single sequence
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token, so reuse EOS for padding
            do_sample=True,                       # Enable sampling
            top_k=50,                             # Top-k sampling
            top_p=0.95                            # Nucleus (top-p) sampling
        )

    # Decode the generated token IDs back into text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
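
# Optional quick sanity check: uncomment to try the generator from the
# command line before wiring it into the UI. The prompt below is only an
# illustrative example, not part of the original app.
# print(generate_text("Once upon a time"))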
# Create a Gradio interface
interface = gr.Interface(
    fn=generate_text,               # Function to call when interacting with the UI
    inputs="text",                  # Input type: single-line text
    outputs="text",                 # Output type: text (the generated output)
    title="Quble Text Generation",  # Title of the UI
    description="Enter a prompt to generate text using Quble."  # Simple description
)
# Launch the Gradio app
interface.launch()
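
# Note: the default launch() settings are sufficient on Hugging Face Spaces.
# When running locally, interface.launch(share=True) would additionally create
# a temporary public link (an optional Gradio feature, not used here).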