import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.segmentation import active_contour
from skimage.draw import line, polygon, circle, ellipse


class ActiveContour(object):
    """Refine a binary segmentation mask with an active contour (snake).

    ``__init__`` runs the whole pipeline:

    1. ``mroph_processing``: morphological closing of the input mask.
    2. ``find_init_contour``: the single external contour of the closed
       mask becomes the initial snake; enclosed background is recorded
       in ``mask_hole``.
    3. ``active_contour``: the snake is evolved on the input image.
    4. ``fill_contour``: the final snake is rasterised into ``seg_mask``
       and the recorded holes are cut back out.

    Attributes:
        img (array): Input image.
        mask (array): Input binary mask (presumably uint8 with values
            0/255 -- TODO confirm for callers other than ``main``).
        img_closing (array): Mask after morphological closing.
        init_contour (array): Initial contour as returned by
            ``cv2.findContours``, shape (N, 1, 2) holding (x, y) points.
        mask_hole (array): Enclosed background, i.e. pixels inside the
            filled external contour that are background in the closed
            mask. Active contour segmentation cannot distinguish enclosed
            background, but it can be obtained from the input mask (e.g.
            one generated by GraphCut segmentation).
        snake (array): (N, 2) array of (row, col) points of the final
            contour.
        seg_mask (array): uint8 segmentation mask of the final contour.
    """

    def __init__(self, img, mask, beta=10000):
        super(ActiveContour, self).__init__()
        self.img = img
        self.mask = mask

        # Pipeline: each step reads attributes set by the previous one.
        self.mroph_processing()
        self.find_init_contour()
        self.active_contour(beta)
        self.fill_contour()

    def mroph_processing(self, iters=1):
        """Close small gaps in the mask with a 5x5 morphological closing.

        Args:
            iters (int): ``iterations`` passed to ``cv2.morphologyEx``
                (default 1, a single closing).
        """
        kernel = np.ones((5, 5), np.uint8)
        # ``+ 0`` builds a new array, so the caller's mask is never aliased.
        self.img_closing = cv2.morphologyEx(
            self.mask + 0, cv2.MORPH_CLOSE, kernel, iterations=iters)

    def find_init_contour(self):
        """Take the single external contour of the closed mask as the initial snake.

        Also computes ``mask_hole``: pixels inside the filled external
        contour that are background in the closed mask.

        Raises:
            ValueError: If the closed mask does not contain exactly one
                external contour.
        """
        # ``[-2]`` selects the contour list under both OpenCV 3.x, which
        # returns (image, contours, hierarchy), and OpenCV 4.x, which
        # returns (contours, hierarchy). A copy is passed because
        # OpenCV < 3.2 modifies its input, and img_closing is reused below.
        contours = cv2.findContours(
            self.img_closing.copy(), cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE)[-2]
        if len(contours) != 1:
            raise ValueError(
                'expected exactly one external contour in the closed mask, '
                'found %d' % len(contours))
        self.init_contour = contours[0]

        # Fill the external contour; whatever is filled here but background
        # in the closed mask is an enclosed hole.
        mask_cnt = np.zeros_like(self.img_closing)
        cv2.drawContours(mask_cnt, contours, 0, 255, -1)
        # Saturating subtraction: plain uint8 ``-`` would wrap negative
        # differences around to large positive values.
        self.mask_hole = cv2.subtract(mask_cnt, self.img_closing)

    def active_contour(self, beta=1000):
        """Evolve the initial contour on ``self.img`` and store it in ``self.snake``.

        Args:
            beta (float): Snake smoothness weight forwarded to
                ``skimage.segmentation.active_contour``.
        """
        # OpenCV contours are (N, 1, 2) in (x, y) order, while skimage with
        # coordinates='rc' expects (N, 2) in (row, col) order.
        init = np.ascontiguousarray(self.init_contour[:, 0, ::-1])
        # Inside a method, ``active_contour`` resolves to the module-level
        # skimage function, not to this method.
        self.snake = active_contour(self.img,
                                    init,
                                    beta=beta,
                                    w_edge=100,
                                    w_line=100,
                                    max_px_move=0.1,
                                    max_iterations=50,
                                    coordinates='rc')

    def fill_contour(self, min_hole_area=100):
        """Rasterise the final snake into ``seg_mask`` and remove enclosed holes.

        Args:
            min_hole_area (int): Holes are subtracted only when
                ``mask_hole`` has more than this many non-zero pixels
                (default 100).
        """
        self.seg_mask = np.zeros(self.mask.shape, np.uint8)
        rr, cc = polygon(self.snake[:, 0],
                         self.snake[:, 1], self.seg_mask.shape)
        self.seg_mask[rr, cc] = 255
        if cv2.countNonZero(self.mask_hole) > min_hole_area:
            # Saturating subtraction, so hole pixels lying outside the
            # snake stay 0 instead of wrapping around in uint8 arithmetic.
            self.seg_mask = cv2.subtract(self.seg_mask, self.mask_hole)

    def plot(self):
        """Show the image with the initial (red, dashed) and final (blue) contours."""
        _, ax = plt.subplots(figsize=(7, 7))
        # NOTE(review): images loaded with cv2.imread are BGR, while
        # matplotlib assumes RGB, so colours may appear swapped.
        ax.imshow(self.img, cmap=plt.cm.gray)
        ax.plot(self.snake[:, 1], self.snake[:, 0], '-b', lw=2)
        ax.plot(self.init_contour[:, :, 0],
                self.init_contour[:, :, 1], '--r', lw=2)
        plt.show()


def main(img_path='1.jpg', mask_path='1_mask.png', size=(400, 400)):
    """Refine a sample image/mask pair with ActiveContour and plot the result.

    Args:
        img_path (str): Path of the input image (loaded with cv2.imread's
            default colour mode).
        mask_path (str): Path of the mask, loaded as single-channel.
        size (tuple): (width, height) that both inputs are resized to.

    Raises:
        IOError: If either file cannot be read.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise IOError('could not read image: %s' % img_path)
    img = cv2.resize(img, size)

    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise IOError('could not read mask: %s' % mask_path)
    # Nearest-neighbour keeps the mask binary; the default bilinear
    # interpolation introduces intermediate grey values along edges.
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

    # Named so it does not shadow skimage's ``active_contour`` import.
    refiner = ActiveContour(img, mask)
    refiner.plot()


if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not when
    # it is imported as a module.
    main()
