What is a Mediator exception? I got this trying to fix the attribution patching tutorial
---------------------------------------------------------------------------
NNsightException Traceback (most recent call last)
/tmp/ipython-input-3154363494.py in <cell line: 0>()
3 corrupted_grads = []
4
----> 5 with model.trace() as tracer:
6 # Using nnsight's tracer.invoke context, we can batch the clean and the
7 # corrupted runs into the same tracing context, allowing us to access
1 frames/usr/local/lib/python3.12/dist-packages/nnsight/intervention/tracing/base.py in __exit__(self, exc_type, exc_val, exc_tb)
599 # This is the expected case - the traced code was intercepted
600 # Execute the captured code using the configured backend
--> 601 self.backend(self)
602
603 # Return True to suppress the ExitTracingException
/usr/local/lib/python3.12/dist-packages/nnsight/intervention/backends/execution.py in __call__(self, tracer)
22 except Exception as e:
23
---> 24 raise wrap_exception(e, tracer.info) from None
25 finally:
26 Globals.exit()
NNsightException:
Traceback (most recent call last):
File "/tmp/ipython-input-3154363494.py", line 56, in <cell line: 0>
corrupted_grads.append(layer.attn.c_proj.input.grad.save())
File "/usr/local/lib/python3.12/dist-packages/nnsight/intervention/envoy.py", line 256, in input
inputs = self.inputs
File "/usr/local/lib/python3.12/dist-packages/nnsight/intervention/envoy.py", line 205, in inputs
return self._interleaver.current.request(
File "/usr/local/lib/python3.12/dist-packages/nnsight/intervention/interleaver.py", line 396, in current
return self.mediators[threading.current_thread().name]
KeyError: 'Mediator136692785885312'
This is the cell that failed:
clean_out = []
corrupted_out = []
corrupted_grads = []
with model.trace() as tracer:
# Using nnsight's tracer.invoke context, we can batch the clean and the
# corrupted runs into the same tracing context, allowing us to access
# information generated within each of these runs within one forward pass
with tracer.invoke(clean_tokens) as invoker_clean:
# Gather each layer's attention
for layer in model.transformer.h:
# Get clean attention output for this layer
# across all attention heads
attn_out = layer.attn.c_proj.input
clean_out.append(attn_out.save())
with tracer.invoke(corrupted_tokens) as invoker_corrupted:
# Gather each layer's attention and gradients
for layer in model.transformer.h:
# Get corrupted attention output for this layer
# across all attention heads
corrupted_out.append(layer.attn.c_proj.input.save())
# Let's get the logits for the model's output
# for the corrupted run
logits = model.lm_head.output.save()
# Our metric uses tensors saved on cpu, so we
# need to move the logits to cpu.
value = ioi_metric(logits.cpu())
# We also need to run a backwards pass to
# update gradient values
with value.backward():
for layer in model.transformer.h:
# save corrupted gradients for attribution patching
corrupted_grads.append(layer.attn.c_proj.input.grad.save())