@ -1,3 +1,5 @@
"""Logits processors for the GPT-Neo model."""
from transformers import (
PreTrainedTokenizer,
LogitsWarper,
@ -48,7 +50,7 @@ class NumberStoppingCriteria(StoppingCriteria):
scores: torch.FloatTensor,
) -> bool:
decoded = self.tokenizer.decode(
input_ids[0][self.prompt_length :],
input_ids[0][self.prompt_length:],
skip_special_tokens=True,
)