//  This file is part of ff3d - http://www.freefem.org/ff3d
//  Copyright (C) 2001, 2002, 2003 Stphane Del Pino

//  This program is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2, or (at your option)
//  any later version.

//  This program is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.

//  You should have received a copy of the GNU General Public License
//  along with this program; if not, write to the Free Software Foundation,
//  Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.  

//  $Id: MultiLinearExpression.hpp,v 1.5 2004/12/31 14:00:47 delpinux Exp $

#ifndef MULTI_LINEAR_EXPRESSION_HPP
#define MULTI_LINEAR_EXPRESSION_HPP

#include <Expression.hpp>
#include <LinearExpression.hpp>

#include <EmbededFunctions.hpp>

#include <Stringify.hpp>
#include <ErrorHandler.hpp>

#include <list>

class MultiLinearExpression
  : public Expression
{
public:
  enum LinearFormType {
    linear,
    biLinear
  };

  enum OperatorType {
    gradUgradV,
    dxUdxV,
    dxUV,
    UdxV,
    UV,
    gradV,
    gradFgradV,
    dxV,
    V,
    dxFV
  };

  typedef std::list<ReferenceCounting<LinearExpression> > LinearListType;
  typedef std::list<ReferenceCounting<FunctionExpression> > FunctionListType;
  typedef std::list<ReferenceCounting<RealExpression> > RealListType;

private:
  LinearListType __linearList;
  FunctionListType __functionList;
  RealListType __realList;

  std::ostream& put(std::ostream& os) const
  {
    bool notFirst = false;
    for (LinearListType::const_iterator i = __linearList.begin();
	 i != __linearList.end(); ++i) {
      if (notFirst) {
	os << '*';
      } else {
	notFirst = true;
      }
      os << (*(*i));
    }

    for (FunctionListType::const_iterator i = __functionList.begin();
	 i != __functionList.end(); ++i) {
      if (notFirst) {
	os << '*';
      } else {
	notFirst = true;
      }
      os << (*(*i));
    }

    for (RealListType::const_iterator i = __realList.begin();
	 i != __realList.end(); ++i) {
      if (notFirst) {
	os << '*';
      } else {
	notFirst = true;
      }
      os << (*(*i));
    }

    return os;
  }
public:

  ReferenceCounting<FunctionExpression>
  getFunction()
  {
    ReferenceCounting<FunctionExpression> f = 0;
    ReferenceCounting<RealExpression> r = 0;

    if (__realList.size() > 0) {
      RealListType::iterator i = __realList.begin();
      r = *i;
      i++;
      for ( ; i != __realList.end(); ++i) {
	r = new RealExpressionBinaryOperator<product>(r,*i);
      }
    }
    if (__functionList.size() > 0) {
      FunctionListType::iterator i = __functionList.begin();
      f = *i;
      ++i;
      for ( ; i != __functionList.end(); ++i) {
	f = new FunctionExpressionBinaryOperator<ExpressionMultiplies<FunctionExpression> >(f, *i);
      }
    }

    for (LinearListType::iterator i = __linearList.begin();
	 i != __linearList.end(); ++i) {
      switch ((*(*i)).formType()) {
      case (LinearExpression::elementary): {
	break;			// nothing to do
      }
      case (LinearExpression::elementaryTimesFunction): {
	LinearExpressionElementaryTimesFunction& L
	  = dynamic_cast<LinearExpressionElementaryTimesFunction&>(*(*i));
	if (f == 0) {
	  f = L.function();
	} else {
	  f = new FunctionExpressionBinaryOperator<ExpressionMultiplies<FunctionExpression> >(f, L.function());
	}
	break;
      }
      case (LinearExpression::elementaryTimesReal): {
	LinearExpressionElementaryTimesReal& L
	  = dynamic_cast<LinearExpressionElementaryTimesReal&>(*(*i));
	if (r == 0) {
	  r = L.real_t();
	} else {
	  r = new RealExpressionBinaryOperator<product>(r,L.real_t());
	}
	break;
      }
      case (LinearExpression::elementaryTimesFunctionOperator): {
	throw ErrorHandler(__FILE__,__LINE__,
			   "not implemented",
			   ErrorHandler::unexpected);
	break;
      }
      default: {
	throw ErrorHandler(__FILE__,__LINE__,
			   "not implemented",
			   ErrorHandler::unexpected);
      }
      }
    }

    if ((r == 0)&&(f == 0)) {
      r = new RealExpressionValue(1);
      f = new FunctionExpressionConstant(r);
    } else if (r == 0) {
      ; // everything is ok !
    } else if (f == 0){
      f = new FunctionExpressionConstant(r);
    } else {
      f = new FunctionExpressionBinaryOperator<ExpressionMultiplies<FunctionExpression> >(f, new FunctionExpressionConstant(r));
    }

    return f;
  }

  ReferenceCounting<FunctionVariable> getUnknown()
  {
    ReferenceCounting<FunctionVariable> u = 0;
    LinearListType::iterator i = __linearList.begin();
    while (i != __linearList.end()) {
      IntegratedExpression& IE =
	(*(*(*(*i)).integrated()).integratedExpression());

      if (IE.integratedExpressionType()
	  == IntegratedExpression::unknownFunction) {
	IntegratedExpressionUnknown& I
	  = dynamic_cast<IntegratedExpressionUnknown&>(IE);

	u = I.unknown();
      }
      ++i;
    }
    assert(u != 0);
    return u;
  }

  ReferenceCounting<TestFunctionVariable> getTestFunction()
  {
    ReferenceCounting<TestFunctionVariable> t = 0;
    LinearListType::iterator i = __linearList.begin();
    while (i != __linearList.end()) {
      IntegratedExpression& IE =
	(*(*(*(*i)).integrated()).integratedExpression());

      if (IE.integratedExpressionType()
	  == IntegratedExpression::testFunction) {
	IntegratedExpressionTest& I
	  = dynamic_cast<IntegratedExpressionTest&>(IE);
	t = I.testFunction();
      }
      ++i;
    }
    assert(t != 0);
    return t;
  }

  IntegratedOperatorExpression::OperatorType getUnknownOperator()
  {
    IntegratedOperatorExpression::OperatorType t
      = IntegratedOperatorExpression::undefined;

    LinearListType::iterator i = __linearList.begin();
    while (i != __linearList.end()) {
      IntegratedOperatorExpression& I = (*(*(*i)).integrated());

      IntegratedExpression& IE = (*I.integratedExpression());

      if (IE.integratedExpressionType()
	  == IntegratedExpression::unknownFunction) {
	t = I.operatorType();
      }
      ++i;
    }

    return t;
  }

  IntegratedOperatorExpression::OperatorType getTestFunctionOperator()
  {
    IntegratedOperatorExpression::OperatorType t
      = IntegratedOperatorExpression::undefined;

    LinearListType::iterator i = __linearList.begin();
    while (i != __linearList.end()) {
      IntegratedOperatorExpression& I = (*(*(*i)).integrated());

      IntegratedExpression& IE = (*I.integratedExpression());

      if (IE.integratedExpressionType()
	  == IntegratedExpression::testFunction) {
	t = I.operatorType();
      }
      ++i;
    }

    return t;
  }

  IntegratedOperatorExpression::OperatorType getFunctionOperator()
  {
    IntegratedOperatorExpression::OperatorType t
      = IntegratedOperatorExpression::undefined;

    LinearListType::iterator i = __linearList.begin();
    while (i != __linearList.end()) {
      IntegratedOperatorExpression& I = (*(*(*i)).integrated());

      IntegratedExpression& IE = (*I.integratedExpression());

      if (IE.integratedExpressionType()
	  == IntegratedExpression::function) {
	t = I.operatorType();
      }
      ++i;
    }

    return t;
  }

  ReferenceCounting<FunctionExpression> getFunctionExpression()
  {
    ReferenceCounting<FunctionExpression> f = 0;
    LinearListType::iterator i = __linearList.begin();
    while (i != __linearList.end()) {
      IntegratedExpression& IE =
	(*(*(*(*i)).integrated()).integratedExpression());

      if (IE.integratedExpressionType()
	  == IntegratedExpression::function) {
	IntegratedExpressionFunctionExpression& I
	  = dynamic_cast<IntegratedExpressionFunctionExpression&>(IE);

	f = I.function();
      }
      ++i;
    }
    assert(f != 0);
    return f;
  }


  LinearListType::iterator beginLinear()
  {
    return __linearList.begin();
  }

  LinearListType::iterator endLinear()
  {
    return __linearList.end();
  }

  FunctionListType::iterator beginFunction()
  {
    return __functionList.begin();
  }

  FunctionListType::iterator endFunction()
  {
    return __functionList.end();
  }

  RealListType::iterator beginReal()
  {
    return __realList.begin();
  }

  RealListType::iterator endReal()
  {
    return __realList.end();
  }

  OperatorType operatorType()
  {
    IntegratedOperatorExpression::OperatorType unknown
      = IntegratedOperatorExpression::undefined;

    IntegratedOperatorExpression::OperatorType test
      = IntegratedOperatorExpression::undefined;

    IntegratedOperatorExpression::OperatorType function
      = IntegratedOperatorExpression::undefined;

    for(LinearListType::iterator i = __linearList.begin();
	i != __linearList.end(); ++i) {
      IntegratedOperatorExpression& I = (*(*(*i)).integrated());
      switch((*I.integratedExpression()).integratedExpressionType()) {
      case IntegratedExpression::testFunction: {
	test = I.operatorType();
	break;
      }
      case IntegratedExpression::unknownFunction: {
	unknown = I.operatorType();
	break;
      }
      case IntegratedExpression::function: {
	function = I.operatorType();
	break;
      }
      }
    }

    OperatorType t;

    switch (__linearList.size()) {
    case 1: {
      switch (test) {
      case IntegratedOperatorExpression::orderZero: {
	t=V;
	break;
      }
      case IntegratedOperatorExpression::gradient: {
	t=gradV;
	break;
      }
      case IntegratedOperatorExpression::dx:
      case IntegratedOperatorExpression::dy:
      case IntegratedOperatorExpression::dz: {
	return dxV;
	break;
      }
      default: {
	throw ErrorHandler(__FILE__,__LINE__,
			   "not implemented",
			   ErrorHandler::unexpected);
      }
      }
      break;
    }
    case 2: {
      if (test != IntegratedOperatorExpression::undefined) {
	if (unknown != IntegratedOperatorExpression::undefined) {
	  switch (unknown) {
	  case IntegratedOperatorExpression::orderZero: {
	    switch (test) {
	    case IntegratedOperatorExpression::orderZero: {
	      t=UV;
	      break;
	    }
	    case IntegratedOperatorExpression::gradient:
	    case IntegratedOperatorExpression::dx:
	    case IntegratedOperatorExpression::dy:
	    case IntegratedOperatorExpression::dz: {
	      t=UdxV;
	      break;
	    }
	    default: {
	      throw ErrorHandler(__FILE__,__LINE__,
				 "not implemented",
				 ErrorHandler::unexpected);
	    }
	    }
	    break;
	  }
	  case IntegratedOperatorExpression::gradient: {
	    switch (test) {
	    case IntegratedOperatorExpression::orderZero: {
	      t=dxUV;
	      break;
	    }
	    case IntegratedOperatorExpression::gradient: {
	      t=gradUgradV;
	      break;
	    }
	    case IntegratedOperatorExpression::dx:
	    case IntegratedOperatorExpression::dy:
	    case IntegratedOperatorExpression::dz: {
	      return dxUdxV;
	      break;
	    }
	    default: {
	      throw ErrorHandler(__FILE__,__LINE__,
				 "not implemented",
				 ErrorHandler::unexpected);
	    }
	    }
	    break;
	  }
	  case IntegratedOperatorExpression::dx:
	  case IntegratedOperatorExpression::dy:
	  case IntegratedOperatorExpression::dz: {
	    switch (test) {
	    case IntegratedOperatorExpression::orderZero: {
	      t=dxUV;
	      break;
	    }
	    case IntegratedOperatorExpression::gradient:
	    case IntegratedOperatorExpression::dx:
	    case IntegratedOperatorExpression::dy:
	    case IntegratedOperatorExpression::dz: {
	      t=dxUdxV;
	      break;
	    }
	    default: {
	      throw ErrorHandler(__FILE__,__LINE__,
				 "not implemented",
				 ErrorHandler::unexpected);
	    }
	    }
	    break;
	  }
	  default: {
	    throw ErrorHandler(__FILE__,__LINE__,
			       "not implemented",
			       ErrorHandler::unexpected);
	  }
	  }
	} else {
	  switch (function) {
	  case IntegratedOperatorExpression::gradient: {
	    switch (test) {
	    case IntegratedOperatorExpression::gradient: {
	      t = gradFgradV;
	      break;
	    }
	    case IntegratedOperatorExpression::orderZero:
	    case IntegratedOperatorExpression::dx:
	    case IntegratedOperatorExpression::dy:
	    case IntegratedOperatorExpression::dz:
	    default: {
	      throw ErrorHandler(__FILE__,__LINE__,
				 "not implemented",
				 ErrorHandler::unexpected);
	    }
	    }
	    break;
	  }
	  case IntegratedOperatorExpression::dx:
	  case IntegratedOperatorExpression::dy:
	  case IntegratedOperatorExpression::dz: {
	    switch (test) {
	    case IntegratedOperatorExpression::orderZero: {
	      t = dxFV;
	      break;
	    }
	    case IntegratedOperatorExpression::gradient:
	    case IntegratedOperatorExpression::dx:
	    case IntegratedOperatorExpression::dy:
	    case IntegratedOperatorExpression::dz: {
	      throw ErrorHandler(__FILE__,__LINE__,
				 "not implemented",
				 ErrorHandler::unexpected);
	      break;
	    }
	    default: {
	      throw ErrorHandler(__FILE__,__LINE__,
				 "not implemented",
				 ErrorHandler::unexpected);
	    }
	    }
	    break;
	  }
	  default: {
	    throw ErrorHandler(__FILE__,__LINE__,
			       "not implemented",
			       ErrorHandler::unexpected);
	  }
	  }
	}
      }
      break;
    }
    default: {
      throw ErrorHandler(__FILE__,__LINE__,
			 "not implemented",
			 ErrorHandler::unexpected);
    }
    }
    return t;
  }

  LinearFormType linearFormType()
  {
    LinearFormType t;
    LinearListType::iterator i = __linearList.begin();
    switch (__linearList.size()) {
    case 1: {
      IntegratedExpression& I=*(*(*(*i)).integrated()).integratedExpression();
      if (I.integratedExpressionType() != IntegratedExpression::testFunction) {
	const std::string errorMsg
	  = "error defining "+stringify(*this)+"\n"
	  "1-linear operator should be defined using test functions";

	throw ErrorHandler(__FILE__,__LINE__,
			   errorMsg,
			   ErrorHandler::normal);
      } else {
	return linear;
      }
      break;
    }
    case 2: {
      return biLinear;
      break;
    }
    default: {
      throw ErrorHandler(__FILE__,__LINE__,
			 "not implemented",
			 ErrorHandler::unexpected);
    }
    }

    return t;
  }

  void check()
  {
    if (__linearList.size() > 2) {
      this->execute();
      const std::string errorMsg
	= stringify(*this)+" is "+stringify(__linearList.size())
	  +"-linear.\ncan discretize at most 2-linear expressions\n";

      throw ErrorHandler(__FILE__,__LINE__,
			 errorMsg,
			 ErrorHandler::normal);
    } else if (__linearList.size() == 2) {
      LinearListType::iterator i = __linearList.begin();
      IntegratedExpression::IType itype
	= (*(*(*(*i)).integrated()).integratedExpression()).integratedExpressionType();
      i++;
      IntegratedExpression::IType jtype
	= (*(*(*(*i)).integrated()).integratedExpression()).integratedExpressionType();
      if (itype == jtype) {
	this->execute();
	const std::string errorMsg 
	  = stringify(*this)+" is non linear\n"
	  +"Cannot discretize non linear operators\n";

	throw ErrorHandler(__FILE__,__LINE__,
			   errorMsg,
			   ErrorHandler::normal);
      }
    }
  }

  void execute()
  {
    for (LinearListType::iterator i = __linearList.begin();
	 i != __linearList.end(); ++i) {
      (*(*i)).execute();
    }

    for (FunctionListType::iterator i = __functionList.begin();
	 i != __functionList.end(); ++i) {
      (*(*i)).execute();
    }

    for (RealListType::iterator i = __realList.begin();
	 i != __realList.end(); ++i) {
      (*(*i)).execute();
    }
  }

  void times(ReferenceCounting<LinearExpression> LF)
  {
    __linearList.push_back(LF);
  }

  void times(ReferenceCounting<FunctionExpression> f)
  {
    __functionList.push_back(f);
  }

  void times(ReferenceCounting<RealExpression> r)
  {
    __realList.push_back(r);
  }

  MultiLinearExpression(ReferenceCounting<LinearExpression> LF)
    : Expression(Expression::multiLinearExp)
  {
    __linearList.push_back(LF);
  }

  MultiLinearExpression(const MultiLinearExpression& MLF)
    : Expression(MLF),
      __linearList(MLF.__linearList),
      __functionList(MLF.__functionList),
      __realList(MLF.__realList)
  {
    ;
  }

  ~MultiLinearExpression()
  {
    ;
  }
};


class MultiLinearExpressionSum
  : public Expression
{
private:
  typedef std::list<ReferenceCounting<MultiLinearExpression> > ListType;

public:
  typedef ListType::iterator iterator;

  MultiLinearExpressionSum::iterator beginPlus()
  {
    return __listPlus.begin();
  }

  MultiLinearExpressionSum::iterator endPlus()
  {
    return __listPlus.end();
  }

  MultiLinearExpressionSum::iterator beginMinus()
  {
    return __listMinus.begin();
  }

  MultiLinearExpressionSum::iterator endMinus()
  {
    return __listMinus.end();
  }

private:
  ListType __listPlus;
  ListType __listMinus;

  std::ostream& put(std::ostream& os) const
  {
    bool first = true;

    for (ListType::const_iterator i = __listPlus.begin();
	 i != __listPlus.end(); ++i) {
      if (!first) {
	os << '+';
      } else {
	first = false;
      }
      os << *(*i);

    }
    for (ListType::const_iterator i = __listMinus.begin();
	 i != __listMinus.end(); ++i) {
      os << '-' << *(*i);
    }

    return os;
  }

public:
  void plus(ReferenceCounting<MultiLinearExpression> m) {
    __listPlus.push_back(m);
  }

  void minus(ReferenceCounting<MultiLinearExpression> m) {
    __listMinus.push_back(m);
  }

  void check()
  {
    for (ListType::iterator i = __listPlus.begin();
	 i != __listPlus.end(); ++i) {
      (*(*i)).check();
    }

    for (ListType::iterator i = __listMinus.begin();
	 i != __listMinus.end(); ++i) {
      (*(*i)).check();
    }
  }

  void execute()
  {
    for (ListType::iterator i = __listPlus.begin();
	 i != __listPlus.end(); ++i) {
      (*(*i)).execute();
    }

    for (ListType::iterator i = __listMinus.begin();
	 i != __listMinus.end(); ++i) {
      (*(*i)).execute();
    }
  }

  MultiLinearExpressionSum()
    : Expression(Expression::multiLinearExpSum)
  {
    ;
  }

  MultiLinearExpressionSum(const MultiLinearExpressionSum& M)
    : Expression(M),
      __listPlus(M.__listPlus),
      __listMinus(M.__listMinus)
  {
    ;
  }

  ~MultiLinearExpressionSum()
  {
    ;
  }

};
#endif // MULTI_LINEAR_EXPRESSION_HPP

