How to break at the EOS token (gemma-2-2b-it)?

Hi everyone! I’m trying to reproduce a paper on improving instruction following.
The authors used gemma-2-2b-it with transformerlens, and they stop generation when the EOS token is reached.

I’m trying to do the same, but .all() isn’t working for gemma-2-2b-it, or at least, I haven’t been able to make it work.

Here’s my current code:

with model.generate(prompt, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id):
    target_layer = model.model.layers[sel_layer]
    for _ in range(512):
        out = target_layer.output
        target_layer.output = (out[0] + weight * direction,)
        target_layer.next()
        # gen_ids = model.generator.output[0]
        # if gen_ids[-1] == tokenizer.eos_token_id:
        #     break
        
    out_ids = model.generator.output.save()

print(tokenizer.decode(out_ids[0], skip_special_tokens=True))

In the original implementation, they checked for the EOS token at each step and broke early. Has anyone dealt with EOS-based stopping for gemma-2-2b-it in this setup?
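For clarity, the pattern I'm trying to reproduce is an ordinary greedy decoding loop that breaks as soon as the EOS token is generated. Here's a minimal stand-in sketch of that pattern in plain PyTorch (the `dummy_logits` function and token ids are hypothetical placeholders for the real forward pass, not the paper's code):

```python
import torch

EOS_ID = 2  # stand-in for tokenizer.eos_token_id

def dummy_logits(ids):
    # Stand-in for a model forward pass: deterministically emits the next
    # token id, then EOS once the last id reaches 4.
    logits = torch.zeros(10)
    nxt = EOS_ID if ids[-1] >= 4 else ids[-1] + 1
    logits[nxt] = 1.0
    return logits

ids = [3]
for _ in range(512):  # max_new_tokens cap
    next_id = int(dummy_logits(ids).argmax())
    ids.append(next_id)
    if next_id == EOS_ID:  # stop as soon as EOS is generated
        break

print(ids)  # -> [3, 4, 2]
```

This is the per-step check-and-break behavior I'd like to get inside the tracing context.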

Hi @theballer! Thanks for making your post detailed.

What version of nnsight were you trying to implement this with?

I highly recommend checking out our pre-released version of nnsight 0.5, with pip install "nnsight>=0.5.0.dev14" (quote the specifier so your shell doesn't treat >= as a redirect). You can learn more about it here: NNsight 0.5 Prerelease: Feedback Requested

Here’s how I would implement an EOS-token-based generation stop with nnsight 0.5:

with model.generate(max_new_tokens=512, pad_token_id=model.tokenizer.eos_token_id) as tracer:

    with tracer.invoke(prompt):
        pass

    with tracer.invoke():
        # your interventions here
        pass

    with tracer.invoke():
        with tracer.all():
            logit = model.output['logits'].argmax(dim=-1)

            if logit == model.tokenizer.eos_token_id:
                tracer.stop()

Let me know if you have any further questions!


Thank you for your help! Yes, I was using an earlier version of nnsight. After upgrading, I have this code running OK:

with model.generate(max_new_tokens=1024, pad_token_id=model.tokenizer.eos_token_id) as tracer:

    with tracer.invoke(prompt):
        pass

    with tracer.invoke():
        with tracer.all():
            model.model.layers[layer].output[0][:] += direction * weight
        
    with tracer.invoke():
        with tracer.all():
            logit = model.output['logits'].argmax(dim=-1)
            if logit == model.tokenizer.eos_token_id:
                tracer.stop()
            out = model.generator.output.save()

print(tokenizer.decode(out[0]))

I want to add the direction at all token positions, generate until EOS, and then take the output. Is this the right code for that? It seems OK to me. The only thing now is that I am getting this warning message:

UserWarning: Execution complete but `model.model.layers.14.output.i341` was not provided. This was in an Iterator at iteration 341 so likely this iteration did not happen. If you were using `.iter[:]`, this is likely not an error. warnings.warn(msg)

What is wrong?