ejschwartz committed
Commit 3035027 · 1 Parent(s): f01d69f

Remove logging and disable field model

Files changed (1)
  1. app.py +28 -45
app.py CHANGED
@@ -3,7 +3,7 @@ from gradio_client import Client
 from gradio_client.exceptions import AppError
 import frontmatter
 import os
-#import spaces
+import spaces
 import torch
 import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -16,18 +16,6 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-# Enable transformers logging
-transformers_logging.set_verbosity_debug()
-transformers_logging.enable_default_handler()
-transformers_logging.enable_explicit_format()
-
-# Enable accelerate and torch logging
-logging.getLogger("accelerate").setLevel(logging.DEBUG)
-logging.getLogger("torch").setLevel(logging.DEBUG)
-logging.getLogger("spaces").setLevel(logging.DEBUG)
-logging.getLogger("spaces.zero").setLevel(logging.DEBUG)
-logging.getLogger("transformers").setLevel(logging.DEBUG)
-
 import huggingface_hub
 
 import prep_decompiled
@@ -43,26 +31,18 @@ tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
 vardecoder_model = AutoModelForCausalLM.from_pretrained(
     "ejschwartz/resym-vardecoder",
     torch_dtype=torch.bfloat16,
+    device_map="auto",
 )
 print("Loaded vardecoder model successfully.")
 
-print(f"Model device: {next(vardecoder_model.parameters()).device}")
-print(f"Model dtype: {next(vardecoder_model.parameters()).dtype}")
-print(f"Model is meta: {next(vardecoder_model.parameters()).is_meta}")
-print(f"Model parameters: {sum(p.numel() for p in vardecoder_model.parameters() if p.requires_grad):,}")
-
-# Check memory after first model
-print(f"GPU memory after vardecoder:")
-print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
-print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
-
 logger.info("Loading fielddecoder model...")
 
-fielddecoder_model = AutoModelForCausalLM.from_pretrained(
-    "ejschwartz/resym-fielddecoder",
-    torch_dtype=torch.bfloat16,
-)
-logger.info("Successfully loaded fielddecoder model")
+fielddecoder_model = None
+#fielddecoder_model = AutoModelForCausalLM.from_pretrained(
+#    "ejschwartz/resym-fielddecoder",
+#    torch_dtype=torch.bfloat16,
+#)
+#logger.info("Successfully loaded fielddecoder model")
 
 make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
 
@@ -155,23 +135,26 @@ def infer(code):
         :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
     ]
 
-    field_output = fielddecoder_model.generate(
-        input_ids=field_input_ids,
-        max_new_tokens=MAX_NEW_TOKENS,
-        num_beams=4,
-        num_return_sequences=1,
-        do_sample=False,
-        early_stopping=False,
-        pad_token_id=0,
-        eos_token_id=0,
-    )[0]
-    field_output = tokenizer.decode(
-        field_output[field_input_ids.size(1) :],
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=True,
-    )
-
-    field_output = fields[0] + ":" + field_output
+    if fielddecoder_model is None:
+        field_output = "TEMPORARILY DISABLED"
+    else:
+        field_output = fielddecoder_model.generate(
+            input_ids=field_input_ids,
+            max_new_tokens=MAX_NEW_TOKENS,
+            num_beams=4,
+            num_return_sequences=1,
+            do_sample=False,
+            early_stopping=False,
+            pad_token_id=0,
+            eos_token_id=0,
+        )[0]
+        field_output = tokenizer.decode(
+            field_output[field_input_ids.size(1) :],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
+        )
+
+        field_output = fields[0] + ":" + field_output
     var_output = first_var + ":" + var_output
     fieldstring = ", ".join(fields)
     return var_output, field_output, varstring, fieldstring
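
For context, the two substantive changes are loading the vardecoder with device_map="auto" (which hands weight placement to the accelerate library) and stubbing out the field decoder, with a None guard in infer so the app keeps running while that model is disabled. A minimal standalone sketch of the resulting load path, assuming only the checkpoint names that appear in the diff:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Tokenizer and variable-name decoder, as in the diff above.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")
vardecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-vardecoder",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # accelerate decides GPU/CPU placement
)

# Field decoder disabled: infer() checks for None and returns a
# placeholder string instead of calling generate().
fielddecoder_model = None

Note that the decode step in the diff slices off the prompt, field_output[field_input_ids.size(1):], so only newly generated tokens are turned back into text.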