/*
 * Copyright (c) 2005 The University of Wroclaw.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *    1. Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *    2. Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *    3. The name of the University may not be used to endorse or promote
 *       products derived from this software without specific prior
 *       written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
 * NO EVENT SHALL THE UNIVERSITY BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

using Nemerle.Collections;
using Nemerle.Utility;
using Nemerle.Logging;

using Nemerle.Compiler;
using Nemerle.Compiler.Typedtree;
using Nemerle.Compiler.SolverMacros;


// CHECK_STV is defined only in DEBUG builds.  It enables CheckSTV below,
// which checks that every type variable used in a method (its header and
// body) is referenced from a type / method where it is in scope.
#if DEBUG
  #define CHECK_STV
#endif

// Log flag used by the `log (STV, ...)' tracing calls in this file; declared
// with `false', presumably meaning the tracing is disabled by default.
[assembly: LogFlag (STV, false)]

namespace Nemerle.Compiler
{
  // Typer4 is a post-typing pass over the body of a single method.  Run ()
  // performs, in order:
  //   * UnShare - copies subtrees that are reachable more than once, so that
  //     every node occurs only once in the tree,
  //   * Walk    - inserts boxing / generic conversions, marks expressions
  //     whose address must be loaded and records the try block of every
  //     label and goto,
  //   * Throws  - computes the Throws / JumpTarget flags, marks tail calls
  //     and reports `throw', `try' and `goto' used where they are not allowed.
  // NOTE(review): goto_targets, label_blocks, current_try_block and
  // current_t4 are static, so presumably only one Typer4 may run at a time.
  class Typer4
  {
    current_fun : Fun_header;    // header of the method being processed
    current_type : TypeBuilder;  // type declaring that method
    messenger : Messenger;       // solver messenger captured by the ctor
    the_method : MethodBuilder;  // the method being processed


    #region Entry points
    // Prepares the pass for `meth'; nothing is rewritten until Run () is called.
    public this (meth : MethodBuilder)
    {
      current_fun = meth.GetHeader ();
      the_method = meth;
      messenger = Passes.Solver.CurrentMessenger;
      current_type = meth.DeclaringType;
    }


    // Runs the pass on the method body.  The body must already be
    // FunBody.Typed (anything else trips the assert); the rewritten body is
    // stored back into current_fun.body.
    public Run () : void
    {
      Util.locate (current_fun.loc, {
        match (current_fun.body) {
          | FunBody.Typed (body) =>
            label_blocks.Clear ();
            // Message.Debug ("start ush");
            def body' = UnShare (body);
            // Message.Debug ("start T4");
            def body' = Walk (body');
            // the value of the whole body may itself need boxing to match
            // the declared return type
            def body' =
              if (NeedBoxing (current_fun.ret_type.Fix (), body'))
                Box (current_fun.ret_type.Fix (), body')
              else body';

            #if CHECK_STV
              // record where the method's and the type's type variables are
              // defined, then check the types in the method header
              current_t4 = this;
              foreach (tv in current_fun.typarms) {
                Util.cassert (tv.current_type == null);
                tv.current_type = current_type;
                tv.current_method = the_method;
              }
              foreach (tv in current_type.typarms) {
                when (tv.current_type == null)
                  tv.current_type = current_type;
              }
              log (STV, $"check header $the_method");
              foreach (parm in current_fun.parms)
                CheckSTV (parm.ty);
              CheckSTV (current_fun.ret_type);
            #endif

            // the body is a statement position (try allowed) and is in tail
            // position of the method (is_top)
            goto_targets.Clear ();
            _ = Throws (body', allow_try = true, is_top = true);

            current_fun.body = FunBody.Typed (body');
            when (Options.ShouldDump (current_fun))
              Message.Debug ($ "after T4: $the_method [$(current_fun.name)] "
                               ":\n$(body')\n");
          | _ => assert (false)
        }
      })
    }


    // True when the value of `expr' has to be boxed to be used as
    // `target_type': either a value-type (or void) value flowing into a
    // target that is neither a value type nor a type variable, or a
    // type-variable value flowing into a non-value-type target.
    static NeedBoxing (target_type : MType, expr : TExpr) : bool
    {
      def src = expr.Type.Fix ();

      (src.IsValueType || src is MType.Void) 
       && (!target_type.IsValueType && !(target_type is MType.TyVarRef)
       ) ||
       (src is MType.TyVarRef && !target_type.IsValueType)
    }
    

    // Wraps `expr' in a boxing conversion to `target_type'.
    static Box (target_type : MType, expr : TExpr) : TExpr
    {
      TExpr.TypeConversion (target_type, expr, target_type,
                            ConversionKind.Boxing ())
    }

    // Wraps `expr' in a GenericSim conversion to `target_type'; used for
    // call results whose declared return type is a type variable.
    static Convert (target_type : TyVar, expr : TExpr) : TExpr
    {
      TExpr.TypeConversion (target_type, expr, target_type,
                            ConversionKind.GenericSim ());
    }
    #endregion


    #region Throw handling
    // Ids of labels targeted by some Goto visited so far by Throws; used as
    // a set (values are always null).
    static goto_targets : Hashtable [int, object] = Hashtable ();

    // True when `expr' starts with a label (directly, or as the first element
    // of a Sequence) that a goto already seen by Throws jumps to.  The
    // result is cached in expr.JumpTarget.
    static IsJumpTarget (expr : TExpr) : bool
    {
      def res =
        match (expr) {
          | Label (id, _) =>
            goto_targets.Contains (id);
          | Sequence (e, _) =>
            IsJumpTarget (e)
          | _ => false
        }
      expr.JumpTarget = res;
      res
    }

    #if CHECK_STV
    // the Typer4 instance whose method is currently being checked
    static mutable current_t4 : Typer4;

    // Debug-only check that every type variable occurring in `t' is used in
    // scope: its defining type must be current_t4.current_type or one of its
    // declaring (enclosing) types, and its defining method, when recorded,
    // must be the method being processed.  Recursion is guarded by the
    // solver's possibly-looping protection; when that refuses entry, an
    // error is reported instead.
    static CheckSTV (t : TyVar) : void
    {
      if (Passes.Solver.CanEnterPossiblyLooping ())
        try {
          match (t.Fix ()) {
            | Class (_, args)
            | Tuple (args) =>
              args.Iter (CheckSTV);

            | Intersection (args) =>
              foreach (a in args) CheckSTV (a);

            | Ref (t)
            | Out (t)
            | Array (t, _) =>
              CheckSTV (t);

            | Void => {}

            | Fun (t1, t2) =>
              CheckSTV (t1);
              CheckSTV (t2);

            | TyVarRef (tv) =>
              Util.cassert (tv.current_type != null, $"type is null for $tv");
              // walk outwards from the current type until the type that
              // defines `tv' is found
              def check (t : TypeInfo) {
                if (t == null)
                  Util.ice ($ "tv $tv defined in $(tv.current_type) and accessed from "
                               "$(current_t4.current_type)");
                else if (t : object == tv.current_type) {}
                else check (t.DeclaringType)
              }
              check (current_t4.current_type);
              Util.cassert (tv.current_method == null ||
                            tv.current_method : object == current_t4.the_method,
                            $ "tv $tv defined in $(tv.current_method) and accessed from "
                               "$(current_t4.the_method)");
              {}
          }
        } finally {
          Passes.Solver.LeavePossiblyLooping ()
        }
      else {
        ReportError (Passes.Solver.CurrentMessenger,
                     $ "CheckSTV for $t failed");
        when (Passes.Solver.CurrentMessenger.NeedMessage)
          Message.MaybeBailout ();
      }
    }
    #endif

    // Reports an error when `e' can throw in a position where that is not
    // allowed.  Note that Throws also returns true for a goto.
    static NoThrowPlease (e : TExpr) : void
    {
      when (Throws (e))
        Message.Error (e.loc, "`throw' is not allowed here");
    }

    // Computes whether `expr' never completes normally (every path ends in
    // a throw or a goto), stores the result in expr.Throws and returns it.
    //
    //   allow_try - whether try blocks may appear at this position
    //               (statement positions only)
    //   is_top    - whether `expr' is in tail position of the method; when
    //               Options.GeneralTailCallOpt is set, eligible calls in
    //               tail position get GenerateTail = true
    //
    // Side effects: records goto targets in goto_targets, rewrites
    // Goto.try_block relative to the try block of its label, and reports
    // errors for throws / try blocks / gotos in invalid positions.
    static Throws (expr : TExpr, allow_try = false, is_top = false) : bool
    {
      log (STV, expr.loc, $ "{ throws: top=$is_top $(expr.GetType()) $expr");
      #if CHECK_STV
        when (expr.ty != null) {
          log (STV, expr.loc, $ "( : $(expr.Type)");
          CheckSTV (expr.Type);
          log (STV, expr.loc, $ "done )");
        }
        match (expr) {
          | DefValIn (name, _, _)
          | TryWith (_, name, _) =>
            log (STV, expr.loc, $ "( name($(name)) : $(name.Type)");
            CheckSTV (name.Type);
            log (STV, expr.loc, $ "done )");
          | _ => {}
        }
      #endif
      def res = 
        match (expr) {
          | Throw (e) =>
            // a null operand is allowed (presumably a bare rethrow)
            unless (e == null)
              NoThrowPlease (e);
            true

          | Goto (id, _) as g =>
            // remember the target label; try_block becomes the difference
            // between the goto's try block id and its label's (both
            // recorded by DoWalk)
            goto_targets [id] = null;
            if (label_blocks.Contains (id))
              g.try_block -= label_blocks [id];
            else
              Message.Error (expr.loc, 
                             $ "non local goto (block return?) detected (l$id)");
            true

          | Sequence (e1, e2)
          | DefValIn (_, e1, e2) =>
            // after a throwing e1, e2 is only reachable if something jumps
            // into it
            if (Throws (e1, allow_try)) {
              if (IsJumpTarget (e2))
                Throws (e2, allow_try, is_top)
              else true
            } else Throws (e2, allow_try, is_top)

          | If (cond, e1, e2) =>
            _ = Throws (cond, allow_try);
            def th1 = Throws (e1, allow_try, is_top);
            Throws (e2, allow_try, is_top) && th1;

          | Assign (e1, e2) =>
            when (Throws (e1))
              Message.Error (expr.loc, "`throw' in assignment target");

            // assigning to a local is treated as a statement (try allowed
            // in the source); any other target forbids a throwing source
            if (e1 is TExpr.LocalRef)
              if (Throws (e2, allow_try)) {
                // Message.Warning (expr.loc, "bad style: each branch in this block "
                //                 "ends with a break");
                true
              } else false
            else
              if (Throws (e2)) {
                Message.Error (expr.loc, "`throw' in assignment source");
                true
              } else false

          | TypeConversion (e, t, _) =>
            #if CHECK_STV
              CheckSTV (t);
            #endif
            // the operand stays in tail position only if the conversion
            // does not change its type
            if (is_top && Options.GeneralTailCallOpt)
              Throws (e, allow_try, t.Fix ().Equals (e.MType))
            else
              Throws (e, allow_try, false)

          | Label (_, e) =>
            Throws (e, allow_try, is_top)

          | Switch (idx, defl, opts) =>
            // avoid closure
            mutable all = Throws (idx, allow_try);
            foreach ((_, e) in opts)
              all = Throws (e, allow_try, is_top) && all;
            match (defl) {
              // the default arm is checked in the same context as the other
              // arms, so a try block or a tail call in it is treated alike
              | Some (e) => Throws (e, allow_try, is_top) && all
              // NOTE(review): with no default arm, control may leave the
              // switch when no case matches, so `all' may overstate whether
              // the switch throws -- confirm before changing.
              | None => all // ???
            }

          // these node kinds are expected to be gone by this pass; meeting
          // one is only legitimate after an earlier error
          | LocalFunRef
          | PropertyMember
          | StaticPropertyRef
          | EventMember
          | StaticEventRef
          | ConstantObjectRef
          | Delayed
          | Error
          | DefFunctionsIn
          | Match
          | Block
          | SelfTailCall =>
            Util.cassert (Message.SeenError);
            false

          | MethodRef (e, meth, tp, _) =>
            NoThrowPlease (e);
            Util.cassert (meth.GetHeader ().typarms.Length == tp.Length,
                          $ "typarms check failed for $meth "
                            "$(meth.GetHeader ().typarms) $tp");
            #if CHECK_STV
              foreach (tp in tp) CheckSTV (tp);
            #endif
            false
            
          | HasType (e, _t) =>
            NoThrowPlease (e);
            #if CHECK_STV
              CheckSTV (_t);
            #endif
            false
          
          | FieldMember (e, _)
          | TupleIndexer (e, _, _)
          | NotNull (e) =>
            NoThrowPlease (e);
            false
          
          | ArrayIndexer (obj, args) =>
            NoThrowPlease (obj);
            args.Iter (NoThrowPlease);
            false
            
          | MultipleAssign (assigns) =>
            foreach ((_, e) in assigns)
              NoThrowPlease (e);
            false

          | Array (args, dimensions) =>
            args.Iter (NoThrowPlease);
            dimensions.Iter (NoThrowPlease);
            false
            
          | Call (func, parms, _) =>
            NoThrowPlease (func);
            // only the first argument may throw; the others must not
            def parm_throws =
              match (parms) {
                | p :: ps =>
                  def fst = Throws (p.expr, allow_try);
                  foreach (parm in ps) NoThrowPlease (parm.expr);
                  fst
                  
                | [] => false
              }
            if (parm_throws) true
            else
              // a call in tail position becomes a tail call unless it has
              // ref/out arguments or a receiver that needs a constrained
              // call or is a value type
              if (is_top && Options.GeneralTailCallOpt) {
                expr.GenerateTail = true;
                foreach (p in parms)
                  when (p.kind != ParmKind.Normal)
                    expr.GenerateTail = false;

                match (func) {
                  | MethodRef (obj = obj)
                    when obj.NeedsConstrained || obj.MType.IsValueType =>
                    expr.GenerateTail = false;
                  | _ => {}
                }

                expr.GenerateTail
              } else false
            
          | Tuple (exprs) =>
            exprs.Iter (NoThrowPlease);
            false

          | TryFault (body, handler)
          | TryWith (body, _, handler)
          | TryFinally (body, handler) =>
            when (! allow_try)
              // use ice here?
              Message.Error (expr.loc, 
                             "try-blocks cannot be used inside expressions, "
                             "this message shouldn't happen though");
            // body and handler are statement positions, but never in tail
            // position of the method
            _ = Throws (body, true);
            _ = Throws (handler, true);
            false

          | StaticRef (_t, meth is IMethod, tp) =>
            Util.cassert (meth.GetHeader ().typarms.Length == tp.Length,
                          $ "typarms check failed for $meth "
                            "$(meth.GetHeader ().typarms) $tp");
            #if CHECK_STV
              CheckSTV (_t);
              foreach (tp in tp) CheckSTV (tp);
            #endif
            false

          | TypeOf (_t) =>
            #if CHECK_STV
              CheckSTV (_t);
            #endif
            false
            
          | StaticRef
          | LocalRef
          | ImplicitValueTypeCtor
          | Literal
          | This
          | Base
          | OpCode
          | MethodAddress
          | DefaultValue =>
            false
        }

      expr.Throws = res;
      log (STV, expr.loc, $ "throws } $(expr.GetType()) ");
      res
    }
    #endregion


    #region Unsharing and throw handling
    // Returns `expr' with every node that is reached a second time replaced
    // by a (recursively unshared) copy, using the Visited flag on TExpr.
    static UnShare (expr : TExpr) : TExpr
    {
      expr.Walk (DoUnShare)
    }
    
    // TExpr.Walk callback for UnShare: the first visit of a node just marks
    // it and returns null (presumably meaning "keep the node and walk its
    // children" -- confirm against TExpr.Walk); a node already visited is
    // replaced by an unshared copy.
    static DoUnShare (expr : TExpr) : TExpr
    {
      // Message.Debug ($"unshare $expr");
      if (expr.Visited) {
        def expr = expr.Copy ();
        Util.cassert (!expr.Visited);
        UnShare (expr)
      } else {
        expr.Visited = true;
        null
      }
    }
    #endregion


    #region Top level stuff
    // id of the try block (allocated by WalkTry) around the node being walked
    static mutable current_try_block : int;
    // label id -> current_try_block at the label; filled by DoWalk, read by
    // Throws
    static label_blocks : Hashtable [int, int] = Hashtable ();
    
    // Applies DoWalk to `expr' via TExpr.Walk.
    static Walk (expr : TExpr) : TExpr
    {
      expr.Walk (DoWalk)
    }

    // Walks `expr' with current_try_block set to a fresh id, restoring the
    // previous id afterwards (also when walking fails).
    static WalkTry (expr : TExpr) : TExpr
    {
      def backup = current_try_block;
      try {
        current_try_block = Util.next_id ();
        Walk (expr)
      } finally {
        current_try_block = backup;
      }
    }

    // TExpr.Walk callback for Walk: returns the rewritten node, or null for
    // nodes it does not rewrite itself.
    static DoWalk (expr : TExpr) : TExpr
    {
      // Message.Debug ($ "dowalk: $(expr.GetType()) $(expr.Type)");
      def res =
        match (expr) {
          | FieldMember (obj, fld) when ! expr.IsAssigned =>
            def obj = Walk (obj);
              
            // use address here for better performance
            when (! obj.NeedAddress && fld.DeclaringType.IsValueType)
              obj.NeedAddress = true;
              
            TExpr.FieldMember (obj, fld)

          | Assign (e1, e2) =>
            // marking the target first keeps it out of the FieldMember case
            e1.IsAssigned = true;

            def real_type =
              match (e1) {
                | LocalRef (decl) => decl.Type.Fix ()
                | ArrayIndexer
                | This
                | FieldMember
                | StaticRef => e1.Type.Fix ()
                | _ => Util.ice ($ "wrong assignment target $e1")
              }
            def e1 = Walk (e1);
            def e2 = Walk (e2);
            
            if (e1.Type.Fix ().IsValueType && ! e1.IsAddressable)
              Message.Error ("this expression is not a proper lvalue: "
                             "cannot load value type address");
            else
              e1.NeedAddress = true;

          
            if (NeedBoxing (real_type, e2))
              TExpr.Assign (InternalType.Void, e1, Box (real_type, e2))
            else 
              TExpr.Assign (InternalType.Void, e1, e2)

          | DefValIn (decl, e1, e2) =>
            def e1 = Walk (e1);
            def e2 = Walk (e2);
            if (NeedBoxing (decl.Type.Fix (), e1))
              TExpr.DefValIn (decl, Box (decl.Type.Fix (), e1), e2)
            else
              TExpr.DefValIn (decl, e1, e2)
            
          | Call (ImplicitValueTypeCtor, [], _) => null
          
          // operands of type-variable type are boxed to object before
          // being compared
          | Call (OpCode ("==.ref"), [p1, p2], _)
          | Call (OpCode ("!=.ref"), [p1, p2], _)
          | Call (OpCode ("=="), [p1, p2], _)
          | Call (OpCode ("!="), [p1, p2], _) =>
            p1.expr = Walk (p1.expr);
            p2.expr = Walk (p2.expr);
            when (p1.expr.MType is MType.TyVarRef)
              p1.expr = Box (InternalType.Object, p1.expr);
            when (p2.expr.MType is MType.TyVarRef)
              p2.expr = Box (InternalType.Object, p2.expr);
            expr

          // I hope other opcodes don't need boxing ...
          | Call (OpCode, _, _) => null
          
          | Call (origfunc, parms, is_tail) =>
            mutable func = Walk (origfunc);
            // fix up the receiver: a type-variable receiver needs a
            // constrained call on its address; a value-type receiver passes
            // its address when the method is declared on a value type and
            // is boxed otherwise
            def meth =
              match (func) {
                | MethodRef (obj, meth, type_parms, nonvirt) =>
                  if (obj.MType is MType.TyVarRef) {
                    obj.NeedsConstrained = true;
                    obj.NeedAddress = true;
                  }
                  else
                  // we would kinda like address here
                  when (obj.Type.Fix ().IsValueType) {
                    def methty = meth.DeclaringType;
                    if (methty.IsValueType)
                      obj.NeedAddress = true;
                    else {
                      // but maybe we should employ boxing
                      def obj = Box (methty.GetMemType (), obj);
                      func = TExpr.MethodRef (func.Type, obj, meth, type_parms, nonvirt);
                    }
                  }
                  
                  meth
                  
                | Base (meth)
                | StaticRef (_, meth is IMethod, _) => meth
                | _ => Util.ice ($ "invalid thing called $func")
              }

            unless (parms.IsEmpty) {
              if (meth.DeclaringType.IsDelegate && meth.Name == ".ctor") {
                // delegate constructor: only the first argument is boxed
                // to object
                // NOTE(review): these arguments are not walked here --
                // confirm that nothing inside them needs rewriting.
                def parm = parms.Head;
                when (NeedBoxing (InternalType.Object, parm.expr))
                  parm.expr = Box (InternalType.Object, parm.expr);
              } else {
                mutable formals = 
                  origfunc.MType.FunReturnTypeAndParms (meth) [0];

                //Message.Debug ($"origfunc: $origfunc type=$(origfunc.MType)");
                Util.cassert (formals.Length == parms.Length,
                              $ "call to $meth $parms $formals");
                
                // walk every argument and adapt it to its formal: by-value
                // void arguments are replaced by null (after evaluation),
                // others are boxed when needed; ref/out arguments must be
                // addressable
                foreach (parm in parms) {
                  match (formals) { 
                    | f :: fs =>
                      def t = f.Fix ();
                      parm.expr = Walk (parm.expr);
                      if (parm.kind == ParmKind.Normal) {
                        if (parm.expr.MType is MType.Void)
                          parm.expr =
                            TExpr.Sequence (
                              InternalType.Object,
                              parm.expr,
                              TExpr.Literal (InternalType.Object,
                                             Literal.Null ()))
                        else
                          when (NeedBoxing (t, parm.expr))
                            parm.expr = Box (t, parm.expr);
                      } else {
                        if (parm.expr.IsAddressable)
                          parm.expr.NeedAddress = true;
                        else
                          Message.Error ($ "non-addressable expression passed "
                                           "as a ref/out parameter");
                      }
                      formals = fs;
                    | [] => Util.ice ();
                  }
                }
              }
            }
            // a callee returning a type variable is called as returning
            // object and converted back (never as a tail call)
            if (meth.ReturnType.Fix () is MType.TyVarRef &&
                expr.MType is MType.Void)
              Convert (expr.Type,
                       TExpr.Call (InternalType.Object, func, parms, false))
            else
              TExpr.Call (func, parms, is_tail)

          | Array (parms, dimensions) =>
            // box elements to the array's element type where needed
            def ty =
              match (expr.Type.Fix ()) {
                | MType.Array (t, _) => t.Fix ()
                | _ => Util.ice ()
              }
            mutable res = [];
            foreach (parm in parms) {
              def parm = Walk (parm);
              def parm =
                if (NeedBoxing (ty, parm))
                  Box (ty, parm)
                else parm;
              res = parm :: res;
            }
            TExpr.Array (res.Rev (), dimensions)

          // try bodies get a fresh try-block id; a catch handler gets its
          // own id too, while finally / fault handlers keep the enclosing one
          | TryWith (body, exn, handler) =>
            TExpr.TryWith (WalkTry (body), exn, WalkTry (handler))
            
          | TryFinally (body, handler) =>
            TExpr.TryFinally (WalkTry (body), Walk (handler))

          | TryFault (body, handler) =>
            TExpr.TryFault (WalkTry (body), Walk (handler))

          | Goto as g =>
            g.try_block = current_try_block;
            null

          | Label (id, _) =>
            label_blocks [id] = current_try_block;
            null

          | DefFunctionsIn
          | Match
          | SelfTailCall =>
            Util.cassert (Message.SeenError);
            null
          
          | _ => null
        }
      //Message.Debug ($"do walk: $expr -> $res");
      res
    }
    #endregion
  }
}
