You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

70 lines
2.8 KiB

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
class Visualization3D:
    """Render detections and an optional follow vector in a matplotlib 3D plot.

    Each detection is drawn as a thin translucent slab placed at its mean
    depth; the follow vector is drawn as an arrow starting at the camera
    position.
    """

    # Thickness of each detection slab along the depth axis.
    _BOX_THICKNESS = 0.05
    # Margins added around the detections when fixing the axis limits.
    _XY_PADDING = 50
    _Z_PADDING = 0.1

    def __init__(self, camera_position):
        """Store the camera position used as the origin of the follow arrow.

        Args:
            camera_position: Indexable with at least three numeric components
                (x, y, z), e.g. a list, tuple or NumPy array.
        """
        self.camera_position = camera_position

    @staticmethod
    def _box_faces(x1, y1, x2, y2, near_z, thickness):
        """Return the six quadrilateral faces of an axis-aligned slab."""
        far_z = near_z + thickness
        corners = [
            [x1, y1, near_z],
            [x2, y1, near_z],
            [x2, y2, near_z],
            [x1, y2, near_z],
            [x1, y1, far_z],
            [x2, y1, far_z],
            [x2, y2, far_z],
            [x1, y2, far_z],
        ]
        # Corner indices per face; the last two are the near and far faces.
        face_indices = (
            (0, 1, 5, 4),
            (3, 2, 6, 7),
            (0, 3, 7, 4),
            (1, 2, 6, 5),
            (0, 1, 2, 3),
            (4, 5, 6, 7),
        )
        return [[corners[i] for i in face] for face in face_indices]

    @staticmethod
    def _padded_range(values, padding):
        """Return ``(min - padding, max + padding)`` or ``None`` if empty."""
        values = list(values)
        if not values:
            return None
        return min(values) - padding, max(values) + padding

    def _target_position(self, follow_vector):
        """Return the tip of the follow arrow as an ``(x, y, z)`` tuple.

        The sum is taken component by component so that plain lists and
        tuples work too; ``+`` on those would concatenate instead of add.
        """
        return tuple(
            self.camera_position[i] + follow_vector[i] for i in range(3)
        )

    def plot_scene(self, detections_info, follow_vector=None):
        """Draw the scene and block on ``plt.show()``.

        Args:
            detections_info: Iterable of dicts with the keys ``'bbox'``
                (``x1, y1, x2, y2``), ``'mean_depth'`` (shown in metres) and
                ``'class_name'``. May be empty.
                NOTE(review): bbox values are presumably pixel coordinates;
                confirm against the caller.
            follow_vector: Optional indexable with three components giving
                the arrow direction and length from the camera position.
        """
        # Materialise once: the detections are traversed several times.
        detections = list(detections_info)
        thickness = self._BOX_THICKNESS

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        for detection in detections:
            x1, y1, x2, y2 = detection['bbox']
            mean_depth = detection['mean_depth']
            class_name = detection['class_name']

            faces = self._box_faces(x1, y1, x2, y2, mean_depth, thickness)
            box = Poly3DCollection(
                faces,
                facecolors='cyan',
                linewidths=1,
                edgecolors='r',
                alpha=0.25,
            )
            ax.add_collection3d(box)
            # Label sits at the centre of the slab's far face.
            ax.text(
                (x1 + x2) / 2,
                (y1 + y2) / 2,
                mean_depth + thickness,
                f'{class_name} {mean_depth:.2f}m',
                color='blue',
                fontsize=8,
            )

        if follow_vector is not None:
            target_position = self._target_position(follow_vector)
            ax.quiver(
                self.camera_position[0],
                self.camera_position[1],
                self.camera_position[2],
                follow_vector[0],
                follow_vector[1],
                follow_vector[2],
                color='red',
                arrow_length_ratio=0.1,
            )
            ax.text(
                target_position[0],
                target_position[1],
                target_position[2],
                "Target",
                color='red',
                fontsize=10,
            )

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Depth (m)")
        ax.set_title("3D Scene with Follow Vector")

        all_x = [d['bbox'][0] for d in detections] + [d['bbox'][2] for d in detections]
        all_y = [d['bbox'][1] for d in detections] + [d['bbox'][3] for d in detections]
        all_z = (
            [d['mean_depth'] for d in detections]
            + [d['mean_depth'] + thickness for d in detections]
        )

        x_limits = self._padded_range(all_x, self._XY_PADDING)
        y_limits = self._padded_range(all_y, self._XY_PADDING)
        z_limits = self._padded_range(all_z, self._Z_PADDING)
        # With no detections there is nothing to frame; keep matplotlib's
        # own limits instead of calling min()/max() on empty sequences.
        if x_limits is not None:
            ax.set_xlim(*x_limits)
            ax.set_ylim(*y_limits)
            ax.set_zlim(*z_limits)

        plt.show()