import pyrealsense2 as rs
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import cv2

# Frame size (width x height, in pixels) requested for both the depth and the
# color stream. An alternative 1280 x 720 configuration was previously kept
# here as commented-out code.
RS_WIDTH, RS_HEIGHT = 640, 480


class RealSenseInterface:
    """Common base for the RealSense helper classes.

    It currently provides only the shared "camera not found" report.
    """

    def noCameraFound(self):
        """Print a message telling the user that no RealSense camera was found."""
        message = "RealSense camera not found. Check the connection."
        print(message)


class ImageAcquisition(RealSenseInterface):
    """Configure, start and own a RealSense pipeline with depth and color streams.

    Both streams are requested at ``RS_WIDTH`` x ``RS_HEIGHT`` and 30 fps.
    Depth uses the ``z16`` format and color uses ``bgr8``.
    """

    def __init__(self):
        """Start the pipeline, or report the failure and terminate the program.

        Raises:
            SystemExit: with status 1 when the pipeline cannot be configured
                or started.
        """
        try:
            self.__pipeline = rs.pipeline()
            config = rs.config()
            config.enable_stream(rs.stream.depth, RS_WIDTH, RS_HEIGHT, rs.format.z16, 30)
            config.enable_stream(rs.stream.color, RS_WIDTH, RS_HEIGHT, rs.format.bgr8, 30)
            self.__pipeline.start(config)
        except Exception as err:
            # A bare ``except:`` would also swallow KeyboardInterrupt and
            # SystemExit; only ordinary errors from the setup are handled here.
            self.noCameraFound()
            # Surface the real cause. A failed start is not always a missing
            # device; for example, the requested stream settings may be rejected.
            print(f"Underlying error: {err}")
            # Exit with a non-zero status so the failure is visible to callers.
            # The builtin ``exit()`` exited with status 0 and is only available
            # when the ``site`` module is loaded.
            raise SystemExit(1) from err

    @staticmethod
    def _frame_to_array(frame):
        """Return ``frame``'s data as a NumPy array, or ``None`` if ``frame`` is falsy."""
        return np.asanyarray(frame.get_data()) if frame else None

    def get_frames(self):
        """Return ``(depth_image, color_image)`` taken from one frameset.

        Calling :meth:`get_depth_image` and then :meth:`get_color_image` waits
        for two separate framesets. In contrast, both arrays returned here come
        from the same ``wait_for_frames()`` call. Either element is ``None``
        when the corresponding frame is falsy.
        """
        frames = self.__pipeline.wait_for_frames()
        return (
            self._frame_to_array(frames.get_depth_frame()),
            self._frame_to_array(frames.get_color_frame()),
        )

    def get_depth_image(self):
        """Wait for the next frameset and return its depth image, or ``None``."""
        frames = self.__pipeline.wait_for_frames()
        return self._frame_to_array(frames.get_depth_frame())

    def get_color_image(self):
        """Wait for the next frameset and return its color image, or ``None``."""
        frames = self.__pipeline.wait_for_frames()
        return self._frame_to_array(frames.get_color_frame())

    def stop(self):
        """Stop the underlying pipeline."""
        self.__pipeline.stop()


class ImageDisplay(RealSenseInterface):
    """Show one depth/color image pair side by side with matplotlib."""

    def display_images(self, depth_image, color_image):
        """Plot a colorized ``depth_image`` next to ``color_image``, then call ``plt.show()``.

        Args:
            depth_image: depth array. It is scaled by 0.03 into 8 bits with
                ``cv2.convertScaleAbs`` and colorized with OpenCV's JET
                colormap. NOTE(review): presumably the raw ``z16`` depth frame;
                confirm against callers.
            color_image: 3-channel image in OpenCV's BGR channel order, as
                produced by the ``bgr8`` color stream.
        """
        _, (depth_ax, color_ax) = plt.subplots(1, 2)
        depth_colormap = cv2.applyColorMap(
            cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET
        )
        # OpenCV images are BGR, but matplotlib interprets 3-channel data as
        # RGB. Without conversion, red and blue appear swapped. ``cmap`` is not
        # passed because matplotlib ignores it for RGB(A) data.
        depth_ax.imshow(cv2.cvtColor(depth_colormap, cv2.COLOR_BGR2RGB))
        color_ax.imshow(cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB))
        plt.show()


class ImageSaver(RealSenseInterface):
    """Write images to disk with OpenCV."""

    def save_image(self, image, filename):
        """Write ``image`` to ``filename`` with ``cv2.imwrite``.

        ``cv2.imwrite`` reports whether the write succeeded through its return
        value. That flag is passed back so callers can detect failed writes
        instead of having them silently ignored.

        Returns:
            bool: the success flag returned by ``cv2.imwrite``.
        """
        return cv2.imwrite(filename, image)


class RealSenseController:
    """Combine acquisition, display and saving into a live side-by-side viewer."""

    def __init__(self):
        """Create the acquisition, display and saver helpers."""
        self.acquisition = ImageAcquisition()
        self.display = ImageDisplay()
        self.saver = ImageSaver()

    @staticmethod
    def _colorize_depth(depth_image):
        """Return an RGB JET-colormap rendering of ``depth_image`` for matplotlib."""
        depth_colormap = cv2.applyColorMap(
            cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET
        )
        # applyColorMap returns BGR, but matplotlib expects RGB.
        return cv2.cvtColor(depth_colormap, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _place_window_top_left(fig):
        """Move ``fig``'s window to screen position (0, 0) when the backend allows it.

        ``wm_geometry`` is a Tk window method. With other backends the window
        (or the method) is absent, so the move is skipped instead of raising
        AttributeError.
        """
        window = getattr(fig.canvas.manager, "window", None)
        wm_geometry = getattr(window, "wm_geometry", None)
        if wm_geometry is not None:
            wm_geometry("+0+0")

    def start(self):
        """Show live depth and color views, then stop the camera when ``plt.show()`` returns.

        On every refresh, the latest color frame is also written to
        ``color_image.png``, overwriting the previous one.
        """
        print("RealSense 3D camera detected.")
        fig, (depth_ax, color_ax) = plt.subplots(1, 2)
        # uint8 placeholders match the dtype of the frames that replace them.
        # ``cmap`` is omitted because matplotlib ignores it for RGB data.
        blank = np.zeros((RS_HEIGHT, RS_WIDTH, 3), dtype=np.uint8)
        depth_im = depth_ax.imshow(blank)
        color_im = color_ax.imshow(blank)

        def update(_frame):
            # NOTE(review): each getter calls wait_for_frames() on its own, so
            # depth and color come from two different framesets and may not be
            # captured at the same instant.
            depth_image = self.acquisition.get_depth_image()
            color_image = self.acquisition.get_color_image()
            if depth_image is None or color_image is None:
                return
            depth_im.set_array(self._colorize_depth(depth_image))
            # Display in RGB, but save the untouched BGR frame, which is the
            # channel order cv2.imwrite expects.
            color_im.set_array(cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB))
            self.saver.save_image(color_image, "color_image.png")

        # Keep a reference: a FuncAnimation with no live reference can be
        # garbage-collected, which stops the updates.
        ani = FuncAnimation(fig, update, blit=False, cache_frame_data=False)
        try:
            self._place_window_top_left(fig)
            plt.show()
        finally:
            # Release the camera even if window setup or the event loop raises.
            self.acquisition.stop()