# swarms/simulations/map_generation/game_map.py
"""
Production-grade AI Vision Pipeline for depth estimation, segmentation, object detection,
and 3D point cloud generation.
This module provides a comprehensive pipeline that combines MiDaS for depth estimation,
SAM (Segment Anything Model) for semantic segmentation, YOLOv8 for object detection,
and Open3D for 3D point cloud generation.
"""
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
import warnings
warnings.filterwarnings("ignore")
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import open3d as o3d
from loguru import logger
# Third-party model imports
try:
import timm
from segment_anything import (
SamAutomaticMaskGenerator,
sam_model_registry,
)
from ultralytics import YOLO
except ImportError as e:
logger.error(f"Missing required dependencies: {e}")
sys.exit(1)
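# The imports above assume the following packages are installed; exact pip
# package names are an assumption (segment-anything may need to be installed
# from the facebookresearch/segment-anything GitHub repository):
#   pip install torch torchvision opencv-python open3d loguru timm ultralytics segment-anything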
class AIVisionPipeline:
"""
A comprehensive AI vision pipeline that performs depth estimation, semantic segmentation,
object detection, and 3D point cloud generation from input images.
This class integrates multiple state-of-the-art models:
- MiDaS for monocular depth estimation
- SAM (Segment Anything Model) for semantic segmentation
- YOLOv8 for object detection
- Open3D for 3D point cloud generation
Attributes:
model_dir (Path): Directory where models are stored
device (torch.device): Computing device (CPU/CUDA)
midas_model: Loaded MiDaS depth estimation model
midas_transform: MiDaS preprocessing transforms
sam_generator: SAM automatic mask generator
yolo_model: YOLOv8 object detection model
Example:
>>> pipeline = AIVisionPipeline()
>>> results = pipeline.process_image("path/to/image.jpg")
>>> point_cloud = results["point_cloud"]
"""
def __init__(
self,
model_dir: str = "./models",
device: Optional[str] = None,
midas_model_type: str = "MiDaS",
sam_model_type: str = "vit_b",
yolo_model_path: str = "yolov8n.pt",
log_level: str = "INFO",
) -> None:
"""
Initialize the AI Vision Pipeline.
Args:
model_dir: Directory to store downloaded models
device: Computing device ('cpu', 'cuda', or None for auto-detection)
midas_model_type: MiDaS model variant ('MiDaS', 'MiDaS_small', 'DPT_Large', etc.)
sam_model_type: SAM model type ('vit_b', 'vit_l', 'vit_h')
yolo_model_path: Path to YOLOv8 model weights
log_level: Logging level ('DEBUG', 'INFO', 'WARNING', 'ERROR')
Raises:
RuntimeError: If required models cannot be loaded
FileNotFoundError: If model files are not found
"""
# Setup logging
logger.remove()
logger.add(
sys.stdout,
level=log_level,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
)
# Initialize attributes
self.model_dir = Path(model_dir)
self.model_dir.mkdir(parents=True, exist_ok=True)
# Device setup
if device is None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
else:
self.device = torch.device(device)
logger.info(f"Using device: {self.device}")
# Model configuration
self.midas_model_type = midas_model_type
self.sam_model_type = sam_model_type
self.yolo_model_path = yolo_model_path
# Initialize model placeholders
self.midas_model: Optional[torch.nn.Module] = None
self.midas_transform: Optional[transforms.Compose] = None
self.sam_generator: Optional[SamAutomaticMaskGenerator] = None
self.yolo_model: Optional[YOLO] = None
# Load all models
self._setup_models()
logger.success("AI Vision Pipeline initialized successfully")
def _setup_models(self) -> None:
"""
Load and initialize all AI models with proper error handling.
Raises:
RuntimeError: If any model fails to load
"""
try:
self._load_midas_model()
self._load_sam_model()
self._load_yolo_model()
except Exception as e:
logger.error(f"Failed to setup models: {e}")
raise RuntimeError(f"Model initialization failed: {e}")
def _load_midas_model(self) -> None:
"""Load MiDaS depth estimation model."""
try:
logger.info(
f"Loading MiDaS model: {self.midas_model_type}"
)
# Load MiDaS model from torch hub
self.midas_model = torch.hub.load(
"intel-isl/MiDaS",
self.midas_model_type,
pretrained=True,
)
self.midas_model.to(self.device)
self.midas_model.eval()
# Load corresponding transforms
midas_transforms = torch.hub.load(
"intel-isl/MiDaS", "transforms"
)
            if self.midas_model_type in ["DPT_Large", "DPT_Hybrid"]:
                self.midas_transform = midas_transforms.dpt_transform
            elif self.midas_model_type == "MiDaS_small":
                self.midas_transform = midas_transforms.small_transform
            else:
                self.midas_transform = (
                    midas_transforms.default_transform
                )
logger.success("MiDaS model loaded successfully")
except Exception as e:
logger.error(f"Failed to load MiDaS model: {e}")
raise
def _load_sam_model(self) -> None:
"""Load SAM (Segment Anything Model) for semantic segmentation."""
try:
logger.info(f"Loading SAM model: {self.sam_model_type}")
# SAM model checkpoints mapping
sam_checkpoint_urls = {
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
}
checkpoint_path = (
self.model_dir / f"sam_{self.sam_model_type}.pth"
)
# Download checkpoint if not exists
if not checkpoint_path.exists():
logger.info(
f"Downloading SAM checkpoint to {checkpoint_path}"
)
import urllib.request
urllib.request.urlretrieve(
sam_checkpoint_urls[self.sam_model_type],
checkpoint_path,
)
# Load SAM model
sam = sam_model_registry[self.sam_model_type](
checkpoint=str(checkpoint_path)
)
sam.to(self.device)
# Create automatic mask generator
self.sam_generator = SamAutomaticMaskGenerator(
model=sam,
points_per_side=32,
pred_iou_thresh=0.86,
stability_score_thresh=0.92,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
min_mask_region_area=100,
)
logger.success("SAM model loaded successfully")
except Exception as e:
logger.error(f"Failed to load SAM model: {e}")
raise
def _load_yolo_model(self) -> None:
"""Load YOLOv8 object detection model."""
try:
logger.info(
f"Loading YOLOv8 model: {self.yolo_model_path}"
)
self.yolo_model = YOLO(self.yolo_model_path)
# Move to appropriate device
if self.device.type == "cuda":
self.yolo_model.to(self.device)
logger.success("YOLOv8 model loaded successfully")
except Exception as e:
logger.error(f"Failed to load YOLOv8 model: {e}")
raise
def _load_and_preprocess_image(
self, image_path: Union[str, Path]
) -> Tuple[np.ndarray, Image.Image]:
"""
Load and preprocess input image.
Args:
image_path: Path to the input image (JPG or PNG)
Returns:
            Tuple of (rgb_image, pil_image): the image as an RGB numpy array and as a PIL Image
Raises:
FileNotFoundError: If image file doesn't exist
ValueError: If image format is not supported
"""
image_path = Path(image_path)
if not image_path.exists():
raise FileNotFoundError(f"Image not found: {image_path}")
if image_path.suffix.lower() not in [".jpg", ".jpeg", ".png"]:
raise ValueError(
f"Unsupported image format: {image_path.suffix}"
)
try:
# Load with OpenCV (BGR format)
cv_image = cv2.imread(str(image_path))
if cv_image is None:
raise ValueError(
f"Could not load image: {image_path}"
)
# Convert BGR to RGB for PIL
rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(rgb_image)
logger.debug(
f"Loaded image: {image_path} ({rgb_image.shape})"
)
return rgb_image, pil_image
except Exception as e:
logger.error(f"Failed to load image {image_path}: {e}")
raise
def estimate_depth(self, image: np.ndarray) -> np.ndarray:
"""
Generate depth map using MiDaS model.
Args:
image: Input image as numpy array (H, W, 3) in RGB format
Returns:
Depth map as numpy array (H, W)
Raises:
RuntimeError: If depth estimation fails
"""
try:
logger.debug("Estimating depth with MiDaS")
# Preprocess image for MiDaS
input_tensor = self.midas_transform(image).to(self.device)
# Perform inference
with torch.no_grad():
depth_map = self.midas_model(input_tensor)
depth_map = torch.nn.functional.interpolate(
depth_map.unsqueeze(1),
size=image.shape[:2],
mode="bicubic",
align_corners=False,
).squeeze()
# Convert to numpy
depth_numpy = depth_map.cpu().numpy()
            # Normalize depth values to [0, 1]; the epsilon guards against
            # division by zero on a constant depth map
            depth_min, depth_max = depth_numpy.min(), depth_numpy.max()
            depth_numpy = (depth_numpy - depth_min) / (
                depth_max - depth_min + 1e-8
            )
logger.debug(
f"Depth estimation completed. Shape: {depth_numpy.shape}"
)
return depth_numpy
except Exception as e:
logger.error(f"Depth estimation failed: {e}")
raise RuntimeError(f"Depth estimation error: {e}")
def segment_image(
self, image: np.ndarray
) -> List[Dict[str, Any]]:
"""
Perform semantic segmentation using SAM.
Args:
image: Input image as numpy array (H, W, 3) in RGB format
Returns:
List of segmentation masks with metadata
Raises:
RuntimeError: If segmentation fails
"""
try:
logger.debug("Performing segmentation with SAM")
# Generate masks
masks = self.sam_generator.generate(image)
logger.debug(f"Generated {len(masks)} segmentation masks")
return masks
except Exception as e:
logger.error(f"Segmentation failed: {e}")
raise RuntimeError(f"Segmentation error: {e}")
def detect_objects(
self, image: np.ndarray
) -> List[Dict[str, Any]]:
"""
Perform object detection using YOLOv8.
Args:
image: Input image as numpy array (H, W, 3) in RGB format
Returns:
List of detected objects with bounding boxes and confidence scores
Raises:
RuntimeError: If object detection fails
"""
try:
logger.debug("Performing object detection with YOLOv8")
            # Run inference; Ultralytics interprets raw numpy arrays as BGR,
            # so convert from RGB first
            results = self.yolo_model(
                cv2.cvtColor(image, cv2.COLOR_RGB2BGR), verbose=False
            )
# Extract detections
detections = []
for result in results:
boxes = result.boxes
if boxes is not None:
for i in range(len(boxes)):
detection = {
"bbox": boxes.xyxy[i]
.cpu()
.numpy(), # [x1, y1, x2, y2]
"confidence": float(
boxes.conf[i].cpu().numpy()
),
"class_id": int(
boxes.cls[i].cpu().numpy()
),
"class_name": result.names[
int(boxes.cls[i].cpu().numpy())
],
}
detections.append(detection)
logger.debug(f"Detected {len(detections)} objects")
return detections
except Exception as e:
logger.error(f"Object detection failed: {e}")
raise RuntimeError(f"Object detection error: {e}")
def generate_point_cloud(
self,
image: np.ndarray,
depth_map: np.ndarray,
masks: Optional[List[Dict[str, Any]]] = None,
) -> o3d.geometry.PointCloud:
"""
Generate 3D point cloud from image and depth data.
Args:
image: RGB image array (H, W, 3)
depth_map: Depth map array (H, W)
masks: Optional segmentation masks for point cloud filtering
Returns:
Open3D PointCloud object
Raises:
ValueError: If input dimensions don't match
RuntimeError: If point cloud generation fails
"""
try:
logger.debug("Generating 3D point cloud")
if image.shape[:2] != depth_map.shape:
raise ValueError(
"Image and depth map dimensions must match"
)
height, width = depth_map.shape
# Create intrinsic camera parameters (assuming standard camera)
fx = fy = width # Focal length approximation
cx, cy = (
width / 2,
height / 2,
) # Principal point at image center
# Create coordinate grids
u, v = np.meshgrid(np.arange(width), np.arange(height))
# Convert depth to actual distances (inverse depth)
# MiDaS outputs inverse depth, so we invert it
z = 1.0 / (
depth_map + 1e-6
) # Add small epsilon to avoid division by zero
# Back-project to 3D coordinates
x = (u - cx) * z / fx
y = (v - cy) * z / fy
# Create point cloud
points = np.stack(
[x.flatten(), y.flatten(), z.flatten()], axis=1
)
colors = (
image.reshape(-1, 3) / 255.0
) # Normalize colors to [0, 1]
# Filter out invalid points
valid_mask = np.isfinite(points).all(axis=1) & (
z.flatten() > 0
)
points = points[valid_mask]
colors = colors[valid_mask]
# Create Open3D point cloud
point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(points)
point_cloud.colors = o3d.utility.Vector3dVector(colors)
# Optional: Filter by segmentation masks
if masks and len(masks) > 0:
# Use the largest mask for filtering
largest_mask = max(masks, key=lambda x: x["area"])
mask_2d = largest_mask["segmentation"]
mask_1d = mask_2d.flatten()[valid_mask]
filtered_points = points[mask_1d]
filtered_colors = colors[mask_1d]
point_cloud.points = o3d.utility.Vector3dVector(
filtered_points
)
point_cloud.colors = o3d.utility.Vector3dVector(
filtered_colors
)
# Remove statistical outliers
point_cloud, _ = point_cloud.remove_statistical_outlier(
nb_neighbors=20, std_ratio=2.0
)
logger.debug(
f"Generated point cloud with {len(point_cloud.points)} points"
)
return point_cloud
except Exception as e:
logger.error(f"Point cloud generation failed: {e}")
raise RuntimeError(f"Point cloud generation error: {e}")
def process_image(
self, image_path: Union[str, Path]
) -> Dict[str, Any]:
"""
Process a single image through the complete AI vision pipeline.
Args:
image_path: Path to input image (JPG or PNG)
Returns:
Dictionary containing all processing results:
- 'image': Original RGB image
- 'depth_map': Depth estimation result
- 'segmentation_masks': SAM segmentation results
- 'detections': YOLO object detection results
- 'point_cloud': Open3D point cloud object
Raises:
FileNotFoundError: If image file doesn't exist
RuntimeError: If any processing step fails
"""
try:
logger.info(f"Processing image: {image_path}")
# Load and preprocess image
rgb_image, pil_image = self._load_and_preprocess_image(
image_path
)
# Depth estimation
depth_map = self.estimate_depth(rgb_image)
# Semantic segmentation
segmentation_masks = self.segment_image(rgb_image)
# Object detection
detections = self.detect_objects(rgb_image)
# 3D point cloud generation
point_cloud = self.generate_point_cloud(
rgb_image, depth_map, segmentation_masks
)
# Compile results
results = {
"image": rgb_image,
"depth_map": depth_map,
"segmentation_masks": segmentation_masks,
"detections": detections,
"point_cloud": point_cloud,
"metadata": {
"image_shape": rgb_image.shape,
"num_segments": len(segmentation_masks),
"num_detections": len(detections),
"num_points": len(point_cloud.points),
},
}
logger.success("Image processing completed successfully")
logger.info(f"Results: {results['metadata']}")
return results
except Exception as e:
logger.error(f"Image processing failed: {e}")
raise
def save_point_cloud(
self,
point_cloud: o3d.geometry.PointCloud,
output_path: Union[str, Path],
) -> None:
"""
Save point cloud to file.
Args:
point_cloud: Open3D PointCloud object
output_path: Output file path (.ply, .pcd, .xyz)
Raises:
RuntimeError: If saving fails
"""
try:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
success = o3d.io.write_point_cloud(
str(output_path), point_cloud
)
if not success:
raise RuntimeError("Failed to write point cloud file")
logger.success(f"Point cloud saved to: {output_path}")
except Exception as e:
logger.error(f"Failed to save point cloud: {e}")
raise RuntimeError(f"Point cloud save error: {e}")
def visualize_point_cloud(
self, point_cloud: o3d.geometry.PointCloud
) -> None:
"""
Visualize point cloud using Open3D viewer.
Args:
point_cloud: Open3D PointCloud object to visualize
"""
try:
logger.info("Opening point cloud visualization")
o3d.visualization.draw_geometries([point_cloud])
except Exception as e:
logger.warning(f"Visualization failed: {e}")
# Example usage and testing
if __name__ == "__main__":
# Example usage
try:
# Initialize pipeline
pipeline = AIVisionPipeline(
model_dir="./models", log_level="INFO"
)
# Process an image (replace with actual image path)
image_path = "map_two.png" # Replace with your image path
if Path(image_path).exists():
results = pipeline.process_image(image_path)
# Save point cloud
pipeline.save_point_cloud(
results["point_cloud"], "output_point_cloud.ply"
)
# Optional: Visualize point cloud
pipeline.visualize_point_cloud(results["point_cloud"])
print(
f"Processing completed! Generated {results['metadata']['num_points']} 3D points"
)
else:
logger.warning(f"Example image not found: {image_path}")
except Exception as e:
logger.error(f"Example execution failed: {e}")