|
|
@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
"""Logits processors for the GPT-Neo model."""
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import (
|
|
|
|
from transformers import (
|
|
|
|
PreTrainedTokenizer,
|
|
|
|
PreTrainedTokenizer,
|
|
|
|
LogitsWarper,
|
|
|
|
LogitsWarper,
|
|
|
@ -48,7 +50,7 @@ class NumberStoppingCriteria(StoppingCriteria):
|
|
|
|
scores: torch.FloatTensor,
|
|
|
|
scores: torch.FloatTensor,
|
|
|
|
) -> bool:
|
|
|
|
) -> bool:
|
|
|
|
decoded = self.tokenizer.decode(
|
|
|
|
decoded = self.tokenizer.decode(
|
|
|
|
input_ids[0][self.prompt_length :],
|
|
|
|
input_ids[0][self.prompt_length:],
|
|
|
|
skip_special_tokens=True,
|
|
|
|
skip_special_tokens=True,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|