You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.1 KiB
78 lines
2.1 KiB
import time
|
|
from abc import ABC, abstractmethod
|
|
|
|
def count_tokens(text: str) -> int:
    """Return the number of whitespace-separated pieces in *text*.

    This is a plain word count based on ``str.split()`` with no arguments,
    so runs of whitespace act as a single separator and an empty or
    whitespace-only string yields 0.
    """
    pieces = text.split()
    return len(pieces)
|
|
|
|
class AbstractModel(ABC):
    """Abstract base class for language models.

    Subclasses must implement :meth:`run` and :meth:`generate_summary`.
    The base class supplies thin convenience wrappers around ``run`` plus
    simple wall-clock timing helpers.
    """

    def __init__(self):
        # Wall-clock stamps (from time.time()) bracketing the most recent
        # call made through _time_for_generation; None until one happens.
        self.start_time = None
        self.end_time = None
        self.temperature = 1.0
        self.max_tokens = None
        self.history = ""
        # Output of the most recent timed run; feeds _tokens_per_second.
        self._last_output = ""

    @abstractmethod
    def run(self, task: str) -> str:
        """Generate text for *task* using the language model."""

    def chat(self, task: str, history: str = "") -> str:
        """Run the model on *task* followed by *history*.

        The two strings are joined with ``" | "`` as a delimiter. Only the
        ``history`` argument is used; ``self.history`` is not read here.
        """
        complete_task = task + " | " + history  # Delimiter for clarity
        return self.run(complete_task)

    def __call__(self, task: str) -> str:
        """Call the model; equivalent to ``self.run(task)``."""
        return self.run(task)

    def _sec_to_first_token(self) -> float:
        """Return a fixed placeholder latency of 0.001 seconds."""
        # Assuming the first token appears instantly after the model starts
        return 0.001

    def _tokens_per_second(self) -> float:
        """Return the token throughput of the most recent timed run.

        Returns:
            0.0 when no run has been timed yet, ``inf`` when the measured
            elapsed time is exactly zero, otherwise the token count of the
            last timed output divided by the elapsed seconds.
        """
        if self.start_time is None or self.end_time is None:
            return 0.0
        elapsed_time = self.end_time - self.start_time
        if elapsed_time == 0:
            return float("inf")
        # getattr keeps this working for subclasses whose __init__ does not
        # call super().__init__() but sets the timing stamps themselves.
        last_output = getattr(self, "_last_output", "")
        return self._num_tokens(last_output) / elapsed_time

    def _num_tokens(self, text: str) -> int:
        """Return the number of tokens in *text* as given by ``count_tokens``."""
        return count_tokens(text)

    def _time_for_generation(self, task: str) -> float:
        """Run the model on *task* and return the elapsed wall-clock seconds.

        Records ``start_time`` and ``end_time`` and remembers the generated
        output so that ``_tokens_per_second`` can be computed afterwards.
        """
        self.start_time = time.time()
        self._last_output = self.run(task)
        self.end_time = time.time()
        return self.end_time - self.start_time

    @abstractmethod
    def generate_summary(self, text: str) -> str:
        """Generate a summary of *text*."""

    def set_temperature(self, value: float):
        """Set the ``temperature`` attribute to *value*."""
        self.temperature = value

    def set_max_tokens(self, value: int):
        """Set the ``max_tokens`` attribute to *value*."""
        self.max_tokens = value

    def clear_history(self):
        """Reset ``history`` to the empty string."""
        self.history = ""

    def get_generation_time(self) -> float:
        """Return the duration of the last timed run, or 0.0 if none exists."""
        # Compare against None explicitly so a legitimate 0.0 stamp is not
        # mistaken for "never timed".
        if self.start_time is not None and self.end_time is not None:
            return self.end_time - self.start_time
        return 0.0
|