Version 0.5 introduces several breaking changes, so the attribution patching example no longer works out of the box.
After some tinkering and help from the team, here is a snippet that works (tested on 0.5.0.dev9):
# Attribution patching: collect each layer's attention output (the input to
# attn.c_proj) on the clean and corrupted prompts, plus the gradient of
# ioi_metric with respect to the corrupted activations.
clean_out = []        # clean_out[i]: layer i's c_proj input on clean_tokens
corrupted_out = []    # corrupted_out[i]: layer i's c_proj input on corrupted_tokens
corrupted_grads = []  # corrupted_grads[i]: grad of ioi_metric w.r.t. corrupted_out[i]

# Clean and corrupted runs use separate traces; running both inside a single
# trace did not work when this was tested on 0.5.0.dev9.
with model.trace(clean_tokens) as tracer:
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input
        clean_out.append(attn_out.save())

with model.trace(corrupted_tokens) as tracer:
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input
        corrupted_out.append(attn_out.save())

    logits = model.lm_head.output.save()
    value = ioi_metric(logits.cpu())

    # New backward context: intermediate gradients are read inside
    # `with value.backward():`. They must be accessed in reverse layer order
    # (last layer first) -- presumably because backprop reaches later layers
    # first; confirm against the library docs.
    with value.backward():
        for saved_attn in corrupted_out[::-1]:
            corrupted_grads.append(saved_attn.grad.save())

# The loop above collected gradients last-layer-first. Restore layer order so
# corrupted_grads[i] lines up with clean_out[i] and corrupted_out[i] when the
# three lists are zipped per layer.
corrupted_grads.reverse()