[FIX][HuggingfaceLLM]

pull/343/head
Kye 1 year ago
parent d0b09ea029
commit 91a36c8557

@ -94,4 +94,4 @@ blocks_by_parent_description = swarm.get_by_parent_description(
# Run the block in the swarm # Run the block in the swarm
inference = swarm.run_block(toolagent, "Hello World") inference = swarm.run_block(toolagent, "Hello World")
print(inference) print(inference)

@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "3.4.1" version = "3.4.2"
description = "Swarms - Pytorch" description = "Swarms - Pytorch"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]

@ -131,7 +131,7 @@ class HuggingfaceLLM(AbstractLLM):
temperature: float = 0.7, temperature: float = 0.7,
top_k: int = 40, top_k: int = 40,
top_p: float = 0.8, top_p: float = 0.8,
dtype = torch.bfloat16, dtype=torch.bfloat16,
*args, *args,
**kwargs, **kwargs,
): ):
@ -189,7 +189,6 @@ class HuggingfaceLLM(AbstractLLM):
self.model_id, *args, **kwargs self.model_id, *args, **kwargs
).to(self.device) ).to(self.device)
def print_error(self, error: str): def print_error(self, error: str):
"""Print error""" """Print error"""
print(colored(f"Error: {error}", "red")) print(colored(f"Error: {error}", "red"))
@ -264,7 +263,7 @@ class HuggingfaceLLM(AbstractLLM):
*args, *args,
**kwargs, **kwargs,
) )
return self.tokenizer.decode( return self.tokenizer.decode(
outputs[0], skip_special_tokens=True outputs[0], skip_special_tokens=True
) )

@ -75,8 +75,8 @@ class BlocksList(BaseStructure):
def get_all(self): def get_all(self):
return self.blocks return self.blocks
def run_block(self, block: Any, task: str, *args, **kwargs): def run_block(self, block: Any, task: str, *args, **kwargs):
"""Run the block for the specified task. """Run the block for the specified task.
Args: Args:

Loading…
Cancel
Save