import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


class Visualization3D:
    """Draw detections as thin 3D boxes, with an optional follow vector.

    Each detection is rendered as a translucent box whose x/y extent comes
    from its bounding box and whose z position comes from its mean depth.
    """

    # Thickness of every detection box along the depth (z) axis. Shared by
    # the box geometry and the z-axis limits so the two cannot drift apart.
    BOX_DEPTH = 0.05
    # Padding added around the detections when fixing the axis limits.
    XY_MARGIN = 50
    Z_MARGIN = 0.1

    def __init__(self, camera_position):
        """Store the camera position used as the follow-vector origin.

        Args:
            camera_position: Indexable with at least three numeric
                components (x, y, z), e.g. a list, tuple or 1-D array.
        """
        self.camera_position = camera_position

    @staticmethod
    def _box_faces(x1, y1, x2, y2, z_near, z_far):
        """Return the six quadrilateral faces of an axis-aligned box."""
        corners = [
            [x1, y1, z_near],
            [x2, y1, z_near],
            [x2, y2, z_near],
            [x1, y2, z_near],
            [x1, y1, z_far],
            [x2, y1, z_far],
            [x2, y2, z_far],
            [x1, y2, z_far],
        ]
        face_indices = (
            (0, 1, 5, 4),
            (3, 2, 6, 7),
            (0, 3, 7, 4),
            (1, 2, 6, 5),
            (0, 1, 2, 3),
            (4, 5, 6, 7),
        )
        return [[corners[i] for i in face] for face in face_indices]

    def plot_scene(self, detections_info, follow_vector=None):
        """Plot all detections and, if given, the follow vector; then show.

        Args:
            detections_info: Iterable of dicts, each providing ``'bbox'``
                (unpacked as ``x1, y1, x2, y2``), ``'mean_depth'`` (numeric,
                labelled in metres) and ``'class_name'``. May be empty.
            follow_vector: Optional indexable with three numeric components.
                When not ``None`` an arrow is drawn from the camera position
                along this vector and its tip is labelled "Target".
        """
        # Materialise once: the input may be a one-shot iterable.
        detections = list(detections_info)
        box_depth = self.BOX_DEPTH

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        xs, ys, zs = [], [], []
        for detection in detections:
            x1, y1, x2, y2 = detection['bbox']
            mean_depth = detection['mean_depth']
            class_name = detection['class_name']
            z_far = mean_depth + box_depth

            box = Poly3DCollection(
                self._box_faces(x1, y1, x2, y2, mean_depth, z_far),
                facecolors='cyan',
                linewidths=1,
                edgecolors='r',
                alpha=0.25,
            )
            ax.add_collection3d(box)
            ax.text(
                (x1 + x2) / 2,
                (y1 + y2) / 2,
                z_far,
                f'{class_name} {mean_depth:.2f}m',
                color='blue',
                fontsize=8,
            )

            xs += [x1, x2]
            ys += [y1, y2]
            zs += [mean_depth, z_far]

        if follow_vector is not None:
            cam_x, cam_y, cam_z = (self.camera_position[i] for i in range(3))
            vec_x, vec_y, vec_z = (follow_vector[i] for i in range(3))
            ax.quiver(
                cam_x, cam_y, cam_z,
                vec_x, vec_y, vec_z,
                color='red',
                arrow_length_ratio=0.1,
            )
            # Add component-wise: "+" on two lists/tuples would concatenate
            # them and leave the label sitting on the camera position.
            ax.text(
                cam_x + vec_x,
                cam_y + vec_y,
                cam_z + vec_z,
                "Target",
                color='red',
                fontsize=10,
            )

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Depth (m)")
        ax.set_title("3D Scene with Follow Vector")

        # min()/max() raise on an empty sequence, so only pin the limits
        # when there is at least one detection; otherwise keep the defaults.
        if detections:
            ax.set_xlim(min(xs) - self.XY_MARGIN, max(xs) + self.XY_MARGIN)
            ax.set_ylim(min(ys) - self.XY_MARGIN, max(ys) + self.XY_MARGIN)
            ax.set_zlim(min(zs) - self.Z_MARGIN, max(zs) + self.Z_MARGIN)

        plt.show()