FlameF0X committed
Commit 62b2e19 · verified · Parent: d29de7e

Update app.py: load models on demand and keep all CUDA work inside a @spaces.GPU-decorated function, so the Space no longer initializes CUDA in the main process.

Files changed (1):
  1. app.py (+73 -63)
app.py CHANGED
@@ -3,17 +3,9 @@ import torch
  import gradio as gr
  import datetime
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
- from safetensors.torch import load_file
 
  import spaces
 
- @spaces.GPU
- def use_gpu():
-     import torch
-     print("Torch CUDA available:", torch.cuda.is_available())
-     return {"cuda_available": torch.cuda.is_available()}
-
-
  # Constants
  MODEL_CONFIG = {
      "G0-Release": "FlameF0X/Snowflake-G0-Release",
@@ -45,48 +37,54 @@ css = """
  .model-select { background-color: #2a2a4a; padding: 10px; border-radius: 8px; margin-bottom: 15px; }
  """
 
+ # Global registry - models will be loaded on-demand within GPU function
  model_registry = {}
 
- def load_all_models():
-     for name, model_id in MODEL_CONFIG.items():
-         print(f"Loading model: {name} from {model_id}")
-         tokenizer = AutoTokenizer.from_pretrained(model_id)
-         if tokenizer.pad_token is None:
-             tokenizer.pad_token = tokenizer.eos_token
-
-         safetensor_path = os.path.join(model_id, "model.safetensors")
-         if os.path.exists(safetensor_path):
-             print("Loading from safetensors...")
-             model = load_file(safetensor_path)
-         else:
-             print("Loading from Hugging Face or .bin...")
-             # Key fix: no device_map, load on CPU only
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_id,
-                 torch_dtype=torch.float32,
-                 device_map=None
-             )
+ def load_model_cpu(model_id):
+     """Load model on CPU only - no CUDA initialization"""
+     print(f"Loading model on CPU: {model_id}")
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load model on CPU only
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float32,
+         device_map=None,  # No device mapping
+         low_cpu_mem_usage=True
+     )
+
+     return model, tokenizer
 
+ @spaces.GPU
+ def generate_text_gpu(prompt, model_version, temperature, top_p, top_k, max_new_tokens):
+     """GPU-decorated function for text generation"""
+     try:
+         # Load model if not already loaded
+         if model_version not in model_registry:
+             model_id = MODEL_CONFIG[model_version]
+             model, tokenizer = load_model_cpu(model_id)
+             model_registry[model_version] = (model, tokenizer)
+
+         model, tokenizer = model_registry[model_version]
+
+         # Move model to GPU only inside this function
+         if torch.cuda.is_available():
+             model = model.cuda()
+             device = "cuda"
+         else:
+             device = "cpu"
+
+         # Create pipeline inside GPU function
          pipeline = TextGenerationPipeline(
              model=model,
              tokenizer=tokenizer,
              return_full_text=False,
-             max_length=MAX_LENGTH
+             max_length=MAX_LENGTH,
+             device=device
          )
-
-         model_registry[name] = (model, tokenizer, pipeline)
-
- def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_tokens, history=None):
-     if history is None:
-         history = []
-     history.append({"role": "user", "content": prompt})
-
-     try:
-         if model_version not in model_registry:
-             raise ValueError(f"Model '{model_version}' not found.")
-
-         _, tokenizer, pipeline = model_registry[model_version]
-
+
          outputs = pipeline(
              prompt,
              do_sample=temperature > 0,
@@ -97,19 +95,43 @@ def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_toke
              pad_token_id=tokenizer.pad_token_id,
              num_return_sequences=1
          )
-
+
          response = outputs[0]["generated_text"]
-         history.append({"role": "assistant", "content": response, "model": model_version})
+         return response, None
+
+     except Exception as e:
+         error_msg = f"Error generating response: {str(e)}"
+         return error_msg, str(e)
 
+ def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_tokens, history=None):
+     """Main generation function that calls GPU function"""
+     if history is None:
+         history = []
+
+     # Add user message to history
+     history.append({"role": "user", "content": prompt})
+
+     try:
+         # Call GPU function
+         response, error = generate_text_gpu(
+             prompt, model_version, temperature, top_p, top_k, max_new_tokens
+         )
+
+         if error:
+             history.append({"role": "assistant", "content": f"[ERROR] {response}", "model": model_version})
+         else:
+             history.append({"role": "assistant", "content": response, "model": model_version})
+
+         # Format history for display
          formatted_history = []
          for entry in history:
              prefix = "👤 User: " if entry["role"] == "user" else f"❄️ [{entry.get('model', 'Model')}]: "
              formatted_history.append(f"{prefix}{entry['content']}")
-
+
          return response, history, "\n\n".join(formatted_history)
-
+
      except Exception as e:
-         error_msg = f"Error generating response: {str(e)}"
+         error_msg = f"Error in generation pipeline: {str(e)}"
          history.append({"role": "assistant", "content": f"[ERROR] {error_msg}", "model": model_version})
          return error_msg, history, str(history)
 
@@ -230,21 +252,9 @@ def create_demo():
 
      return demo
 
- # Initialize
- print("Loading Snowflake models...")
- try:
-     load_all_models()
-     print("All models loaded successfully!")
-     demo = create_demo()
- except Exception as e:
-     print(f"Error loading models: {e}")
-     with gr.Blocks(css=css) as demo:
-         gr.HTML(f"""
-         <div class="header" style="background-color: #ffebee;">
-             <h1><span class="snowflake-icon">⚠️</span> Error Loading Models</h1>
-             <p>There was a problem loading the Snowflake models: {str(e)}</p>
-         </div>
-         """)
+ # Initialize demo without loading models (they'll load on-demand)
+ print("Initializing Snowflake Models Demo...")
+ demo = create_demo()
 
  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
 
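The change follows the standard ZeroGPU recipe: keep the main process CUDA-free, cache models loaded on CPU, and let the @spaces.GPU decorator attach a GPU only for the duration of the decorated call. For reference, a minimal self-contained sketch of the same pattern; the names here (_registry, _load_cpu, generate) are illustrative and not part of this commit:

import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

_registry = {}  # model_id -> (model, tokenizer), filled lazily on CPU

def _load_cpu(model_id):
    # CPU-only load: on ZeroGPU the main process must never touch CUDA.
    if model_id not in _registry:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True
        )
        _registry[model_id] = (model, tokenizer)
    return _registry[model_id]

@spaces.GPU  # a GPU is attached only while this call is running
def generate(model_id, prompt, max_new_tokens=64):
    model, tokenizer = _load_cpu(model_id)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Return only the completion, not the echoed prompt.
    completion = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True)

Note the decorated function still works when no GPU is attached (the torch.cuda.is_available() check), mirroring generate_text_gpu above, so the Space remains testable on CPU.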