Patching in 0.5

Hello,

I am patching MLP outputs to better understand what each MLP layer is doing. To do this, I’m running something like the code below:

prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a data scientist. Given a data table, return only the *maximum y-value*. Respond with only the y-value as it appears in the table. Do not explain. Do not return the x-value or the row index.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the maximum y-value?\nData Table:\nInput Log <s> Point Statistic Grp H 10 Grp L 31.2 Grp C 19.1 Grp T 34.2 Grp E 13 Grp G 28.2 Grp A 7 Grp U 22.1 Grp D 25.1 Grp W 16<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
corrupt_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a data scientist. Given a data table, return only the *maximum y-value*. Respond with only the y-value as it appears in the table. Do not explain. Do not return the x-value or the row index.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the maximum y-value?\nData Table:\nInput Log <s> Point Statistic Grp H 10 Grp L 31.2 Grp C 19.1 Grp T 14.7 Grp E 13 Grp G 28.2 Grp A 7 Grp U 22.1 Grp D 25.1 Grp W 16<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

row_token_map = {0: {'x': [96, 97, 98], 'y': [100]},
 1: {'x': [101, 102, 103], 'y': [105, 106, 107]},
 2: {'x': [108, 109, 110], 'y': [112, 113, 114]},
 3: {'x': [115, 116, 117], 'y': [119, 120, 121]},
 4: {'x': [122, 123, 124], 'y': [126]},
 5: {'x': [127, 128, 129], 'y': [131, 132, 133]},
 6: {'x': [134, 135, 136], 'y': [138]},
 7: {'x': [139, 140, 141], 'y': [143, 144, 145]},
 8: {'x': [146, 147, 148], 'y': [150, 151, 152]},
 9: {'x': [153, 154, 155], 'y': [157]}}

llamaInstruct = LanguageModel("meta-llama/Llama-3.3-70B-Instruct", device_map="auto")
layer_idx = 0
max_row = 3
clean_answer_idx = 1958
corrupt_answer_idx = 2148
patching_results = {}

with llamaInstruct.trace(remote = True) as tracer:
    barrier = tracer.barrier(3)

    with tracer.invoke(prompt) as invoker:
        mlp_out = llamaInstruct.model.layers[layer_idx].mlp.output
        value_idx = row_token_map[max_row]['y']
        mlp_clean = mlp_out[:, value_idx, :]

        clean_logits = llamaInstruct.lm_head.output
        clean_logit_diff = (
            clean_logits[0, -1, clean_answer_idx] - clean_logits[0, -1, corrupt_answer_idx]
        )
        barrier()

    with tracer.invoke(corrupt_prompt) as invoker:
        corrupted_logits = llamaInstruct.lm_head.output

        corrupted_logit_diff = (
            corrupted_logits[0, -1, clean_answer_idx] - corrupted_logits[0, -1, corrupt_answer_idx]
        )
        barrier()

    with tracer.invoke(corrupt_prompt) as invoker:
        barrier()
        mlp_patched = llamaInstruct.model.layers[layer_idx].mlp.output
        value_idx = row_token_map[max_row]['y']
        for i, idx in enumerate(value_idx):
            mlp_patched[:, idx, :] = mlp_clean[:, i, :]

        llamaInstruct.model.layers[layer_idx].mlp.output = mlp_patched
    
        patched_logits = llamaInstruct.lm_head.output
        patched_output = patched_logits.argmax(dim=-1).save()
        patched_logit_diff = (
            patched_logits[0, -1, clean_answer_idx]
            - patched_logits[0, -1, corrupt_answer_idx]
        )

        # Calculate the improvement in the correct token after patching.
        patched_result = (patched_logit_diff - corrupted_logit_diff) / (
            clean_logit_diff - corrupted_logit_diff
        )

    patching_results[layer_idx] = patched_result.item().save()

The idea is to patch specific tokens in the corrupted run with the MLP outputs from the clean run. This used to work fine, but now I’m getting the following error when I try to patch the layer’s output:

ValueError: Execution complete but model.model.layers.0.mlp.output.i0 was not provided. Did you call an Envoy out of order? Investigate why this module was not called?
Remote exception.

Do I need to change how I’m handling this to make it compatible with nnsight 0.5? Thanks!

Hi @hhyeminbang! Here’s a fixed version of your code:

from nnsight import LanguageModel

prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a data scientist. Given a data table, return only the *maximum y-value*. Respond with only the y-value as it appears in the table. Do not explain. Do not return the x-value or the row index.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the maximum y-value?\nData Table:\nInput Log <s> Point Statistic Grp H 10 Grp L 31.2 Grp C 19.1 Grp T 34.2 Grp E 13 Grp G 28.2 Grp A 7 Grp U 22.1 Grp D 25.1 Grp W 16<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
corrupt_prompt = '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a data scientist. Given a data table, return only the *maximum y-value*. Respond with only the y-value as it appears in the table. Do not explain. Do not return the x-value or the row index.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the maximum y-value?\nData Table:\nInput Log <s> Point Statistic Grp H 10 Grp L 31.2 Grp C 19.1 Grp T 14.7 Grp E 13 Grp G 28.2 Grp A 7 Grp U 22.1 Grp D 25.1 Grp W 16<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

row_token_map = {0: {'x': [96, 97, 98], 'y': [100]},
 1: {'x': [101, 102, 103], 'y': [105, 106, 107]},
 2: {'x': [108, 109, 110], 'y': [112, 113, 114]},
 3: {'x': [115, 116, 117], 'y': [119, 120, 121]},
 4: {'x': [122, 123, 124], 'y': [126]},
 5: {'x': [127, 128, 129], 'y': [131, 132, 133]},
 6: {'x': [134, 135, 136], 'y': [138]},
 7: {'x': [139, 140, 141], 'y': [143, 144, 145]},
 8: {'x': [146, 147, 148], 'y': [150, 151, 152]},
 9: {'x': [153, 154, 155], 'y': [157]}}

llamaInstruct = LanguageModel("meta-llama/Llama-3.3-70B-Instruct", device_map="auto")
layer_idx = 0
max_row = 3
clean_answer_idx = 1958
corrupt_answer_idx = 2148

with llamaInstruct.trace(remote=True) as tracer:

    patching_results = dict().save()
    barrier_mlp = tracer.barrier(2)
    barrier_logits = tracer.barrier(2)

    with tracer.invoke(prompt) as invoker:
        mlp_out = llamaInstruct.model.layers[layer_idx].mlp.output
        value_idx = row_token_map[max_row]['y']
        mlp_clean = mlp_out[:, value_idx, :]

        barrier_mlp()

        clean_logits = llamaInstruct.lm_head.output
        clean_logit_diff = (
            clean_logits[0, -1, clean_answer_idx] - clean_logits[0, -1, corrupt_answer_idx]
        )

    with tracer.invoke(corrupt_prompt) as invoker:
        corrupted_logits = llamaInstruct.lm_head.output

        corrupted_logit_diff = (
            corrupted_logits[0, -1, clean_answer_idx] - corrupted_logits[0, -1, corrupt_answer_idx]
        )

        barrier_logits()

    with tracer.invoke(corrupt_prompt) as invoker:
        barrier_mlp()
        mlp_patched = llamaInstruct.model.layers[layer_idx].mlp.output
        value_idx = row_token_map[max_row]['y']
        for i, idx in enumerate(value_idx):
            mlp_patched[:, idx, :] = mlp_clean[:, i, :]

        llamaInstruct.model.layers[layer_idx].mlp.output = mlp_patched

        barrier_logits()
    
        patched_logits = llamaInstruct.lm_head.output
        patched_output = patched_logits.argmax(dim=-1).save()
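        # Use the same answer token indices as the clean and corrupted runs
        # so the three logit differences are comparable.
        patched_logit_diff = (
            patched_logits[0, -1, clean_answer_idx] - patched_logits[0, -1, corrupt_answer_idx]
        )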

        # Calculate the improvement in the correct token after patching.
        patched_result = (patched_logit_diff - corrupted_logit_diff) / (
            clean_logit_diff - corrupted_logit_diff
        )

        patching_results[layer_idx] = patched_result.item().save()


print(patching_results)
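
As a sanity check on the metric computed at the end of the third invoke, here is the same formula on plain floats. This is only a sketch: the function name normalized_patching_effect and the logit differences are made up for illustration and do not come from a real run.

```python
def normalized_patching_effect(patched_diff, clean_diff, corrupted_diff):
    """Fraction of the clean-vs-corrupted logit-diff gap recovered by the patch.

    0.0 means the patch changed nothing relative to the corrupted run;
    1.0 means the patched run matches the clean run.
    """
    gap = clean_diff - corrupted_diff
    if gap == 0:
        # The clean and corrupted runs are indistinguishable on this metric,
        # so the ratio is undefined.
        raise ValueError("clean and corrupted logit diffs are equal")
    return (patched_diff - corrupted_diff) / gap


# Hypothetical logit differences, chosen only to illustrate the scale.
clean_diff, corrupted_diff = 4.0, -2.0
print(normalized_patching_effect(1.0, clean_diff, corrupted_diff))  # halfway between the two runs
```

Values outside the 0 to 1 range are possible: a patch can overshoot the clean run or push the logit difference further in the corrupted direction.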

Some context about why tracer.barrier() is needed here:

Each invoker is essentially a parallel thread. Each reference to a module/Envoy (e.g. model.layers[0].output) also acts as a barrier, because your intervention code needs to wait for that module to be called by the model’s computation graph.

There is a special case, though: when invoker (b) references a value from invoker (a) that is defined at the same point in execution as the reference location in invoker (b). This comes up most often in patching experiments.

Just like in any multi-threaded program, you want a barrier call in invoker (b) (to pause the thread) before it references the value from invoker (a), and a matching call in invoker (a) after that value is defined.
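
If it helps, here is a rough analogy in plain Python using the standard library's threading.Barrier. It is not nnsight code and says nothing about how nnsight works internally; the names clean_run, patched_run and the toy values are invented for the illustration. The point is only the ordering: the producer reaches the barrier after defining the value, and the consumer waits at the barrier before reading it.

```python
import threading


def run_patching_analogy():
    shared = {}                     # stands in for values visible across invokers
    result = {}
    barrier = threading.Barrier(2)  # two participants, similar in spirit to tracer.barrier(2)

    def clean_run():
        # "Invoker (a)": define the value first, then reach the barrier.
        shared["mlp_clean"] = [1.0, 2.0, 3.0]
        barrier.wait()

    def patched_run():
        # "Invoker (b)": pause at the barrier before referencing the value.
        barrier.wait()
        result["patched"] = [2 * v for v in shared["mlp_clean"]]

    # Start the consumer first to show that the start order does not matter.
    threads = [threading.Thread(target=patched_run), threading.Thread(target=clean_run)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result["patched"]


print(run_patching_analogy())
```

Without the barrier.wait() call in patched_run, the read of shared["mlp_clean"] could happen before clean_run has written it, which is the same kind of ordering problem as referencing mlp_clean too early in the patched invoke.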

Note: I used two barriers to make the code more readable, one for the MLP output and one for the logits from the unpatched corrupted prompt. But you can definitely reuse the MLP barrier in place of the logits one!

Thank you so much for your help – this makes a lot of sense!

I noticed that you added patching_results = dict().save(). Could you explain why that’s necessary? My goal is to store results in a dictionary outside of the tracer block, so I can run the tracer block multiple times and append values to this dictionary (I currently have the tracer block inside a for loop, saving values at the end of each iteration).

In this case, what’s the best way to declare and manage this dictionary?

You can have the dict defined outside the tracer; just call .save() on it inside the tracing context, so its updated state is returned to you by the server during remote execution.