zqu2004 commited on
Commit
8b1dbc7
·
verified ·
1 Parent(s): d70b214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -4,14 +4,27 @@ from PIL import Image
4
  import torch
5
  import spaces
6
 
7
- # Model name and arguments
8
- repo_name = "cyan2k/molmo-7B-D-bnb-4bit"
9
- arguments = {"device_map": "auto", "torch_dtype": "auto", "trust_remote_code": True}
10
 
11
  # Load the processor and model
12
- processor = AutoProcessor.from_pretrained(repo_name, **arguments)
13
 
14
- model = AutoModelForCausalLM.from_pretrained(repo_name, **arguments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Predefined prompts
17
  prompts = [
@@ -49,7 +62,6 @@ def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
49
  generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
50
 
51
  return generated_text
52
-
53
 
54
  def chatbot(image, text, history, max_new_tokens, temperature, top_p):
55
  if image is None:
@@ -64,7 +76,7 @@ def update_textbox(prompt):
64
 
65
  # Define the Gradio interface
66
  with gr.Blocks() as demo:
67
- gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")
68
 
69
  with gr.Row():
70
  image_input = gr.Image(type="numpy")
 
4
  import torch
5
  import spaces
6
 
7
+ # Flag to use GPU (set to False by default)
8
+ USE_GPU = False
 
9
 
10
  # Load the processor and model
11
+ device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")
12
 
13
+ processor = AutoProcessor.from_pretrained(
14
+ 'allenai/MolmoE-1B-0924',
15
+ trust_remote_code=True,
16
+ torch_dtype='auto',
17
+ )
18
+
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ 'allenai/MolmoE-1B-0924',
21
+ trust_remote_code=True,
22
+ torch_dtype='auto',
23
+ device_map='auto' if USE_GPU else None
24
+ )
25
+
26
+ if not USE_GPU:
27
+ model.to(device)
28
 
29
  # Predefined prompts
30
  prompts = [
 
62
  generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
63
 
64
  return generated_text
 
65
 
66
  def chatbot(image, text, history, max_new_tokens, temperature, top_p):
67
  if image is None:
 
76
 
77
  # Define the Gradio interface
78
  with gr.Blocks() as demo:
79
+ gr.Markdown("# Image Chatbot with MolmoE-1B-0924")
80
 
81
  with gr.Row():
82
  image_input = gr.Image(type="numpy")