File size: 6,933 Bytes
5d4acdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
from transformers import pipeline
import torch
import time

# --- 配置 ---
MODEL_ID = "jinv2/opt125m-wikitext2-finetuned"
TASK = "text-generation"

# --- 设备选择 ---
# 优先使用 GPU (如果 Space 配置了)
device = 0 if torch.cuda.is_available() else -1
device_name = "GPU" if device == 0 else "CPU"
print(f"使用设备: {device_name}")

# --- 加载模型 Pipeline ---
# 使用 pipeline 简化文本生成任务
print(f"开始加载模型: {MODEL_ID}...")
try:
    # 对于 OPT 模型,通常不需要 trust_remote_code=True
    # torch_dtype 设为 'auto' 让 transformers 自动选择最佳精度
    pipe = pipeline(
        TASK,
        model=MODEL_ID,
        torch_dtype='auto', # 自动选择精度 (float32 on CPU, float16/bfloat16 on GPU if supported)
        device=device
    )
    print("模型加载成功。")
    # 获取模型实际加载的数据类型
    if hasattr(pipe.model, 'dtype'):
         loaded_dtype = pipe.model.dtype
         print(f"模型加载使用的数据类型: {loaded_dtype}")
    else:
         print("无法自动检测模型加载的数据类型,可能使用默认值。")

except Exception as e:
    print(f"加载模型时出错: {e}")
    raise gr.Error(f"加载模型 '{MODEL_ID}' 失败。错误: {e}。请检查 Space 日志。")

# --- 文本生成函数 ---
def generate_text(prompt, max_length, temperature, top_p, repetition_penalty):
    """使用加载的 pipeline 生成文本"""
    if not prompt:
        return "请输入起始文本 (prompt)。"

    print(f"\n收到提示词: '{prompt}'")
    print(f"生成参数: 最大长度={max_length}, 温度={temperature}, Top-p={top_p}, 重复惩罚={repetition_penalty}")

    # 注意:max_length 通常包含 prompt 的长度。
    # 我们希望生成 max_new_tokens,所以总长度是 prompt 长度 + max_new_tokens
    # 但 text-generation pipeline 的 max_length 参数是 *总* 长度。
    # 为简单起见,我们直接使用 max_length 作为总长度限制,用户输入的 prompt 会被计算在内。
    # 或者,我们可以计算 prompt 的 token 数量并加上期望的新 token 数。
    # 这里我们采用更简单的 max_length 方法。

    start_time = time.time()
    try:
        # OPT 模型通常用于文本续写,不需要复杂的聊天模板
        outputs = pipe(
            prompt,
            max_length=max_length, # 这是生成的总文本长度,包括 prompt
            do_sample=True if temperature > 0 else False, # 仅当 temperature > 0 时采样
            temperature=max(temperature, 1e-6), # Temperature 不能为 0 或负数
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=1,
            pad_token_id=pipe.tokenizer.eos_token_id # 避免填充警告
        )
        generated_text = outputs[0]['generated_text']

        # pipeline 输出通常包含原始提示,我们只返回生成的部分
        # (如果需要完整文本,可以直接返回 generated_text)
        response = generated_text[len(prompt):].strip()

        end_time = time.time()
        duration = end_time - start_time
        print(f"生成完成。原始输出长度: {len(generated_text)}, 提取的续写部分: {response}")
        print(f"生成耗时: {duration:.2f} 秒")

        # 如果模型有时不生成任何新内容,返回提示信息
        if not response and len(generated_text) <= len(prompt):
             return "(模型没有生成新的文本,可能需要调整参数或 prompt)"
        return response

    except Exception as e:
        print(f"生成过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return f"生成过程中发生错误: {e}"

# --- 创建 Gradio 界面 ---
with gr.Blocks(theme=gr.themes.Soft(), title=f"测试 {MODEL_ID}") as demo:
    gr.Markdown(f"""
    # 测试文本生成模型: `{MODEL_ID}`
    输入一段起始文本 (prompt),模型将尝试续写它。
    **注意:** 模型运行在 **{device_name}** 上。
    """)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="输入起始文本 (Prompt)",
                lines=5,
                placeholder="例如:从前有一只勇敢的小兔子,它梦想着..."
            )
            with gr.Accordion("高级生成选项", open=False):
                max_length_slider = gr.Slider(
                    minimum=20,
                    maximum=512, # OPT-125m 的标准上下文长度通常是 2048,但设置低一些以防内存问题和过长生成
                    value=100, # 默认生成较短的续写
                    step=10,
                    label="最大总长度 (Max Length)",
                    info="生成的文本(包括提示)的最大令牌数。"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.05,
                    label="温度 (Temperature)",
                    info="控制随机性。>1 更随机, <1 更确定。0 表示贪婪解码。"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="累积概率阈值,用于筛选下一个词的候选。仅在 temperature > 0 时有效。"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="重复惩罚 (Repetition Penalty)",
                    info="大于 1 可减少重复。设为 1.0 则禁用。"
                )
        submit_button = gr.Button("生成续写", variant="primary")

    with gr.Column(scale=3):
        output_text = gr.Textbox(
            label="模型续写内容 (Generated Text)",
            lines=15,
            interactive=False
        )

    gr.Examples(
        examples=[
            ["人工智能的未来是", 150, 0.8, 0.9, 1.1],
            ["今天天气真不错,阳光明媚,", 80, 0.7, 0.95, 1.0],
            ["The quick brown fox jumps over the", 50, 0.5, 0.9, 1.2],
        ],
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider, repetition_penalty_slider],
        outputs=output_text,
        fn=generate_text,
        cache_examples=False,
        label="示例"
    )

    submit_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_slider, temperature_slider, top_p_slider, repetition_penalty_slider],
        outputs=output_text,
        api_name="generate"
    )

# 启动 Gradio 应用
demo.launch()