/*
 * Copyright (c) 2004 The University of Wroclaw.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met :
 *    1. Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *    2. Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *    3. The name of the University may not be used to endorse or promote
 *       products derived from this software without specific prior
 *       written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
 * NO EVENT SHALL THE UNIVERSITY BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

namespace Nemerle.Collections
{
  /**
   * A functional Red-Black Trees implementation
   */
  public module Tree
  {
    /**
     * A node of a red-black tree holding keys of type 'a.
     *
     * Red and Black are interior nodes, each carrying a key together with
     * a left and a right subtree; Leaf is the empty tree.
     */
    public variant Node ['a] where 'a : System.IComparable ['a] {
      | Red {
          key : 'a;
          lchild : Node ['a];
          rchild : Node ['a]; 
        }
      | Black {
          key : 'a;
          lchild : Node ['a];
          rchild : Node ['a]; 
        }
      | Leaf
    }

    /**
     * Looks up ELEM in TREE.
     *
     * Returns Some (key) holding the stored key that compares equal to
     * ELEM, or None () when there is no such key.
     */
    public Get ['a] (tree : Node ['a], elem : 'a) : option ['a]
      where 'a : System.IComparable ['a]
    {
      // The descent never tests for equality: it only remembers the last
      // key that was not greater than ELEM and performs one equality test
      // when a leaf is reached.
      mutable best = elem;
      mutable have_best = false;

      def descend (node : Node ['a]) : option ['a]
      {
        | Node.Black (key, ltree, rtree)
        | Node.Red (key, ltree, rtree) =>
          if (key.CompareTo (elem) > 0)
            descend (ltree)
          else {
            have_best = true;
            best = key;
            descend (rtree)
          }

        | Node.Leaf =>
          if (have_best && elem.CompareTo (best) == 0)
            Some (best)
          else
            None ()
      }

      descend (tree)
    }

    
    /**
     * Returns TREE with the element ELEM inserted.
     *
     * If a key comparing equal to ELEM is already present, it is replaced
     * by ELEM when REPLACE is true; otherwise System.ArgumentException is
     * thrown. The root of the returned tree is always Black.
     */
    public Insert ['a] (tree : Node ['a], elem : 'a, replace : bool) : Node ['a]
      where 'a : System.IComparable ['a]
    {
      def insert (tree) {
        | Node.Red (key, ltree, rtree) =>
          // Compare once per visited node and reuse the result for both
          // tests; the comparison may be costly for complex keys.
          def cmp = elem.CompareTo (key);
          if (cmp > 0)
            Node.Red (key, ltree, insert (rtree))
          else if (cmp < 0)
            Node.Red (key, insert (ltree), rtree)
          else if (replace)
            Node.Red (elem, ltree, rtree)
          else
            throw System.ArgumentException ("node already in the tree")
            
        | Node.Black (key, ltree, rtree) =>
          def cmp = elem.CompareTo (key);
          // Only a Black node rebalances: the subtree that received the
          // new element is handed to the matching Balance* helper.
          if (cmp > 0)
            BalanceRight (key, ltree, insert (rtree))
          else if (cmp < 0)
            BalanceLeft (key, insert (ltree), rtree)
          else if (replace)
            Node.Black (elem, ltree, rtree)
          else 
            throw System.ArgumentException ("node already in the tree")
            
        | Node.Leaf => 
          // A new element always enters the tree as a Red node.
          Node.Red (elem, Node.Leaf (), Node.Leaf ())
      }
      
      // Force the root to be Black.
      match (insert (tree)) {
        | (Node.Black) as tree => tree
        | Node.Red (key, ltree, rtree) => Node.Black (key, ltree, rtree)
        | Node.Leaf => assert (false)
      }
    }

    
    /** 
     * Returns TREE with the element comparing equal to ELEM removed.
     *
     * When ELEM is not in the tree, System.ArgumentException is thrown if
     * THROW_ON_ERR is true; otherwise no element is removed (the nodes
     * along the search path are still rebuilt). The root of a non-empty
     * result is always Black.
     */ 
    public Delete ['a] (tree : Node ['a], elem : 'a, throw_on_err : bool) : Node ['a]
      where 'a : System.IComparable ['a]
    {
      def delete (tree) {
        | Node.Red (key, ltree, rtree)
        | Node.Black (key, ltree, rtree) =>
          // Compare once per visited node and reuse the result for both
          // tests; the comparison may be costly for complex keys.
          def cmp = elem.CompareTo (key);
          if (cmp > 0)
            match (rtree) {
              | Node.Black =>
                BalRight (key, ltree, delete (rtree))
              | _ => 
                Node.Red (key, ltree, delete (rtree))
            } 
          else if (cmp < 0)
            match (ltree) {
              | Node.Black =>
                BalLeft (key, delete (ltree), rtree)
              | _ =>
                Node.Red (key, delete (ltree), rtree)
            }
          else 
            // Found the node: splice its two subtrees together.
            GetSubst (ltree, rtree)
            
        | Node.Leaf =>
          if (throw_on_err)
            throw System.ArgumentException ("node not in the tree")
          else tree
      }
      
      // Force the root of a non-empty result to be Black.
      match (delete (tree)) {
        | (Node.Black) as res => res
        | Node.Red (key, ltree, rtree) =>
          Node.Black (key, ltree, rtree)
        | Node.Leaf => Node.Leaf ()
      } 
    }

    /**
     * Convenience overload of Delete that always passes true for
     * THROW_ON_ERR, so removing an element that is not in TREE throws
     * System.ArgumentException.
     */
    public Delete ['a] (tree : Node ['a], elem : 'a) : Node ['a]
      where 'a : System.IComparable ['a]
    {
      Delete (tree, elem, true)
    }
    
    /**
     * Folds FUNC over the keys of TREE, starting from the initial value INI.
     *
     * For every node the left subtree is folded first, then FUNC is applied
     * to the node's key, and finally the right subtree is folded.
     */
    public Fold ['a, 'b] (tree : Node ['a], ini : 'b, func : 'a * 'b -> 'b) : 'b 
      where 'a : System.IComparable ['a] 
    {
      match (tree) {
        | Node.Leaf => ini
        | Node.Red (key, ltree, rtree)
        | Node.Black (key, ltree, rtree) =>
          def after_left = Fold (ltree, ini, func);
          def after_key = func (key, after_left);
          Fold (rtree, after_key, func)
      }
    } 

    
    /**
     * Returns true if and only if FUNC (X) is true for at least one key X
     * of TREE. The node's own key is tested before its left and then its
     * right subtree; evaluation stops at the first success.
     */
    public Exists ['a] (tree : Node ['a], func : 'a -> bool) : bool
      where 'a : System.IComparable ['a] 
    {
      match (tree) {
        | Node.Red (key, ltree, rtree)
        | Node.Black (key, ltree, rtree) =>
          if (func (key))
            true
          else if (Exists (ltree, func))
            true
          else
            Exists (rtree, func)
        | _ => false
      }
    }

    
    /**
     * Returns true if and only if FUNC (X) is true for every key X of TREE
     * (hence true for an empty tree). The node's own key is tested before
     * its left and then its right subtree; evaluation stops at the first
     * failure.
     */ 
    public ForAll ['a] (tree : Node ['a], func : 'a -> bool) : bool
      where 'a : System.IComparable ['a] 
    {
      match (tree) {
        | Node.Red (key, ltree, rtree)
        | Node.Black (key, ltree, rtree) =>
          if (func (key))
            if (ForAll (ltree, func))
              ForAll (rtree, func)
            else
              false
          else
            false
        | _ => true
      }
    }


    /**
     * Splits TREE by the predicate FUNC.
     *
     * Returns TREE1 * INT1 * TREE2 * INT2, where TREE1 holds the keys X of
     * TREE for which FUNC (X) is true, TREE2 holds all remaining keys, and
     * INT1 / INT2 are the sizes of TREE1 / TREE2.
     */
    public CountPartition ['a] (tree : Node ['a], func : 'a -> bool) 
      : Node ['a] * int * Node ['a] * int where 'a : System.IComparable ['a]
    {
      // Walks the tree (node first, then left, then right) and threads the
      // two result trees together with their sizes through the recursion.
      def split (node, acc)
      {
        match (node) {
          | Node.Leaf => acc
          | Node.Red (key, ltree, rtree)
          | Node.Black (key, ltree, rtree) =>
            def (yes, yes_count, no, no_count) = acc;
            def with_key =
              if (func (key))
                (Insert (yes, key, false), yes_count + 1, no, no_count)
              else
                (yes, yes_count, Insert (no, key, false), no_count + 1);
            def after_left = split (ltree, with_key);
            split (rtree, after_left)
        }
      }

      split (tree, (Node.Leaf (), 0, Node.Leaf (), 0))
    } 

    
    /**
     * Returns TREE1 * TREE2, where TREE1 holds the keys X of TREE for which
     * FUNC (X) is true and TREE2 holds all keys of TREE that are not in
     * TREE1. Same as CountPartition with the two sizes dropped.
     */
    public Partition ['a] (tree : Node ['a], func : 'a -> bool)
      : Node ['a] * Node ['a] where 'a : System.IComparable ['a] 
    {
      match (CountPartition (tree, func)) {
        | (accepted, _, rejected, _) => (accepted, rejected)
      }
    }

    
    /**
     * Returns TREE1 * INT1, where TREE1 holds the keys X of TREE for which
     * FUNC (X) is true and INT1 is the size of TREE1.
     */
    public CountFilter ['a] (tree : Node ['a], func : 'a -> bool)
      : Node ['a] * int where 'a : System.IComparable ['a] 
    {
      // Walks the tree (node first, then left, then right) and threads the
      // result tree together with its size through the recursion.
      def collect (node, acc)
      {
        match (node) {
          | Node.Leaf => acc
          | Node.Red (key, ltree, rtree)
          | Node.Black (key, ltree, rtree) =>
            def (kept, count) = acc;
            def seeded =
              if (func (key))
                (Insert (kept, key, false), count + 1)
              else
                (kept, count);
            collect (rtree, collect (ltree, seeded))
        }
      }

      collect (tree, (Node.Leaf (), 0))
    } 


    /**
     * Returns a tree holding the keys X of TREE for which FUNC (X) is true.
     * Same as CountFilter with the size dropped.
     */
    public Filter ['a] (tree : Node ['a], func : 'a -> bool) : Node ['a]
      where 'a : System.IComparable ['a] 
    {
      match (CountFilter (tree, func)) {
        | (kept, _) => kept
      }
    }

  
/*
  FIXME: There is a problem with this function - it does not necessarily perform a correct transformation

    public Map['a,'b] (tree : Node ['a], func : 'a -> 'b) : Node ['b]
      where 'a : System.IComparable ['a]
      where 'b :> System.IComparable ['b]
    {
      def map (tree: Node ['a]) : Node ['b]
          {
            match (tree) {
              | Tr (key, color, ltree, rtree) =>
                  Tr (func (key), color, map (ltree), map (rtree))
              | Node.Leaf => Node.Leaf ()
            } 
          };
      map (tree)
    }
*/   

    /**
     * Internal functions used for tree balancing.
     *
     * BalRight builds a node with key ELEM from LCHILD and RCHILD. Delete
     * calls it after removing an element from a Black right subtree.
     * NOTE(review): presumably RCHILD is then one black level shorter than
     * LCHILD and this function makes up for it - confirm.
     */
    private BalRight ['a] (elem : 'a, lchild : Node ['a], rchild : Node ['a])
      : Node ['a] where 'a : System.IComparable ['a]
    {
      match ((elem, lchild, rchild)) {
        // Red right child: recolor it Black under a Red root.
        | (key, ltree, Node.Red (key1, ltree1, rtree1)) =>
          Node.Red (key, ltree, Node.Black (key1, ltree1, rtree1))
        // Black left child: recolor it Red and let BalanceLeft fix up.
        | (key, Node.Black (key1, ltree1, rtree1), rtree) =>
          BalanceLeft (key, Node.Red (key1, ltree1, rtree1), rtree)
        // Red left child with two Black children: its right child's key
        // becomes the new Red root and its left child is recolored Red.
        | (key, Node.Red (key1, Node.Black (key3, ltree3, rtree3), Node.Black (key2, ltree2, rtree2)), rtree) =>
          Node.Red (key2,
                   BalanceLeft (key1, Node.Red (key3, ltree3, rtree3), ltree2),
                   Node.Black (key, rtree2, rtree))
        // Any other shape is treated as a broken tree.
        | _ =>
          assert (false, "balance violation")
      }
    }

    /**
     * Mirror image of BalRight: builds a node with key ELEM from LCHILD and
     * RCHILD. Delete calls it after removing an element from a Black left
     * subtree.
     * NOTE(review): presumably LCHILD is then one black level shorter than
     * RCHILD and this function makes up for it - confirm.
     */
    private BalLeft ['a] (elem : 'a, lchild : Node ['a], rchild : Node ['a])
      : Node ['a] where 'a : System.IComparable ['a]
    {
      match ((elem, lchild, rchild)) {
        // Red left child: recolor it Black under a Red root.
        | (key, Node.Red (key1, ltree1, rtree1), rtree) =>
          Node.Red (key, Node.Black (key1, ltree1, rtree1), rtree)
        // Black right child: recolor it Red and let BalanceRight fix up.
        | (key, ltree, Node.Black (key1, ltree1, rtree1)) =>
          BalanceRight (key, ltree, Node.Red (key1, ltree1, rtree1))
        // Red right child with two Black children: its left child's key
        // becomes the new Red root and its right child is recolored Red.
        | (key, ltree, Node.Red (key1, Node.Black (key3, ltree3, rtree3), Node.Black (key2, ltree2, rtree2))) =>
          Node.Red (key3,
                   Node.Black (key, ltree, ltree3),
                   BalanceRight (key1, rtree3, Node.Red (key2, ltree2, rtree2)))
        // Any other shape is treated as a broken tree.
        | _ =>
          assert (false, "balance violation")
      }
    }

    /**
     * Joins LCHILD and RCHILD into a single tree. Delete calls it with the
     * two subtrees of the node being removed, so the result takes that
     * node's place.
     * NOTE(review): assumes every key of LCHILD is smaller than every key
     * of RCHILD, as holds for the two children of one node - confirm.
     */
    private GetSubst ['a] (lchild : Node ['a], rchild : Node ['a])
      : Node ['a] where 'a : System.IComparable ['a]
    {
      match ((lchild, rchild)) {
        // One side is empty: the other side is the answer.
        | (Node.Leaf, tree) => tree
        | (tree, Node.Leaf) => tree
        // Same color on both sides: join the two inner subtrees first,
        // then hang the result between the two outer nodes.
        | (Node.Red (key, ltree, rtree), Node.Red (key1, ltree1, rtree1)) =>
          match (GetSubst (rtree, ltree1)){
            | Node.Red (key2, ltree2, rtree2) =>
              Node.Red (key2, Node.Red (key, ltree, ltree2), Node.Red (key1, rtree2, rtree1))
            | tree => Node.Red (key, ltree, Node.Red (key1, tree, rtree1))
          }
        | (Node.Black (key, ltree, rtree), Node.Black (key1, ltree1, rtree1)) =>
          match (GetSubst (rtree, ltree1)){
            | Node.Red (key2, ltree2, rtree2) =>
              Node.Red (key2, Node.Black (key, ltree, ltree2), Node.Black (key1, rtree2, rtree1))
            // The joined middle part is not Red: BalLeft rebuilds the node.
            | tree => BalLeft (key, ltree, Node.Black (key1, tree, rtree1))
          }
        // Mixed colors: descend into the Red side only.
        | (tree, Node.Red (key, ltree, rtree)) =>
          Node.Red (key, GetSubst (tree, ltree), rtree)
        | (Node.Red (key, ltree, rtree), tree) =>
          Node.Red (key, ltree, GetSubst (rtree, tree))
      }
    }

    /**
     * Builds a node with key ELEM from LCHILD and RCHILD, repairing a Red
     * node with a Red child on the left side.
     *
     * Two Red children are both recolored Black under a Red root; a Red
     * left child that has a Red child of its own is rotated into a Red
     * root with two Black children; anything else becomes a plain Black
     * node.
     */
    private BalanceLeft['a] (elem : 'a, lchild : Node ['a], rchild : Node ['a])
      : Node ['a] where 'a : System.IComparable ['a]
    {
      match ((lchild, rchild)) {
        | (Node.Red (lkey, la, lb), Node.Red (rkey, ra, rb)) =>
          Node.Red (elem, Node.Black (lkey, la, lb), Node.Black (rkey, ra, rb))
        | (Node.Red (lkey, Node.Red (ikey, ia, ib), lb), _) =>
          Node.Red (lkey, Node.Black (ikey, ia, ib), Node.Black (elem, lb, rchild))
        | (Node.Red (lkey, la, Node.Red (ikey, ia, ib)), _) =>
          Node.Red (ikey, Node.Black (lkey, la, ia), Node.Black (elem, ib, rchild))
        | _ =>
          Node.Black (elem, lchild, rchild)
      }
    }

    /**
     * Builds a node with key ELEM from LCHILD and RCHILD, repairing a Red
     * node with a Red child on the right side.
     *
     * Two Red children are both recolored Black under a Red root; a Red
     * right child that has a Red child of its own is rotated into a Red
     * root with two Black children; anything else becomes a plain Black
     * node.
     */
    private BalanceRight['a] (elem : 'a, lchild : Node ['a], rchild : Node ['a])
      : Node ['a] where 'a : System.IComparable ['a]
    {
      match ((lchild, rchild)) {
        | (Node.Red (lkey, la, lb), Node.Red (rkey, ra, rb)) =>
          Node.Red (elem, Node.Black (lkey, la, lb), Node.Black (rkey, ra, rb))
        | (_, Node.Red (rkey, ra, Node.Red (ikey, ia, ib))) =>
          Node.Red (rkey, Node.Black (elem, lchild, ra), Node.Black (ikey, ia, ib))
        | (_, Node.Red (rkey, Node.Red (ikey, ia, ib), rb)) =>
          Node.Red (ikey, Node.Black (elem, lchild, ia), Node.Black (rkey, ib, rb))
        | _ =>
          Node.Black (elem, lchild, rchild)
      }
    } 
  } 


  /* ------------------------------------------------------------------------ */
  /* -- PUBLIC INTERFACES --------------------------------------------------- */
  /* ------------------------------------------------------------------------ */

  /**
   * Interface intended to be the only way to interact with a Map object.
   *
   * FIXME: why isn't this IDictionary?
   */
  public interface IMap ['a, 'b]
  {
    /**
     * Returns an IMap ['a, 'b] with the pair (K, V) added.
     */
    Add (k : 'a, v : 'b) : IMap ['a, 'b];
    
    /**
     * Returns the value associated with the key K.
     */
    Get (k : 'a) : 'b;

    /**
     * Returns an empty IMap ['a, 'b].
     */
    Clear () : IMap ['a, 'b];
    
    /**
     * Returns a copy of THIS IMap ['a, 'b].
     */
    Copy () : IMap ['a, 'b];
    
    /**
     * Returns true if and only if THIS IMap contains a pair (X, Y)
     * for which FUNC (X, Y) is true.
     */
    Exists (func : 'a * 'b -> bool) : bool;
    
    /**
     * Returns an IMap consisting of those pairs (X, Y) of THIS IMap
     * for which FUNC (X, Y) is true.
     */ 
    Filter (func : 'a * 'b -> bool) : IMap ['a, 'b];
    
    /**
     * Finds and returns the value associated with the key K
     * (None is returned if there is no such value).
     */
    Find (k : 'a) : option ['b];
    
    /**
     * Returns some value contained in THIS IMap.
     * Note: which value is returned depends on how the IMap has been
     * manipulated.
     */
    First () : 'b;
    
    /**
     * Goes through every pair of THIS IMap and accumulates the value of
     * the function FUNC, starting from the initial value INI.
     */  
    Fold['c] (ini : 'c, func : 'a * 'b * 'c -> 'c) : 'c;
    
    /**
     * Returns true if and only if FUNC (X, Y) is true for every pair
     * (X, Y) of THIS IMap.
     */
    ForAll (func : 'a * 'b -> bool) : bool;
    
    /**
     * Goes through every pair (X, Y) of THIS IMap and evaluates
     * FUNC (X, Y).
     */
    Iter (func : 'a * 'b -> void) : void;
    
    /**
     * Returns true if the key K is contained in THIS IMap.
     */
    Member (k : 'a) : bool;
    
    /**
     * Returns IMAP1 * IMAP2, where IMAP1 consists of those pairs (X, Y)
     * of THIS IMap for which FUNC (X, Y) is true and IMAP2 contains all
     * pairs of THIS IMap that are not in IMAP1.
     */
    Partition (func : 'a * 'b -> bool) : IMap ['a, 'b] * IMap ['a, 'b];
    
    /**
     * Returns THIS IMap with the key K and its associated value removed.
     */
    Remove (k : 'a) : IMap ['a, 'b];
    
    /**
     * Returns THIS IMap with the pair (K, V) replaced.
     */
    Replace (k : 'a, v : 'b) : IMap ['a, 'b];

    /**
     * Tells whether THIS IMap is empty, i.e. contains no elements.
     */
    IsEmpty : bool { get; }
    
    /**
     * Returns the number of elements in THIS IMap.
     */
    Size : int { get; }

    /**
     * Returns the number of elements in THIS IMap (documented identically
     * to Size).
     */
    Count : int { get; }
  }     

  /**
   * A key/value entry stored in the tree behind Map. Entries are ordered
   * by KEY alone; VAL takes no part in comparisons.
   */
  internal struct NodeNem ['a, 'b]
    : System.IComparable [NodeNem ['a, 'b]]
    where 'a : System.IComparable ['a] 
  {
    public key : 'a;
    public val : 'b;

    /** Creates an entry with only the key set; VAL is not assigned here. */
    public this (k : 'a) 
    {
      key = k;
    }

    /** Creates an entry pairing the key K with the value V. */
    public this (k : 'a, v : 'b) 
    {
      key = k;
      val = v;
    }

    /** Compares THIS entry with X by comparing their keys. */
    public CompareTo (x : NodeNem ['a, 'b]) : int
    {
      key.CompareTo (x.key)
    }
  } 

  /**
   * An immutable map from keys of type 'a to values of type 'b, kept in a
   * red-black tree of key/value entries ordered by key.
   *
   * No method changes THIS map; the "modifying" methods return a new Map
   * instead.
   */
  public class Map ['a, 'b]
//    : IMap ['a, 'b] 
    where 'a : System.IComparable ['a]
  {
    private root : Tree.Node [NodeNem ['a, 'b]];
    private size : int;

    // TODO : Make it implement ICollection (problem with names and being functional or imperative - both mb)

    /** Creates an empty map. */
    public this ()
    {
      this.size = 0;
      this.root = Tree.Node.Leaf ();
    }

    /** Wraps an already built tree R that holds SIZE entries. */
    private this (size : int, r : Tree.Node [NodeNem ['a, 'b]]) 
    {
      this.root = r;
      this.size = size;
    }

    /** Returns THIS map itself; no separate copy is created. */
    public Copy () : Map ['a, 'b] 
    {
      this
    }

    /**
     * Returns the value stored at the root of the tree, or throws
     * System.ArgumentException when the map is empty.
     */
    public First () : 'b  
    {
      match (root) {
        | Tree.Node.Red (entry, _, _)
        | Tree.Node.Black (entry, _, _) => entry.val
        | Tree.Node.Leaf => throw System.ArgumentException ("map is empty")
      }
    }

    /** True when the map holds no entries, i.e. the root is a Leaf. */
    public IsEmpty : bool 
    {
      get
      {
        match (root) {
          | Tree.Node.Leaf => true
          | _ => false
        }
      }
    }

    /** Returns a new, empty map. */
    public Clear () : Map ['a, 'b] 
    {
      // NOTE(review): the discarded reference to `this' presumably only
      // keeps the compiler from flagging the method as not using its
      // instance - confirm.
      _ = this;
      Map ()
    }

    /**
     * Returns a map that additionally holds the pair (K, V). Throws
     * System.ArgumentException when K is already present.
     */
    public Add (k : 'a, v : 'b) : Map ['a, 'b]  
    { 
      def entry = NodeNem (k, v);
      Map (size + 1, Tree.Insert (root, entry, false))
    }

    /**
     * Returns a map in which K is associated with V, whether or not K was
     * present before.
     */
    public Replace (k : 'a, v : 'b) : Map ['a, 'b]  
    {
      def entry = NodeNem (k, v);
      // The size only grows when the key is new.
      def new_size =
        if (Option.IsSome (Tree.Get (root, entry)))
          size
        else
          size + 1;
      Map (new_size, Tree.Insert (root, entry, true))
    }

    /** Returns Some (value) for the key K, or None () when K is absent. */
    public Find (k : 'a) : option ['b]  
    {
      match (Tree.Get (root, NodeNem (k))) {
        | None => None ()
        | Some (entry) => Some (entry.val)
      }
    }

    /**
     * Returns the value for the key K, or throws System.ArgumentException
     * when K is absent.
     */
    public Get (k : 'a) : 'b  
    {
      match (Tree.Get (root, NodeNem (k))) {
        | None => throw System.ArgumentException ("key not found")
        | Some (entry) => entry.val
      }
    }

    /** Same as Contains. */
    public Member (k : 'a) : bool 
    {
      Contains (k)
    }

    /** Returns true when the key K is present in the map. */
    public Contains (k : 'a) : bool
    {
      def hit = Tree.Get (root, NodeNem (k));
      Option.IsSome (hit)
    }
    
    /**
     * Returns a map without the key K. Throws System.ArgumentException
     * when K is absent.
     */
    public Remove (k : 'a) : Map ['a, 'b] 
    {
      def remaining = Tree.Delete (root, NodeNem (k), true);
      Map (size - 1, remaining)
    }

    /** Number of entries in the map. */
    public Size : int 
    {
      get { this.size }
    }

    /** Number of entries in the map (same value as Size). */
    public Count : int 
    {
      get { this.size }
    }
    
    /**
     * Folds FUNC over all (key, value) pairs in the order of Tree.Fold,
     * starting from INI.
     */
    public Fold['d] (ini : 'd, func : 'a * 'b * 'd -> 'd) : 'd 
    {
      def step (entry : NodeNem ['a, 'b], acc) {
        func (entry.key, entry.val, acc)
      };
      Tree.Fold (root, ini, step)
    }

    /** Calls FUNC on every (key, value) pair in the order of Tree.Fold. */
    public Iter (func : 'a * 'b -> void) : void 
    {
      // Implemented as a fold whose accumulator is carried along unused.
      def visit (entry : NodeNem ['a, 'b], unused) { 
        func (entry.key, entry.val);
        unused
      };
      ignore (Tree.Fold (root, null, visit));
    }
    
    /** True if and only if FUNC holds for every (key, value) pair. */
    public ForAll (func : 'a * 'b -> bool) : bool 
    {
      Tree.ForAll (root, fun (entry : NodeNem ['a, 'b]) {
        func (entry.key, entry.val)
      })
    }

    /** True if and only if FUNC holds for at least one (key, value) pair. */
    public Exists (func : 'a * 'b -> bool) : bool 
    {
      Tree.Exists (root, fun (entry : NodeNem ['a, 'b]) {
        func (entry.key, entry.val)
      })
    }

    /** Returns a map holding only the pairs for which FUNC is true. */
    public Filter (func : 'a * 'b -> bool) : Map ['a, 'b] 
    {
      def accept = fun (entry : NodeNem ['a, 'b]) {
        func (entry.key, entry.val)
      };
      def (kept, kept_count) = Tree.CountFilter (root, accept);
      Map (kept_count, kept) 
    }

    /**
     * Returns two maps: the first holds the pairs for which FUNC is true,
     * the second holds all the remaining pairs.
     */
    public Partition (func : 'a * 'b -> bool) 
      : Map ['a, 'b] * Map ['a, 'b] 
    {
      def accept = fun (entry : NodeNem ['a, 'b]) {
        func (entry.key, entry.val)
      };
      def (yes, yes_count, no, no_count) = Tree.CountPartition (root, accept);
      (Map (yes_count, yes), Map (no_count, no)) 
    }

  }
}
