Hi guys! I’m trying nnsight for the first time and have a simple question that I couldn’t quite figure out from the tutorials:
We can edit activations; for example, we can replace MLP8's output with MLP2's output:
with gpt2.trace(batch) as tracer:
    gpt2.transformer.h[8].mlp.output = gpt2.transformer.h[2].mlp.output
But what if we do it the other way around? What happens if we replace MLP2 outputs with MLP8 outputs?
with gpt2.trace(batch) as tracer:
    gpt2.transformer.h[2].mlp.output = gpt2.transformer.h[8].mlp.output
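To check what actually happens in this second case, I compared layer 2's block output between a clean run and the patched run, roughly like this (I'm assuming indexing into the block's output tuple works as written):

import torch

# clean run: record the residual stream right after block 2
with gpt2.trace(batch):
    clean = gpt2.transformer.h[2].output[0].save()

# patched run: try to overwrite MLP2's output with MLP8's output
with gpt2.trace(batch):
    gpt2.transformer.h[2].mlp.output = gpt2.transformer.h[8].mlp.output
    patched = gpt2.transformer.h[2].output[0].save()

print(torch.allclose(clean, patched))  # prints True for me, i.e. the patch had no effect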
I understand that nnsight builds an intervention graph internally to track and execute all operations, but I don't understand how that works here. The second example seems to create a cyclic dependency: mlp8-out depends on mlp2-out through the forward pass, but now mlp2-out is also supposed to depend on mlp8-out. So what is supposed to happen here, and why? I expected either an error or mlp2-out being overwritten with something, but instead nothing happened: no error, yet also no change in mlp2-out.
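For this specific case I can of course break the cycle with two passes: record MLP8's clean output first, then patch it into MLP2 in a fresh trace. Something like this seems to work (as far as I understand, the saved value is a plain tensor once the first trace exits):

# first pass: save MLP8's clean output as a concrete tensor
with gpt2.trace(batch):
    mlp8_out = gpt2.transformer.h[8].mlp.output.save()

# second pass: patch the saved tensor into MLP2
with gpt2.trace(batch):
    gpt2.transformer.h[2].mlp.output = mlp8_out

But that costs an extra forward pass, which brings me to my actual use case.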
Ultimately, I want to use nnsight to run the model until, e.g., MLP0-in, patch MLP0-out (e.g. with activations computed from MLP0-in), continue running until MLP1-in, and so on. For efficiency, I tried something like this:
import torch

n_layers = len(gpt2.transformer.h)  # 12 for GPT-2 small

with gpt2.trace(batch) as tracer:
    mlp_ins = []
    for l in range(n_layers):
        mlp_ins.append(gpt2.transformer.h[l].ln_2.input)  # cache activations
    mlp_ins = torch.stack(mlp_ins)
    mlp_outs = ...  # calculate from the stacked mlp_ins
    for l in range(n_layers):
        gpt2.transformer.h[l].mlp.output = mlp_outs[l]  # patch activations
I now wonder what happens here, and what is supposed to happen? Mathematically, every activation value I patch in depends only on previous layers - but maybe that is no longer true in the intervention graph, since I stack all cached activations together before computing the patches? What would be the correct approach to solve this?
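My best guess at a correct version is to do everything per layer inside a single trace, so that each patched value depends only on that layer's own input and the graph stays acyclic. Here `compute_mlp_out` is a placeholder name for my own computation (I'm assuming it is built from torch ops, so it can run on nnsight proxies):

n_layers = len(gpt2.transformer.h)

with gpt2.trace(batch) as tracer:
    for l in range(n_layers):
        mlp_in = gpt2.transformer.h[l].ln_2.input  # this layer's MLP input
        # the patch depends only on the same layer's input, so no cycle
        gpt2.transformer.h[l].mlp.output = compute_mlp_out(mlp_in)

But then I lose the cross-layer batching I was hoping to get from stacking. Is the per-layer version the intended approach, or is there a way to keep the stacked formulation?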