"""4D Gaussian splatting driven by a constant-velocity state space model.

The module provides :class:`GaussianSplat4DStateSpace` plus two helpers for
looking at the rendered 4D volume: an interactive 2D slicer (needs
``ipywidgets``) and a maximum intensity projection.
"""

from typing import Tuple

import matplotlib.pyplot as plt
import torch
from loguru import logger
from torch import Tensor

try:
    # ipywidgets is available in interactive environments like Jupyter.
    from ipywidgets import interact, IntSlider

    HAS_IPYWIDGETS = True
except ImportError:
    HAS_IPYWIDGETS = False
    logger.warning(
        "ipywidgets not installed. Interactive slicing will be disabled."
    )


class GaussianSplat4DStateSpace:
    """
    4D Gaussian splatting with a state space model in PyTorch.

    Each Gaussian is defined by an 8D state vector:
        [x, y, z, w, vx, vy, vz, vw],
    where the first four dimensions are the spatial coordinates and the last
    four are the velocities. Only the spatial (first four) dimensions are used
    for the 4D Gaussian splat, with a corresponding 4×4 covariance matrix.

    Attributes:
        num_gaussians (int): Number of Gaussians.
        state_dim (int): Dimension of the state vector (always 8).
        states (Tensor): Current state for each Gaussian, shape
            (num_gaussians, state_dim).
        covariances (Tensor): Covariance matrices for the spatial dimensions,
            shape (num_gaussians, 4, 4), stored with the same dtype and
            device as ``states``.
        A (Tensor): State transition matrix of shape (state_dim, state_dim).
        dt (float): Time step for state updates.
    """

    def __init__(
        self,
        num_gaussians: int,
        init_states: Tensor,
        init_covariances: Tensor,
        dt: float = 1.0,
    ) -> None:
        """
        Initialize the 4D Gaussian splat model.

        Args:
            num_gaussians (int): Number of Gaussians. Must equal the number
                of rows in ``init_states`` and ``init_covariances``.
            init_states (Tensor): Initial states of shape (num_gaussians, 8).
                Each state is [x, y, z, w, vx, vy, vz, vw].
            init_covariances (Tensor): Initial covariance matrices for the
                spatial dimensions, shape (num_gaussians, 4, 4).
            dt (float): Time step for the state update.

        Raises:
            ValueError: If a tensor has the wrong shape, or if the number of
                states/covariances does not match ``num_gaussians``.
        """
        # Check the rank first so a 1-D tensor yields the documented
        # ValueError instead of an IndexError on shape[1].
        if init_states.ndim != 2 or init_states.shape[1] != 8:
            raise ValueError(
                "init_states should have shape (N, 8) where 8 = 4 position + 4 velocity."
            )
        if init_covariances.ndim != 3 or tuple(
            init_covariances.shape[1:]
        ) != (4, 4):
            raise ValueError(
                "init_covariances should have shape (N, 4, 4)."
            )
        # render() indexes both tensors with range(num_gaussians); a mismatch
        # would otherwise fail late or silently drop Gaussians.
        if not (
            num_gaussians
            == init_states.shape[0]
            == init_covariances.shape[0]
        ):
            raise ValueError(
                f"num_gaussians ({num_gaussians}) must match the number of "
                f"states ({init_states.shape[0]}) and covariances "
                f"({init_covariances.shape[0]})."
            )

        dtype = init_states.dtype
        device = init_states.device

        self.num_gaussians = num_gaussians
        self.states = init_states.clone()  # shape: (N, 8)
        # Keep covariances on the states' dtype/device so the einsum in
        # _compute_gaussian never mixes dtypes or devices.
        self.covariances = init_covariances.clone().to(
            dtype=dtype, device=device
        )  # shape: (N, 4, 4)
        self.dt = dt
        self.state_dim = init_states.shape[1]

        # Create an 8x8 constant-velocity state transition matrix:
        # new position = position + velocity * dt, velocity remains unchanged.
        #     A = [[I, dt * I],
        #          [0,      I]]
        eye4 = torch.eye(4, dtype=dtype, device=device)
        zeros4 = torch.zeros((4, 4), dtype=dtype, device=device)
        top = torch.cat([eye4, dt * eye4], dim=1)
        bottom = torch.cat([zeros4, eye4], dim=1)
        self.A = torch.cat([top, bottom], dim=0)  # shape: (8, 8)

        logger.info(
            "Initialized 4D GaussianSplatStateSpace with {} Gaussians.",
            num_gaussians,
        )

    def update_states(self) -> None:
        """
        Advance every Gaussian by one time step.

        Applies ``state_next = A @ state_current`` to each row of
        ``self.states`` and stores the result back in ``self.states``.
        """
        # (A @ S^T)^T == S @ A^T, written without the double transpose.
        self.states = self.states @ self.A.t()  # shape: (num_gaussians, 8)
        logger.debug("States updated: {}", self.states)

    def _compute_gaussian(
        self, pos: Tensor, cov: Tensor, coords: Tensor
    ) -> Tensor:
        """
        Evaluate an unnormalized 4D Gaussian over a grid of coordinates.

        Computes ``exp(-0.5 * d^T cov^-1 d)`` with ``d = coords - pos``.

        Args:
            pos (Tensor): The center of the Gaussian, shape (4,).
            cov (Tensor): The 4×4 covariance matrix.
            coords (Tensor): A grid of coordinates of shape (..., 4).

        Returns:
            Tensor: Gaussian values with shape ``coords.shape[:-1]``.
        """
        try:
            cov_inv = torch.linalg.inv(cov)
        except RuntimeError as e:
            # Singular covariance: degrade to the pseudo-inverse rather than
            # aborting the whole render.
            logger.warning(
                "Covariance inversion failed; using pseudo-inverse. Error: {}",
                e,
            )
            cov_inv = torch.linalg.pinv(cov)

        # pos has shape (4,), so it broadcasts over the leading grid dims.
        diff = coords - pos
        mahal = torch.einsum("...i,ij,...j->...", diff, cov_inv, diff)
        return torch.exp(-0.5 * mahal)

    def render(
        self,
        canvas_size: Tuple[int, int, int, int],
        sigma_scale: float = 1.0,
        normalize: bool = False,
    ) -> Tensor:
        """
        Render the current 4D Gaussian splats onto a 4D canvas.

        Grid points sit at integer coordinates ``0 .. d - 1`` along each axis,
        and every Gaussian's contribution is summed into the canvas.

        Args:
            canvas_size (Tuple[int, int, int, int]): The size of the canvas
                (d1, d2, d3, d4).
            sigma_scale (float): Positive factor multiplied into each
                covariance matrix (affects spread).
            normalize (bool): Whether to divide the canvas by its maximum so
                values fall in [0, 1]. Skipped when the canvas is empty or its
                maximum is not positive.

        Returns:
            Tensor: A 4D tensor of shape ``canvas_size`` with the accumulated
            contributions from all Gaussians.

        Raises:
            ValueError: If ``sigma_scale`` is not strictly positive.
        """
        if sigma_scale <= 0:
            # A zero/negative scale gives a singular or negative-definite
            # covariance, for which the Gaussian is meaningless.
            raise ValueError("sigma_scale must be positive.")

        d1, d2, d3, d4 = canvas_size
        dtype = self.states.dtype
        device = self.states.device

        # One integer-valued coordinate axis per canvas dimension, created
        # directly in the states' dtype so no implicit promotion is needed.
        axes = [
            torch.arange(size, dtype=dtype, device=device)
            for size in (d1, d2, d3, d4)
        ]
        # 4D meshgrid with "ij" indexing so grid[a, b, c, d] == (a, b, c, d).
        grid = torch.stack(
            torch.meshgrid(*axes, indexing="ij"),
            dim=-1,
        )  # shape: (d1, d2, d3, d4, 4)

        canvas = torch.zeros(
            (d1, d2, d3, d4),
            dtype=dtype,
            device=device,
        )

        for i in range(self.num_gaussians):
            pos = self.states[i, :4]  # spatial center (4,)
            cov = self.covariances[i] * sigma_scale  # scaled covariance
            canvas += self._compute_gaussian(pos, cov, grid)
            logger.debug(
                "Rendered Gaussian {} at position {}", i, pos.tolist()
            )

        # max() raises on an empty tensor, so skip normalization in that case.
        if normalize and canvas.numel() > 0:
            max_val = canvas.max()
            if max_val > 0:
                canvas = canvas / max_val
                logger.debug("Canvas normalized.")

        logger.info("4D Rendering complete.")
        return canvas


def interactive_slice(canvas: Tensor) -> None:
    """
    Display an interactive 2D slice of the 4D canvas using ipywidgets.

    Two of the four dimensions (d3 and d4) are fixed via sliders and the
    resulting 2D slice (over dimensions d1 and d2) is shown. When ipywidgets
    is not installed, a maximum intensity projection is shown instead.

    Args:
        canvas (Tensor): A 4D tensor with shape (d1, d2, d3, d4).
    """
    if not HAS_IPYWIDGETS:
        # `interact` / `IntSlider` are undefined here; fall back instead of
        # failing with a NameError.
        logger.warning(
            "ipywidgets not available; falling back to maximum intensity projection."
        )
        mip_projection(canvas)
        return

    _, _, d3, d4 = canvas.shape

    def display_slice(slice_d3: int, slice_d4: int) -> None:
        # detach() so tensors that track gradients can still be plotted.
        slice_2d = canvas[:, :, slice_d3, slice_d4].detach().cpu().numpy()
        plt.figure(figsize=(6, 6))
        plt.imshow(slice_2d, cmap="hot", origin="lower")
        plt.title(f"2D Slice at d3={slice_d3}, d4={slice_d4}")
        plt.colorbar()
        plt.show()

    interact(
        display_slice,
        slice_d3=IntSlider(min=0, max=d3 - 1, step=1, value=d3 // 2),
        slice_d4=IntSlider(min=0, max=d4 - 1, step=1, value=d4 // 2),
    )


def mip_projection(canvas: Tensor) -> None:
    """
    Render a 2D view of the 4D canvas using maximum intensity projection
    (MIP) along the 3rd and 4th dimensions.

    Args:
        canvas (Tensor): A 4D tensor with shape (d1, d2, d3, d4).
    """
    # Maximum over d3 and d4 in one reduction; equivalent to two successive
    # single-axis maxima.
    mip_2d = canvas.amax(dim=(2, 3))  # shape: (d1, d2)

    plt.figure(figsize=(6, 6))
    plt.imshow(mip_2d.detach().cpu().numpy(), cmap="hot", origin="lower")
    plt.title("2D MIP (Projecting dimensions d3 and d4)")
    plt.colorbar()
    plt.show()


def main() -> None:
    """
    Main function that:
      - Creates a 4D Gaussian splat model.
      - Updates the states to simulate motion.
      - Renders a 4D canvas.
      - Visualizes the 4D volume via interactive slicing (if available) or MIP.
    """
    torch.manual_seed(42)
    num_gaussians = 2

    # Define initial states for each Gaussian:
    # Each state is [x, y, z, w, vx, vy, vz, vw].
    init_states = torch.tensor(
        [
            [10.0, 15.0, 20.0, 25.0, 0.5, -0.2, 0.3, 0.1],
            [30.0, 35.0, 40.0, 45.0, -0.3, 0.4, -0.1, 0.2],
        ],
        dtype=torch.float32,
    )

    # Define initial 4x4 covariance matrices for the spatial dimensions.
    init_covariances = torch.stack(
        [
            torch.diag(
                torch.tensor([5.0, 5.0, 5.0, 5.0], dtype=torch.float32)
            ),
            torch.diag(
                torch.tensor([3.0, 3.0, 3.0, 3.0], dtype=torch.float32)
            ),
        ]
    )

    # Create the 4D Gaussian splat model.
    model = GaussianSplat4DStateSpace(
        num_gaussians, init_states, init_covariances, dt=1.0
    )

    # Update states to simulate one time step.
    model.update_states()

    # Render the 4D canvas.
    # NOTE(review): grid coordinates run 0..19 on every axis, while after one
    # step the centers are near (10.5, 14.8, 20.3, 25.1) and
    # (29.7, 35.4, 39.9, 45.2) -- both outside the canvas. Only the tail of
    # the first Gaussian is visible; normalize=True rescales it to [0, 1].
    canvas_size = (20, 20, 20, 20)
    canvas = model.render(canvas_size, sigma_scale=1.0, normalize=True)

    # Visualize the 4D data.
    if HAS_IPYWIDGETS:
        logger.info("Launching interactive slicing tool for 4D data.")
        interactive_slice(canvas)
    else:
        logger.info(
            "ipywidgets not available; using maximum intensity projection instead."
        )
        mip_projection(canvas)


if __name__ == "__main__":
    main()