Update app.py
app.py CHANGED
@@ -75,15 +75,15 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
+                  [*extra_components])


-def get_hidden_states(raw_original_prompt):
+def get_hidden_states(raw_original_prompt, force_hidden_states=False):
     model, tokenizer = global_state.model, global_state.tokenizer
     original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
-    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
-    if global_state.wait_with_hidden_states:
+    if global_state.wait_with_hidden_states and not force_hidden_states:
         global_state.local_state.hidden_states = None
     else:
+        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
         hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
         global_state.local_state.hidden_states = hidden_states.cpu().detach()

@@ -102,7 +102,7 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
     tokenizer = global_state.tokenizer
     print(f'run {model}')
     if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
-        get_hidden_states(raw_original_prompt)
+        get_hidden_states(raw_original_prompt, force_hidden_states=True)
     interpreted_vectors = torch.tensor(global_state.local_state.hidden_states[:, i]).to(model.device).to(model.dtype)
     length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
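Taken together, the two hunks implement a deferred-computation pattern: `get_hidden_states` now skips the forward pass when `wait_with_hidden_states` is set, leaving `None` as a sentinel, and `run_interpretation` later forces the computation via the new `force_hidden_states=True` flag. Below is a minimal, self-contained sketch of that control flow; the `SimpleNamespace` state object and the `fake_forward` placeholder are stand-ins for illustration, not the app's real `global_state` or model call.

```python
from types import SimpleNamespace

# Stand-in for the app's global state (assumption: the real app also holds
# a model, tokenizer, and prompt template here).
global_state = SimpleNamespace(
    wait_with_hidden_states=True,  # defer the expensive forward pass by default
    local_state=SimpleNamespace(hidden_states=None),
)

def fake_forward(prompt):
    # Placeholder for the real forward pass:
    # outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    return f"hidden_states<{prompt}>"

def get_hidden_states(raw_original_prompt, force_hidden_states=False):
    if global_state.wait_with_hidden_states and not force_hidden_states:
        # Lazy path: leave the sentinel so callers know nothing was computed yet.
        global_state.local_state.hidden_states = None
    else:
        global_state.local_state.hidden_states = fake_forward(raw_original_prompt)

def run_interpretation(raw_original_prompt):
    # If the forward pass was deferred, force it now before reading the states.
    if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
        get_hidden_states(raw_original_prompt, force_hidden_states=True)
    return global_state.local_state.hidden_states

get_hidden_states("hello")          # deferred: hidden_states stays None
print(run_interpretation("hello"))  # forces the pass -> hidden_states<hello>
```

The `None` sentinel plus a force flag keeps the common UI path cheap (no forward pass on every prompt edit) while guaranteeing the states exist by the time interpretation actually reads them.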