#!/usr/bin/env python

__copyright__ = "(C) 2006-2008 Ola Skavhaug and Simula Research Laboratory"
__license__   = "GNU LGPL Version 2.1."
__cite__      = """Ola Skavhaug, Viper Visualization Software, 
http://www.fenics.org/wiki/viper/"""
__version__ = "0.4.3"


__doc__ = """
Viper for DOLFIN

A simple mesh plotter and run-time visualization module for plotting and
saving simulation data. This is a version of Viper adapted for plotting native
DOLFIN data structures such as dolfin::Mesh, dolfin::MeshFunction, and
dolfin::Function.

Copyright: %s
Licence: %s
Citation: %s

""" % (__copyright__, __license__, __cite__)

from viper import Viper as ViperBase
import dolfin
from dolfin.cpp import Mesh, MeshFunctionInt, MeshFunctionUInt, MeshFunctionBool, MeshFunctionReal, Function, FunctionSpace, FunctionPlotData
import numpy
import vtk

import ffc

# Module-level PlotManager shared by figure(), plot() and interactive();
# created lazily on first use (None until then).
_plotter = None

class Viper(ViperBase):
    """A custom Viper sub-class for visualizing meshes and fields in DOLFIN.

    Accepted inputs are ``dolfin.Mesh``, ``Function``, ``FunctionPlotData``
    and ``MeshFunction{Int,UInt,Real,Bool}``; any other object is passed to
    ``dolfin.project`` and the resulting projection is plotted instead.

    The public ``update`` method is rebound to ``dolfin_update`` so an
    existing plot can be refreshed with a new (or modified) DOLFIN object of
    the same kind as the one originally plotted.
    """

    # MeshFunction flavours handled by plot_meshfunction().
    _meshfunction_types = (MeshFunctionInt, MeshFunctionUInt,
                           MeshFunctionReal, MeshFunctionBool)

    def __init__(self, data, *args, **kwargs):
        """Create the plotter and immediately plot *data*.

        ``rescale`` defaults to True here and ``lutfile`` to
        ``"gauss_120.lut"``; all arguments are forwarded to ``plot``.
        """
        # Keep ViperBase's own update() reachable as _update and route the
        # public update() through dolfin_update(), which first converts a
        # DOLFIN object into the value array handed to _update().
        self._update = self.update
        self.update = self.dolfin_update
        kwargs["rescale"] = kwargs.get("rescale", True)
        self.initcommon(kwargs)
        self.warpscalar = True
        self.lutfile = kwargs.get("lutfile", "gauss_120.lut")
        self.plot(data, *args, **kwargs)

    @staticmethod
    def _parse_value(text):
        """Evaluate one token of a string of plot options.

        Tokens naming an undefined identifier (e.g. the bare flag
        ``noninteractive`` or the value in ``mode=displacement``) are returned
        unchanged as strings instead of raising ``NameError``.
        """
        # NOTE(review): eval() executes arbitrary code. Option strings must
        # come from the script author, never from external/untrusted data.
        try:
            return eval(text)
        except NameError:
            return text

    def parse_input(self, args, kwargs):
        """Normalise plot options.

        If *args* is a single string such as ``"mode='vector', vmin=0"`` it is
        split on commas: ``key=value`` items become keyword arguments and bare
        items become positional arguments (empty items are ignored). Keyword
        arguments passed explicitly take precedence over ones given in the
        string. Values containing commas cannot be expressed this way.

        Returns an ``(args, kwargs)`` pair; any other input is returned as is.
        """
        if not (len(args) == 1 and isinstance(args[0], str)):
            return args, kwargs
        positional = []
        keywords = {}
        for item in args[0].split(","):
            item = item.strip()
            if not item:
                continue
            if "=" in item:
                # Split only on the first '=' so values may contain '='.
                key, value = item.split("=", 1)
                keywords[key.strip()] = self._parse_value(value.strip())
            else:
                positional.append(self._parse_value(item))
        # Previously explicit kwargs (e.g. interactive=False forced by
        # PlotManager) were silently discarded here.
        keywords.update(kwargs)
        return positional, keywords

    def plot(self, data, *args, **kwargs):
        """Dispatch *data* to the matching ``plot_*`` method.

        The positional option ``"noninteractive"`` is equivalent to
        ``interactive=False``. Objects of unsupported types are projected with
        ``dolfin.project`` and the projection is plotted.
        """
        args, kwargs = self.parse_input(args, kwargs)
        # Only compare string options; ``"x" in args`` could trigger an
        # element-wise comparison against array-valued options.
        if any(isinstance(a, str) and a == "noninteractive" for a in args):
            kwargs["interactive"] = False

        if isinstance(data, dolfin.Mesh):
            self.plottype = "mesh"
            self.plot_mesh(data, *args, **kwargs)
        elif isinstance(data, Function):
            self.plottype = "function"
            self.plot_function(data, *args, **kwargs)
        elif isinstance(data, FunctionPlotData):
            self.plottype = "functionplotdata"
            self.plot_functionplotdata(data, *args, **kwargs)
        elif isinstance(data, self._meshfunction_types):
            self.plottype = "meshfunction"
            self.plot_meshfunction(data, *args, **kwargs)
        else:
            from dolfin.project import project
            print("using project")
            data = project(data)
            self.plot(data, *args, **kwargs)

    def plot_mesh(self, data, *args, **kwargs):
        """Draw the mesh *data* as a wireframe carrying all-zero vertex values."""
        self.rescale = False
        self.mesh = data
        self.vtkgrid = self.make_vtk_grid(self.mesh)
        self.filter = self.vtkgrid
        self.x = numpy.zeros(self.mesh.numVertices())
        (self.iren, self.renWin, self.ren) = self.simple_plotter(
            self.filter, self.x.min(), self.x.max(), wireframe=True)
        self._update(self.x)
        self.renWin.Render()
        if kwargs.get("interactive", True):
            self.interactive()

    def plot_function(self, data, *args, **kwargs):
        """Plot a DOLFIN ``Function``.

        Keyword options: ``mode`` ("auto", "scalar", "scalar_xy", "vector",
        "displacement"), ``warpscalar``, ``rescale``, ``wireframe``,
        ``eval_pts`` (points to evaluate at; vector mode only), ``vmin``,
        ``vmax`` and ``interactive``.

        Raises RuntimeError if the number of values per vertex cannot be
        plotted automatically or ``mode`` is unknown.
        """
        self.mesh = data.function_space().mesh()
        self.vtkgrid = self.make_vtk_grid(self.mesh)
        self.filter = self.vtkgrid
        wireframe = self._read_field_options(kwargs)
        (dpn, rank) = self._dofs_pr_node(data)
        self._select_mode(dpn)

        if self.pts is None:
            nno = self.mesh.numVertices()
            self.x = self._vertex_values(data, dpn)
            coords = self.mesh.coordinates()
        else:
            # Arbitrary point evaluation is only supported for vector fields.
            assert self.mode == "vector"
            nno = len(self.pts)
            self.x = self._point_values(data, dpn)
            coords = numpy.array(self.pts)
            coords.shape = (len(self.pts), len(self.pts[0]))

        self._render_field(dpn, nno, coords, wireframe, kwargs)

    def plot_functionplotdata(self, data, *args, **kwargs):
        """Plot a ``FunctionPlotData`` object (pre-computed vertex values).

        Accepts the same keyword options as ``plot_function`` except that
        ``eval_pts`` is ignored.
        """
        self.mesh = data.mesh
        self.vtkgrid = self.make_vtk_grid(self.mesh)
        self.filter = self.vtkgrid
        wireframe = self._read_field_options(kwargs)
        dpn = self._functionplotdata_dpn(data)
        self._select_mode(dpn)

        nno = self.mesh.numVertices()
        self.x = data.vertex_values.array().copy()
        self._render_field(dpn, nno, self.mesh.coordinates(), wireframe,
                           kwargs)

    def _read_field_options(self, kwargs):
        """Store the options shared by field plots; return the wireframe flag."""
        self.warpscalar = kwargs.get("warpscalar", True)
        self.rescale = kwargs.get("rescale", False)
        self.mode = kwargs.get("mode", "auto")
        self.pts = kwargs.get("eval_pts", None)
        return kwargs.get("wireframe", False)

    def _select_mode(self, dpn):
        """Resolve ``mode == "auto"`` from the number of values per vertex."""
        if self.mode != "auto":
            return
        if dpn == 1:
            # Two vertices per cell (presumably an interval mesh): x-y graph.
            if self.mesh.cells().shape[1] == 2:
                self.mode = "scalar_xy"
            else:
                self.mode = "scalar"
        elif dpn == self.mesh.geometry().dim():
            self.mode = "vector"
        else:
            raise RuntimeError("Can't plot function with %d dofs pr node" % dpn)

    def _functionplotdata_dpn(self, data):
        """Values per vertex of a FunctionPlotData: 1 for rank 0, else gdim."""
        if data.rank > 0:
            return self.mesh.geometry().dim()
        return 1

    def _vertex_values(self, data, dpn):
        """Interpolate *data* into a flat array of dpn * numVertices values."""
        values = numpy.zeros(self.mesh.numVertices() * dpn)
        data.interpolate(values)
        return values

    def _point_values(self, data, dpn):
        """Evaluate *data* at ``self.pts``.

        Returns a flat array ordered component by component (all first
        components, then all second components, ...), matching the vertex
        layout expected by ``_to_points``.
        """
        nno = len(self.pts)
        values = numpy.zeros((nno, dpn))
        v = numpy.zeros(dpn, dtype='d')
        function_space = data.function_space()
        for (i, point) in enumerate(self.pts):
            function_space.eval(v, point, data)
            values[i] = v
        return values.transpose().flatten()

    def _to_points(self, values, dpn, nno):
        """Reshape component-major *values* to one row per point via vec3d."""
        values.shape = (dpn, nno)
        return self.vec3d(values.transpose().copy())

    def _vector_norms(self, vectors):
        """Euclidean norm of each row of a 3-column array."""
        return numpy.sqrt(vectors[:, 0]**2 + vectors[:, 1]**2 + vectors[:, 2]**2)

    def _render_field(self, dpn, nno, coords, wireframe, kwargs):
        """Build the VTK pipeline for ``self.x`` according to ``self.mode``,
        render it and optionally start interaction."""
        vmin = kwargs.get("vmin", self.x.min())
        vmax = kwargs.get("vmax", self.x.max())

        assert vmax >= vmin, "Empty range, please specify vmin and/or vmax"

        if self.mode == "scalar":
            minmax = self.x.max() - self.x.min()
            if self.warpscalar and minmax > 0 and self.mesh.geometry().dim() < 3:
                self.filter = self.warp_scalar(self.x, minmax)
            (self.iren, self.renWin, self.ren) = self.simple_plotter(
                self.filter, vmin, vmax, wireframe=wireframe)
        elif self.mode == "scalar_xy":
            (self.iren, self.renWin, self.ren) = self.plot_xy(
                self.mesh.coordinates(), self.x, ".-")
        elif self.mode in ("vector", "displacement"):
            self.x = self._to_points(self.x, dpn, nno)
            if self.mode == "vector":
                (self.iren, self.renWin, self.ren) = self.vector_plotter(
                    coords, self.x, vmin, vmax, wireframe=False)
            else:
                self.displacement = self.x.copy()
                # Use vector norms for scalar coloring.
                self.x = self._vector_norms(self.x)
                self.filter = self.warp_vector(self.displacement)
                (self.iren, self.renWin, self.ren) = self.simple_plotter(
                    self.filter, vmin, vmax, wireframe=False)
        else:
            # Previously this fell through and failed later on self.iren.
            raise RuntimeError("Unknown plot mode %r" % (self.mode,))

        self.iren.Initialize()
        self._update(self.x)
        self.renWin.Render()
        if kwargs.get("interactive", True):
            self.interactive()

    def _refresh_field(self, dpn, nno):
        """Re-apply mode-specific processing to freshly sampled ``self.x``."""
        if self.mode == "scalar":
            minmax = self.x.max() - self.x.min()
            if self.warpscalar and minmax > 0 and self.mesh.geometry().dim() < 3:
                self.filter = self.warp_scalar(self.x, minmax)
        if self.mode in ("vector", "displacement"):
            self.x = self._to_points(self.x, dpn, nno)
            if self.mode == "displacement":
                self.displacement[:] = self.x.copy()
                self.x = self._vector_norms(self.x)

    def plot_meshfunction(self, data, *args, **kwargs):
        """Plot a MeshFunction defined on vertices or on cells.

        Raises RuntimeError for MeshFunctions on any other entity dimension.
        """
        self.mesh = data.mesh()
        dim = data.dim()
        size = data.size()
        values = data.values()
        self.vtkgrid = self.make_vtk_grid(self.mesh)
        self.filter = self.vtkgrid
        self.x = numpy.array(values, dtype='d')
        if "int" in str(values.dtype):
            self.lutfile = ""
        # Check the entity dimension as well as the size: an entity function
        # whose size merely coincides with the vertex/cell count must not be
        # plotted as if it lived on vertices/cells.
        on_vertices = dim == 0 and size == self.mesh.numVertices()
        on_cells = (dim == self.mesh.topology().dim()
                    and size == self.mesh.numCells())
        if not (on_vertices or on_cells):
            print("Unknown size %d" % size)
            print("Valid size are v(%d) and c(%d)" % (self.mesh.numVertices(), self.mesh.numCells()))
            raise RuntimeError("Only vertex and cell valued meshfunctions can be plotted")
        if on_cells:
            self.vertex_plot = False
            self.lutfile = ""
        (self.iren, self.renWin, self.ren) = self.simple_plotter(
            self.filter, self.x.min(), self.x.max(), wireframe=False)
        self.iren.Initialize()
        self._update(self.x)
        self.renWin.Render()
        if kwargs.get("interactive", True):
            self.interactive()

    def _dofs_pr_node(self, f):
        """Return ``(values per node, value rank)`` of the element of *f*."""
        element = FunctionSpace.element(f.function_space())
        rank = element.value_rank()
        dpn = 1
        for i in range(rank):
            dpn *= element.value_dimension(i)
        return dpn, rank

    def add_polygon(self, polygon, idx=0):
        """Overlay *polygon* (a sequence of 2D or 3D points) as a blue polyline
        using the actor ``self.polygon_actors[idx]``."""
        assert isinstance(polygon, (list, tuple))
        numpoints = len(polygon)
        assert isinstance(polygon[0], (list, tuple, numpy.ndarray))
        # 2D points are lifted into the z = 0 plane.
        points2d = len(polygon[0]) == 2
        points = vtk.vtkPoints()
        points.SetNumberOfPoints(numpoints)
        for i in range(numpoints):
            point = list(polygon[i])
            if points2d:
                point.append(0.0)
            points.InsertPoint(i, *point)
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(numpoints)
        for i in range(numpoints):
            line.GetPointIds().SetId(i, i)
        grid = vtk.vtkUnstructuredGrid()
        grid.Allocate(1, 1)
        grid.InsertNextCell(line.GetCellType(),
                            line.GetPointIds())
        grid.SetPoints(points)

        extract = vtk.vtkGeometryFilter()
        extract.SetInput(grid)
        extract.GetOutput().ReleaseDataFlagOn()
        self.polygon_data.AddInput(extract.GetOutput())

        actor = self.polygon_actors[idx]
        mapper = actor.GetMapper()
        if mapper is None:
            mapper = vtk.vtkPolyDataMapper()
            actor.SetMapper(mapper)

        mapper.SetInput(self.polygon_data.GetOutput())
        actor.GetProperty().SetColor(0, 0, 1)
        actor.GetProperty().SetLineWidth(1)

        self.ren.AddActor(actor)

    def dolfin_update(self, data):
        """Refresh the plot with *data*, which must be of the same kind as the
        object originally plotted (see ``self.plottype``)."""
        # Convert data to Viper's internal value array.
        if self.plottype == "mesh":
            self.mesh = data
            self.vtkgrid = self.make_vtk_grid(self.mesh)
            self.filter = self.vtkgrid
            self.update_scalar_mapper(self.filter)
            self.x = numpy.zeros(self.mesh.numVertices())
        elif self.plottype == "function":
            (dpn, rank) = self._dofs_pr_node(data)
            if self.pts is None:
                nno = self.mesh.numVertices()
                self.x = self._vertex_values(data, dpn)
            else:
                nno = len(self.pts)
                self.x = self._point_values(data, dpn)
            self._refresh_field(dpn, nno)
        elif self.plottype == "functionplotdata":
            # Previously missing: updates re-rendered the stale values.
            dpn = self._functionplotdata_dpn(data)
            nno = self.mesh.numVertices()
            self.x = data.vertex_values.array().copy()
            self._refresh_field(dpn, nno)
        elif self.plottype == "meshfunction":
            self.mesh = data.mesh()
            self.vtkgrid = self.make_vtk_grid(self.mesh)
            self.filter = self.vtkgrid
            self.update_scalar_mapper(self.filter)
            self.x = numpy.array(data.values(), dtype='d')
        else:
            print("Unknown plottype %s. Can't update" % (self.plottype,))

        # Plot data
        self._update(self.x)

class PlotManager(object):
    """Keep track of open Viper plots so repeated plot() calls update them.

    In automatic mode (no figure selected) a plot is reused when the very
    same object (identity comparison) is plotted again. After ``figure(i)``
    every plot() call draws into the plot stored under key ``i``.
    """

    def __init__(self):
        self.plots = {}     # key -> (Viper instance, last plotted object)
        self.index = None   # selected figure key; None means automatic mode
        self.length = 0     # number of automatically keyed plots created

    def figure(self, index):
        """Direct subsequent plot() calls to the figure keyed by *index*."""
        self.index = index

    def reset(self):
        """Forget all plots and return to automatic mode."""
        self.plots = {}
        self.index = None
        self.length = 0

    def plot(self, plot_object, *args, **kwargs):
        """Plot *plot_object* in the selected figure, or automatically."""
        if self.index is None:
            return self.autoplot(plot_object, *args, **kwargs)
        return self.figureplot(plot_object, *args, **kwargs)

    def autoplot(self, plot_object, *args, **kwargs):
        """Update the plot already showing *plot_object*, or create a new one.

        Returns the plotter showing *plot_object*.
        """
        for (plotter, obj) in self.plots.values():
            if obj is plot_object:
                plotter.update(plot_object)
                return plotter

        kwargs["interactive"] = False
        plotter = Viper(plot_object, *args, **kwargs)
        # Tuple keys cannot clash with the figure indices users pass to
        # figure(); the old random integer keys could.
        self.length += 1
        self.plots[("auto", self.length)] = (plotter, plot_object)
        return plotter

    def figureplot(self, plot_object, *args, **kwargs):
        """Draw *plot_object* into the selected figure, creating it if needed.

        Returns the plotter of the selected figure.
        """
        if self.index in self.plots:
            plotter = self.plots[self.index][0]
            plotter.update(plot_object)
        else:
            kwargs["interactive"] = False
            plotter = Viper(plot_object, *args, **kwargs)
        self.plots[self.index] = (plotter, plot_object)
        return plotter

    def interactive(self):
        """Start interaction on the first stored plot; no-op if none exist."""
        if self.plots:
            # dict.keys()[0] is not valid on Python 3; iterate instead.
            plotter = next(iter(self.plots.values()))[0]
            plotter.interactive()

def figure(index):
    """Select the figure, keyed by *index*, that later plot() calls draw into.

    Creates the shared plot manager on first use.
    """
    global _plotter
    manager = _plotter
    if manager is None:
        manager = _plotter = PlotManager()
    manager.figure(index)

def plot(data, *args, **kwargs):
    """Plot *data*, reusing an existing plot where the manager allows it.

    Positional and keyword arguments are forwarded to the plot manager, and
    from there to ``Viper``. Positional arguments were previously dropped
    here, so string options such as ``plot(u, "mode=displacement")`` had no
    effect. The figure is always built non-interactively. If
    ``interactive=True`` is given, the manager's ``interactive()`` is called
    afterwards.

    Returns the plotter object produced by the manager.
    """
    global _plotter
    if _plotter is None:
        _plotter = PlotManager()

    show = kwargs.get("interactive", False)
    if "interactive" in kwargs:
        # Never block while creating/updating the figure; interaction is
        # started below once the figure exists.
        kwargs["interactive"] = False
    fig = _plotter.plot(data, *args, **kwargs)
    if show:
        _plotter.interactive()
    return fig

def update(data):
    """Refresh the existing plot of *data*.

    Searches the plots created through ``plot()`` for one that last showed
    this very object (identity comparison) and calls its ``update`` method.
    The previous body was unreachable and referred to an undefined global
    ``_viper``.

    Returns the updated plotter. If *data* has not been plotted, prints a
    message and returns None.
    """
    if _plotter is not None:
        for (plotter, obj) in _plotter.plots.values():
            if obj is data:
                plotter.update(data)
                return plotter
    print("No plot object, cannot update")

def interactive():
    """Hand control to the shared plot manager's interactive mode.

    If nothing has been plotted yet, print a notice and do nothing.
    """
    manager = _plotter
    if manager is not None:
        manager.interactive()
        return
    print("No plot object, interaction not possible")

def save_plot(u, filename="plot.png"):
    """Plot *u* without interaction and write the image via ``write_png``.

    A fresh ``Viper`` is created for each call; the plot manager is not used.
    """
    viewer = Viper(u, interactive=False)
    viewer.init_writer()
    viewer.write_png(filename)
