diff --git a/swarms/tools/logits_processor.py b/swarms/tools/logits_processor.py index ed7fef18..c6ba1691 100644 --- a/swarms/tools/logits_processor.py +++ b/swarms/tools/logits_processor.py @@ -1,3 +1,5 @@ +"""Logits processors for the GPT-Neo model.""" + from transformers import ( PreTrainedTokenizer, LogitsWarper, @@ -48,7 +50,7 @@ class NumberStoppingCriteria(StoppingCriteria): scores: torch.FloatTensor, ) -> bool: decoded = self.tokenizer.decode( - input_ids[0][self.prompt_length :], + input_ids[0][self.prompt_length:], skip_special_tokens=True, )