// Copyright (C) 2006-2008 Garth N. Wells.
// Licensed under the GNU LGPL Version 2.1.
//
// Modified by Anders Logg 2006.
// Modified by Dag Lindbo 2008.
//
// First added:  2006-06-01
// Last changed: 2008-07-08

#include <dolfin/log/dolfin_log.h>
#include "UmfpackLUSolver.h"
#include "GenericMatrix.h"
#include "GenericVector.h"
#include "KrylovSolver.h"
#include "LUSolver.h"

extern "C"
{
#ifdef HAS_UMFPACK
#include <umfpack.h>
#endif
}

using namespace dolfin;

//-----------------------------------------------------------------------------
// Default parameters for the UMFPACK LU solver: a copy of the generic
// LUSolver defaults, renamed to "umfpack_lu_solver".
Parameters UmfpackLUSolver::default_parameters()
{
  Parameters params = LUSolver::default_parameters();
  params.rename("umfpack_lu_solver");
  return params;
}
//-----------------------------------------------------------------------------
UmfpackLUSolver::UmfpackLUSolver()
{
  // Start out with this solver's own default parameter set
  this->parameters = UmfpackLUSolver::default_parameters();
}
//-----------------------------------------------------------------------------
UmfpackLUSolver::~UmfpackLUSolver()
{
  // Intentionally empty: nothing is released explicitly here.
  // NOTE(review): presumably the Umfpack member frees any remaining UMFPACK
  // objects in its own destructor -- confirm in the header.
}
//-----------------------------------------------------------------------------
#ifdef HAS_UMFPACK
dolfin::uint UmfpackLUSolver::solve(const GenericMatrix& A, GenericVector& x,
                                    const GenericVector& b)
{
  // One-shot direct solve of A x = b: LU-factorize A, perform the
  // forward/backward substitutions, then release the factorization so the
  // solver does not keep UMFPACK state around between calls.
  this->factorize(A);
  this->factorized_solve(x, b);
  this->umfpack.clear();

  return 1;
}
//-----------------------------------------------------------------------------
// LU-factorize the square matrix A with UMFPACK and keep the factors for
// subsequent calls to factorized_solve(). Returns 1.
//
// A's sparse arrays are borrowed, not copied, so A must stay alive (and
// unchanged) until the factorization is cleared.
dolfin::uint UmfpackLUSolver::factorize(const GenericMatrix& A)
{
  // Raw sparse data of A; element 3 is the number of non-zeroes.
  // Elements 0-2 are presumably (row pointers, column indices, values) of the
  // compressed-row storage -- TODO confirm against GenericMatrix::data().
  std::tr1::tuple<const std::size_t*, const std::size_t*, const double*, int> data = A.data();
  const uint M   = A.size(0);
  const uint nnz = std::tr1::get<3>(data);

  // UMFPACK is used here for square systems only
  assert(A.size(0) == A.size(1));
  assert(nnz >= M);

  // The std::size_t index arrays are handed to UMFPACK's "dl" (long int)
  // interface without conversion. That is only valid when both types have
  // the same width. This is true on LP64 platforms but not, for example,
  // on 64-bit Windows (LLP64).
  assert(sizeof(std::size_t) == sizeof(long int));

  // Initialise umfpack data (compressed rows are read by UMFPACK as
  // compressed columns, i.e. it sees A^T; factorized_solve compensates)
  umfpack.init(reinterpret_cast<const long int*>(std::tr1::get<0>(data)),
               reinterpret_cast<const long int*>(std::tr1::get<1>(data)),
               std::tr1::get<2>(data), M, nnz);

  // Factorize
  info("LU-factorizing linear system of size %d x %d (UMFPACK).", M, M);
  umfpack.factorize();

  return 1;
}
//-----------------------------------------------------------------------------
// Solve A x = b using the factorization computed by factorize().
// x is resized to match b. Returns 1.
dolfin::uint UmfpackLUSolver::factorized_solve(GenericVector& x, const GenericVector& b) const
{
  // Guard: a factorization of matching size must exist
  if (!umfpack.factorized)
    error("Factorized solve must be preceded by call to factorize.");

  const uint N = b.size();
  if (N != umfpack.N)
    error("Vector does not match size of factored matrix");

  // Make room for the solution
  x.resize(N);

  info("Solving factorized linear system of size %d x %d (UMFPACK).", N, N);

  // The matrix was passed in compressed-row form, which UMFPACK interprets as
  // compressed columns of the transpose; solving the transposed system
  // therefore yields the solution of A x = b.
  const bool solve_transposed = true;
  umfpack.factorized_solve(x.data(), b.data(), solve_transposed);

  return 1;
}
//-----------------------------------------------------------------------------
#else
dolfin::uint UmfpackLUSolver::solve(const GenericMatrix& A, GenericVector& x,
                                    const GenericVector& b)
{
  // Built without UMFPACK: fall back to a default-constructed Krylov solver
  warning("UMFPACK must be installed to peform a LU solve for uBLAS matrices. A Krylov iterative solver will be used instead.");

  KrylovSolver fallback;
  return fallback.solve(A, x, b);
}
//-----------------------------------------------------------------------------
dolfin::uint UmfpackLUSolver::factorize(const GenericMatrix& /*A*/)
{
  // Built without UMFPACK: sparse LU factorization is unavailable
  error("UMFPACK must be installed to perform sparse LU factorization.");

  // Nominal failure value, in case error() returns
  return 0;
}
//-----------------------------------------------------------------------------
dolfin::uint UmfpackLUSolver::factorized_solve(GenericVector& /*x*/,
                                              const GenericVector& /*b*/) const
{
  // Built without UMFPACK: no factorization exists to substitute with
  error("UMFPACK must be installed to perform sparse backward and forward substitutions.");

  // Nominal failure value, in case error() returns
  return 0;
}
#endif
//-----------------------------------------------------------------------------
#ifdef HAS_UMFPACK
// UmfpackLUSolver::Umfpack implementation
//-----------------------------------------------------------------------------
// Release all UMFPACK objects and matrix data held by this object and return
// it to an empty, unfactorized state.
void UmfpackLUSolver::Umfpack::clear()
{
  // Placeholder pointers passed to UMFPACK for optional arguments
  delete dnull; dnull = 0;
  delete inull; inull = 0;

  // Symbolic analysis, if still alive
  if(Symbolic)
  {
    umfpack_dl_free_symbolic(&Symbolic);
    Symbolic = 0;
  }

  // Numeric LU factors, if any
  if(Numeric)
  {
    umfpack_dl_free_numeric(&Numeric);
    Numeric = 0;
  }

  // Free the matrix arrays only if this object allocated them
  // (init_transpose); arrays set by init() are borrowed from the caller.
  if(local_matrix)
  {
    delete [] Rp;
    delete [] Ri;
    delete [] Rx;
    local_matrix = false;
  }

  // Forget the matrix arrays in all cases, owned or borrowed. Otherwise
  // borrowed pointers would dangle after clear(), and init_transpose() would
  // wrongly report that a matrix is still attached.
  Rp = 0;
  Ri = 0;
  Rx = 0;

  factorized =  false;
  N = 0;
}
//-----------------------------------------------------------------------------
// Attach an M x M sparse matrix given as (Ap, Ai, Ax) arrays in the "dl"
// index type, discarding any previous factorization.
//
// The arrays are borrowed, not copied: they must outlive factorize() and
// factorized_solve(), which read them. nz is currently unused here.
void UmfpackLUSolver::Umfpack::init(const long int* Ap, const long int* Ai,
                                         const double* Ax, uint M, uint nz)
{
  if(factorized)
    warning("LUSolver already contains a factorized matrix! Clearing and starting over.");

  // Clear any data
  clear();

  // Set umfpack data (borrowed, hence not owned)
  Rp = Ap;
  Ri = Ai;
  Rx = Ax;
  N  = M;
  local_matrix = false;
}
//-----------------------------------------------------------------------------
// Attach the transpose of the M x M sparse matrix (Ap, Ai, Ax) with nz
// non-zeroes. The transpose is computed by UMFPACK into arrays owned by this
// object; clear() frees them. Errors if matrix arrays are already attached.
void UmfpackLUSolver::Umfpack::init_transpose(const long int* Ap, const long int* Ai,
                                         const double* Ax, uint M, uint nz)
{
  if(Rp || Ri || Rx)
    error("UmfpackLUSolver data already points to a matrix");

  // Reset state, then take ownership *before* allocating. If one of the
  // allocations below throws, the arrays already allocated are still freed
  // by a later clear() instead of leaking.
  clear();
  local_matrix = true;
  N  = M;
  Rp = new long int[M+1];
  Ri = new long int[nz];
  Rx = new double[nz];

  // Compute transpose (inull: no row/column permutations)
  long int status = umfpack_dl_transpose(M, M, Ap, Ai, Ax, inull, inull,
                    const_cast<long int*>(Rp), const_cast<long int*>(Ri), const_cast<double*>(Rx));
  Umfpack::check_status(status, "transpose");
}
//-----------------------------------------------------------------------------
// Compute the numeric LU factorization of the attached N x N matrix.
// Requires matrix data (init/init_transpose) and no existing factorization
// objects. Only the numeric factors are kept afterwards.
void UmfpackLUSolver::Umfpack::factorize()
{
  // Preconditions
  assert(Rp);
  assert(Ri);
  assert(Rx);
  assert(!Symbolic);
  assert(!Numeric);

  // Step 1: symbolic analysis (reordering etc.)
  const long int symbolic_status
    = umfpack_dl_symbolic(N, N, Rp, Ri, Rx, &Symbolic, dnull, dnull);
  check_status(symbolic_status, "symbolic");

  // Step 2: numeric factorization based on the symbolic analysis
  const long int numeric_status
    = umfpack_dl_numeric(Rp, Ri, Rx, Symbolic, &Numeric, dnull, dnull);
  check_status(numeric_status, "numeric");

  // Once the numeric factors exist, the symbolic object is no longer needed
  umfpack_dl_free_symbolic(&Symbolic);
  Symbolic = 0;

  factorized = true;
}
//-----------------------------------------------------------------------------
void UmfpackLUSolver::Umfpack::factorized_solve(double*x, const double* b, bool transpose) const
{
  assert(Rp);
  assert(Ri);
  assert(Rx);
  assert(Numeric);

  long int status;
  if(transpose)
    status = umfpack_dl_solve(UMFPACK_At, Rp, Ri, Rx, x, b, Numeric, dnull, dnull);
  else
    status = umfpack_dl_solve(UMFPACK_A, Rp, Ri, Rx, x, b, Numeric, dnull, dnull);

  Umfpack::check_status(status, "solve");
}
//-----------------------------------------------------------------------------
void UmfpackLUSolver::Umfpack::check_status(long int status, std::string function) const
{
  if(status == UMFPACK_OK)
    return;

  // Printing which UMFPACK function is returning an warning/error
  cout << "UMFPACK problem related to call to " << function << endl;

  if(status == UMFPACK_WARNING_singular_matrix)
    warning("UMFPACK reports that the matrix being solved is singular.");
  else if(status == UMFPACK_ERROR_out_of_memory)
    error("UMFPACK has run out of memory solving a system.");
  else if(status == UMFPACK_ERROR_invalid_system)
    error("UMFPACK reports an invalid system. Is the matrix square?.");
  else if(status == UMFPACK_ERROR_invalid_Numeric_object)
    error("UMFPACK reports an invalid Numeric object.");
  else if(status == UMFPACK_ERROR_invalid_Symbolic_object)
    error("UMFPACK reports an invalid Symbolic object.");
  else if(status != UMFPACK_OK)
    warning("UMFPACK is reporting an unknown error.");
}
#endif
//-----------------------------------------------------------------------------

