dicklee2046 committed
Commit f39e4e5 · verified · 1 Parent(s): e9e576f

Update app.py

Files changed (1):
  1. app.py +31 -21
app.py CHANGED
@@ -19,57 +19,67 @@ Original file is located at
 # !pip install Ninja
 
 import gradio as gr
-import gc, copy, re
+import os, gc, copy, torch  # os was missing before; torch is kept for the CUDA_HOME fix
+from datetime import datetime
 from huggingface_hub import hf_hub_download
 from pynvml import *
 
-
+# Set CUDA_HOME explicitly for custom CUDA kernel compilation
+os.environ["CUDA_HOME"] = "/usr/local/cuda"
 
 # Flag to check if GPU is present
-HAS_GPU = True
+HAS_GPU = False  # Initialize to False and let pynvml decide
+GPU_COUNT = 0
 
 # Model title and context size limit
-ctx_limit = 8192
-title = "RWKV-5-H-World-7B"
-model_file = "rwkv-5-h-world-7B"
+ctx_limit = 2000
+# Load the 3B model; it fits in the available VRAM.
+title = "RWKV-5-H-World-3B"
+model_file = "rwkv-5-h-world-3B"  # Stick with 3B for now
 
-#title = "RWKV-5-H-World-3B"
-#model_file = "rwkv-5-h-world-3B"
+#title = "RWKV-5-H-World-7B"  # 7B was causing OOM
+#model_file = "rwkv-5-h-world-7B"
 
-# Get the GPU count
+# Get the GPU count (this part is fine, though pynvml might warn)
 try:
     nvmlInit()
     GPU_COUNT = nvmlDeviceGetCount()
     if GPU_COUNT > 0:
         HAS_GPU = True
         gpu_h = nvmlDeviceGetHandleByIndex(0)
+        # nvmlDeviceGetName() already returns str in recent pynvml, so no .decode() is needed
+        print(f"GPU detected: {nvmlDeviceGetName(gpu_h)} with {nvmlDeviceGetMemoryInfo(gpu_h).total / (1024**3):.2f} GB VRAM")
+    else:
+        print("No NVIDIA GPU detected. Will use CPU strategy.")
 except NVMLError as error:
-    print(error)
+    print(f"NVIDIA driver not found or error: {error}. Will use CPU strategy.")
+except Exception as e:  # Catch other potential errors during NVML init
+    print(f"An unexpected error occurred during GPU detection: {e}. Will use CPU strategy.")
 
 
 os.environ["RWKV_JIT_ON"] = '1'
 
 # Model strat to use
-MODEL_STRAT="cpu bf16"
-os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
+MODEL_STRAT = "cpu bf16"  # Default to CPU
+os.environ["RWKV_CUDA_ON"] = '0'  # Default to '0'; '1' uses the CUDA kernel for seq mode (much faster)
 
 # Switch to GPU mode
-print(f"HAS_GPU = {HAS_GPU}")
-if HAS_GPU == True :
+if HAS_GPU:  # More robust than comparing to True
     os.environ["RWKV_CUDA_ON"] = '1'
-    MODEL_STRAT = "cuda bf16"
-    #MODEL_STRAT = "cuda bf16 *20 -> cpu fp32"
-    #MODEL_STRAT = "cuda fp16i8" # or "cuda int8" if that's supported
-
+    MODEL_STRAT = "cuda bf16"  # bf16 is fine for the 3B model, since it fits
+    # To try 7B again, this is where to switch to "cuda fp16i8"
 print(f"MODEL_STRAT: {MODEL_STRAT}")
 
+
 # Load the model accordingly
-from rwkv.model import RWKV
+from rwkv.model import RWKV  # Keep this import here, after the RWKV env vars are set
 model_path = hf_hub_download(repo_id="a686d380/rwkv-5-h-world", filename=f"{model_file}.pth")
 model = RWKV(model=model_path, strategy=MODEL_STRAT)
-from rwkv.utils import PIPELINE, PIPELINE_ARGS
+
+from rwkv.utils import PIPELINE, PIPELINE_ARGS  # Keep this import here
 pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
-print("RWKV model loaded successfully!")
+
+print("RWKV model and pipeline loaded successfully!")
 
 def generate_prompt(instruction, input=None, history=None):
     if instruction:
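The strategy strings toggled in this diff decide where the weights live: "cuda bf16" keeps every layer on the GPU at about two bytes per parameter (roughly 14 GB for the 7B model, which is what caused the OOM), "cuda fp16i8" stores weights as 8-bit integers, and "cuda bf16 *20 -> cpu fp32" puts the first 20 layers on the GPU and the rest on the CPU. Below is a minimal sketch, not code from this commit, of deriving the strategy from the VRAM pynvml already reports; the pick_strategy helper and its gigabyte thresholds are illustrative assumptions.

# Hedged sketch (hypothetical helper, assumed thresholds): choose an RWKV
# strategy string from free VRAM instead of hard-coding it.
from pynvml import (NVMLError, nvmlDeviceGetCount,
                    nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                    nvmlInit)

def pick_strategy() -> str:
    try:
        nvmlInit()
        if nvmlDeviceGetCount() == 0:
            return "cpu bf16"              # no GPU present
        free_gb = nvmlDeviceGetMemoryInfo(
            nvmlDeviceGetHandleByIndex(0)).free / 1024**3
    except NVMLError:
        return "cpu bf16"                  # no NVIDIA driver
    if free_gb >= 16:
        return "cuda bf16"                 # enough room even for 7B at ~14 GB
    if free_gb >= 9:
        return "cuda fp16i8"               # 8-bit weights roughly halve VRAM use
    return "cuda bf16 *20 -> cpu fp32"     # split layers between GPU and CPU

With such a helper, app.py could set MODEL_STRAT = pick_strategy() and flip RWKV_CUDA_ON to '1' whenever the chosen string starts with "cuda".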
 
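The new code imports PIPELINE_ARGS, but the hunk ends before it is used. For reference, a hedged sketch of how the pipeline built above is typically driven through the rwkv pip package's generate() call; the prompt text and sampling values are illustrative assumptions, not values from app.py.

# Hedged usage sketch (assumed values, not part of this commit): stream
# tokens from the pipeline constructed above.
from rwkv.utils import PIPELINE_ARGS

args = PIPELINE_ARGS(
    temperature=1.0,
    top_p=0.7,
    alpha_frequency=0.25,   # penalize tokens by how often they already appeared
    alpha_presence=0.25,    # penalize tokens that have appeared at all
    token_ban=[0],          # never sample token 0
    token_stop=[],          # no early-stop tokens
)

def on_token(text):
    print(text, end="", flush=True)   # print each decoded chunk as it arrives

pipeline.generate("Instruction: say hello\n\nResponse:",
                  token_count=64, args=args, callback=on_token)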