[REFACTOR][HuggingfaceLLM]

pull/343/head
Kye 1 year ago
parent 41b858a91d
commit e8ca14f071

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "3.2.7" version = "3.2.8"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]

@ -28,6 +28,7 @@ class CogAgent(BaseMultiModalModel):
>>> cog_agent.run("How are you?", "images/1.jpg") >>> cog_agent.run("How are you?", "images/1.jpg")
<s> I'm fine. How are you? </s> <s> I'm fine. How are you? </s>
""" """
def __init__( def __init__(
self, self,
model_name: str = "ZhipuAI/cogagent-chat", model_name: str = "ZhipuAI/cogagent-chat",

@ -3,18 +3,18 @@ import concurrent.futures
import logging import logging
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from termcolor import colored from termcolor import colored
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
BitsAndBytesConfig, BitsAndBytesConfig,
) )
from swarms.models.base_llm import AbstractLLM
class HuggingfaceLLM: class HuggingfaceLLM(AbstractLLM):
""" """
A class for running inference on a given model. A class for running inference on a given model.
@ -123,7 +123,6 @@ class HuggingfaceLLM:
quantize: bool = False, quantize: bool = False,
quantization_config: dict = None, quantization_config: dict = None,
verbose=False, verbose=False,
# logger=None,
distributed=False, distributed=False,
decoding=False, decoding=False,
max_workers: int = 5, max_workers: int = 5,
@ -135,6 +134,7 @@ class HuggingfaceLLM:
*args, *args,
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs)
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.device = ( self.device = (
device device
@ -174,16 +174,21 @@ class HuggingfaceLLM:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained( self.tokenizer = AutoTokenizer.from_pretrained(
self.model_id, *args, **kwargs self.model_id
) )
if quantize:
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, self.model_id,
quantization_config=bnb_config, quantization_config=bnb_config,
*args, *args,
**kwargs, **kwargs,
) ).to(self.device)
else:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, *args, **kwargs
).to(self.device)
self.model # .to(self.device)
except Exception as e: except Exception as e:
# self.logger.error(f"Failed to load the model or the tokenizer: {e}") # self.logger.error(f"Failed to load the model or the tokenizer: {e}")
# raise # raise
@ -205,33 +210,6 @@ class HuggingfaceLLM:
"""Asynchronously generate text for a given prompt""" """Asynchronously generate text for a given prompt"""
return await asyncio.to_thread(self.run, task) return await asyncio.to_thread(self.run, task)
def load_model(self):
"""Load the model"""
if not self.model or not self.tokenizer:
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_id
)
bnb_config = (
BitsAndBytesConfig(**self.quantization_config)
if self.quantization_config
else None
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id, quantization_config=bnb_config
).to(self.device)
if self.distributed:
self.model = DDP(self.model)
except Exception as error:
self.logger.error(
"Failed to load the model or the tokenizer:"
f" {error}"
)
raise
def concurrent_run(self, tasks: List[str], max_workers: int = 5): def concurrent_run(self, tasks: List[str], max_workers: int = 5):
"""Concurrently generate text for a list of prompts.""" """Concurrently generate text for a list of prompts."""
with concurrent.futures.ThreadPoolExecutor( with concurrent.futures.ThreadPoolExecutor(
@ -252,7 +230,7 @@ class HuggingfaceLLM:
results = [future.result() for future in futures] results = [future.result() for future in futures]
return results return results
def run(self, task: str): def run(self, task: str, *args, **kwargs):
""" """
Generate a response based on the prompt text. Generate a response based on the prompt text.
@ -263,20 +241,12 @@ class HuggingfaceLLM:
Returns: Returns:
- Generated text (str). - Generated text (str).
""" """
self.load_model()
max_length = self.max_length
self.print_dashboard(task)
try: try:
inputs = self.tokenizer.encode(task, return_tensors="pt") inputs = self.tokenizer.encode(task, return_tensors="pt")
# self.log.start()
if self.decoding: if self.decoding:
with torch.no_grad(): with torch.no_grad():
for _ in range(max_length): for _ in range(self.max_length):
output_sequence = [] output_sequence = []
outputs = self.model.generate( outputs = self.model.generate(
@ -300,7 +270,11 @@ class HuggingfaceLLM:
else: else:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(
inputs, max_length=max_length, do_sample=True inputs,
max_length=self.max_length,
do_sample=True,
*args,
**kwargs,
) )
del inputs del inputs
@ -320,67 +294,8 @@ class HuggingfaceLLM:
) )
raise raise
def __call__(self, task: str): def __call__(self, task: str, *args, **kwargs):
""" return self.run(task, *args, **kwargs)
Generate a response based on the prompt text.
Args:
- task (str): Text to prompt the model.
- max_length (int): Maximum length of the response.
Returns:
- Generated text (str).
"""
self.load_model()
max_length = self.max_length
self.print_dashboard(task)
try:
inputs = self.tokenizer.encode(
task, return_tensors="pt"
).to(self.device)
# self.log.start()
if self.decoding:
with torch.no_grad():
for _ in range(max_length):
output_sequence = []
outputs = self.model.generate(
inputs,
max_length=len(inputs) + 1,
do_sample=True,
)
output_tokens = outputs[0][-1]
output_sequence.append(output_tokens.item())
# print token in real-time
print(
self.tokenizer.decode(
[output_tokens],
skip_special_tokens=True,
),
end="",
flush=True,
)
inputs = outputs
else:
with torch.no_grad():
outputs = self.model.generate(
inputs, max_length=max_length, do_sample=True
)
del inputs
return self.tokenizer.decode(
outputs[0], skip_special_tokens=True
)
except Exception as e:
self.logger.error(f"Failed to generate the text: {e}")
raise
async def __call_async__(self, task: str, *args, **kwargs) -> str: async def __call_async__(self, task: str, *args, **kwargs) -> str:
"""Call the model asynchronously""" "" """Call the model asynchronously""" ""

Loading…
Cancel
Save