#
##
##  SPDX-FileCopyrightText: © 2007-2021 Benedict Verhegghe <bverheg@gmail.com>
##  SPDX-License-Identifier: GPL-3.0-or-later
##
##  This file is part of pyFormex 2.6  (Mon Aug 23 15:13:50 CEST 2021)
##  pyFormex is a tool for generating, manipulating and transforming 3D
##  geometrical models by sequences of mathematical operations.
##  Home page: https://pyformex.org
##  Project page: https://savannah.nongnu.org/projects/pyformex/
##  Development: https://gitlab.com/bverheg/pyformex
##  Distributed under the GNU General Public License version 3 or later.
##
##  This program is free software: you can redistribute it and/or modify
##  it under the terms of the GNU General Public License as published by
##  the Free Software Foundation, either version 3 of the License, or
##  (at your option) any later version.
##
##  This program is distributed in the hope that it will be useful,
##  but WITHOUT ANY WARRANTY; without even the implied warranty of
##  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##  GNU General Public License for more details.
##
##  You should have received a copy of the GNU General Public License
##  along with this program.  If not, see http://www.gnu.org/licenses/.
##
#

#
# This is a modification of the pyvista/pyacvd extension module
# from https://github.com/pyvista/pyacvd.
# It was modified to not use pyvista or vtk data structures,
# but directly use the pyFormex data instead.
#
# The original pyvista/pyacvd is distributed under the MIT license
#
# MIT License
#
# Copyright (c) 2017-2021 The PyVista Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#

"""Point based clustering module

Uniform remeshing of triangulated surfaces by point clustering, based on
ACVD as implemented in pyvista/pyacvd, but adapted to work directly on
pyFormex TriSurface data instead of pyvista/vtk data structures.

The simplest entry point is :func:`remesh`. The :class:`Clustering` class
gives finer control over the subdivision, clustering and mesh generation
steps.
"""
import ctypes
import numpy as np
from scipy import sparse

from pyformex import arraytools as at
from pyformex.trisurface import TriSurface
from pyformex.lib import _clustering

# Temporary class to replace pyvista.PolyData
class PolyData:
    """Minimal stand-in for pyvista.PolyData wrapping a pyFormex TriSurface.

    Only the attributes and methods used by the clustering functions in
    this module are provided.

    Parameters
    ----------
    S : TriSurface or compatible
        The surface to wrap. It is converted with ``TriSurface(S)``.
    """

    def __init__(self, S):
        self.S = TriSurface(S)

    @property
    def faces(self):
        """Faces in pyvista style: each row is ``[3, i, j, k]``.

        The plexitude column (always 3) is prepended to the element
        connectivity of the wrapped surface.
        """
        plex = np.full((self.S.nelems(), 1), 3, dtype=at.Int)
        return np.column_stack([plex, self.S.elems])

    @property
    def points(self):
        """The vertex coordinates of the wrapped surface."""
        return self.S.coords

    @property
    def number_of_points(self):
        """The number of vertices of the wrapped surface."""
        return self.S.ncoords()

    def compute_normals(self, avg):
        """Return normals of the wrapped surface.

        Parameters
        ----------
        avg : bool
            If True, return the averaged vertex normals
            (``TriSurface.avgVertexNormals()``). Otherwise return
            ``TriSurface.normals()`` (presumably one normal per triangle).
        """
        if avg:
            return self.S.avgVertexNormals()
        # The original called self.normals(), which PolyData does not define.
        return self.S.normals()


class Clustering(object):
    """Uniform point clustering based on ACVD.

    Parameters
    ----------
    S : TriSurface
        The triangulated surface to cluster. It is wrapped in a
        :class:`PolyData`, and the point weights, neighbors and edges needed
        by the clustering are computed immediately.

    Notes
    -----
    Typical use is to optionally :meth:`subdivide` the mesh to get more
    points, then :meth:`cluster` the points, and finally call
    :meth:`create_mesh` to build a new surface from the clusters.
    """

    def __init__(self, S):
        """Wrap the input surface and initialize the clustering data."""
        self.mesh = PolyData(S)
        self.clusters = None  # cluster number of each point, set by cluster()
        self.nclus = None  # number of clusters, set by cluster()
        self.remesh = None  # mesh created by create_mesh()
        self._area = None  # point weights
        self._wcent = None  # weighted point coordinates
        self._neigh = None  # neighbor indices of each point
        self._nneigh = None  # number of neighbors of each point
        self._edges = None  # edge data derived from the neighbors
        self._update_data()

    def _update_data(self, weights=None):
        """(Re)compute the point weights, weighted points, neighbors and edges.

        Parameters
        ----------
        weights : array_like, optional
            Additional non-negative point weights, passed on to
            :func:`weighted_points`.
        """
        # Compute point weights and weighted points
        self._area, self._wcent = weighted_points(
            self.mesh, additional_weights=weights)
        # neighbors and edges
        self._neigh, self._nneigh = neighbors_from_mesh(self.mesh)
        self._edges = _clustering.edge_id(self._neigh, self._nneigh)

    def cluster(self, nclus, maxiter=100, debug=False, iso_try=10):
        """Cluster the mesh points.

        Parameters
        ----------
        nclus : int
            Requested number of clusters.
        maxiter : int
            Passed on to the compiled clustering routine (presumably the
            maximum number of iterations).
        debug : bool
            Passed on to the compiled clustering routine.
        iso_try : int
            Passed on to the compiled clustering routine.

        Returns
        -------
        array
            The cluster number of each point. It is also stored in
            ``self.clusters``. The number of clusters reported by the
            clustering routine is stored in ``self.nclus``.
        """
        self.clusters, _, self.nclus = _clustering.cluster(
            self._neigh, self._nneigh, nclus, self._area, self._wcent,
            self._edges, maxiter, debug, iso_try)
        return self.clusters

    def subdivide(self, ndiv):
        """Perform a linear subdivision of the mesh

        Parameters
        ----------
        ndiv : int
            Final number of divisions.

        Notes
        -----
        Unlike with the original pyvista/pyacvd module, we can directly set
        the final number of subdivisions and the subdivision is done in a
        single step. In pyvista/pyacvd one specifies a number of subdivision
        steps and each step subdivides in 2. Thus a value of nsub = 3 in
        pyvista/pyacvd corresponds to ndiv = 2^3 = 8 in pyFormex. pyFormex
        allows subdivision numbers that are not powers of two. This is not
        possible in pyvista/pyacvd.
        """
        S = self.mesh.S
        # fuse/compact merge the duplicate nodes created on shared edges
        S = S.subdivide(ndiv).fuse().compact()
        self.mesh = PolyData(S)
        print(f"Subdivide: {self.mesh.number_of_points}")
        self._update_data()

    def create_mesh(self, flipnorm=True):
        """Generate a new mesh from the clusters.

        Parameters
        ----------
        flipnorm : bool
            If True, the new triangles are oriented consistently with the
            mean cluster normals (:attr:`cluster_norm`).

        Returns
        -------
        TriSurface
            The new mesh, also stored in ``self.remesh``.
        """
        cnorm = self.cluster_norm if flipnorm else None
        # Generate mesh
        self.remesh = create_mesh(self.mesh, self._area, self.clusters,
                                  cnorm, flipnorm)
        return self.remesh

    @property
    def cluster_norm(self):
        """Area weighted mean normal of each cluster, scaled to unit length.

        Clusters whose summed normal is zero get a zero normal.

        Raises
        ------
        RuntimeError
            If :meth:`cluster` has not been called yet.
        """
        # self.clusters always exists (set in __init__), so test for None
        if self.clusters is None:
            raise RuntimeError('No clusters available')

        # Averaged vertex normals of the (possibly subdivided) mesh
        norm = self.mesh.compute_normals(avg=True)

        # Sum the area weighted normals per cluster. minlength guarantees a
        # row for every cluster, even if the highest numbered ones are empty.
        cnorm = np.empty((self.nclus, 3))
        for i in range(3):
            cnorm[:, i] = np.bincount(self.clusters,
                                      weights=norm[:, i] * self._area,
                                      minlength=self.nclus)
        weights = np.linalg.norm(cnorm, axis=1).reshape((-1, 1))
        weights[weights == 0] = 1  # leave zero normals as zero
        cnorm /= weights
        return cnorm

    @property
    def cluster_centroid(self):
        """Area weighted centroid of each cluster.

        Raises
        ------
        RuntimeError
            If :meth:`cluster` has not been called yet.
        """
        if self.clusters is None:
            raise RuntimeError('No clusters available')
        wval = self.mesh.points * self._area.reshape(-1, 1)
        cval = np.vstack((np.bincount(self.clusters, weights=wval[:, 0]),
                          np.bincount(self.clusters, weights=wval[:, 1]),
                          np.bincount(self.clusters, weights=wval[:, 2])))
        weights = np.bincount(self.clusters, weights=self._area)
        weights[weights == 0] = 1  # avoid division by zero for empty clusters
        cval /= weights
        return cval.T


def cluster_centroid(cent, area, clusters):
    """Return the area weighted centroid of each cluster.

    Parameters
    ----------
    cent : (npoints, 3) array
        Point coordinates.
    area : (npoints,) array
        Weight of each point.
    clusters : (npoints,) int array
        Cluster number of each point. A value of -1 marks a point that does
        not belong to any cluster. The input array is not modified.

    Returns
    -------
    array with 3 columns
        One row per cluster number, holding the weighted centroid. If any
        point has cluster -1, those points are collected in an extra last
        row whose coordinates are all set to infinity.
    """
    null_mask = clusters == -1
    has_null = np.any(null_mask)
    if has_null:
        # Collect the unassigned points in an extra cluster after the last one
        clusters = clusters.copy()
        clusters[null_mask] = clusters.max() + 1

    weighted = cent * area.reshape(-1, 1)
    total = np.bincount(clusters, weights=area)
    total[total == 0] = 1  # avoid division by zero for empty clusters

    sums = np.array([np.bincount(clusters, weights=weighted[:, axis])
                     for axis in range(3)])
    centroids = sums / total

    if has_null:
        centroids[:, -1] = np.inf
    return centroids.T


def create_mesh(mesh, area, clusters, cnorm, flipnorm=True):
    """Generate a new triangle mesh from the cluster data.

    Each cluster becomes a vertex of the new mesh, located at the area
    weighted centroid of its points. A triangle is created for every face of
    the original mesh whose three vertices belong to three different
    clusters. Duplicate triangles are removed.

    Parameters
    ----------
    mesh : PolyData
        The clustered (all triangular) mesh.
    area : array
        Weight of each mesh point, used to weight the centroids.
    clusters : int array
        Cluster number of each mesh point.
    cnorm : array or None
        Unit normal of each cluster (see ``Clustering.cluster_norm``).
        It is only used, and is then required, if `flipnorm` is True.
    flipnorm : bool
        If True, reverse the orientation of every new triangle whose normal
        points opposite to the mean normal of its three clusters.

    Returns
    -------
    TriSurface
        The new mesh.

    Raises
    ------
    ValueError
        If `flipnorm` is True and `cnorm` is None.
    """
    if flipnorm and cnorm is None:
        raise ValueError("create_mesh: cnorm is required when flipnorm is True")
    print("create_mesh")
    faces = mesh.faces.reshape(-1, 4)
    points = mesh.points
    if points.dtype != np.double:
        points = points.astype(np.double)
    print(f"points {points.shape}")
    print(f"faces {faces.shape}")
    print(f"clusters {clusters.shape}")
    # cnorm is None when flipnorm is False: don't crash on the debug output
    if cnorm is not None:
        print(f"cnorm {cnorm.shape}")

    # Compute centroids
    ccent = np.ascontiguousarray(cluster_centroid(points, area, clusters))

    # Sparse boolean face x cluster matrix: entry (i, j) is True if a
    # vertex of face i belongs to cluster j
    rng = np.arange(faces.shape[0]).reshape((-1, 1))
    a = np.hstack((rng, rng, rng)).ravel()
    b = clusters[faces[:, 1:]].ravel()
    c = np.ones(len(a), dtype='bool')

    boolmatrix = sparse.csr_matrix((c, (a, b)), dtype='bool')

    # Find all faces touching three different clusters. Each of those
    # three clusters becomes a vertex of a new triangle
    nadjclus = boolmatrix.sum(1)
    adj = np.array(nadjclus == 3).nonzero()[0]
    idx = boolmatrix[adj].nonzero()[1]

    # The cluster centroids become the points of the new mesh
    points = ccent
    f = idx.reshape((-1, 3))

    # Remove duplicate faces
    f = f[unique_row_indices(np.sort(f, 1))]

    if flipnorm:
        # Mean normals of the clusters each face is built from
        adjcnorm = cnorm[f].sum(1)
        length = np.linalg.norm(adjcnorm, axis=1).reshape(-1, 1)
        length[length == 0] = 1  # leave cancelled normals as zero
        adjcnorm /= length

        # and compare this with the normals of each face
        newnorm = TriSurface(points, f).normals()
        print(f"newnorm {newnorm.shape}")

        # If the dot product is negative, reverse the order of those faces
        agg = (adjcnorm * newnorm).sum(1)
        mask = agg < 0.0
        f[mask] = f[mask, ::-1]

    return TriSurface(points, f)


def unique_row_indices(a):
    """Return the index of the first occurrence of each unique row of `a`.

    Each row is viewed as a single opaque (void) item, so that
    :func:`numpy.unique` can compare whole rows at once. The indices are
    returned in the sort order of those opaque items.
    """
    row_dtype = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    rows = np.ascontiguousarray(a).view(row_dtype)
    return np.unique(rows, return_index=True)[1]


def weighted_points(mesh, return_weighted=True, additional_weights=None):
    """Return point weights based on the adjacent face areas, and weighted points.

    Parameters
    ----------
    mesh : PolyData
        All triangular surface mesh.
    return_weighted : bool, optional
        If True, also return the vertices multiplied by their point weights.
        It is forced to True when `additional_weights` is given.
    additional_weights : array_like, optional
        Extra non-negative weight for each point, passed on to the compiled
        weighting routine. Any array_like of numbers is accepted; it is
        converted to a C-contiguous array of doubles.

    Returns
    -------
    pweight : np.ndarray, np.double
        Point weight array
    wvertex : np.ndarray, np.double
        Vertices multiplied by their corresponding weights. Returned only
        when return_weighted is True.

    Raises
    ------
    ValueError
        If `additional_weights` contains negative values.
    """
    faces = mesh.faces.reshape(-1, 4)
    if faces.dtype != np.int32:
        faces = faces.astype(np.int32)
    points = mesh.points

    if additional_weights is not None:
        # The compiled routine needs a C-contiguous double array; this also
        # accepts lists and returns the input unchanged if already suitable
        weights = np.ascontiguousarray(additional_weights, dtype=ctypes.c_double)
        return_weighted = True
        if (weights < 0).any():
            raise ValueError('Negative weights not allowed')
    else:
        weights = np.array([])

    if points.dtype == np.float64:
        weighted_point_func = _clustering.weighted_points_double
    else:
        weighted_point_func = _clustering.weighted_points_float

    return weighted_point_func(points, faces, weights, return_weighted)


def neighbors_from_mesh(mesh):
    """Build the point neighbor arrays of an all-triangular mesh.

    Parameters
    ----------
    mesh : PolyData
        Mesh to assemble the neighbors from. Only its ``number_of_points``
        and ``faces`` are used.

    Returns
    -------
    neigh : int np.ndarray [:, ::1]
        Indices of each neighboring node for each node.
    nneigh : int np.ndarray [::1]
        Number of neighbors for each node.
    """
    npts = mesh.number_of_points
    # The compiled routine expects int32 connectivity
    tri = mesh.faces.reshape(-1, 4)
    tri = tri if tri.dtype == np.int32 else tri.astype(np.int32)
    return _clustering.neighbors_from_faces(npts, tri)


def remesh(S, npoints=None, ndiv=1):
    """Remesh a triangulated surface by uniform point clustering.

    Parameters
    ----------
    S : TriSurface
        The surface to remesh.
    npoints : int, optional
        Requested number of clusters, i.e. points of the new mesh. Defaults
        to the number of points of `S`.
    ndiv : int, optional
        Number of subdivisions applied to `S` before clustering (see
        :meth:`Clustering.subdivide`).

    Returns
    -------
    TriSurface
        The remeshed surface.
    """
    target = S.ncoords() if npoints is None else npoints
    clustering = Clustering(S)
    clustering.subdivide(ndiv)
    clustering.cluster(target)
    return clustering.create_mesh()

# End
