hackeracademy committed on
Commit 722f6b0 · 1 Parent(s): 90f5d7c

Fix Gradio signature & set MPLCONFIGDIR

Files changed (1):
  1. app.py +20 -10
app.py CHANGED
```diff
@@ -1,4 +1,4 @@
-import os, gradio as gr, requests, tempfile
+import os, gradio as gr, requests, tempfile, logging, time
 from llama_cpp import Llama
 
 MODEL_URL = (
@@ -6,29 +6,39 @@ MODEL_URL = (
     "resolve/main/foundation-sec-8b-q4_k_m.gguf"
 )
 
-# writable directory
 CACHE_DIR = "/tmp"
 MODEL_PATH = os.path.join(CACHE_DIR, "foundation-sec-8b-q4_k_m.gguf")
 
-# download only once
+# silence matplotlib cache warning
+os.environ["MPLCONFIGDIR"] = CACHE_DIR
+
+# download once
 if not os.path.exists(MODEL_PATH):
-    print("Downloading model … (~4.9 GB)")
+    logging.info("Downloading model …")
     with requests.get(MODEL_URL, stream=True) as r:
         r.raise_for_status()
         with open(MODEL_PATH, "wb") as f:
             for chunk in r.iter_content(chunk_size=8192):
                 f.write(chunk)
-    print("Download finished.")
+    logging.info("Download finished.")
 
-# load model
 llm = Llama(model_path=MODEL_PATH, n_ctx=4096, verbose=False)
 
+# correct signature: message, history
 def chat_fn(message, history):
-    messages = [{"role": "user", "content": message}]
-    out = llm.create_chat_completion(messages=messages, max_tokens=256, temperature=0.7)
+    messages = []
+    for human, ai in history:
+        messages.append({"role": "user", "content": human})
+        messages.append({"role": "assistant", "content": ai})
+    messages.append({"role": "user", "content": message})
+
+    out = llm.create_chat_completion(
+        messages=messages,
+        max_tokens=512,
+        temperature=0.7,
+        stream=False,
+    )
     return out["choices"][0]["message"]["content"]
 
 demo = gr.ChatInterface(chat_fn, title="Foundation-Sec-8B")
-
-# expose on 0.0.0.0:7860 (Gradio default)
 demo.launch(server_name="0.0.0.0", server_port=7860)
```
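The "Gradio signature" fix: `gr.ChatInterface` calls its function as `fn(message, history)`, where `history` (in Gradio's classic tuple format) is a list of `[user, assistant]` pairs, and the old `chat_fn` was dropping it entirely. A minimal sketch of the new flattening step in isolation, with a hypothetical history value, assuming that tuple format:

```python
# Sketch of the fixed chat_fn's history handling; the history and
# message values here are hypothetical, purely for illustration.
history = [
    ["What is a GGUF file?", "A single-file model format used by llama.cpp."],
]
message = "And what does q4_k_m mean?"

messages = []
for human, ai in history:  # one [user, assistant] pair per past turn
    messages.append({"role": "user", "content": human})
    messages.append({"role": "assistant", "content": ai})
messages.append({"role": "user", "content": message})  # the new prompt last

# messages now alternates user/assistant roles and ends with the new
# user turn, the shape create_chat_completion() expects.
```

Note that newer Gradio releases can instead pass history as a list of `{"role": ..., "content": ...}` dicts (`type="messages"`); the tuple unpacking above assumes the older default format.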
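Setting `MPLCONFIGDIR` silences the warning matplotlib (a transitive gradio dependency) emits when its default config directory, normally under `$HOME`, is not writable, as on Spaces where only `/tmp` is. The variable only takes effect if it is set before matplotlib is first imported; a minimal sketch of the ordering, assuming matplotlib is only pulled in indirectly:

```python
import os

# Must run before anything imports matplotlib; once matplotlib is
# imported it has already chosen its config/cache directory and emits
# "Matplotlib created a temporary cache directory ..." if $HOME is
# read-only.
os.environ["MPLCONFIGDIR"] = "/tmp"

import matplotlib
print(matplotlib.get_configdir())  # -> /tmp
```

In app.py the assignment happens after the top-level gradio import, so if gradio imports matplotlib at import time the warning may already have fired; moving the assignment above the imports would make the ordering robust.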
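One caveat in the download-once guard: `os.path.exists(MODEL_PATH)` also succeeds on a file truncated by an earlier interrupted run, which would then be loaded as a corrupt model. A common hardening, sketched here as a possible follow-up rather than as part of this commit, is to stream into a temporary file and rename it into place only on success (`os.replace` is atomic within one filesystem):

```python
import os, tempfile, requests

def download_once(url: str, dest: str) -> None:
    """Stream url to dest atomically, skipping the download if dest exists."""
    if os.path.exists(dest):
        return
    # Temp file in the same directory, so os.replace() stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(dest))
    try:
        with os.fdopen(fd, "wb") as f, requests.get(url, stream=True) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        os.replace(tmp_path, dest)  # only becomes visible at dest when complete
    except BaseException:
        os.unlink(tmp_path)         # drop the partial file
        raise
```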