# -*- coding: utf-8 -*-
"""rwkv_h.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Z6xYOW9UPksew3P6bBvCK6FHYzqfBFIo
"""
# Please go to Edit -> Notebook settings -> T4 GPU, then confirm the output below shows "Tesla T4" to verify a GPU is available.
# Then click Runtime -> Run all and wait roughly five to ten minutes.
# Finally, open the link shown at the bottom of the output.
# !nvidia-smi
# !pip install gradio
# !pip install huggingface_hub
# !pip install pynvml
# !pip install rwkv
# !pip install Ninja

import gradio as gr
import os, gc, copy, torch  # torch is used below for the CUDA checks and the custom kernel build
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
import re  # used by generator() to normalize prompts

# Set CUDA_HOME explicitly for custom CUDA kernel compilation
os.environ["CUDA_HOME"] = "/usr/local/cuda"

# Flag to check whether a GPU is present
HAS_GPU = False  # initialize to False; let pynvml decide
GPU_COUNT = 0

# Model title and context size limit
ctx_limit = 2000
# The 3B model fits in T4 VRAM with bf16; the 7B model caused out-of-memory errors with this setup.
title = "RWKV-5-H-World-3B"
model_file = "rwkv-5-h-world-3B"
# title = "RWKV-5-H-World-7B"
# model_file = "rwkv-5-h-world-7B"

# Detect the GPU (pynvml may emit a harmless warning here)
try:
    print(f"Is CUDA available: {torch.cuda.is_available()}")  # expected: True on a GPU runtime
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")  # e.g. Tesla T4
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)
        # nvmlDeviceGetName returns a str in recent pynvml versions, so no .decode() is needed
        print(f"GPU detected: {nvmlDeviceGetName(gpu_h)} with {nvmlDeviceGetMemoryInfo(gpu_h).total / (1024**3):.2f} GB VRAM")
    else:
        print("No NVIDIA GPU detected. Will use CPU strategy.")
except NVMLError as error:
    print(f"NVIDIA driver not found or error: {error}. Will use CPU strategy.")
except Exception as e:  # catch other potential errors during NVML init
    print(f"An unexpected error occurred during GPU detection: {e}. Will use CPU strategy.")
os.environ["RWKV_JIT_ON"] = '1' | |
# Model strat to use | |
MODEL_STRAT="cpu bf16" # Default to CPU | |
os.environ["RWKV_CUDA_ON"] = '0' # Default to 0 | |
# Switch to GPU mode | |
if HAS_GPU: # Use this more robust check | |
os.environ["RWKV_CUDA_ON"] = '1' | |
MODEL_STRAT = "cuda bf16" # Keep bf16 for 3B model, as it fits. | |
# If you were to try 7B again, THIS is where you'd change to "cuda fp16i8" | |
print(f"MODEL_STRAT: {MODEL_STRAT}") | |

# Load the model accordingly
from rwkv.model import RWKV  # imported only after RWKV_JIT_ON / RWKV_CUDA_ON are set, since the library reads them at import time
model_path = hf_hub_download(repo_id="a686d380/rwkv-5-h-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
print("RWKV model and pipeline loaded successfully!")

def generate_prompt(instruction, input=None, history=None):
    if instruction:
        instruction = (
            instruction.strip()
            .replace("\r\n", "\n")
            .replace("\n\n", "\n")
            .replace("\n\n", "\n")
        )
    if (history is not None) and len(history) > 1:
        input = ""
        for pair in history:
            if pair[0] is not None and pair[1] is not None and len(pair[1]) > 0:
                input += f"{pair[0]},{pair[1]},"
        input = input[:-1] + f". {instruction}"
        instruction = "Generate a Response to the **last** question below."
    if input and len(input) > 0:
        input = (
            input.strip()
            .replace("\r\n", "\n")
            .replace("\n\n", "\n")
            .replace("\n\n", "\n")
        )
return f"""Instruction: {instruction} | |
Input: {input} | |
Response:""" | |
else: | |
return f"""User: {instruction} | |
Assistant:""" | |
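
# For reference, with no Input and no history the prompt sent to the model is simply:
#   User: <instruction>
#
#   Assistant:
# and when an Input (or folded chat history) is present it uses the
# Instruction / Input / Response template above.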

examples = [
    ["東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", "", 3900, 1.2, 0.5, 0.5, 0.5],
    [
        "Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires.",
        "",
        3333,
        1.2,
        0.5,
        0.5,
        0.5,
    ],
    ["Write a song about ravens.", "", 3900, 1.2, 0.5, 0.5, 0.5],
    ["Explain the following metaphor: Life is like cats.", "", 3900, 1.2, 0.5, 0.5, 0.5],
    [
        "Write a story using the following information",
        "A man named Alex chops a tree down",
        3333,
        1.2,
        0.5,
        0.5,
        0.5,
    ],
    [
        "Generate a list of adjectives that describe a person as brave.",
        "",
        3333,
        1.2,
        0.5,
        0.5,
        0.5,
    ],
    [
        "You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with a detailed plan.",
        "",
        3333,
        1.2,
        0.5,
        0.5,
        0.5,
    ],
]
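
# Each example row uses the same column order as the gr.Dataset components defined in
# the Instruct tab below: [instruction, input, max tokens, temperature, top_p,
# presence penalty, count penalty].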

def generator(
    instruction,
    input=None,
    token_count=3900,
    temperature=1.0,
    top_p=0.5,
    presencePenalty=0.5,
    countPenalty=0.5,
    history=None
):
    args = PIPELINE_ARGS(
        temperature=max(0.2, float(temperature)),  # clamp temperature to a sane minimum
        top_p=float(top_p),
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )
    instruction = re.sub(r"\n{2,}", "\n", instruction).strip().replace("\r\n", "\n")
    no_history = (history is None)
    if no_history:
        input = re.sub(r"\n{2,}", "\n", input).strip().replace("\r\n", "\n")
    ctx = generate_prompt(instruction, input, history)
    print(ctx + "\n")

    all_tokens = []
    out_last = 0
    out_str = ""
    occurrence = {}  # per-token counts used for the presence/frequency penalties
    state = None
    for i in range(int(token_count)):
        # Feed the whole (truncated) prompt on the first step, then one token at a time
        out, state = model.forward(
            pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state
        )
        # Apply presence and frequency penalties to tokens seen so far
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        token = pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )
        if token in args.token_stop:
            break
        all_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= 0.996  # decay old penalties so long outputs are not over-penalized
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1

        # Only emit text once it decodes cleanly (no partial UTF-8 sequences)
        tmp = pipeline.decode(all_tokens[out_last:])
        if "\ufffd" not in tmp:
            out_str += tmp
            if no_history:
                yield out_str.strip()
            else:
                yield tmp
            out_last = i + 1
        if "\n\n" in out_str:
            break

    del out
    del state
    gc.collect()
    if no_history:
        yield out_str.strip()
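
# generator() is an ordinary Python generator: with no history it yields progressively
# longer strings (which is what the Instruct tab streams into its output box), while in
# chat mode it yields the newly decoded chunks. A minimal stand-alone use would be:
# for partial in generator("Write a haiku about rain.", "", token_count=100,
#                          temperature=1.0, top_p=0.5,
#                          presencePenalty=0.5, countPenalty=0.5):
#     print(partial)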

# Note: user() and alternative() are not wired to any UI event below; the Chat tab
# defines its own user_msg()/respond() handlers.
def user(message, chatbot):
    chatbot = chatbot or []
    return "", chatbot + [[message, None]]

def alternative(chatbot, history):
    if not chatbot or not history:
        return chatbot, history
    chatbot[-1][1] = None
    history[0] = copy.deepcopy(history[1])
    return chatbot, history

with gr.Blocks(title=title) as demo:
    gr.HTML(f'<div style="text-align: center;">\n<h1>🌍Chat - {title}</h1>\n</div>')
    with gr.Tab("Chat mode"):
        with gr.Row():
            with gr.Column():
                # The handlers below build the history as [user, bot] pairs, so the
                # Chatbot uses its default pair format.
                chatbot = gr.Chatbot()
                msg = gr.Textbox(
                    scale=4,
                    show_label=False,
                    placeholder="Enter text and press enter",
                    container=False,
                )
                clear = gr.ClearButton([msg, chatbot])
            with gr.Column():
                token_count_chat = gr.Slider(
                    # 10, 512, label="Max Tokens", step=10, value=333
                    10, 8000, label="Max Tokens", step=10, value=4000
                )
                temperature_chat = gr.Slider(
                    0.2, 2.0, label="Temperature", step=0.1, value=1.2
                )
                top_p_chat = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.5)
                presence_penalty_chat = gr.Slider(
                    0.0, 1.0, label="Presence Penalty", step=0.1, value=0.5
                )
                count_penalty_chat = gr.Slider(
                    0.0, 1.0, label="Count Penalty", step=0.1, value=0.7
                )
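                # "Presence Penalty" and "Count Penalty" map to alpha_presence and
                # alpha_frequency in PIPELINE_ARGS inside generator(); higher values
                # discourage the model from repeating itself.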

        def clear_chat():
            return "", []

        def user_msg(message, history):
            history = history or []
            return "", history + [[message, None]]

        def respond(history, token_count, temperature, top_p, presence_penalty, count_penalty):
            instruction = history[-1][0]
            history[-1][1] = ""
            for character in generator(
                instruction,
                None,
                token_count,
                temperature,
                top_p,
                presence_penalty,
                count_penalty,
                history
            ):
                history[-1][1] += character
                yield history

        msg.submit(user_msg, [msg, chatbot], [msg, chatbot], queue=False).then(
            respond,
            [chatbot, token_count_chat, temperature_chat, top_p_chat, presence_penalty_chat, count_penalty_chat],
            chatbot,
            api_name="chat",
        )
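
        # Submitting the textbox first appends the user turn to the history (user_msg),
        # then streams the model reply into the last history entry (respond); each yield
        # from respond refreshes the Chatbot display.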
with gr.Tab("Instruct mode"): | |
with gr.Row(): | |
with gr.Column(): | |
instruction = gr.Textbox( | |
lines=2, | |
label="Instruction", | |
value="東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", | |
) | |
input_instruct = gr.Textbox( | |
lines=2, label="Input", placeholder="", value="" | |
) | |
token_count_instruct = gr.Slider( | |
#10, 512, label="Max Tokens", step=10, value=333 | |
10, 8000, label="Max Tokens", step=10, value=4000 | |
) | |
temperature_instruct = gr.Slider( | |
0.2, 2.0, label="Temperature", step=0.1, value=1.2 | |
) | |
top_p_instruct = gr.Slider( | |
0.0, 1.0, label="Top P", step=0.05, value=0.5 | |
) | |
presence_penalty_instruct = gr.Slider( | |
0.0, 1.0, label="Presence Penalty", step=0.1, value=0.5 | |
) | |
count_penalty_instruct = gr.Slider( | |
0.0, 1.0, label="Count Penalty", step=0.1, value=0.5 | |
) | |
with gr.Column(): | |
with gr.Row(): | |
submit = gr.Button("Submit", variant="primary") | |
clear = gr.Button("Clear", variant="secondary") | |
output = gr.Textbox(label="Output", lines=5) | |
        data = gr.Dataset(
            components=[
                instruction,
                input_instruct,
                token_count_instruct,
                temperature_instruct,
                top_p_instruct,
                presence_penalty_instruct,
                count_penalty_instruct,
            ],
            samples=examples,
            label="Example Instructions",
            headers=[
                "Instruction",
                "Input",
                "Max Tokens",
                "Temperature",
                "Top P",
                "Presence Penalty",
                "Count Penalty",
            ],
        )
        submit.click(
            generator,
            [
                instruction,
                input_instruct,
                token_count_instruct,
                temperature_instruct,
                top_p_instruct,
                presence_penalty_instruct,
                count_penalty_instruct,
            ],
            [output],
        )
        clear.click(lambda: None, [], [output])
        data.click(
            lambda x: x,
            [data],
            [
                instruction,
                input_instruct,
                token_count_instruct,
                temperature_instruct,
                top_p_instruct,
                presence_penalty_instruct,
                count_penalty_instruct,
            ],
        )

demo.queue(max_size=10)
# demo.launch(share=False)
demo.launch(share=True)  # share=True prints a public gradio.live link in the Colab output