Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -19,57 +19,67 @@ Original file is located at
# --- app.py top-of-file initialization ---------------------------------------
# Reconstructed from a diff-view paste: the paste interleaved the old and new
# revisions of this script with diff-UI line numbers; this is the resolved
# "new" (post-change) side only, formatted as real Python.

# !pip install Ninja

import gradio as gr
import os
import gc
import copy
import torch  # Keep torch here for the CUDA_HOME fix
from datetime import datetime

from huggingface_hub import hf_hub_download
# Explicit names instead of `from pynvml import *` so the NVML surface this
# script depends on is visible at a glance.
from pynvml import (
    NVMLError,
    nvmlDeviceGetCount,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlDeviceGetName,
    nvmlInit,
)

# Set CUDA_HOME explicitly for custom CUDA kernel compilation
os.environ["CUDA_HOME"] = "/usr/local/cuda"

# Flag to check if GPU is present; the pynvml probe below may flip it.
HAS_GPU = False
GPU_COUNT = 0

# Model title and context size limit
ctx_limit = 2000
title = "RWKV-5-H-World-3B"
model_file = "rwkv-5-h-world-3B"

# The 7B variant caused OOM on this Space; kept commented out deliberately.
#title = "RWKV-5-H-World-7B"
#model_file = "rwkv-5-h-world-7B"

# Probe for NVIDIA GPUs. Any failure (missing driver, NVML init error) is
# deliberately non-fatal: the script falls back to the CPU strategy.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)
        # No .decode() on nvmlDeviceGetName — newer pynvml returns str.
        print(f"GPU detected: {nvmlDeviceGetName(gpu_h)} with {nvmlDeviceGetMemoryInfo(gpu_h).total / (1024**3):.2f} GB VRAM")
    else:
        print("No NVIDIA GPU detected. Will use CPU strategy.")
except NVMLError as error:
    print(f"NVIDIA driver not found or error: {error}. Will use CPU strategy.")
except Exception as e:  # best-effort: never let GPU detection kill startup
    print(f"An unexpected error occurred during GPU detection: {e}. Will use CPU strategy.")


os.environ["RWKV_JIT_ON"] = '1'

# Model strat to use: default to CPU, switch to CUDA when a GPU was found.
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0'

# Switch to GPU mode
if HAS_GPU:
    os.environ["RWKV_CUDA_ON"] = '1'
    MODEL_STRAT = "cuda bf16"  # bf16 fits the 3B model
    # If you were to try 7B again, THIS is where you'd change to "cuda fp16i8"

print(f"MODEL_STRAT: {MODEL_STRAT}")


# Load the model accordingly.
# NOTE(review): import kept here, after the env-var setup above, per the
# original working code structure — rwkv presumably reads RWKV_JIT_ON /
# RWKV_CUDA_ON at import time; confirm before hoisting to the top.
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="a686d380/rwkv-5-h-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)

from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

print("RWKV model and pipeline loaded successfully!")
def generate_prompt(instruction, input=None, history=None):
|
85 |
if instruction:
|