You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/examples/hackathon_feb16/gassisan_splat.py

312 lines
9.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import torch
from torch import Tensor
from loguru import logger
from typing import Tuple
import matplotlib.pyplot as plt
try:
# ipywidgets is available in interactive environments like Jupyter.
from ipywidgets import interact, IntSlider
HAS_IPYWIDGETS = True
except ImportError:
HAS_IPYWIDGETS = False
logger.warning(
"ipywidgets not installed. Interactive slicing will be disabled."
)
class GaussianSplat4DStateSpace:
"""
4D Gaussian splatting with a state space model in PyTorch.
Each Gaussian is defined by an 8D state vector:
[x, y, z, w, vx, vy, vz, vw],
where the first four dimensions are the spatial coordinates and the last
four are the velocities. Only the spatial (first four) dimensions are used
for the 4D Gaussian splat, with a corresponding 4×4 covariance matrix.
Attributes:
num_gaussians (int): Number of Gaussians.
state_dim (int): Dimension of the state vector (should be 8).
states (Tensor): Current state for each Gaussian of shape (num_gaussians, state_dim).
covariances (Tensor): Covariance matrices for the spatial dimensions, shape (num_gaussians, 4, 4).
A (Tensor): State transition matrix of shape (state_dim, state_dim).
dt (float): Time step for state updates.
"""
def __init__(
self,
num_gaussians: int,
init_states: Tensor,
init_covariances: Tensor,
dt: float = 1.0,
) -> None:
"""
Initialize the 4D Gaussian splat model.
Args:
num_gaussians (int): Number of Gaussians.
init_states (Tensor): Initial states of shape (num_gaussians, 8).
Each state is assumed to be
[x, y, z, w, vx, vy, vz, vw].
init_covariances (Tensor): Initial covariance matrices for the spatial dimensions,
shape (num_gaussians, 4, 4).
dt (float): Time step for the state update.
"""
if init_states.shape[1] != 8:
raise ValueError(
"init_states should have shape (N, 8) where 8 = 4 position + 4 velocity."
)
if init_covariances.shape[1:] != (4, 4):
raise ValueError(
"init_covariances should have shape (N, 4, 4)."
)
self.num_gaussians = num_gaussians
self.states = init_states.clone() # shape: (N, 8)
self.covariances = (
init_covariances.clone()
) # shape: (N, 4, 4)
self.dt = dt
self.state_dim = init_states.shape[1]
# Create an 8x8 constant-velocity state transition matrix:
# New position = position + velocity*dt, velocity remains unchanged.
I4 = torch.eye(
4, dtype=init_states.dtype, device=init_states.device
)
zeros4 = torch.zeros(
(4, 4), dtype=init_states.dtype, device=init_states.device
)
top = torch.cat([I4, dt * I4], dim=1)
bottom = torch.cat([zeros4, I4], dim=1)
self.A = torch.cat([top, bottom], dim=0) # shape: (8, 8)
logger.info(
"Initialized 4D GaussianSplatStateSpace with {} Gaussians.",
num_gaussians,
)
def update_states(self) -> None:
"""
Update the state of each Gaussian using the constant-velocity state space model.
Applies:
state_next = A @ state_current.
"""
self.states = (
self.A @ self.states.t()
).t() # shape: (num_gaussians, 8)
logger.debug("States updated: {}", self.states)
def _compute_gaussian(
self, pos: Tensor, cov: Tensor, coords: Tensor
) -> Tensor:
"""
Compute the 4D Gaussian function over a grid of coordinates.
Args:
pos (Tensor): The center of the Gaussian (4,).
cov (Tensor): The 4×4 covariance matrix.
coords (Tensor): A grid of coordinates of shape (..., 4).
Returns:
Tensor: Evaluated Gaussian values on the grid with shape equal to coords.shape[:-1].
"""
try:
cov_inv = torch.linalg.inv(cov)
except RuntimeError as e:
logger.warning(
"Covariance inversion failed; using pseudo-inverse. Error: {}",
e,
)
cov_inv = torch.linalg.pinv(cov)
# Broadcast pos over the grid
diff = coords - pos.view(
*(1 for _ in range(coords.ndim - 1)), 4
)
mahal = torch.einsum("...i,ij,...j->...", diff, cov_inv, diff)
gaussian = torch.exp(-0.5 * mahal)
return gaussian
def render(
self,
canvas_size: Tuple[int, int, int, int],
sigma_scale: float = 1.0,
normalize: bool = False,
) -> Tensor:
"""
Render the current 4D Gaussian splats onto a 4D canvas.
Args:
canvas_size (Tuple[int, int, int, int]): The size of the canvas (d1, d2, d3, d4).
sigma_scale (float): Scaling factor for the covariance (affects spread).
normalize (bool): Whether to normalize the final canvas to [0, 1].
Returns:
Tensor: A 4D tensor (canvas) with the accumulated contributions from all Gaussians.
"""
d1, d2, d3, d4 = canvas_size
# Create coordinate grids for each dimension.
grid1 = torch.linspace(
0, d1 - 1, d1, device=self.states.device
)
grid2 = torch.linspace(
0, d2 - 1, d2, device=self.states.device
)
grid3 = torch.linspace(
0, d3 - 1, d3, device=self.states.device
)
grid4 = torch.linspace(
0, d4 - 1, d4, device=self.states.device
)
# Create a 4D meshgrid (using indexing "ij")
grid = torch.stack(
torch.meshgrid(grid1, grid2, grid3, grid4, indexing="ij"),
dim=-1,
) # shape: (d1, d2, d3, d4, 4)
# Initialize the canvas.
canvas = torch.zeros(
(d1, d2, d3, d4),
dtype=self.states.dtype,
device=self.states.device,
)
for i in range(self.num_gaussians):
pos = self.states[i, :4] # spatial center (4,)
cov = (
self.covariances[i] * sigma_scale
) # scaled covariance
gaussian = self._compute_gaussian(pos, cov, grid)
canvas += gaussian
logger.debug(
"Rendered Gaussian {} at position {}", i, pos.tolist()
)
if normalize:
max_val = canvas.max()
if max_val > 0:
canvas = canvas / max_val
logger.debug("Canvas normalized.")
logger.info("4D Rendering complete.")
return canvas
def interactive_slice(canvas: Tensor) -> None:
"""
Display an interactive 2D slice of the 4D canvas using ipywidgets.
This function fixes two of the four dimensions (d3 and d4) via sliders and
displays the resulting 2D slice (over dimensions d1 and d2).
Args:
canvas (Tensor): A 4D tensor with shape (d1, d2, d3, d4).
"""
d1, d2, d3, d4 = canvas.shape
def display_slice(slice_d3: int, slice_d4: int):
slice_2d = canvas[:, :, slice_d3, slice_d4].cpu().numpy()
plt.figure(figsize=(6, 6))
plt.imshow(slice_2d, cmap="hot", origin="lower")
plt.title(f"2D Slice at d3={slice_d3}, d4={slice_d4}")
plt.colorbar()
plt.show()
interact(
display_slice,
slice_d3=IntSlider(min=0, max=d3 - 1, step=1, value=d3 // 2),
slice_d4=IntSlider(min=0, max=d4 - 1, step=1, value=d4 // 2),
)
def mip_projection(canvas: Tensor) -> None:
"""
Render a 2D view of the 4D canvas using maximum intensity projection (MIP)
along the 3rd and 4th dimensions.
Args:
canvas (Tensor): A 4D tensor with shape (d1, d2, d3, d4).
"""
# MIP along dimension 3
mip_3d = canvas.max(dim=2)[0] # shape: (d1, d2, d4)
# MIP along dimension 4
mip_2d = mip_3d.max(dim=2)[0] # shape: (d1, d2)
plt.figure(figsize=(6, 6))
plt.imshow(mip_2d.cpu().numpy(), cmap="hot", origin="lower")
plt.title("2D MIP (Projecting dimensions d3 and d4)")
plt.colorbar()
plt.show()
def main() -> None:
"""
Main function that:
- Creates a 4D Gaussian splat model.
- Updates the states to simulate motion.
- Renders a 4D canvas.
- Visualizes the 4D volume via interactive slicing (if available) or MIP.
"""
torch.manual_seed(42)
num_gaussians = 2
# Define initial states for each Gaussian:
# Each state is [x, y, z, w, vx, vy, vz, vw].
init_states = torch.tensor(
[
[10.0, 15.0, 20.0, 25.0, 0.5, -0.2, 0.3, 0.1],
[30.0, 35.0, 40.0, 45.0, -0.3, 0.4, -0.1, 0.2],
],
dtype=torch.float32,
)
# Define initial 4x4 covariance matrices for the spatial dimensions.
init_covariances = torch.stack(
[
torch.diag(
torch.tensor(
[5.0, 5.0, 5.0, 5.0], dtype=torch.float32
)
),
torch.diag(
torch.tensor(
[3.0, 3.0, 3.0, 3.0], dtype=torch.float32
)
),
]
)
# Create the 4D Gaussian splat model.
model = GaussianSplat4DStateSpace(
num_gaussians, init_states, init_covariances, dt=1.0
)
# Update states to simulate one time step.
model.update_states()
# Render the 4D canvas.
canvas_size = (20, 20, 20, 20)
canvas = model.render(
canvas_size, sigma_scale=1.0, normalize=True
)
# Visualize the 4D data.
if HAS_IPYWIDGETS:
logger.info("Launching interactive slicing tool for 4D data.")
interactive_slice(canvas)
else:
logger.info(
"ipywidgets not available; using maximum intensity projection instead."
)
mip_projection(canvas)
if __name__ == "__main__":
main()