tafxle committed
Commit d8e2347
1 Parent(s): 3ac45a0

Cache + Measure time

Files changed (1): app.py (+26 -14)
app.py CHANGED
@@ -1,20 +1,23 @@
 import torch
 import transformers
-import numpy as np
+import time
 from huggingface_hub import hf_hub_download
+import streamlit as st
 
 
-tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-
-hf_hub_download("OpenDungeon/gpt-j-8bit-ffbgem", "model.pt")
+@st.cache
+def load_model():
+    hf_hub_download("OpenDungeon/gpt-j-8bit-ffbgem", "model.pt")
+    qmodel = torch.load("model.pt")
+    return transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"), qmodel
 
-qmodel = torch.load("model.pt")
 
 def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens = 50):
     past_key_values = None # used to keep track of conversation history
     input_dict = tokenizer([prompt] * batch, return_tensors='pt', padding=False)
     output = [""] * batch
-
+    batch_time = 0
+
     with torch.inference_mode():
         for i in range(limit_tokens + 20):
             if i == 5:
@@ -33,16 +36,25 @@ def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_toke
             if single_hook is not None:
                 single_hook(tokenizer.decode(token_ix[0]))
             if i == limit_tokens:
-                print()
-                print((time.perf_counter() - start_time) / (i - 4), "s per token")
+                batch_time = (time.perf_counter() - start_time) / (i - 4)
                 break
 
             input_dict = dict(input_ids=token_ix)
-    print()
-    return output
-
-import streamlit as st
+    return output, batch_time
+
+
+tokenizer, model = load_model()
+text = st.text_area("Prefix")
+batch = st.number_input("Variants", value=1)
+
+t = st.empty()
+firstline = ""
+
+def PrintSome(text):
+    global t, firstline
+    firstline += text
+    t.markdown(f"## {firstline}...")
 
-text = st.text_area("Prompt")
-PrintContinuation(text, qmodel, lambda x: t.markdown(f"## {x}..."))
+choices, batch_time = PrintContinuation(text, model, PrintSome, batch, 50)
 
+t.markdown(" \n\n".join(choices) + f" \n\nSeconds per batch: {batch_time}, Batch: {batch}")
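
For the caching half of this commit: Streamlit re-executes the whole script on every widget interaction, so before this change the checkpoint was re-downloaded and re-deserialized with torch.load on each rerun. Wrapping the load in a function decorated with @st.cache memoizes its return value across reruns. A minimal sketch of the pattern, separate from this app (slow_load and the sleep are illustrative stand-ins, not code from the commit):

import time
import streamlit as st

# Legacy Streamlit caching API, as used in the commit; mutable return values
# such as torch models usually also need allow_output_mutation=True.
@st.cache(allow_output_mutation=True)
def slow_load():
    time.sleep(5)  # stands in for hf_hub_download + torch.load
    return {"weights": list(range(1024))}

start = time.perf_counter()
model = slow_load()  # ~5 s on the first run, near-instant on later reruns
st.write(f"Model ready in {time.perf_counter() - start:.2f} s")

Newer Streamlit releases deprecate st.cache in favor of st.cache_resource for objects like models, but the legacy decorator is what this commit uses.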
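For the measurement half: start_time is presumably set in the `if i == 5:` branch visible at the end of the first hunk (the diff elides its body), so the first five decoding steps count as warmup and the `(i - 4)` divisor averages only the post-warmup steps. A standalone sketch of that warmup-then-average pattern with illustrative names (average_step_time is not part of the app):

import time

def average_step_time(step, n_steps, warmup=5):
    """Average seconds per call of step(), ignoring the first warmup calls."""
    start_time = None
    for i in range(n_steps):
        if i == warmup:
            start_time = time.perf_counter()  # start the clock after warmup
        step()
    return (time.perf_counter() - start_time) / (n_steps - warmup)

print(average_step_time(lambda: sum(range(100_000)), n_steps=50), "s per step")

Excluding warmup matters here because the first few generation steps pay one-off costs (the initial forward pass without past_key_values, memory allocation), which would skew a short 50-token average.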