Update interpret.py
interpret.py  +2 -2
@@ -90,10 +90,10 @@ class InterpretationPrompt:
         else:
             raise NotImplementedError
 
-    def generate(self, model, embeds, k,
+    def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
         num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
         tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
-        module = model.get_submodule(
+        module = model.get_submodule(layers_format.format(k=k))
         with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
             generated = model.generate(tokens_batch, **generation_kwargs)
         return generated
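For context, the new layers_format argument parameterizes the submodule path passed to get_submodule, so the same interpretation prompt can target model families whose decoder blocks live under a different attribute path than the 'model.layers.{k}' default. The snippet below is a minimal sketch, not part of the repository: the "gpt2" checkpoint and the final commented-out prompt.generate(...) call (with an already-constructed InterpretationPrompt named prompt and a per-placeholder embedding batch named embeds) are illustrative assumptions; only get_submodule and the generate signature above come from the commit.

# Sketch only: shows how layers_format selects the k-th transformer block.
from transformers import AutoModelForCausalLM

k = 5

# LLaMA-style checkpoints match the new default 'model.layers.{k}';
# GPT-2-style checkpoints expose their blocks under 'transformer.h.{k}' instead.
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
module = gpt2.get_submodule("transformer.h.{k}".format(k=k))
print(type(module).__name__)  # -> GPT2Block

# Hypothetical call of the updated generate(); 'prompt' and 'embeds' are assumed
# to exist ({placeholder_index: [embedding tensors, ...]} for embeds).
# generated = prompt.generate(
#     gpt2,
#     embeds,
#     k=k,
#     layers_format="transformer.h.{k}",  # new keyword introduced in this commit
#     max_new_tokens=30,                  # forwarded via **generation_kwargs
# )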