Update app.py
app.py CHANGED
@@ -23,7 +23,8 @@ class GlobalState:
     model : Optional[PreTrainedModel] = None
     hidden_states : Optional[torch.Tensor] = None
     interpretation_prompt_template : str = '{prompt}'
-    original_prompt_template : str = '{prompt}'
+    original_prompt_template : str = 'User: [X]\n\nAnswer: {prompt}'
+    layers_format : str = 'model.layers.{k}'
 
 
 suggested_interpretation_prompts = [
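Both new fields are plain Python format strings: `original_prompt_template` now defaults to a chat-style wrapper that keeps a literal `[X]` placeholder, and `layers_format` names the model's decoder layers by index. A tiny illustration of how such strings expand with `str.format` (the values are the new defaults; the call site for the prompt template appears in a later hunk, while `layers_format` is consumed inside `InterpretationPrompt`, which this diff does not show):

```python
original_prompt_template = 'User: [X]\n\nAnswer: {prompt}'
layers_format = 'model.layers.{k}'

# Only '{prompt}' is a format field; '[X]' is not, so it stays in the text.
print(original_prompt_template.format(prompt='The capital of France is'))
# User: [X]
#
# Answer: The capital of France is

# The layer pattern turns an index into a dotted module path.
print(layers_format.format(k=3))   # model.layers.3
```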
@@ -46,6 +47,7 @@ def reset_model(model_name, *extra_components):
     model_path = model_args.pop('model_path')
     global_state.original_prompt_template = model_args.pop('original_prompt_template')
     global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
+    global_state.layers_format = model_args.pop('layers_format')
     tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
     use_ctransformers = model_args.pop('ctransformers', False)
     AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
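`reset_model` now pops a mandatory `layers_format` key from the per-model arguments, alongside `model_path`, the two prompt templates, and the optional `tokenizer` and `ctransformers` entries, so every entry in `model_info` has to provide it. The `model_info` table itself is not part of this diff; a hypothetical entry shaped to match these pops (display names, paths, and values are illustrative only):

```python
# Hypothetical model_info entries; the keys mirror the pops in reset_model,
# the display names and values are illustrative only.
model_info = {
    'LLaMA 2 7B Chat': dict(
        model_path='meta-llama/Llama-2-7b-chat-hf',
        original_prompt_template='User: [X]\n\nAnswer: {prompt}',
        interpretation_prompt_template='{prompt}',
        layers_format='model.layers.{k}',       # LLaMA-style decoder layer names
        # tokenizer='...',                      # optional: falls back to model_path
        # ctransformers=True,                   # optional: use the CTransformers loader
    ),
    'GPT-2': dict(
        model_path='gpt2',
        original_prompt_template='User: [X]\n\nAnswer: {prompt}',
        interpretation_prompt_template='{prompt}',
        layers_format='transformer.h.{k}',      # GPT-2-style decoder layer names
    ),
}
```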
@@ -96,7 +98,7 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
 
     # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
     interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
-    interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
+    interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt, layers_format=global_state.layers_format)
 
     # generate the interpretations
     # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
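`InterpretationPrompt` now also receives `layers_format`, presumably so it can look up decoder blocks by name instead of assuming one architecture's module layout. Its implementation is not shown in this diff; a minimal sketch of the kind of lookup the format string enables, using PyTorch's `get_submodule`:

```python
import torch.nn as nn

def get_decoder_layer(model: nn.Module, layers_format: str, k: int) -> nn.Module:
    """Resolve the k-th decoder block by name, e.g. 'model.layers.{k}' for
    LLaMA-style models or 'transformer.h.{k}' for GPT-2-style models."""
    return model.get_submodule(layers_format.format(k=k))

# Usage (assuming a loaded causal LM bound to `model` and the new global field):
#   layer_5 = get_decoder_layer(model, global_state.layers_format, 5)
#   handle = layer_5.register_forward_pre_hook(...)   # attach a patching hook here
```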
@@ -138,23 +140,24 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     gr.Markdown(
         '''
-
-
-
-
+**👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
+In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
+So we can inject a representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure! I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
+we expect to get back a summary of the information that exists inside the hidden state, despite being from a different layer and a different run!! How cool is that! 💯💯💯
         ''', line_breaks=True)
 
     # with gr.Column(scale=1):
     #     gr.Markdown('<span style="font-size:180px;">🤖</span>')
 
     with gr.Group():
-        model_chooser = gr.Radio(label='Model', choices=list(model_info.keys()), value=model_name)
+        model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
 
     with gr.Blocks() as demo_blocks:
         gr.Markdown('## Choose Your Interpretation Prompt')
         with gr.Group('Interpretation'):
             interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
-            gr.Examples([[p] for p in suggested_interpretation_prompts],
+            interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
+                                                         [interpretation_prompt], cache_examples=False)
 
 
         gr.Markdown('## The Prompt to Analyze')
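The new intro text describes the mechanism: take a hidden state from one forward pass, then substitute it for the representation of a placeholder token in a second pass and let the model describe what it sees. A self-contained sketch of that two-pass procedure, with everything model-specific hedged: the model, prompts, layer and position choices are illustrative, the injection here targets every decoder layer's input, and the app's own `InterpretationPrompt` may do this differently (requires PyTorch ≥ 2.0 for `with_kwargs=True` hooks):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'gpt2'                                  # illustrative; any causal LM works
layers_format = 'transformer.h.{k}'                  # GPT-2-style layer naming
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

# Pass 1: run the prompt we want to analyze and keep one hidden state to inspect.
enc = tokenizer('The Eiffel Tower is located in', return_tensors='pt')
with torch.no_grad():
    out = model(**enc, output_hidden_states=True)
source_layer, source_pos = 6, -1                     # arbitrary choices for illustration
hidden = out.hidden_states[source_layer][0, source_pos]   # shape: (hidden_dim,)

# Pass 2: run an interpretation prompt and overwrite the placeholder token's
# representation at the input of every decoder layer, then let the model continue.
interp_enc = tokenizer('User: X\n\nAnswer: Sure! Your message says', return_tensors='pt')
ids = interp_enc['input_ids'][0].tolist()
placeholder_pos = next(i for i, t in enumerate(ids)
                       if tokenizer.decode([t]).strip() == 'X')

def make_pre_hook(replacement, pos):
    def pre_hook(module, args, kwargs):
        hidden_states = args[0].clone()              # (batch, seq_len, hidden_dim)
        if hidden_states.shape[1] > pos:             # only patch during the prefill pass
            hidden_states[:, pos] = replacement.to(hidden_states.dtype)
        return (hidden_states, *args[1:]), kwargs
    return pre_hook

handles = [
    model.get_submodule(layers_format.format(k=k))
         .register_forward_pre_hook(make_pre_hook(hidden, placeholder_pos), with_kwargs=True)
    for k in range(model.config.num_hidden_layers)
]
try:
    with torch.no_grad():
        generated = model.generate(**interp_enc, max_new_tokens=20, do_sample=False)
finally:
    for h in handles:
        h.remove()

print(tokenizer.decode(generated[0], skip_special_tokens=True))
```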
@@ -198,8 +201,8 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
 
     # event listeners
-    extra_components = [
-
+    extra_components = [interpretation_prompt, interpretation_prompt_examples, original_prompt_raw, *tokens_container,
+                        original_prompt_btn, *interpretation_bubbles]
     model_chooser.change(reset_model, [model_chooser, *extra_components], extra_components)
 
     for i, btn in enumerate(tokens_container):
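The rebuilt `extra_components` list is passed to `model_chooser.change` as both inputs and outputs, so `reset_model` can read the current widgets and return replacements when a different model is selected. A stripped-down sketch of that wiring pattern; the component set and the no-op update function are placeholders, and the real list also contains the `gr.Examples` dataset, which needs extra care that this sketch skips:

```python
import gradio as gr

def reset_model_stub(model_name, *extra_components):
    # Stand-in for app.py's reset_model: it receives the current values of the extra
    # components and must return one updated value per output component, in order.
    return [gr.update() for _ in extra_components]   # no-op updates in this sketch

with gr.Blocks() as sketch:
    model_chooser = gr.Radio(['model-a', 'model-b'], label='Choose Your Model', value='model-a')
    prompt_box = gr.Textbox(label='Interpretation Prompt')
    analyze_btn = gr.Button('Analyze')
    extra_components = [prompt_box, analyze_btn]
    # Same wiring shape as the commit: the components appear in both inputs and outputs,
    # so selecting a different model can rebuild them in one event.
    model_chooser.change(reset_model_stub, [model_chooser, *extra_components], extra_components)
```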