File size: 3,333 Bytes
6810eb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import os
import torch
import torch.nn.functional as F
import gradio as gr
from tokenizers import Tokenizer
from model.gpt_model import GPTModel
from data import utils

# Module-level state shared with generate_text(). Every name starts as None and
# is assigned by main() (which declares them global) before the Gradio UI is
# launched, so generate_text() only works after main() has populated them.
model = None  # GPTModel instance, moved to `device` and switched to eval mode
tokenizer = None  # tokenizers.Tokenizer loaded from a tokenizer.json file
device = None  # torch.device: "cuda" when available, otherwise "cpu"
max_position = None  # context length: max_position_embeddings from the config (default 512)

def generate_text(prompt, max_length, temperature):
    """Sample a continuation of *prompt* from the globally loaded model.

    Relies on the module-level ``model``, ``tokenizer``, ``device`` and
    ``max_position`` that ``main()`` initialises before the UI starts.

    Args:
        prompt: Text to continue.
        max_length: Number of new tokens to sample. Coerced with ``int()``
            because the Gradio slider may deliver a float.
        temperature: Softmax temperature. Values ``<= 0`` switch to greedy
            (argmax) decoding instead of sampling.

    Returns:
        The decoded prompt followed by the newly sampled tokens. If the prompt
        encodes to no tokens at all, the prompt is returned unchanged.
    """
    generated = list(tokenizer.encode(prompt).ids)
    if not generated:
        # Nothing to condition on; an empty (1, 0) input tensor would be
        # handed to the model, so return the input untouched instead.
        return prompt

    temperature = float(temperature)
    model.eval()
    with torch.no_grad():
        for _ in range(int(max_length)):
            # The model is built with max_position_embeddings=max_position, so
            # window the *growing* sequence on every step, not just the prompt.
            context = generated[-max_position:]
            inp = torch.tensor([context], dtype=torch.long, device=device)
            logits = model(inp)[0, -1, :]
            if temperature <= 0:
                # Temperature 0 is the greedy limit; avoid dividing by zero.
                next_token_id = int(torch.argmax(logits))
            else:
                probabilities = F.softmax(logits / temperature, dim=-1)
                next_token_id = int(torch.multinomial(probabilities, num_samples=1))
            generated.append(next_token_id)
    return tokenizer.decode(generated)

def _load_model(config_path, checkpoint_path, device):
    """Build a GPTModel from the config at *config_path* and load *checkpoint_path* into it.

    Returns:
        ``(model, max_position, vocab_size)``: the model on *device* in eval
        mode, plus the context length and vocabulary size it was built with.
    """
    config = utils.load_config(config_path)
    model_conf = config.get("model", {})
    vocab_size = model_conf["vocab_size"]
    max_pos = model_conf.get("max_position_embeddings", 512)
    net = GPTModel(
        vocab_size=vocab_size,
        max_position_embeddings=max_pos,
        n_layers=model_conf.get("n_layers", 12),
        n_heads=model_conf.get("n_heads", 12),
        hidden_dim=model_conf.get("embedding_dim", 768),
        dropout=model_conf.get("dropout", 0.0),
    ).to(device)
    utils.load_checkpoint(net, optimizer=None, filepath=checkpoint_path, device=device)
    net.eval()
    return net, max_pos, vocab_size

def _build_interface():
    """Create the Gradio UI that wraps generate_text()."""
    return gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(lines=3, label="Prompt"),
            # step=1: this slider is a token count. Without an explicit step
            # Gradio may choose a fractional increment, which int() would
            # silently truncate inside generate_text().
            gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Max Length"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="OpenGPT Text Generation Demo",
        description="Enter a prompt and generate text using the OpenGPT model."
    )

def main():
    """Parse CLI arguments, load the model and tokenizer, and launch the Gradio demo.

    Populates the module-level ``model``, ``tokenizer``, ``device`` and
    ``max_position`` that generate_text() reads.
    """
    import warnings

    global model, tokenizer, device, max_position
    parser = argparse.ArgumentParser(description="Launch OpenGPT Gradio demo.")
    parser.add_argument("--model", type=str, required=True, help="Path to model checkpoint (.pt).")
    parser.add_argument("--config", type=str, required=True, help="Path to model config (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to tokenizer directory or tokenizer.json file.")
    args = parser.parse_args()

    # Resolve and validate the tokenizer path first so a typo fails fast,
    # before the (potentially slow) checkpoint load.
    tk_path = args.tokenizer
    if os.path.isdir(tk_path):
        tk_path = os.path.join(tk_path, "tokenizer.json")
    if not os.path.isfile(tk_path):
        parser.error(f"tokenizer file not found: {tk_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, max_position, vocab_size = _load_model(args.config, args.model, device)

    tokenizer = Tokenizer.from_file(tk_path)
    tok_vocab = tokenizer.get_vocab_size()
    if tok_vocab > vocab_size:
        # Token ids >= vocab_size would index past the model's vocabulary.
        warnings.warn(
            f"tokenizer vocabulary size ({tok_vocab}) exceeds the model's "
            f"vocab_size ({vocab_size}); some token ids may be out of range"
        )

    _build_interface().launch()

# Script entry point: parse CLI arguments, load everything, and launch the demo.
if __name__ == "__main__":
    main()