#-------------------------------------------------------------------------------
#
#  Define a 'splitter bar' control for the Enable toolkit.
#
#  Written by: David C. Morrill
#
#  Date: 11/27/2003
#
#  (c) Copyright 2003 by Enthought, Inc.
#
#  Classes defined: Splitter
#
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
#  Imports:
#-------------------------------------------------------------------------------

from enthought.traits.api    import Trait, List, Range, TraitError, true
from enthought.traits.ui.api import Group, View, Include

from base                import xy_in_bounds, transparent_color
from base_container      import BaseContainer                             
from component           import Component
from enable_traits       import layout_style_trait, eval_editor, \
                                basic_sequence_types, clear_color_trait, \
                                black_color_trait

#-------------------------------------------------------------------------------
#  Verify the validity of a set of splitter sizes:
#-------------------------------------------------------------------------------

def valid_sizes ( object, name, value ):
    """ Trait validation function for a splitter's 'sizes' trait.

        *value* must be a sequence whose type is in 'basic_sequence_types'
        and whose length equals len( object.components ). Every element must
        convert to a non-negative float. The result is a new list of floats
        scaled so that they sum to 1.0. If every element is zero, the first
        entry is given the whole space. An empty sequence is returned
        unchanged.

        *name* (the trait name) is not used.

        Raises TraitError for any other value.
    """
    try:
        if ((type( value ) in basic_sequence_types) and
            (len( value ) == len( object.components ))):
            if len( value ) == 0:
                return value
            # Convert once and use the converted values for the division as
            # well, so every element 'float' accepts is treated consistently:
            sizes = [ float( n ) for n in value ]
            total = 0.0
            for size in sizes:
                # 'not (size >= 0.0)' also rejects NaN, which a plain
                # 'size < 0.0' test would let through:
                if not (size >= 0.0):
                    raise TraitError
                total += size
            if total == 0.0:
                sizes[0] = total = 1.0
            return [ size / total for size in sizes ]
    except Exception:
        # Bad element types, an object without 'components', the explicit
        # TraitError above, etc. are all reported by the raise below.
        pass
    raise TraitError

valid_sizes.info = 'a list or tuple of non-negative numbers'

#-------------------------------------------------------------------------------
#  'Splitter' class:
#-------------------------------------------------------------------------------

class Splitter ( Component, BaseContainer ):
    """ A container that lays its components out in a single row
        ('horizontal' style) or column ('vertical' style), separated by
        splitter bars that the user can drag to redistribute space.

        The share of the available space given to each component is held in
        the 'sizes' trait. 'valid_sizes' keeps it at one entry per component,
        normalized to sum to 1.0.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The contained components, in layout order. Horizontally the first is
    # placed at the smallest x; vertically the first is placed at the highest
    # y coordinate:
    components   = List( Component )

    # Layout direction. Only the first character is tested: 'v' means
    # vertical, anything else horizontal:
    style        = Trait( 'horizontal', layout_style_trait )

    # Fill color of the splitter bars (bars are not drawn when transparent):
    bg_color     = clear_color_trait

    # Outline color of the splitter bars:
    border_color = black_color_trait

    # Thickness of each splitter bar, in the same units as 'bounds':
    bar_size     = Range( 1, 15, value = 4 )

    # Fraction of the available space assigned to each component:
    sizes        = Trait( [], valid_sizes, edit = eval_editor )

    # When True, a bounds change re-lays out every component in proportion
    # to 'sizes'. When False, the bars stay put and only the first (vertical)
    # or last (horizontal) component absorbs the change:
    auto_scale   = true
    
    #---------------------------------------------------------------------------
    #  Trait view definitions:
    #---------------------------------------------------------------------------
    
    traits_view = View( Group( '<component>', id = 'component' ),
                        Group( '<links>', '<clinks>', 'components', 
                               id = 'links' ),
                        Group( 'bg_color{Background color}', '_', 
                               'style', 'auto_scale', 'bar_size', ' ', 
                               'sizes', 
                               id    = 'splitter',
                               style = 'custom' ) )
    
    colorchip_map = {
        'bg_color':  'bg_color',
        'alt_color': 'border_color'
    }
    
    #---------------------------------------------------------------------------
    #  Initialize the object:
    #---------------------------------------------------------------------------
    
    def __init__ ( self, *components, **traits ):
        """ Creates the splitter, applies the keyword trait values, then
            appends each positional component in order.
        """
        # 'components' must exist before Component.__init__ runs, because
        # 'valid_sizes' reads it if a 'sizes' keyword is supplied:
        self.components = []   ### TEMPORARY ###
        Component.__init__( self, **traits )
        self.add( *components )
        
    #---------------------------------------------------------------------------
    #  Add one or more components to the container:
    #---------------------------------------------------------------------------
    
    def add ( self, *components ):
        """ Appends each of *components* to the end of the splitter. """
        the_components = self.components
        for component in components:
            self.add_at( len( the_components ), component )
        
    #----------------------------------------------------------------------------
    #  Add a component at a specified index:
    #----------------------------------------------------------------------------
    
    def add_at ( self, index, component ):
        """ Inserts *component* at position *index*.

            A component already in this splitter is moved to *index*. A
            component owned by another container is first removed from that
            container. The inserted component gets the average of the
            existing shares before 'sizes' is renormalized. An index past the
            end appends.

            Raises IndexError if *index* is negative.
        """
        if index < 0:
            raise IndexError( 'index must be non-negative' )
        the_components = self.components
        sizes          = self.sizes[:]
        container      = component.container
        if container is self:
            # Moving within this splitter: drop the old entry first. A
            # separate name is used so the caller's 'index' is not lost.
            try:
                old_index = the_components.index( component )
            except ValueError:
                pass
            else:
                del the_components[ old_index ]
                del sizes[ old_index ]
        else:
            container.remove( component )
        the_components[ index: index ] = [ component ]
        component.container = self
        if len( sizes ) == 0:
            sizes = [ 1.0 ]
        else:
            # Average share of the existing components; assigning to 'sizes'
            # renormalizes the list through 'valid_sizes':
            sizes[ index: index ] = [ sum( sizes ) / len( sizes ) ]
        self.sizes = sizes
                
    #---------------------------------------------------------------------------
    #  Remove one or more components from the container:
    #---------------------------------------------------------------------------
    
    def remove ( self, *components ):
        """ Removes each of *components* from the splitter and points its
            'container' at the default container. Components that are not
            in this splitter are ignored.
        """
        # NOTE(review): the module never imported 'default_container'; it is
        # assumed to be exported by base_container -- confirm.
        from base_container import default_container
        the_components = self.components
        sizes          = self.sizes[:]
        for component in components:
            try:
                index = the_components.index( component )
            except ValueError:
                # Not one of ours: nothing to remove.
                continue
            del the_components[ index ]
            del sizes[ index ]
            component.container = default_container
        self.sizes = sizes
        
    #---------------------------------------------------------------------------
    #  Handle the bounds being changed: 
    #---------------------------------------------------------------------------
    
    def _bounds_changed ( self, old, new ):
        """ Change handler for 'bounds'. It re-lays out the components
            proportionally when 'auto_scale' is set; otherwise it keeps the
            bars in place.
        """
        Component._bounds_changed( self, old, new )
        if self.auto_scale:
            self.update_sizes()
        else:
            self.update_bounds()
        
    #---------------------------------------------------------------------------
    #  Handle the style being changed:
    #---------------------------------------------------------------------------
    
    def _style_changed ( self ):
        """ Change handler for 'style': redraws and re-lays out. """
        self._sizes_changed()

    #----------------------------------------------------------------------------
    #  Handle the bar size being changed:
    #----------------------------------------------------------------------------
    
    def _bar_size_changed ( self ):
        """ Change handler for 'bar_size': redraws and re-lays out. """
        self._sizes_changed()
        
    #---------------------------------------------------------------------------
    #  Handle the sizes trait being changed:
    #---------------------------------------------------------------------------
    
    def _sizes_changed ( self ):
        """ Change handler for 'sizes': requests a redraw and recomputes the
            component and bar bounds.
        """
        self.redraw()
        self.update_sizes()
        
    #---------------------------------------------------------------------------
    #  Update the sizes of each of the contained components:
    #---------------------------------------------------------------------------
    
    def update_sizes ( self ):
        """ Lays out every component in proportion to 'sizes'.

            Space left after subtracting the bars is split by cumulative
            rounding, so the component extents add up exactly. Each
            component's bounds are recorded in '_cbounds' and each bar's
            bounds in '_bars' (both rebuilt from scratch).
        """
        x, y, dx, dy  = self.bounds
        sizes         = self.sizes
        self._cbounds = _cbounds = []
        self._bars    = _bars   = []
        num_bars      = len( sizes ) - 1
        if self.style[0] == 'v':
            # Vertical: start at the top edge and work downwards.
            bdy   = float( self.bar_size )
            ady   = max( 0.0, dy - (num_bars * bdy) )
            total = pdy = 0.0
            y    += dy
            for i, component in enumerate( self.components ):
                total += sizes[i]
                ndy    = round( total * ady )
                cdy    = ndy - pdy
                y     -= cdy
                bounds = ( x, y, dx, cdy )
                component.bounds = bounds
                _cbounds.append( bounds )
                pdy = ndy
                y  -= bdy
                if i < num_bars:
                    _bars.append( ( x, y, dx, bdy ) )
        else:
            # Horizontal: start at the left edge and work rightwards.
            bdx   = float( self.bar_size )
            adx   = max( 0.0, dx - (num_bars * bdx) )
            total = pdx = 0.0
            for i, component in enumerate( self.components ):
                total += sizes[i]
                ndx    = round( total * adx )
                cdx    = ndx - pdx
                bounds = ( x, y, cdx, dy )
                component.bounds = bounds
                _cbounds.append( bounds )
                pdx = ndx
                x  += cdx + bdx
                if i < num_bars:
                    _bars.append( ( x - bdx, y, bdx, dy ) )
        
    #---------------------------------------------------------------------------
    #  Update the sizes of each of the contained components to match the
    #  bounds of the component (without moving the splitter bars):
    #---------------------------------------------------------------------------
    
    def update_bounds ( self ):
        """ Fits the components to the current bounds without moving the bars.

            Vertical style: the first component's height absorbs the change
            and every component and bar takes the new width. Horizontal
            style: the last component's width absorbs the change and every
            component and bar takes the new height. '_cbounds' and '_bars'
            are updated in place.
        """
        components = self.components
        if len( components ) == 0:
            # Nothing to fit; without this guard components[0] / [-1] below
            # would raise IndexError.
            return
        x, y, dx, dy = self.bounds
        bars         = self._bars
        cbounds      = self._cbounds
        if self.style[0] == 'v':
            component        = components[0]
            component.height = max( 0, dy - (component.y - y) )
            for i, component in enumerate( components ):
                component.width = dx
                cbounds[i]      = component.bounds
            for i, bounds in enumerate( bars ):
                bars[i] = ( bounds[0], bounds[1], dx, bounds[3] ) 
        else:
            component        = components[-1]
            component.width  = max( 0, dx - (component.x - x) )
            for i, component in enumerate( components ):
                component.height = dy
                cbounds[i]       = component.bounds
            for i, bounds in enumerate( bars ):
                bars[i] = ( bounds[0], bounds[1], bounds[2], dy ) 
                    
    #---------------------------------------------------------------------------
    #  Return whether or not a specified point is over one of the splitter bars:
    #---------------------------------------------------------------------------
    
    def _is_over ( self, x, y = None ):
        """ Returns the index of the splitter bar containing the point, or -1.

            Accepts either explicit ( x, y ) coordinates or a single object
            with 'x' and 'y' attributes (e.g. a mouse event) passed as *x*.
        """
        if y is None:
            y = x.y
            x = x.x
        for i, bounds in enumerate( self._bars ):
            if xy_in_bounds( x, y, bounds ):
                return i
        return -1
                       
    #---------------------------------------------------------------------------
    #  Return the components that contain a specified (x,y) point:
    #---------------------------------------------------------------------------
       
    def _components_at ( self, x, y ):
        """ Returns the components under ( x, y ).

            The result is [ self ] when the point is on a bar. Otherwise the
            lookup is delegated to the first child whose laid-out bounds
            contain the point. The result is [] when nothing matches.
        """
        if self._is_over( x, y ) >= 0:
            return [ self ]
        _cbounds = self._cbounds
        for i, component in enumerate( self.components ):
            if xy_in_bounds( x, y, _cbounds[i] ):
                return component.components_at( x, y )
        return []

    #---------------------------------------------------------------------------
    #  Draw the splitter and its contained components in a specified graphics 
    #  context:
    #---------------------------------------------------------------------------
    
    def _draw ( self, gc ):
        """ Draws each child clipped to its own bounds, then the bars. """
        # Draw each component contained in the splitter:
        _cbounds = self._cbounds
        for i, component in enumerate( self.components ):
            gc.save_state()
            gc.clip_to_rect(*_cbounds[i])
            component.draw(gc)
            gc.restore_state()
        
        # Draw the splitter bars (if required):
        bg_color = self.bg_color_
        if bg_color is not transparent_color:
            gc.save_state()
            gc.set_fill_color( bg_color )
            gc.set_stroke_color( self.border_color_ )
            for bounds in self._bars:
                x, y, dx, dy = bounds
                gc.begin_path()
                # Inset by half a unit so the outline falls inside the bar:
                gc.rect( x + 0.5, y + 0.5, dx - 1.0, dy - 1.0 )
                gc.draw_path()
            gc.restore_state()
        
    #---------------------------------------------------------------------------
    #  Handle mouse events:
    #---------------------------------------------------------------------------
    
    def _left_down_changed ( self, event ):
        """ Starts a bar drag when the left button is pressed over a bar.

            '_info' records ( bar index, combined 'sizes' share of the two
            components beside the bar, extent of the component before the
            bar, combined extent of both components, starting mouse
            coordinate ).
        """
        event.handled = True
        bar = self._is_over( event )
        if bar >= 0:
            self.window.mouse_owner = self
            sizes    = self.sizes
            _cbounds = self._cbounds
            if self.style[0] == 'v':
                self._dragging = event.y
                self._info     = ( bar, sizes[ bar ] + sizes[ bar + 1 ],
                                   _cbounds[ bar ][3],
                                   _cbounds[ bar ][3] + _cbounds[ bar + 1 ][3],
                                   event.y )
            else:
                self._dragging = event.x
                self._info     = ( bar, sizes[ bar ] + sizes[ bar + 1 ],
                                   _cbounds[ bar ][2],
                                   _cbounds[ bar ][2] + _cbounds[ bar + 1 ][2],
                                   event.x )
    
    def _left_up_changed ( self, event ):
        """ Ends any drag, releases the mouse and restores the arrow. """
        event.handled = True
        self.pointer  = 'arrow'
        self.window.mouse_owner = self._dragging = self._info = None
    
    def _mouse_move_changed ( self, event ):
        """ While dragging, redistributes the two components' combined share
            according to the mouse offset. Otherwise it shows a resize
            pointer over bars and the arrow elsewhere.
        """
        event.handled = True
        if self._dragging is not None:
            if self.style[0] == 'v':
                np    = event.y
                # The component before the bar lies above it (higher y), so
                # moving the mouse up must shrink it:
                mult  = -1.0
                index = 3
            else:
                np    = event.x
                mult  = 1.0
                index = 2
            if np != self._dragging:
                bar, tsize, sdp, tdp, sp = self._info
                ndp = sdp + mult * (np - sp)
                ndp = max( 0.0, min( tdp, ndp ) )
                if ndp != self._cbounds[ bar ][ index ]:
                    self._dragging = np 
                    nsize          = (ndp * tsize) / tdp
                    sizes          = self.sizes[:]
                    sizes[ bar: bar + 2 ] = [ nsize, tsize - nsize ]
                    self.sizes = sizes
        elif self._is_over( event ) >= 0:
            self.window.mouse_owner = self
            self.pointer = [ 'size left', 'size top' ][ self.style[0] == 'v' ]
        else:
            self.window.mouse_owner = None
            self.pointer = 'arrow'
