What is a Mediator exception?

What is a Mediator exception? I got this while trying to fix the attribution patching tutorial:

 ---------------------------------------------------------------------------
NNsightException                          Traceback (most recent call last)
/tmp/ipython-input-3154363494.py in <cell line: 0>()
      3 corrupted_grads = []
      4 
----> 5 with model.trace() as tracer:
      6 # Using nnsight's tracer.invoke context, we can batch the clean and the
      7 # corrupted runs into the same tracing context, allowing us to access

/usr/local/lib/python3.12/dist-packages/nnsight/intervention/tracing/base.py in __exit__(self, exc_type, exc_val, exc_tb)
    599             # This is the expected case - the traced code was intercepted
    600             # Execute the captured code using the configured backend
--> 601             self.backend(self)
    602 
    603             # Return True to suppress the ExitTracingException

/usr/local/lib/python3.12/dist-packages/nnsight/intervention/backends/execution.py in __call__(self, tracer)
     22         except Exception as e:
     23 
---> 24             raise wrap_exception(e, tracer.info) from None
     25         finally:
     26             Globals.exit()

NNsightException: 

Traceback (most recent call last):
  File "/tmp/ipython-input-3154363494.py", line 56, in <cell line: 0>
    corrupted_grads.append(layer.attn.c_proj.input.grad.save())
  File "/usr/local/lib/python3.12/dist-packages/nnsight/intervention/envoy.py", line 256, in input
    inputs = self.inputs
  File "/usr/local/lib/python3.12/dist-packages/nnsight/intervention/envoy.py", line 205, in inputs
    return self._interleaver.current.request(
  File "/usr/local/lib/python3.12/dist-packages/nnsight/intervention/interleaver.py", line 396, in current
    return self.mediators[threading.current_thread().name]

KeyError: 'Mediator136692785885312'

This is the cell that failed:

clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:
# Using nnsight's tracer.invoke context, we can batch the clean and the
# corrupted runs into the same tracing context, allowing us to access
# information generated within each of these runs within one forward pass

    with tracer.invoke(clean_tokens) as invoker_clean:
        # Gather each layer's attention
        for layer in model.transformer.h:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in model.transformer.h:
            # Get corrupted attention output for this layer
            # across all attention heads
            corrupted_out.append(layer.attn.c_proj.input.save())

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = model.lm_head.output.save()

        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        
        with value.backward(): 
          for layer in model.transformer.h:
            # save corrupted gradients for attribution patching
            corrupted_grads.append(layer.attn.c_proj.input.grad.save())

@clement_dumas

Here’s code to make this work. I believe @ebortz is almost done refactoring the attribution patching tutorial.

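# Start from empty lists so nothing from the failed run is left over
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer: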
# Using nnsight's tracer.invoke context, we can batch the clean and the
# corrupted runs into the same tracing context, allowing us to access
# information generated within each of these runs within one forward pass

    with tracer.invoke(clean_tokens) as invoker_clean:
        # Gather each layer's attention
        for layer in model.transformer.h:
            # Get clean attention output for this layer
            # across all attention heads
            attn_out = layer.attn.c_proj.input
            clean_out.append(attn_out)

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:
        # Gather each layer's attention and gradients
        for layer in model.transformer.h:
            # Get corrupted attention output for this layer
            # across all attention heads
            corruped_out = layer.attn.c_proj.input
            layer.attn.c_proj.input = corruped_out
            corrupted_out.append(corruped_out)

        # Let's get the logits for the model's output
        # for the corrupted run
        logits = model.lm_head.output.save()

        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())

        # We also need to run a backwards pass to
        # update gradient values
        
        with value.backward():
            for i in reversed(range(len(model.transformer.h))):
                # save corrupted gradients for attribution patching
                corrupted_grads.append(corrupted_out[i].grad)

I did three things.

  1. I reference the tensors from the existing corrupted_out list instead of calling .input inside the backward pass. All tensors referenced inside the .backward() context must already exist, so accessing .input/.output is not allowed there. That’s why you were getting that error. I’ll look for a way to detect this and provide a better error message.

  2. I re-inject the corrupted input right after accessing it:

corruped_out = layer.attn.c_proj.input
layer.attn.c_proj.input = corruped_out

@AdamBelfki and I are looking at a way to make this unnecessary, but essentially, because we’re using two invokes with two prompts, nnsight narrows (slices) the tensors to only the part of the batch that belongs to the current invoke. The sliced tensor is therefore technically off the backward graph, and its .grad will never be populated. Re-injecting it fixes this.

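  3. I reference the entries of corrupted_out in reverse order, because that’s the order in which the gradient for each value is populated during the backward pass. Note that corrupted_grads therefore ends up ordered from the last layer to the first.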
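
To make points 2 and 3 more concrete, here is a minimal plain-PyTorch sketch. It is not nnsight code and not how nnsight is implemented: it assumes that .grad behaves roughly like a tensor hook that records the gradient when it reaches that tensor. The two Linear layers, the random batch, and the run helper are hypothetical stand-ins for the model, the two prompts, and the trace.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: two "layers" and a batch holding two prompts
# (row 0 = clean prompt, row 1 = corrupted prompt).
layers = [torch.nn.Linear(4, 4) for _ in range(2)]
batch = torch.randn(2, 4)


def run(reinject):
    """Run forward + backward; return the layer indices whose hook fired, in order."""
    fired = []
    hidden = batch
    for i, layer in enumerate(layers):
        hidden = layer(hidden)
        # Narrow to the corrupted prompt's row, similar to what an invoke sees.
        corrupted = hidden[1:2]
        # Record the moment a gradient reaches this slice.
        corrupted.register_hook(lambda grad, i=i: fired.append(i))
        if reinject:
            # Rebuild the batch from the slice so later layers depend on it.
            hidden = torch.cat([hidden[0:1], corrupted], dim=0)
    hidden.sum().backward()
    return fired


fired_without = run(reinject=False)
fired_with = run(reinject=True)
print(fired_without)  # [] : the slices are side branches, no gradient reaches them
print(fired_with)     # [1, 0] : gradients arrive, last layer first
```

Without re-injection the slice is only a side branch of the graph, so no gradient ever reaches it. With re-injection the later layers depend on the slice, and the gradients show up starting from the last layer, which is why the backward loop above walks the layers in reverse.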

oh that makes sense, thanks!