Can't generate multiple tokens and patch for each token

I tried different variants of the code to generate multiple tokens and intervene at each forward pass, but couldn’t make it work. This code works with nnsight 0.4, but not with 0.5.

from nnsight import LanguageModel
from torch import Tensor

# Passed as `remote=` to model.generate() in the function below.
REMOTE = True
# Passed as `max_new_tokens=` to model.generate() in the function below.
n_tokens = 700

# nnsight LanguageModel wrapper for the Llama 3.1 405B Instruct checkpoint.
model = LanguageModel("meta-llama/Meta-Llama-3.1-405B-Instruct")
# Used below to encode the prompt and to decode the generated token ids.
tokenizer = model.tokenizer

def intervene_with_h(
model: LanguageModel,
prompt: str,
h: Tensor,
layer: int,
remote: bool = REMOTE,
):
    """Generate a completion for `prompt` while adding `h` to one layer's output.

    This is the version posted in the original question; under nnsight 0.5 it
    fails with the traceback quoted further down in this thread.

    Args:
        model: The nnsight LanguageModel used for generation.
        prompt: Prompt text; it is also tokenized to find where new tokens start.
        h: Tensor added in place to `model.model.layers[layer].output[0][-1]`.
        layer: Index into `model.model.layers`.
        remote: NOTE(review): currently unused -- the body passes the
            module-level `REMOTE` to `model.generate()` instead.

    Returns:
        The decoding of `tokens[0][start_idx:]`, i.e. the saved generator
        output with the first `start_idx` token positions dropped.
    """

    # Length of the encoded prompt (uses the module-level `tokenizer`);
    # used below to slice the prompt off the generated sequence.
    start_idx = len(tokenizer.encode(prompt))
    
    with model.generate(remote=REMOTE, max_new_tokens=n_tokens) as generator:
        # NOTE(review): the replies in this thread say model.all() is
        # deprecated and that the invoker must be the outer context, with
        # the all() context nested inside it.
        with model.all():
            with generator.invoke(prompt):
                model.model.layers[layer].output[0][-1] += h
                # NOTE(review): per the replies, model.generator is only
                # called once, at the end of generation, so this line should
                # sit outside the all() block.
                tokens = model.generator.output.save()
    
    return tokenizer.decode(tokens[0][start_idx:])
    
# NOTE: `prompt`, `alpha`, `h` and `layer` are not defined in the snippet as posted.
completions_intervention = intervene_with_h(model, prompt, alpha * h, layer=layer)

Error message:

--------------------------------------------------------------------------- ExitTracingException Traceback (most recent call last) Cell In[42], line 41, in intervene_with_h(model, prompt, h, layer, remote) 40 with model.generate(remote=REMOTE, max_new_tokens=n_tokens) as generator: —> 41 with model.all(): 42 with generator.invoke(prompt): Cell In[42], line 41, in intervene_with_h(model, prompt, h, layer, remote) 40 with model.generate(remote=REMOTE, max_new_tokens=n_tokens) as generator: —> 41 with model.all(): 42 with generator.invoke(prompt): File ~/Library/Caches/pypoetry/virtualenvs/multimodality-2TwJzaIP-py3.10/lib/python3.10/site-packages/nnsight/intervention/tracing/base.py:386, in Tracer.enter..skip(new_frame, event, arg) 385 self.info.frame.f_trace = None → 386 raise ExitTracingException() ExitTracingException: During handling of the above exception, another exception occurred: RemoteException Traceback (most recent call last) Cell In[43], line 1 ----> 1 completions_intervention = intervene_with_h( 2 model, prompt, alpha * h, layer=layer 3 ) Cell In[42], line 40, in intervene_with_h(model, prompt, h, layer, remote) 32 start_idx = len(tokenizer.encode(prompt)) 34 # with model.generate(max_new_tokens=n_tokens, remote=remote) as generator: 35 # with model.all(): 36 # with generator.invoke(prompt): 37 # model.model.layers[layer].output[0][-1] += h 38 # tokens = model.generator.output.save() —> 40 with model.generate(remote=REMOTE, max_new_tokens=n_tokens) as generator: 41 with model.all(): 42 with generator.invoke(prompt): File ~/Library/Caches/pypoetry/virtualenvs/multimodality-2TwJzaIP-py3.10/lib/python3.10/site-packages/nnsight/intervention/tracing/base.py:416, in Tracer.exit(self, exc_type, exc_val, exc_tb) 412 # Suppress the ExitTracingException but let other exceptions propagate 413 if exc_type is ExitTracingException: 414 415 # Execute the traced code using the configured backend → 416 self.backend(self) 418 return True 420 self.backend(self) File 
~/Library/Caches/pypoetry/virtualenvs/multimodality-2TwJzaIP-py3.10/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:88, in RemoteBackend.call(self, tracer) 83 def call(self, tracer = None): 85 if self.blocking: 86 87 # Do blocking request. —> 88 result = self.blocking_request(tracer) 90 else: 91 92 # Otherwise we are getting the status / result of the existing job. 93 result = self.non_blocking_request(tracer) File ~/Library/Caches/pypoetry/virtualenvs/multimodality-2TwJzaIP-py3.10/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:300, in RemoteBackend.blocking_request(self, tracer) 296 return result 298 except Exception as e: → 300 raise e 302 finally: 303 LocalTracer.deregister() File ~/Library/Caches/pypoetry/virtualenvs/multimodality-2TwJzaIP-py3.10/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:293, in RemoteBackend.blocking_request(self, tracer) 291 response = ResponseModel.unpickle(response) 292 # Handle the response. → 293 result = self.handle_response(response, tracer=tracer) 294 # Break when completed. 
295 if result is not None: File ~/Library/Caches/pypoetry/virtualenvs/multimodality-2TwJzaIP-py3.10/lib/python3.10/site-packages/nnsight/intervention/backends/remote.py:123, in RemoteBackend.handle_response(self, response, tracer) 120 self.job_status = response.status 122 if response.status == ResponseModel.JobStatus.ERROR: → 123 raise RemoteException(f"{response.description}\nRemote exception.") 125 # Log response for user 126 response.log(remote_logger) RemoteException: Traceback (most recent call last): File "/u/svcndifuser/ndif-deployment/repos/prod/ndif/src/services/ray/src/ray/deployments/modeling/base.py", line 213, in call result = await job_task File "/u/svcndifuser/miniconda3/lib/python3.10/asyncio/threads.py", line 25, in to_thread return await loop.run_in_executor(None, func_call) File "/u/svcndifuser/miniconda3/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span return method(self, *_args, **_kwargs) File "/u/svcndifuser/ndif-deployment/repos/prod/ndif/src/services/ray/src/ray/deployments/modeling/base.py", line 283, in execute result = RemoteExecutionBackend(request.interventions, self.execution_protector)(request.tracer) File "/u/svcndifuser/ndif-deployment/repos/prod/ndif/src/services/ray/src/ray/nn/backend.py", line 27, in call run(tracer, self.fn) File "/u/svcndifuser/ndif-deployment/repos/prod/ndif/src/services/ray/src/ray/nn/sandbox.py", line 16, in run raise wrap_exception(e,tracer.info) from None nnsight.NNsightException: Traceback (most recent call last): File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/nnsight/intervention/backends/execution.py", line 21, in call tracer.execute(fn) File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/nnsight/intervention/tracing/iterator.py", line 51, in execute mediator = Mediator(fn, self.info, 
batch_group=self.interleaver.current.batch_group, stop=self.interleaver.current.all_stop) File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/nnsight/intervention/interleaver.py", line 376, in current return self.mediators[threading.current_thread().name] TypeError: 'NoneType' object is not subscriptable Traceback (most recent call last): File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/nnsight/intervention/backends/execution.py", line 21, in call tracer.execute(fn) File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/nnsight/intervention/tracing/iterator.py", line 51, in execute mediator = Mediator(fn, self.info, batch_group=self.interleaver.current.batch_group, stop=self.interleaver.current.all_stop) File "/u/svcndifuser/miniconda3/lib/python3.10/site-packages/nnsight/intervention/interleaver.py", line 376, in current return self.mediators[threading.current_thread().name] TypeError: 'NoneType' object is not subscriptable Remote exception.

I think we have to use several contexts for this, e.g.:

with model.generate(remote=REMOTE, max_new_tokens=n_tokens) as generator:
    # First invoker: runs the prompt and applies the intervention inside
    # the all() context.
    with generator.invoke(prompt):
        with generator.all():
            model.model.layers[layer].output[0][-1] += h
    # Second, prompt-less invoker: only collects the final generator output,
    # so it is not tied to the all() loop above.
    with generator.invoke():
        tokens = model.generator.output.save()
3 Likes

Hi @Amina, thanks for posting about this!

Let me complement @ilya.lasy 's reply and give you some best nnsight tracing practices:

  1. Invokers are the top-level wrapper of any tracing intervention; think of an invoker as defining a function. So model.all() should be indented below it, inside the invoker.

  2. model.all() is deprecated. Consider using with tracer.all(): or with tracer.iter[<slice>] as idx: instead.

  3. In model.generate(), the model.generator module is called only once, at the very end of generation. If you request model.generator.output inside the per-token loop, the intervention code waits there for a call that only happens after the last token, so your other interventions hang and are then missed by tracing entirely: by the time model.generator is actually called, token generation has already ended.

You can debug this by adding a print statement,

with tracer.all():
    print("Hi")
    # some intervention
    model.generator.output

and see how many times the statement is actually printed.

Instead, what you should do is de-indent model.generator.output so that it is only requested once tracer.all() is done iterating. (@ilya.lasy 's suggestion also works: it calls the generator inside a prompt-less invoker.)

Finally, here’s how I would re-implement your tracing:

with model.generate(remote=REMOTE, max_new_tokens=n_tokens) as tracer:
        # The invoker is the outer context; the all() block is nested inside it.
        with tracer.invoke(prompt):
            # Intervention placed inside tracer.all() so it is not limited to
            # a single generation step.
            with tracer.all():
                model.model.layers[layer].output[0][-1] += h

            # De-indented on purpose: model.generator is only called once, at
            # the end of generation, so its output is requested after the
            # all() block rather than inside it.
            tokens = model.generator.output.save()
2 Likes

Thank you, both, it worked!

1 Like

I was having the exact same issue, but with a local model, and for me the solution you provided does not work.

Here’s a snippet of my code:

with model.generate(**generate_kwargs) as tracer:
        with tracer.invoke(input_ids):
            with tracer.all():
                model.model.layers[3].self_attn.o_proj.output *= 0.0
        
            model_output = model.generator.output.save()

return model_output

This throws an

UnboundLocalError: cannot access local variable 'model_output' where it is not associated with a value

at the return statement.
It seems that the lines after the with tracer.all() context manager are not executed at all. Adding some print statements supports this.

@ilya.lasy 's solution with the empty invoke(), i.e. changing this to:

with model.generate(**generate_kwargs) as tracer:
        with tracer.invoke(input_ids):
            with tracer.all():
                model.model.layers[3].self_attn.o_proj.output *= 0.0
        with tracer.invoke():
            model_output = model.generator.output.save()

return model_output

works perfectly fine.

Hi @kortukov, are you passing max_new_tokens in your **generate_kwargs?

If you’re not specifying it, then @ilya.lasy’s solution is the way to go. This is because tracer.all() will not know when to stop and will run forever until the model execution is stopped, by which time it is too late to intercept the model output.

Yes, I do.

I have created a clean Colab notebook to reproduce my issue. In a clean notebook, both your solution and @ilya.lasy 's one work.

So apparently, this issue is on my side. Sorry for the confusion, and thanks for the swift reply.

1 Like