"""Gradio demo for interactive text generation with a trained OpenGPT model."""
import argparse
import os

import torch
import torch.nn.functional as F
import gradio as gr
from tokenizers import Tokenizer

from model.gpt_model import GPTModel
from data import utils

# Populated in main() and read by the Gradio callback generate_text().
model = None
tokenizer = None
device = None
max_position = None


def generate_text(prompt, max_length, temperature):
    """Autoregressively sample up to max_length new tokens after the prompt."""
    encoded = tokenizer.encode(prompt)
    input_ids = encoded.ids
    # Keep only the most recent tokens that fit the context window.
    if len(input_ids) > max_position:
        input_ids = input_ids[-max_position:]
    generated = list(input_ids)

    model.eval()
    with torch.no_grad():
        for _ in range(int(max_length)):
            # Re-truncate on every step so the input never outgrows the
            # model's positional embeddings as new tokens are appended.
            context = generated[-max_position:]
            inp = torch.tensor([context], dtype=torch.long, device=device)
            outputs = model(inp)
            logits = outputs[0, -1, :]  # next-token logits

            # Temperature scaling: values below 1.0 sharpen the
            # distribution, values above 1.0 flatten it.
            if temperature != 1.0:
                logits = logits / temperature

            probabilities = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probabilities, num_samples=1).item()
            generated.append(next_token_id)

    return tokenizer.decode(generated)
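
# NOTE: generate_text() uses plain temperature sampling. Top-k or nucleus
# (top-p) filtering could be layered onto the probabilities before the
# torch.multinomial call, but is not implemented here.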
|
|
|
def main():
    global model, tokenizer, device, max_position

    parser = argparse.ArgumentParser(description="Launch the OpenGPT Gradio demo.")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to model checkpoint (.pt).")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to model config (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Path to tokenizer directory or tokenizer.json file.")
    args = parser.parse_args()
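
    # Example invocation (script name and paths are illustrative):
    #   python gradio_demo.py --model checkpoints/model.pt \
    #       --config configs/model_config.yaml --tokenizer tokenizer/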

    # Rebuild the model with the same hyperparameters used for training.
    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    vocab_size = model_conf["vocab_size"]
    max_position = model_conf.get("max_position_embeddings", 512)
    hidden_dim = model_conf.get("embedding_dim", 768)
    n_layers = model_conf.get("n_layers", 12)
    n_heads = model_conf.get("n_heads", 12)
    dropout = model_conf.get("dropout", 0.0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTModel(vocab_size=vocab_size,
                     max_position_embeddings=max_position,
                     n_layers=n_layers,
                     n_heads=n_heads,
                     hidden_dim=hidden_dim,
                     dropout=dropout).to(device)
    utils.load_checkpoint(model, optimizer=None, filepath=args.model, device=device)
    model.eval()

    # Accept either a directory containing tokenizer.json or the file itself;
    # Tokenizer.from_file() expects a HuggingFace `tokenizers` JSON file.
    tk_path = args.tokenizer
    if os.path.isdir(tk_path):
        tk_path = os.path.join(tk_path, "tokenizer.json")
    tokenizer = Tokenizer.from_file(tk_path)

    interface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(lines=3, label="Prompt"),
            gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Max Length"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="OpenGPT Text Generation Demo",
        description="Enter a prompt and generate text using the OpenGPT model.",
    )
    interface.launch()
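    # NOTE: launch() serves on http://127.0.0.1:7860 by default and blocks
    # until interrupted; passing share=True also creates a temporary public URL.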
|
|
|
if __name__ == "__main__":
    main()