From d9a1d75c7c445422925391707ad212c19bbdcf81 Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 20 Oct 2023 01:49:54 -0400 Subject: [PATCH] abstract model --- swarms/models/base.py | 73 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 7 deletions(-) diff --git a/swarms/models/base.py b/swarms/models/base.py index 63ee4db1..57045165 100644 --- a/swarms/models/base.py +++ b/swarms/models/base.py @@ -1,18 +1,77 @@ +import time from abc import ABC, abstractmethod +def count_tokens(text: str) -> int: + return len(text.split()) class AbstractModel(ABC): + """ + AbstractModel + + """ # abstract base class for language models - def __init__(): - pass + def __init__(self): + self.start_time = None + self.end_time = None + self.temperature = 1.0 + self.max_tokens = None + self.history = "" @abstractmethod - def run(self, prompt): - # generate text using language model + def run(self, task: str) -> str: + """generate text using language model""" pass - def chat(self, prompt, history): - pass + def chat(self, task: str, history: str = "") -> str: + """Chat with the model""" + complete_task = task + " | " + history # Delimiter for clarity + return self.run(complete_task) + + def __call__(self, task: str) -> str: + """Call the model""" + return self.run(task) + + def _sec_to_first_token(self) -> float: + # Assuming the first token appears instantly after the model starts + return 0.001 + + def _tokens_per_second(self) -> float: + """Tokens per second""" + elapsed_time = self.end_time - self.start_time + if elapsed_time == 0: + return float("inf") + return self._num_tokens() / elapsed_time + + def _num_tokens(self, text: str) -> int: + """Number of tokens""" + return count_tokens(text) - def __call__(self, task): + def _time_for_generation(self, task: str) -> float: + """Time for Generation""" + self.start_time = time.time() + output = self.run(task) + self.end_time = time.time() + return self.end_time - self.start_time + + @abstractmethod + def generate_summary(self, text: str) -> str: + """Generate Summary""" pass + + def set_temperature(self, value: float): + """Set Temperature""" + self.temperature = value + + def set_max_tokens(self, value: int): + """Set new max tokens""" + self.max_tokens = value + + def clear_history(self): + """Clear history""" + self.history = "" + + def get_generation_time(self) -> float: + """Get generation time""" + if self.start_time and self.end_time: + return self.end_time - self.start_time + return 0