Caching computation in a tracing context

Hi!

I’m running experiments where I want to intervene at the last token position of a batch of (quite long) prompts, at different layers, with different vectors and steering weights.

Currently, it seems I have to run the full forward pass over the entire prompt for every intervention (once per layer/vector/weight combination), even though the computation for all tokens except the last one is identical. In vanilla Hugging Face Transformers I could, for example, use past_key_values to cache the computation for the prefix and recompute only the last token, but I haven’t found a way to do this with NNsight’s trace context.

Is there a way in NNsight to:

  • Cache or reuse the computation for the first (n-1) tokens,

  • And only recompute the last token’s forward pass with a different intervention at each layer/vector/weight,

  • While still being able to use NNsight’s tracing and intervention features?

Or possibly are there any workarounds?

Thanks!

Are you using NNsight’s LanguageModel class? Its trace is just a wrapper around the HF forward pass, so I think you should be able to pass the past_key_values argument as you do in HF.

Did you get an error when trying it?

Yeah, I’m trying this and getting this error

from nnsight import LanguageModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

model_nn = LanguageModel(model, tokenizer=tokenizer)

prefix = "I went to the kitchen"
prefix_inputs = tokenizer(prefix, return_tensors="pt")
new_text = " and I got a sandwich"

with torch.no_grad():
    prefix_outputs = model(**prefix_inputs, use_cache=True)

with torch.no_grad():
    with model_nn.trace(new_text, past_key_values=prefix_outputs.past_key_values, use_cache=True) as trace:
        h = model_nn.transformer.h[-1].output[0].save()

Traceback (most recent call last):
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/tracing/graph/node.py", line 289, in execute
    self.target.execute(self)
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/intervention/contexts/interleaving.py", line 161, in execute
    graph.model.interleave(interleaver, *invoker_args, fn=method,**kwargs, **invoker_kwargs)
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/modeling/mixins/meta.py", line 52, in interleave
    return super().interleave(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/intervention/base.py", line 341, in interleave
    with interleaver:
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/intervention/interleaver.py", line 129, in __exit__
    raise exc_val
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/intervention/base.py", line 342, in interleave
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/nnsight/modeling/language.py", line 297, in _execute
    return self._model(
           ^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1062, in forward
    transformer_outputs = self.transformer(
                          ^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 829, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py", line 395, in _prepare_4d_causal_attention_mask_for_sdpa
    expanded_4d_mask = attn_mask_converter.to_4d(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gustaw/Documents/concept_vectors/.venv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py", line 139, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 3

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
    
  File "<frozen runpy>", line 88, in _run_code
    

NNsightError: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 3
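
One plausible reading of the numbers in that error (a guess from the traceback, not something confirmed against NNsight's source): both strings tokenize to 5 GPT-2 tokens, so the causal mask built in `to_4d` spans 5 cached + 5 new = 10 key positions, while the 2D attention mask that reaches it only covers the 5 new tokens. The sketch below reproduces that kind of shape clash with plain PyTorch. `apply_padding_mask` is a hypothetical helper that loosely imitates the `masked_fill` step in the traceback; it is not the actual Transformers code.

```python
import torch

past_len, new_len = 5, 5           # token counts of the two strings in the post
total_len = past_len + new_len     # key positions once the cache is included

# Stand-in for the causal mask: (batch, 1, query_len, key_len)
causal = torch.zeros(1, 1, new_len, total_len)


def apply_padding_mask(causal_4d, attention_mask_2d):
    """Hypothetical helper: fold a 2D padding mask into a 4D causal mask."""
    query_len = causal_4d.shape[2]
    # (batch, key_len) -> (batch, 1, query_len, key_len)
    expanded = attention_mask_2d[:, None, None, :].expand(-1, 1, query_len, -1)
    inverted = 1.0 - expanded.to(causal_4d.dtype)   # 1.0 where attention is blocked
    return causal_4d.masked_fill(inverted.bool(), torch.finfo(causal_4d.dtype).min)


# Mask covering only the new tokens: 5 key positions vs 10, so this cannot broadcast.
try:
    apply_padding_mask(causal, torch.ones(1, new_len))
    short_mask_failed = False
except RuntimeError:
    short_mask_failed = True

# Mask covering cached prefix + new tokens: shapes line up.
ok = apply_padding_mask(causal, torch.ones(1, total_len))
```

If this reading is right, the mask passed alongside past_key_values would need to cover the cached prefix as well as the new tokens.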

Oh weird, and the exact same code works with a plain Hugging Face forward pass?

Yup, it does… :confused:

next_input = tokenizer(new_text, return_tensors="pt")
with torch.no_grad():
    out = model.forward(input_ids=next_input["input_ids"], past_key_values=prefix_outputs.past_key_values)
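
Note that the plain call above passes no attention_mask at all. If the mask length is what differs between the two code paths (an assumption, not verified here), a mask spanning the cached prefix plus the new tokens is what the model expects when a cache is supplied. A minimal plain-HF sketch of that check follows. It uses a tiny randomly initialised GPT-2 and random token ids as stand-ins so it runs without downloading weights; whether `model_nn.trace` forwards an `attention_mask` keyword unchanged is not something this thread confirms.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)

# Tiny random GPT-2 (illustrative sizes); use from_pretrained("gpt2") for real runs.
config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
model = GPT2LMHeadModel(config).eval()   # eval() disables dropout

prefix_ids = torch.randint(0, 100, (1, 5))   # stand-in for the tokenized prefix
new_ids = torch.randint(0, 100, (1, 5))      # stand-in for the tokenized continuation

with torch.no_grad():
    # Reference: one uncached pass over prefix + continuation.
    full = model(input_ids=torch.cat([prefix_ids, new_ids], dim=1))

    # Cached route: run the prefix once, then only the new tokens.
    prefix_out = model(input_ids=prefix_ids, use_cache=True)
    full_mask = torch.ones(1, prefix_ids.shape[1] + new_ids.shape[1], dtype=torch.long)
    cached = model(
        input_ids=new_ids,
        past_key_values=prefix_out.past_key_values,
        attention_mask=full_mask,   # covers cached + new positions
        use_cache=True,
    )

# Largest logit difference between the two routes on the new positions.
max_diff = (full.logits[:, -new_ids.shape[1]:] - cached.logits).abs().max().item()
```

A small `max_diff` indicates the cached route with a full-length mask reproduces the uncached logits for the new positions, which is the behaviour one would want the trace to match.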