OneEyeDJ committed on
Commit
40afd1d
·
verified ·
1 Parent(s): d681921

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -20
main.py CHANGED
@@ -6,25 +6,40 @@ import gradio as gr
6
  from transformers import AutoModelForCausalLM, AutoProcessor
7
  import argparse
8
  import os
 
9
 
10
  class SimpleVideoLLaMA3Interface:
11
  def __init__(self, model_path):
12
- print(f"Loading model from {model_path}...")
13
- self.model = AutoModelForCausalLM.from_pretrained(
14
- model_path,
15
- trust_remote_code=True,
16
- device_map="auto",
17
- torch_dtype=torch.bfloat16,
18
- attn_implementation="flash_attention_2",
19
- )
20
- self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
21
- print("Model loaded successfully!")
22
-
23
  self.image_formats = ("png", "jpg", "jpeg", "bmp", "gif", "webp")
24
  self.video_formats = ("mp4", "avi", "mov", "mkv", "webm", "m4v", "3gp", "flv")
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @torch.inference_mode()
27
  def predict(self, messages, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=4096, fps=10, max_frames=256):
 
 
 
28
  if not messages or len(messages) == 0:
29
  return messages
30
 
@@ -202,13 +217,9 @@ class SimpleVideoLLaMA3Interface:
202
 
203
  return interface
204
 
205
- if __name__ == "__main__":
206
- parser = argparse.ArgumentParser()
207
- parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoLLaMA3-7B")
208
- parser.add_argument("--port", type=int, default=7860)
209
- parser.add_argument("--share", action="store_true")
210
- args = parser.parse_args()
211
 
212
- app = SimpleVideoLLaMA3Interface(args.model_path)
213
- interface = app.create_interface()
214
- interface.launch(server_port=args.port, share=args.share, server_name="0.0.0.0")
 
6
  from transformers import AutoModelForCausalLM, AutoProcessor
7
  import argparse
8
  import os
9
+ import spaces # Import spaces for ZEROGPU
10
 
11
  class SimpleVideoLLaMA3Interface:
12
  def __init__(self, model_path):
13
+ self.model_path = model_path
14
+ self.model = None
15
+ self.processor = None
 
 
 
 
 
 
 
 
16
  self.image_formats = ("png", "jpg", "jpeg", "bmp", "gif", "webp")
17
  self.video_formats = ("mp4", "avi", "mov", "mkv", "webm", "m4v", "3gp", "flv")
18
+
19
+ # Load processor on CPU (doesn't need GPU)
20
+ print(f"Loading processor from {model_path}...")
21
+ self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
22
+ print("Processor loaded successfully!")
23
 
24
+ def load_model(self):
25
+ """Load model - this will be called inside GPU-decorated functions"""
26
+ if self.model is None:
27
+ print(f"Loading model from {self.model_path}...")
28
+ self.model = AutoModelForCausalLM.from_pretrained(
29
+ self.model_path,
30
+ trust_remote_code=True,
31
+ device_map="auto",
32
+ torch_dtype=torch.bfloat16,
33
+ attn_implementation="flash_attention_2",
34
+ )
35
+ print("Model loaded successfully!")
36
+
37
+ @spaces.GPU(duration=120) # Allocate GPU for up to 120 seconds
38
  @torch.inference_mode()
39
  def predict(self, messages, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=4096, fps=10, max_frames=256):
40
+ # Load model inside GPU context
41
+ self.load_model()
42
+
43
  if not messages or len(messages) == 0:
44
  return messages
45
 
 
217
 
218
  return interface
219
 
220
+ # For Hugging Face Spaces
221
+ app = SimpleVideoLLaMA3Interface("DAMO-NLP-SG/VideoLLaMA3-7B")
222
+ interface = app.create_interface()
 
 
 
223
 
224
+ if __name__ == "__main__":
225
+ interface.launch()