import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video


class ZeroscopeTTV:
    """
    ZeroscopeTTV is a zero-shot text-to-video generation model built on a
    pre-trained diffusion pipeline.

    Args:
        model_name (str): The name of the pre-trained model to use.
        torch_dtype (torch.dtype): The torch data type to use for computations.
        chunk_size (int): The size of chunks for forward chunking.
        dim (int): The dimension along which to split the input for forward chunking.
        num_inference_steps (int): The number of inference steps to perform.
        height (int): The height of the video frames in pixels.
        width (int): The width of the video frames in pixels.
        num_frames (int): The number of frames in the generated video.

    Attributes:
        model_name (str): The name of the pre-trained model.
        torch_dtype (torch.dtype): The torch data type used for computations.
        chunk_size (int): The size of chunks for forward chunking.
        dim (int): The dimension along which the input is split for forward chunking.
        num_inference_steps (int): The number of inference steps to perform.
        height (int): The height of the video frames in pixels.
        width (int): The width of the video frames in pixels.
        num_frames (int): The number of frames in the generated video.
        pipe (DiffusionPipeline): The diffusion pipeline used for video generation.

    Methods:
        run(task: str = None, *args, **kwargs) -> str:
            Runs the pipeline on the input task and returns the path of the
            generated video.

    Examples:
        >>> from swarms.models import ZeroscopeTTV
        >>> zeroscope = ZeroscopeTTV()
        >>> task = "A person is walking on the street."
        >>> video_path = zeroscope.run(task)
    """

    def __init__(
        self,
        model_name: str = "cerspense/zeroscope_v2_576w",
        torch_dtype=torch.float16,
        chunk_size: int = 1,
        dim: int = 1,
        num_inference_steps: int = 40,
        height: int = 320,
        width: int = 576,
        num_frames: int = 36,
        *args,
        **kwargs,
    ):
        self.model_name = model_name
        self.torch_dtype = torch_dtype
        self.chunk_size = chunk_size
        self.dim = dim
        self.num_inference_steps = num_inference_steps
        self.height = height
        self.width = width
        self.num_frames = num_frames

        self.pipe = DiffusionPipeline.from_pretrained(
            model_name, *args, torch_dtype=torch_dtype, **kwargs
        )
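        # Swap in a DPM-Solver++ multistep scheduler built from the
        # pipeline's existing scheduler config.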
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
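        # Trade compute for memory: offload idle sub-models to the CPU,
        # decode the VAE in slices, and chunk the UNet feed-forward pass
        # along `dim` in chunks of `chunk_size`.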
        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_vae_slicing()
        self.pipe.unet.enable_forward_chunking(
            chunk_size=chunk_size, dim=dim
        )
    def run(self, task: str = None, *args, **kwargs):
        """
        Runs the pipeline on the input task and returns the path of the
        generated video.

        Args:
            task (str): The text prompt to generate the video from.

        Returns:
            str: The path of the generated video.
        """
        try:
            video_frames = self.pipe(
                task,
                *args,
                num_inference_steps=self.num_inference_steps,
                height=self.height,
                width=self.width,
                num_frames=self.num_frames,
                **kwargs,
            ).frames
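            # When no output path is given, export_to_video writes the
            # frames to a temporary .mp4 file and returns its path.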
            video_path = export_to_video(video_frames)
            return video_path
        except Exception as error:
            print(f"Error in [ZeroscopeTTV.run]: {error}")
            raise
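

# A minimal usage sketch, assuming a CUDA-capable GPU with enough memory
# for the fp16 pipeline; the prompt is illustrative.
if __name__ == "__main__":
    zeroscope = ZeroscopeTTV()
    video_path = zeroscope.run("A person is walking on the street.")
    print(f"Video saved to: {video_path}")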