lyimo committed on
Commit
2be2417
·
1 Parent(s): 7a9342b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -41
app.py CHANGED
@@ -1,33 +1,29 @@
 
1
  import gradio as gr
 
 
2
  import copy
3
- import os
4
- import time
5
- import llama_cpp
6
  from llama_cpp import Llama
7
  from huggingface_hub import hf_hub_download
8
- from fastai.vision.all import *
9
 
10
- # Load the LLM model
 
 
 
 
11
  llm = Llama(
12
  model_path=hf_hub_download(
13
  repo_id=os.environ.get("REPO_ID", "TheBloke/Llama-2-7B-Chat-GGML"),
14
  filename=os.environ.get("MODEL_FILE", "llama-2-7b-chat.ggmlv3.q5_0.bin"),
15
  ),
16
  n_ctx=2048,
17
- n_gpu_layers=50, # change n_gpu_layers if you have more or less VRAM
18
  )
19
 
20
  history = []
21
-
22
  system_message = """
23
- You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe...
24
  """
25
- # The rest of the system message
26
-
27
- # Load the Vision Model
28
- learn = load_learner('export.pkl')
29
-
30
- labels = learn.dls.vocab
31
 
32
  def generate_text(message, history):
33
  temp = ""
@@ -59,36 +55,25 @@ def generate_text(message, history):
59
  temp += stream["choices"][0]["text"]
60
  yield temp
61
 
62
- history = ["init", input_prompt]
 
63
 
64
- def predict(img):
65
- try:
66
- img = PILImage.create(img)
67
- except:
68
- return {"bird": "Unknown"}
69
  pred, pred_idx, probs = learn.predict(img)
70
- return {"bird": labels[pred_idx], "probs": {labels[i]: float(probs[i]) for i in range(len(labels))}}
71
-
72
- title = "Bird Detector with LLM"
73
- description = "Detect birds and get LLM responses."
74
- examples = [{"img": "BIRD.png", "message": "Tell me about the bird."}]
75
- interpretation = 'default'
76
- enable_queue = True
77
-
78
- def combined(img, message):
79
- prediction = predict(img)
80
- response = list(generate_text(f"I have detected {prediction['bird']} in the image. {message}", history))
81
- return response[0] # Return the first generated response
82
 
 
83
  gr.Interface(
84
- fn=combined,
85
- inputs=[
86
- gr.inputs.Image(),
87
- gr.inputs.Textbox(label="Message to LLM")
88
- ],
89
  outputs=gr.outputs.Textbox(),
90
- title=title,
91
- description=description,
92
- examples=examples,
93
- interpretation=interpretation,
94
  ).launch()
 
1
+ import os
2
  import gradio as gr
3
+ from fastai.vision.all import *
4
+ import skimage
5
  import copy
 
 
 
6
  from llama_cpp import Llama
7
  from huggingface_hub import hf_hub_download
 
8
 
9
+ # Load the FastAI vision model
10
+ learn = load_learner('export.pkl')
11
+ labels = learn.dls.vocab
12
+
13
+ # Load the Llama language model
14
  llm = Llama(
15
  model_path=hf_hub_download(
16
  repo_id=os.environ.get("REPO_ID", "TheBloke/Llama-2-7B-Chat-GGML"),
17
  filename=os.environ.get("MODEL_FILE", "llama-2-7b-chat.ggmlv3.q5_0.bin"),
18
  ),
19
  n_ctx=2048,
20
+ n_gpu_layers=50,
21
  )
22
 
23
  history = []
 
24
  system_message = """
25
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
26
  """
 
 
 
 
 
 
27
 
28
  def generate_text(message, history):
29
  temp = ""
 
55
  temp += stream["choices"][0]["text"]
56
  yield temp
57
 
58
+ history.append(("USER:", message))
59
+ history.append(("ASSISTANT:", temp))
60
 
61
+ # Define the predict function for the FastAI model
62
+ def predict_with_llama_and_generate_text(img):
63
+ img = PILImage.create(img)
 
 
64
  pred, pred_idx, probs = learn.predict(img)
65
+ detected_object = labels[pred_idx]
66
+
67
+ response = f"The system has detected {detected_object}. Do you want to know about {detected_object}?"
68
+
69
+ for llama_response in generate_text(response, history):
70
+ yield llama_response
 
 
 
 
 
 
71
 
72
+ # Define the Gradio interface
73
  gr.Interface(
74
+ fn=predict_with_llama_and_generate_text,
75
+ inputs=gr.inputs.Image(shape=(512, 512)),
 
 
 
76
  outputs=gr.outputs.Textbox(),
77
+ title="Multimodal Assistant",
78
+ description="An AI model that combines image classification with text generation.",
 
 
79
  ).launch()