Hi everyone! I’m trying to reproduce a paper on improving instruction following.
The authors used `gemma-2-2b-it` with `transformerlens`, and they stop generation when the EOS token is reached. I'm trying to do the same, but `.all()` isn't working for `gemma-2-2b-it`, or at least I haven't been able to make it work.
Here’s my current code:
```python
with model.generate(prompt, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id):
    target_layer = model.model.layers[sel_layer]
    for _ in range(512):
        out = target_layer.output
        target_layer.output = (out[0] + weight * direction,)
        target_layer.next()
        # gen_ids = model.generator.output[0]
        # if gen_ids[-1] == tokenizer.eos_token_id:
        #     break
    out_ids = model.generator.output.save()

print(tokenizer.decode(out_ids[0], skip_special_tokens=True))
```
In the original implementation, they checked for the EOS token at each step and broke early. Has anyone dealt with EOS-based stopping for `gemma-2-2b-it` in this setup?
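For context, the stopping rule itself is just a per-step check in a greedy decoding loop: sample the argmax token, append it, and break once it equals the EOS id. Here's a minimal, self-contained sketch of that logic; the `fake_logits` stub and the token ids are made up for illustration and stand in for a real forward pass (with a real model you would use `tokenizer.eos_token_id`):

```python
EOS_TOKEN_ID = 1  # hypothetical id; with a real model, use tokenizer.eos_token_id

def fake_logits(ids):
    # Stub standing in for a forward pass: favors token 7 for the first
    # three steps, then favors EOS so generation terminates.
    vocab_size = 16
    logits = [0.0] * vocab_size
    logits[7 if len(ids) < 4 else EOS_TOKEN_ID] = 10.0
    return logits

def generate(prompt_ids, max_new_tokens=512):
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = fake_logits(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        ids.append(next_id)
        if next_id == EOS_TOKEN_ID:  # the early-stop check
            break
    return ids

print(generate([3]))  # → [3, 7, 7, 7, 1]
```

The nnsight question is about where to put that `if ... break` check inside the tracing context, since the loop over generation steps is driven by the library rather than by your own `for` loop.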
Hi @theballer! Thanks for making your post detailed.
What version of `nnsight` were you trying to implement this with?
I highly recommend checking out our pre-released version of nnsight 0.5, with `pip install "nnsight>=0.5.0.dev14"` (the quotes keep the shell from treating `>` as a redirect). You can learn more about it here: NNsight 0.5 Prerelease: Feedback Requested
Here’s how I would implement a token-based generation stop with nnsight 0.5:
```python
with model.generate(max_new_tokens=512, pad_token_id=model.tokenizer.eos_token_id) as tracer:
    with tracer.invoke(prompt):
        pass

    with tracer.invoke():
        # your interventions here
        pass

    with tracer.invoke():
        with tracer.all():
            logit = model.output['logits'].argmax(dim=-1)
            if logit == model.tokenizer.eos_token_id:
                tracer.stop()
```
Let me know if you have any further questions!
Thank you for your help! Yes, I was using an earlier version of nnsight. After upgrading, I have this code running OK:
```python
with model.generate(max_new_tokens=1024, pad_token_id=model.tokenizer.eos_token_id) as tracer:
    with tracer.invoke("prompt"):
        pass

    with tracer.invoke():
        with tracer.all():
            model.model.layers[layer].output[0][:] += direction * weight

    with tracer.invoke():
        with tracer.all():
            logit = model.output['logits'].argmax(dim=-1)
            if logit == model.tokenizer.eos_token_id:
                tracer.stop()

    out = model.generator.output.save()

print(tokenizer.decode(out[0]))
```
I want to add the direction at all token positions, generate until EOS, and then take the output. Is this the right code for that? It seems OK to me. The only thing now is that I'm getting this warning message:

```
UserWarning: Execution complete but `model.model.layers.14.output.i341` was not provided.
This was in an Iterator at iteration 341 so likely this iteration did not happen.
If you were using `.iter[:]`, this is likely not an error.
warnings.warn(msg)
```

What is wrong?