Intervention Graph and Execution Order

Hi guys! I’m trying nnsight for the first time and have a simple question that I couldn’t quite figure out from the tutorials:

We can edit activations; for example, we can replace the MLP8 outputs with the MLP2 outputs:

with gpt2.trace(batch) as tracer:
    gpt2.transformer.h[8].mlp.output = gpt2.transformer.h[2].mlp.output

But what if we do it the other way around? What happens if we replace MLP2 outputs with MLP8 outputs?

with gpt2.trace(batch) as tracer:
    gpt2.transformer.h[2].mlp.output = gpt2.transformer.h[8].mlp.output

I understand that nnsight builds a graph internally to track and execute all operations, but I don’t understand exactly how this works. In the second example we have a cyclical dependency: mlp8-out naturally depends on mlp2-out, but now mlp2-out also depends on mlp8-out. So what is supposed to happen here, and why? I expected either an error or mlp2-out being overwritten with something. Instead, nothing happened: no error, but mlp2-out did not change either.

Ultimately, I wanna use nnsight to run the model until, e.g., MLP0-in, then patch MLP0-out (e.g., using activations calculated from MLP0-in), then continue running the model until MLP1-in, and so on. For efficiency, I tried something like this:

with gpt2.trace(batch) as tracer:
    mlp_ins = []
    for l in range(n_layers):
        mlp_ins.append(gpt2.transformer.h[l].ln_2.input)  # cache activations
    mlp_ins = torch.stack(mlp_ins)
    mlp_outs = ...  # calculate from the concatenated mlp_ins
    for l in range(n_layers):
        gpt2.transformer.h[l].mlp.output = mlp_outs[l]  # patch activations

Now I wonder: what will happen here, and what is supposed to happen? Mathematically, every activation value we patch in depends only on previous layers, but maybe not in the computational graph, since we stack all cached activations together? What would be the correct approach here?

Thanks for your detailed post. It’s awesome to see new users looking to understand how nnsight works in-depth.

I see two questions in your post:

  1. How does the order of accessing modules affect results?

In short, trying to set a module’s input or output to a value computed later in the model’s computation graph (e.g., setting layer_2’s output to layer_8’s output) will have no effect.

This is because the assignment to layer_2 depends on layer_8 having already executed. By the time layer_8 executes, the model can’t go back in its computation graph and set layer_2’s value so that it has a downstream effect.

If you have such a use case, you will need to run two forward passes to achieve the desired effect. See below for an example of how to do that.
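To make the ordering argument concrete, here is a toy stand-in written in plain Python. It is not nnsight’s API. Layers run strictly in order, and a patch for a target layer can only use a source output that already exists when the target runs. The function `forward`, the `x + layer` arithmetic, and the `skipped` list are all hypothetical. Unlike this toy, which records the skipped patch for visibility, the thread describes nnsight as silently leaving an out-of-order patch unapplied.

```python
# Toy model of a single forward pass. NOT nnsight: a plain-Python sketch.

def forward(x, n_layers, patches):
    """Run layers 0..n_layers-1 in order.

    patches maps target_layer -> source_layer, meaning "replace the target's
    output with the source's output". A patch can only take effect if the
    source output has already been computed when the target layer runs.
    """
    outputs = {}   # outputs computed so far in this pass
    skipped = []   # patches whose source was not available in time
    for layer in range(n_layers):
        x = x + layer  # hypothetical layer computation
        if layer in patches:
            source = patches[layer]
            if source in outputs:
                x = outputs[source]    # source is earlier: patch applies
            else:
                skipped.append(layer)  # source is later: nothing to patch with
        outputs[layer] = x
    return x, skipped


# MLP8 <- MLP2: the source (layer 2) already ran, so the patch applies.
print(forward(0, 10, {8: 2}))  # (12, [])

# MLP2 <- MLP8: layer 8 has not run when layer 2 executes, so no effect;
# the final value equals the unpatched run.
print(forward(0, 10, {2: 8}))  # (45, [2])
print(forward(0, 10, {}))      # (45, [])
```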

  2. Custom use case?

If you need all the “MLP inputs” and want to do some operation on them before patching each layer’s output, you will need to run two forward passes. The most efficient way to accomplish that is with an nnsight session:

with gpt2.session():
    with gpt2.trace(input):
        mlp_ins = []
        for l in range(n_layers):
            mlp_ins.append(gpt2.transformer.h[l].ln_2.input)  # cache activations
        mlp_ins = torch.stack(mlp_ins)
        mlp_outs = ...  # calculate from the concatenated mlp_ins

    with gpt2.trace(input):
        for l in range(n_layers):
            gpt2.transformer.h[l].mlp.output = mlp_outs[l]  # patch activations
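The same two-pass idea in a toy residual model, again plain Python rather than nnsight: pass 1 only observes each block’s MLP input, the patch values are computed from all of those inputs at once, and pass 2 only writes MLP outputs. The `run` function, its `2 * mlp_in` “MLP”, and the sum used as the combining op are hypothetical placeholders for the real model and for the elided `mlp_outs = ...` step.

```python
# Toy residual model illustrating the two-pass pattern. NOT nnsight.

def run(x, n_layers, observe=None, patches=None):
    """Each block: mlp_in = x, mlp_out = 2 * mlp_in, x = x + mlp_out.

    observe(layer, mlp_in) is called like an input hook; patches maps
    layer -> value that replaces that block's MLP output.
    """
    for layer in range(n_layers):
        mlp_in = x
        if observe is not None:
            observe(layer, mlp_in)
        mlp_out = 2 * mlp_in  # hypothetical MLP
        if patches is not None and layer in patches:
            mlp_out = patches[layer]
        x = x + mlp_out
    return x


n_layers = 3

# Pass 1: cache every MLP input (no patching).
mlp_ins = []
baseline = run(1, n_layers, observe=lambda layer, v: mlp_ins.append(v))

# Between passes: each patch may depend on *all* layers, including later ones.
total = sum(mlp_ins)  # hypothetical stand-in for the elided computation
mlp_outs = {layer: total for layer in range(n_layers)}

# Pass 2: patch every MLP output with the precomputed values.
patched = run(1, n_layers, patches=mlp_outs)

print(mlp_ins, baseline, patched)  # [1, 3, 9] 27 40
```

Because the patch values are fixed before pass 2 starts, it no longer matters that each one was computed from later layers.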

Alternatively, if you are simply skipping the MLP layer of each transformer block, you can patch the input of the layer norm directly into the output of the MLP:

with gpt2(input):
    for l in range(n_layers):
        gpt2.transformer.h[l].mlp.output = gpt2.transformer.h[l].ln_2.input

I hope this is helpful! Let me know if you have any further questions or need more help with your use case.

Thanks a lot!

I’m trying to understand the internals a bit better: what exactly happens under the hood if I write gpt2.transformer.h[2].ln_2.input = ...? Does this register a hook on that specific nn.Module, or is this done differently? I’m trying to understand why no error is thrown when I set it to a value that hasn’t been computed yet.

For my specific use case, the solution you provided wouldn’t work, because I don’t want to freeze activations and patch them in; I want to apply interventions at each layer, with interventions on earlier layers affecting those on later layers.

For example something like this:

with gpt2.trace(input):
    cache = []
    for l in range(n_layers):
        cache.append(gpt2.transformer.h[l].ln_2.input)
        new_actvs = fancy_op(cache)
        gpt2.transformer.h[l].mlp.output = new_actvs

So basically, I wanna run the model until MLP0-in and save this vector, then continue until MLP0-out and replace its activations with something I computed from MLP0-in. Then I wanna continue running this patched model to MLP1-in, save that vector as well, then continue to MLP1-out and replace its activations with new values computed from MLP0-in and MLP1-in.

Does my code do this? In this version, every activation I patch is computed only from activations that come earlier in the forward pass.

Follow-up question: if my new approach works, how about this:

with gpt2.trace(input):
    cache = []
    for l in range(n_layers):
        cache.append(gpt2.transformer.h[l].ln_2.input)
    cache = torch.stack(cache).unbind(0)  # stack into one tensor, then split back into per-layer tensors
    for l in range(n_layers):
        new_actvs = fancy_op(cache[:l + 1])  # layers 0..l, matching the first version
        gpt2.transformer.h[l].mlp.output = new_actvs

This should do exactly the same thing as the example above, so it should be equivalent. But will it yield the same result? Or is nothing patched now, because the stack operation makes earlier layers depend on later layers in the graph (even though in reality they don’t)?
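For the two loops to be equivalent, the second one has to slice `cache[:l + 1]`: at layer `l`, the first loop’s `cache` already contains layer `l`’s own input, while `cache[:l]` stops one short. Below is a quick check in plain eager Python that ignores nnsight’s deferred execution entirely. Strings stand in for the cached tensors, and the helper names are hypothetical.

```python
# Compare the arguments each loop would pass to fancy_op, in plain Python.

def first_loop_args(inputs):
    """Interleaved version: append layer l's input, then call fancy_op(cache)."""
    cache, calls = [], []
    for x in inputs:
        cache.append(x)
        calls.append(list(cache))  # snapshot of what fancy_op would see
    return calls


def second_loop_args(inputs, inclusive):
    """Cache-everything-first version, slicing cache[:l + 1] or cache[:l]."""
    cache = list(inputs)
    calls = []
    for l in range(len(inputs)):
        end = l + 1 if inclusive else l
        calls.append(cache[:end])
    return calls


inputs = ["in0", "in1", "in2"]  # placeholders for the cached ln_2 inputs
print(first_loop_args(inputs) == second_loop_args(inputs, inclusive=True))   # True
print(first_loop_args(inputs) == second_loop_args(inputs, inclusive=False))  # False
print(second_loop_args(inputs, inclusive=False)[0])  # [] (layer 0 sees nothing)
```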

I think my main fear is this: if I’m not careful, some of my edits might silently not get applied, with no error message, so I need to understand what goes on under the hood.

1) What happens when gpt2.transformer.h[2].ln_2.input = ... is called?

.input is an entrypoint on the Envoy (wrapper) of the module at the path gpt2.transformer.h[2].ln_2. Accessing this entrypoint registers an input hook on the corresponding PyTorch module; the hook is executed when the module is reached in the model’s computation graph. At the same time, it creates a root node in the nnsight intervention graph, which is populated with the module’s input value when the hook is called.

Doing gpt2.transformer.h[2].ln_2.input = gpt2.transformer.h[8].ln_2.input also adds a node for the layer 8 ln_2 input (registering a hook for it as well), plus a swap node that executes the assignment and depends on both the layer 2 and the layer 8 nodes. Here’s the exact intervention graph for this intervention:

with gpt2("Hi") as tracer:
    gpt2.transformer.h[2].ln_2.input = gpt2.transformer.h[8].ln_2.input
    tracer.vis(title="Intervention Graph", save=True)

Because the swap node also depends on the layer 8 InterventionProtocol node, it doesn’t get executed while the model is executing layer 2 (a node only executes once all of its dependencies have executed), so the swap never has any downstream effect on the results.
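The “execute only when every dependency is ready” rule can be sketched with a toy graph executor in plain Python. This is not nnsight’s implementation: `Graph`, `Node`, `set_root`, `build_swap`, and `forward` are hypothetical, root nodes stand in for the hook-populated InterventionProtocol nodes, and each layer simply adds 1. The log shows the swap firing only when layer 8 is reached, after layer 2 has already run.

```python
# Toy intervention graph: a node runs only after all of its dependencies
# have values. NOT nnsight's implementation; every name here is hypothetical.

class Node:
    def __init__(self, fn=None, deps=()):
        self.fn = fn          # None for root nodes filled by hooks
        self.deps = list(deps)
        self.value = None
        self.done = False


class Graph:
    def __init__(self):
        self.nodes = {}
        self.log = []  # order in which nodes received values

    def add(self, name, fn=None, deps=()):
        self.nodes[name] = Node(fn, deps)

    def set_root(self, name, value):
        """Called from a (toy) module hook: fill a root, then run ready nodes."""
        node = self.nodes[name]
        node.value, node.done = value, True
        self.log.append(name)
        ran = True
        while ran:  # keep sweeping until no further node becomes ready
            ran = False
            for other_name, other in self.nodes.items():
                if other.done or other.fn is None:
                    continue
                if all(self.nodes[d].done for d in other.deps):
                    other.value = other.fn(*(self.nodes[d].value for d in other.deps))
                    other.done = True
                    self.log.append(other_name)
                    ran = True


def build_swap(target, source):
    """Graph for `h[target].input = h[source].input`."""
    g = Graph()
    g.add(f"h{target}.input")
    g.add(f"h{source}.input")
    g.add("swap", fn=lambda _old, new: new,
          deps=[f"h{target}.input", f"h{source}.input"])
    return g


def forward(graph, n_layers, target, x=0):
    for layer in range(n_layers):
        name = f"h{layer}.input"
        if name in graph.nodes:
            graph.set_root(name, x)  # the hook fires as the layer is reached
            swap = graph.nodes["swap"]
            if layer == target and swap.done:
                x = swap.value  # swap is ready in time: it takes effect
        x = x + 1  # hypothetical layer computation
    return x


g = build_swap(target=2, source=8)
print(forward(g, 10, target=2), g.log)  # 10 ['h2.input', 'h8.input', 'swap']

g = build_swap(target=8, source=2)
print(forward(g, 10, target=8), g.log)  # 4 ['h2.input', 'h8.input', 'swap']
```

In the first case the swap does run, but only at layer 8, so the result (10) is identical to an unpatched pass.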

2) Your use case

Your new solution should work properly, because it doesn’t access any modules out of order.

Your second new solution may not work as intended, because you access later modules in the first loop and then access earlier modules again in the second loop. However, it might still work fine: we do some “fancy” skipping of operations under the hood during intervention graph execution, and your setting/swapping doesn’t depend on any later values, so I think it will work. Regardless, your first solution is more efficient and should work as intended.
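Here is the dependency rule described above applied to both of your loops, modeling nothing else. In one forward pass, the patch for layer `l` has to run while layer `l` executes, so it is in time only if everything it depends on comes from layer `l` or earlier. `ready_in_time` is a hypothetical helper, and the stacked case naively assumes that every slice of the stacked tensor depends on all layers. As noted above, nnsight’s skipping under the hood may still let the stacked version work; this toy does not model that.

```python
# Toy readiness check under one rule only: a node runs after all of its
# dependencies have run. nnsight's own optimizations are NOT modeled.

def ready_in_time(deps_per_patch):
    """deps_per_patch[l] lists the layers whose ln_2 input the patch for
    layer l depends on. In a single forward pass that patch must run while
    layer l executes, so it is in time only if every dependency is <= l."""
    return [all(d <= layer for d in deps)
            for layer, deps in enumerate(deps_per_patch)]


n_layers = 4

# First solution: the patch for layer l reads cache[0..l] only.
interleaved = [list(range(l + 1)) for l in range(n_layers)]

# Stacked solution, read naively: every slice comes out of one tensor that
# was built from all layers, so each patch depends on every layer.
stacked = [list(range(n_layers)) for _ in range(n_layers)]

print(ready_in_time(interleaved))  # [True, True, True, True]
print(ready_in_time(stacked))      # [False, False, False, True]
```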

Hope this makes sense 🙂
