//                                               -*- C++ -*-
/**
 * @file  PythonNumericalMathEvaluationImplementation.cxx
 * @brief This class binds a Python function to an Open TURNS' NumericalMathFunction
 *
 *  Copyright (C) 2005-2014 Airbus-EDF-Phimeca
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 * @author schueller
 * @date   2012-07-16 12:24:33 +0200 (Mon, 16 Jul 2012)
 */

#include "PythonNumericalMathEvaluationImplementation.hxx"
#include "OSS.hxx"
#include "Description.hxx"
#include "PythonWrappingFunctions.hxx"
#include "PersistentObjectFactory.hxx"
#include "Exception.hxx"

BEGIN_NAMESPACE_OPENTURNS

// Local aliases for the cache-related types declared by NumericalMathEvaluationImplementation
typedef NumericalMathEvaluationImplementation::CacheKeyType             CacheKeyType;
typedef NumericalMathEvaluationImplementation::CacheValueType           CacheValueType;
typedef NumericalMathEvaluationImplementation::CacheType                CacheType;


// NOTE(review): presumably defines the class-name data returned by GetClassName() -- confirm against the macro definition
CLASSNAMEINIT(PythonNumericalMathEvaluationImplementation);

// File-local factory object built with the class name string.
// NOTE(review): presumably this registers the class so that it can be rebuilt by the persistence layer -- confirm
static Factory<PythonNumericalMathEvaluationImplementation> RegisteredFactory("PythonNumericalMathEvaluationImplementation");



/* Default constructor: builds an evaluation that is not bound to any Python object */
PythonNumericalMathEvaluationImplementation::PythonNumericalMathEvaluationImplementation()
  : NumericalMathEvaluationImplementation()
  , pyObj_(NULL)
{
  // No Python object is referenced, hence no reference count to adjust
}


/* Constructor from Python object
 *
 * Takes an additional reference on pyCallable, names the evaluation after the
 * Python class of the object, and builds the description (inputs first, then
 * outputs) from the object's getInputDescription() / getOutputDescription()
 * methods. When one of these calls does not yield a sequence of the expected
 * size, the corresponding entries fall back to x0, x1, ... (inputs) or
 * y0, y1, ... (outputs).
 */
PythonNumericalMathEvaluationImplementation::PythonNumericalMathEvaluationImplementation(PyObject * pyCallable)
  : NumericalMathEvaluationImplementation()
  , pyObj_(pyCallable)
{
  // This object shares the Python object with the caller: hold our own reference
  Py_XINCREF( pyCallable );

  // Set the name of the object as its Python classname
  ScopedPyObjectPointer cls(PyObject_GetAttrString ( pyObj_,
                            const_cast<char *>( "__class__" ) ));
  ScopedPyObjectPointer name(PyObject_GetAttrString( cls.get(),
                             const_cast<char *>( "__name__" ) ));
  setName( convert< _PyString_, String >( name.get() ) );

  // Both dimensions are queried from the Python object itself
  const UnsignedLong inputDimension  = getInputDimension();
  const UnsignedLong outputDimension = getOutputDimension();
  Description description(inputDimension + outputDimension);

  // Input part of the description
  // NOTE(review): when the call below fails, the Python error indicator is
  // presumably left set -- confirm whether it should be cleared here.
  ScopedPyObjectPointer descIn(PyObject_CallMethod( pyObj_,
                               const_cast<char *>( "getInputDescription" ),
                               const_cast<char *>( "()" ) ));
  if ( descIn.get()
       && PySequence_Check( descIn.get() )
       && ( PySequence_Size( descIn.get() ) == (SignedInteger)inputDimension ) )
  {
    Description inputDescription(convert< _PySequence_, Description >( descIn.get() ));
    for (UnsignedLong i = 0; i < inputDimension; ++i)
    {
      description[i] = inputDescription[i];
    }
  }
  else
  {
    // Default input labels
    for (UnsignedLong i = 0; i < inputDimension; ++i)
    {
      description[i] = (OSS() << "x" << i);
    }
  }

  // Output part of the description
  ScopedPyObjectPointer descOut(PyObject_CallMethod( pyObj_,
                                const_cast<char *>( "getOutputDescription" ),
                                const_cast<char *>( "()" ) ));
  // The null check must be made on descOut itself: testing descIn here would
  // hand a null pointer to PySequence_Check when only the output call failed,
  // and would discard a valid output description when only the input call failed.
  if ( descOut.get()
       && PySequence_Check( descOut.get() )
       && ( PySequence_Size( descOut.get() ) == (SignedInteger)outputDimension ) )
  {
    Description outputDescription(convert< _PySequence_, Description >( descOut.get() ));
    for (UnsignedLong i = 0; i < outputDimension; ++i)
    {
      description[inputDimension + i] = outputDescription[i];
    }
  }
  else
  {
    // Default output labels
    for (UnsignedLong i = 0; i < outputDimension; ++i)
    {
      description[inputDimension + i] = (OSS() << "y" << i);
    }
  }

  setDescription(description);
}

/* Virtual constructor: returns a newly allocated copy of this object */
PythonNumericalMathEvaluationImplementation * PythonNumericalMathEvaluationImplementation::clone() const
{
  PythonNumericalMathEvaluationImplementation * p_copy = new PythonNumericalMathEvaluationImplementation(*this);
  return p_copy;
}

/* Copy constructor: the copy refers to the same Python object as the original */
PythonNumericalMathEvaluationImplementation::PythonNumericalMathEvaluationImplementation(const PythonNumericalMathEvaluationImplementation & other)
  : NumericalMathEvaluationImplementation(other)
  , pyObj_(other.pyObj_)
{
  // One more owner of the shared Python object: take an extra reference
  Py_XINCREF( pyObj_ );
}

/* Destructor: gives back the reference held on the Python object, if any */
PythonNumericalMathEvaluationImplementation::~PythonNumericalMathEvaluationImplementation()
{
  // Counterpart of the Py_XINCREF performed by the constructors
  Py_XDECREF( pyObj_ );
}

/* Comparison operator: every pair of instances is reported as equal */
Bool PythonNumericalMathEvaluationImplementation::operator ==(const PythonNumericalMathEvaluationImplementation & other) const
{
  // 'other' is not inspected: the answer does not depend on it
  const Bool alwaysEqual = true;
  return alwaysEqual;
}

/* Detailed string converter: class name, object name, description and parameters */
String PythonNumericalMathEvaluationImplementation::__repr__() const
{
  OSS repr;
  repr << "class=" << PythonNumericalMathEvaluationImplementation::GetClassName();
  repr << " name=" << getName();
  repr << " description=" << getDescription();
  repr << " parameters=" << getParameters();
  return repr;
}

/* Short string converter: class name and object name only (the offset argument is not used) */
String PythonNumericalMathEvaluationImplementation::__str__(const String & offset) const
{
  OSS text;
  text << "class=" << PythonNumericalMathEvaluationImplementation::GetClassName();
  text << " name=" << getName();
  return text;
}

/* Test for actual implementation: this class always answers yes */
Bool PythonNumericalMathEvaluationImplementation::isActualImplementation() const
{
  const Bool isActual = true;
  return isActual;
}



/* Here is the interface that all derived classes must implement */

/* Operator () on a single point
 *
 * Checks the input dimension, then returns the cached value when the cache is
 * enabled and already holds the point; otherwise calls the Python object with
 * the point, converts and checks the result, and records it in the cache when
 * the cache is enabled. The (input, output) pair is stored in the history
 * when the history is enabled.
 */
NumericalPoint PythonNumericalMathEvaluationImplementation::operator() (const NumericalPoint & inP) const
{
  const UnsignedLong inputSize = inP.getDimension();
  if ( inputSize != getInputDimension() )
  {
    throw InvalidDimensionException(HERE) << "Input point has incorrect dimension. Got " << inputSize << ". Expected " << getInputDimension();
  }

  NumericalPoint value;
  CacheKeyType key( inP.getCollection() );
  const bool alreadyCached = p_cache_->isEnabled() && p_cache_->hasKey( key );
  if ( !alreadyCached )
  {
    // A real evaluation is about to take place
    ++ callsNumber_;

    ScopedPyObjectPointer pyArgument(convert< NumericalPoint, _PySequence_ >(inP));
    ScopedPyObjectPointer pyValue(PyObject_CallFunctionObjArgs( pyObj_, pyArgument.get(), NULL ));
    if ( pyValue.isNull() )
    {
      // The Python call failed
      handleException();
    }

    try
    {
      value = convert< _PySequence_, NumericalPoint >( pyValue.get() );
    }
    catch (const InvalidArgumentException &)
    {
      // Replace the conversion error by a message that names the evaluation
      throw InvalidArgumentException(HERE) << "Output value for " << getName() << "._exec() method is not a sequence object (list, tuple, NumericalPoint, etc.)";
    }

    if ( value.getDimension() != getOutputDimension() )
    {
      throw InvalidDimensionException(HERE) << "Output point has incorrect dimension. Got " << value.getDimension() << ". Expected " << getOutputDimension();
    }

    if ( p_cache_->isEnabled() )
    {
      CacheValueType cachedValue( value.getCollection() );
      p_cache_->add( key, cachedValue );
    }
  }
  else
  {
    // Cache hit: no call to the Python object
    value = NumericalPoint::ImplementationType( p_cache_->find( key ) );
  }

  if (isHistoryEnabled_)
  {
    inputStrategy_.store(inP);
    outputStrategy_.store(value);
  }
  return value;
}


/* Operator () on a sample
 *
 * Evaluates the Python object on a whole sample with a single Python call.
 * When the cache is enabled, points already in the cache are read from it and
 * only the distinct remaining points are sent to Python; the new values are
 * then merged into the cache. When the cache is disabled, every point of the
 * sample is sent, duplicates included.
 *
 * Throws InvalidArgumentException when the Python result is not a sequence,
 * does not have one element per submitted point, or has an element that is
 * not a sequence of getOutputDimension() values.
 */
NumericalSample PythonNumericalMathEvaluationImplementation::operator() (const NumericalSample & inS) const
{
  const UnsignedLong size = inS.getSize();
  const UnsignedLong inDim = inS.getDimension();
  const UnsignedLong outDim = getOutputDimension();
  const bool useCache = p_cache_->isEnabled();

  NumericalSample outS(size, outDim);
  NumericalSample toDo(0, inDim);
  if ( useCache )
  {
    // Serve cache hits immediately and collect the distinct missing points
    std::set<NumericalPoint> uniqueValues;
    for ( UnsignedLong i = 0; i < size; ++ i )
    {
      CacheKeyType inKey( inS[i].getCollection() );
      if ( p_cache_->hasKey( inKey ) )
      {
        outS[i] = NumericalPoint::ImplementationType( p_cache_->find( inKey ) );
      }
      else
      {
        uniqueValues.insert( inS[i] );
      }
    }
    for (std::set<NumericalPoint>::const_iterator it = uniqueValues.begin(); it != uniqueValues.end(); ++it)
    {
      // store unique values
      toDo.add(*it);
    }
  }
  else
  {
    // compute all values, including duplicates
    toDo = inS;
  }

  const UnsignedLong toDoSize = toDo.getSize();
  // Temporary cache receiving the freshly computed values before they are merged
  CacheType tempCache( toDoSize );
  if ( useCache ) tempCache.enable();

  if ( toDoSize > 0 )
  {
    callsNumber_ += toDoSize;

    // Build a tuple of tuples of floats holding the points to evaluate.
    // PyTuple_SetItem steals the reference to the item it is given, so the
    // inner tuples and the floats are not released here.
    ScopedPyObjectPointer inTuple(PyTuple_New( toDoSize ));
    for ( UnsignedLong i = 0; i < toDoSize; ++ i )
    {
      PyObject * eltTuple = PyTuple_New( inDim );
      for ( UnsignedLong j = 0; j < inDim; ++ j ) PyTuple_SetItem( eltTuple, j, convert< NumericalScalar, _PyFloat_ > ( toDo[i][j] ) );
      PyTuple_SetItem( inTuple.get(), i, eltTuple );
    }
    ScopedPyObjectPointer result(PyObject_CallFunctionObjArgs( pyObj_, inTuple.get(), NULL ));

    if ( result.isNull() )
    {
      handleException();
    }

    // A result that is not a sequence used to be ignored, which silently
    // returned a zero-filled sample: report it instead.
    if ( !PySequence_Check( result.get() ) )
    {
      throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an object which is NOT a sequence";
    }

    const UnsignedLong lengthResult = PySequence_Size( result.get() );
    if ( lengthResult != toDoSize )
    {
      throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an sequence object with incorrect size (got "
                                           << lengthResult << ", expected " << toDoSize << ")";
    }

    for (UnsignedLong i = 0; i < toDoSize; ++i)
    {
      ScopedPyObjectPointer elt(PySequence_GetItem( result.get(), i ));
      if ( !PySequence_Check( elt.get() ) )
      {
        throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an object which is NOT a sequence (at position "
                                             << i << ")";
      }

      const UnsignedLong lengthElt = PySequence_Size( elt.get() );
      if ( lengthElt != outDim )
      {
        throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an sequence object with incorrect dimension (at position "
                                             << i << ")";
      }

      if ( useCache )
      {
        // toDo holds distinct points only: keep the value in the temporary
        // cache, the output sample is filled from it afterwards
        NumericalPoint outP(outDim);
        for (UnsignedLong j = 0; j < outDim; ++j)
        {
          ScopedPyObjectPointer val(PySequence_GetItem( elt.get(), j ));
          outP[j] = convert< _PyFloat_, NumericalScalar >( val.get() );
        }
        tempCache.add( toDo[i].getCollection(), outP.getCollection() );
      }
      else
      {
        // toDo is the input sample itself: row i of the result is row i of the output
        for (UnsignedLong j = 0; j < outDim; ++j)
        {
          ScopedPyObjectPointer val(PySequence_GetItem( elt.get(), j ));
          outS[i][j] = convert< _PyFloat_, NumericalScalar >( val.get() );
        }
      }
    }
  }

  if ( useCache )
  {
    // fill all the output values
    for ( UnsignedLong i = 0; i < size; ++i )
    {
      CacheKeyType inKey( inS[i].getCollection() );
      if ( tempCache.hasKey( inKey ) )
      {
        outS[i] = NumericalPoint::ImplementationType( tempCache.find( inKey ) );
      }
    }
    p_cache_->merge( tempCache );
  }
  if (isHistoryEnabled_)
  {
    inputStrategy_.store(inS);
    outputStrategy_.store(outS);
  }
  outS.setDescription(getOutputDescription());
  return outS;
}


/* Accessor for input point dimension
 *
 * The value is obtained by calling getInputDimension() on the Python object
 * and converting the result to an UnsignedLong.
 */
UnsignedLong PythonNumericalMathEvaluationImplementation::getInputDimension() const
{
  ScopedPyObjectPointer result(PyObject_CallMethod ( pyObj_,
                               const_cast<char *>( "getInputDimension" ),
                               const_cast<char *>( "()" ) ));
  // A null result means the Python call failed: go through the same error
  // path as the evaluation operators instead of converting a null pointer.
  // NOTE(review): handleException() presumably throws from the pending Python error -- confirm
  if ( result.isNull() )
  {
    handleException();
  }
  return convert< _PyInt_, UnsignedLong >( result.get() );
}


/* Accessor for output point dimension
 *
 * The value is obtained by calling getOutputDimension() on the Python object
 * and converting the result to an UnsignedLong.
 */
UnsignedLong PythonNumericalMathEvaluationImplementation::getOutputDimension() const
{
  ScopedPyObjectPointer result(PyObject_CallMethod ( pyObj_,
                               const_cast<char *>( "getOutputDimension" ),
                               const_cast<char *>( "()" ) ));
  // A null result means the Python call failed: go through the same error
  // path as the evaluation operators instead of converting a null pointer.
  // NOTE(review): handleException() presumably throws from the pending Python error -- confirm
  if ( result.isNull() )
  {
    handleException();
  }
  return convert< _PyInt_, UnsignedLong >( result.get() );
}


/* Method save() stores the object through the StorageManager:
 * the base class part first, then the Python object via pickleSave() */
void PythonNumericalMathEvaluationImplementation::save(Advocate & adv) const
{
  // Base class attributes
  NumericalMathEvaluationImplementation::save( adv );
  // Python object
  pickleSave( adv, pyObj_ );
}


/* Method load() reloads the object from the StorageManager:
 * the base class part first, then the Python object via pickleLoad() */
void PythonNumericalMathEvaluationImplementation::load(Advocate & adv)
{
  // Base class attributes
  NumericalMathEvaluationImplementation::load( adv );
  // Python object
  pickleLoad( adv, pyObj_ );
}


END_NAMESPACE_OPENTURNS
