#include <petscdevice.h>
#include <../src/ksp/ksp/utils/lmvm/lmvm.h> /*I "petscksp.h" I*/
#include <petsc/private/deviceimpl.h>

/*
  MatReset_LMVM - Base-class reset for LMVM matrices.

  Input Parameters:
+ B           - the LMVM matrix
- destructive - when PETSC_TRUE and the matrix is allocated, also release its storage

  Notes:
  Every call empties the history (k = -1), marks the saved previous point as invalid
  (prev_set = PETSC_FALSE), zeroes the shift and increments nresets.
  A destructive reset of an allocated matrix additionally clears J0 (MatLMVMClearJ0()),
  destroys the S/Y history vectors and the Xprev/Fprev vectors, zeroes nupdates, nrejects
  and m_old, and flags the matrix as unallocated and unassembled.
*/
PetscErrorCode MatReset_LMVM(Mat B, PetscBool destructive)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

  PetscFunctionBegin;
  /* Soft part: forget the history but keep the vectors */
  lmvm->shift    = 0.0;
  lmvm->prev_set = PETSC_FALSE;
  lmvm->k        = -1;
  if (!destructive || !lmvm->allocated) {
    ++lmvm->nresets;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* Hard part: release J0 and every work/history vector */
  PetscCall(MatLMVMClearJ0(B));
  PetscCall(VecDestroyVecs(lmvm->m, &lmvm->S));
  PetscCall(VecDestroyVecs(lmvm->m, &lmvm->Y));
  PetscCall(VecDestroy(&lmvm->Xprev));
  PetscCall(VecDestroy(&lmvm->Fprev));
  lmvm->nupdates  = 0;
  lmvm->nrejects  = 0;
  lmvm->m_old     = 0;
  lmvm->allocated = PETSC_FALSE;
  B->preallocated = PETSC_FALSE;
  B->assembled    = PETSC_FALSE;
  ++lmvm->nresets;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatAllocate_LMVM - Allocates (or re-allocates) the base LMVM storage from template vectors.

  Input Parameters:
+ B - the LMVM matrix
. X - template for the solution-space vectors; its layout becomes B's column layout
- F - template for the function-space vectors; its layout becomes B's row layout

  Notes:
  If B is already allocated, X and F are checked for compatibility with B. Storage is rebuilt
  (after a destructive MatLMVMReset()) only when X's vector type differs from the stored Xprev's.
  Otherwise nothing is done. A fresh allocation creates Xprev/Fprev and, when m > 0, m history
  vectors each for S and Y, then records m_old = m and marks B preallocated and assembled.
*/
PetscErrorCode MatAllocate_LMVM(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm      = (Mat_LMVM *)B->data;
  PetscBool needAlloc = lmvm->allocated ? PETSC_FALSE : PETSC_TRUE;
  PetscBool sameType;
  VecType   xtype;

  PetscFunctionBegin;
  if (!needAlloc) {
    VecCheckMatCompatible(B, X, 2, F, 3);
    PetscCall(VecGetType(X, &xtype));
    PetscCall(PetscObjectTypeCompare((PetscObject)lmvm->Xprev, xtype, &sameType));
    if (!sameType) {
      /* The existing vectors have a different type than X: throw them away and rebuild from X/F */
      needAlloc = PETSC_TRUE;
      PetscCall(MatLMVMReset(B, PETSC_TRUE));
    }
  }
  if (!needAlloc) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCall(VecGetType(X, &xtype));
  PetscCall(MatSetVecType(B, xtype));
  PetscCall(PetscLayoutReference(F->map, &B->rmap));
  PetscCall(PetscLayoutReference(X->map, &B->cmap));
  PetscCall(VecDuplicate(X, &lmvm->Xprev));
  PetscCall(VecDuplicate(F, &lmvm->Fprev));
  if (lmvm->m > 0) {
    PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lmvm->S));
    PetscCall(VecDuplicateVecs(lmvm->Fprev, lmvm->m, &lmvm->Y));
  }
  lmvm->m_old     = lmvm->m;
  lmvm->allocated = PETSC_TRUE;
  B->preallocated = PETSC_TRUE;
  B->assembled    = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatUpdateKernel_LMVM - Appends a precomputed correction pair (S, Y) to the LMVM history.

  Input Parameters:
+ B - the LMVM matrix (must be allocated with history size m > 0)
. S - the new step; its values are copied into the history
- Y - the matching function change; its values are copied into the history

  Notes:
  When the history is full (k == m - 1) the vector handles are rotated one slot toward the
  front, so the oldest storage ends up in the last slot and is overwritten by the new pair.
  Otherwise k is advanced by one. nupdates is incremented in both cases.
*/
PetscErrorCode MatUpdateKernel_LMVM(Mat B, Vec S, Vec Y)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

  PetscFunctionBegin;
  PetscCheck(lmvm->allocated, PetscObjectComm((PetscObject)B), PETSC_ERR_ORDER, "LMVM matrix must be allocated first");
  /* With m == 0 no S/Y arrays exist and k == m - 1 == -1, so the rotation below would
     dereference a NULL array; report the misuse instead of crashing. */
  PetscCheck(lmvm->m > 0, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "LMVM matrix has no history storage (history size is 0)");
  if (lmvm->k == lmvm->m - 1) {
    /* Memory limit reached: shift all vectors back one spot and move the oldest to the end */
    Vec Stmp = lmvm->S[0];
    Vec Ytmp = lmvm->Y[0];

    for (PetscInt i = 0; i < lmvm->k; ++i) {
      lmvm->S[i] = lmvm->S[i + 1];
      lmvm->Y[i] = lmvm->Y[i + 1];
    }
    lmvm->S[lmvm->k] = Stmp;
    lmvm->Y[lmvm->k] = Ytmp;
  } else {
    ++lmvm->k;
  }
  /* Put the precomputed update into the last vector */
  PetscCall(VecCopy(S, lmvm->S[lmvm->k]));
  PetscCall(VecCopy(Y, lmvm->Y[lmvm->k]));
  ++lmvm->nupdates;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatUpdate_LMVM - Base-class update with a new point (X, F).

  Input Parameters:
+ B - the LMVM matrix
. X - the new solution vector
- F - the function value at X

  Notes:
  Does nothing when the history size m is 0. If a previous point is stored (prev_set), the pair
  (X - Xprev, F - Fprev) is handed to MatUpdateKernel_LMVM(). In every other case (X, F) is then
  saved into Xprev/Fprev and prev_set is raised, so the first call only records the point.
*/
PetscErrorCode MatUpdate_LMVM(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
  Vec       xp, fp;

  PetscFunctionBegin;
  if (lmvm->m == 0) PetscFunctionReturn(PETSC_SUCCESS);
  xp = lmvm->Xprev;
  fp = lmvm->Fprev;
  if (lmvm->prev_set) {
    /* Reuse the previous-point storage for the differences: xp <- X - xp, fp <- F - fp */
    PetscCall(VecAXPBY(xp, 1.0, -1.0, X));
    PetscCall(VecAXPBY(fp, 1.0, -1.0, F));
    PetscCall(MatUpdateKernel_LMVM(B, xp, fp));
  }
  /* (X, F) becomes the reference point for the next update */
  PetscCall(VecCopy(X, xp));
  PetscCall(VecCopy(F, fp));
  lmvm->prev_set = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMultAdd_LMVM - Computes Z = Y + B*X.

  MatMultAdd() allows Y and Z to be the same vector. In that case B*X is formed in a temporary
  vector; otherwise MatMult() would overwrite Y before it is added, giving 2*B*X.
*/
static PetscErrorCode MatMultAdd_LMVM(Mat B, Vec X, Vec Y, Vec Z)
{
  PetscFunctionBegin;
  if (Y == Z) {
    Vec BX;

    PetscCall(VecDuplicate(Z, &BX));
    PetscCall(MatMult(B, X, BX));
    PetscCall(VecAXPY(Z, 1.0, BX));
    PetscCall(VecDestroy(&BX));
  } else {
    PetscCall(MatMult(B, X, Z));
    PetscCall(VecAXPY(Z, 1.0, Y));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatMult_LMVM - Applies the LMVM approximation: Y = (type-specific product of B with X) + shift * X.

  Validates that X and Y have the same size and are compatible with B, and that B is allocated,
  then delegates the product to lmvm->ops->mult before adding any accumulated MatShift() amount.
*/
static PetscErrorCode MatMult_LMVM(Mat B, Vec X, Vec Y)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

  PetscFunctionBegin;
  /* argument validation */
  VecCheckSameSize(X, 2, Y, 3);
  VecCheckMatCompatible(B, X, 2, Y, 3);
  PetscCheck(lmvm->allocated, PetscObjectComm((PetscObject)B), PETSC_ERR_ORDER, "LMVM matrix must be allocated first");

  /* type-specific kernel, then the shift term */
  PetscCall(lmvm->ops->mult(B, X, Y));
  if (lmvm->shift != 0.0) PetscCall(VecAXPY(Y, lmvm->shift, X));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatSolve_LMVM - Applies the inverse LMVM approximation, dX = solve(B, F), via lmvm->ops->solve.

  Errors if F/dX are incompatible, if B is not allocated, or if the concrete LMVM type provides no
  solve implementation.
  NOTE(review): unlike MatMult_LMVM(), lmvm->shift is not applied here — confirm whether the
  type-specific solve accounts for it.
*/
static PetscErrorCode MatSolve_LMVM(Mat B, Vec F, Vec dX)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
  MPI_Comm  comm;

  PetscFunctionBegin;
  VecCheckSameSize(F, 2, dX, 3);
  VecCheckMatCompatible(B, F, 2, dX, 3);
  comm = PetscObjectComm((PetscObject)B);
  PetscCheck(lmvm->allocated, comm, PETSC_ERR_ORDER, "LMVM matrix must be allocated first");
  PetscCheck(lmvm->ops->solve, comm, PETSC_ERR_ARG_INCOMP, "LMVM matrix does not have a solution or inversion implementation");
  PetscCall(lmvm->ops->solve(B, F, dX));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatCopy_LMVM - Copies the LMVM state of B into M.

  Input Parameters:
+ B   - the source LMVM matrix
. M   - the target LMVM matrix
- str - DIFFERENT_NONZERO_PATTERN rebuilds M's storage from B's Xprev/Fprev (after a destructive
        reset). Any other value requires M to be allocated already and of the same size as B.

  Notes:
  Copies the J0 configuration, the update/reject counters, the history (S, Y pairs 0..k), the shift,
  and the previous point (Xprev, Fprev) together with its validity flag prev_set. Finally the
  type-specific copy (bctx->ops->copy) runs, if any.
*/
static PetscErrorCode MatCopy_LMVM(Mat B, Mat M, MatStructure str)
{
  Mat_LMVM *bctx = (Mat_LMVM *)B->data;
  Mat_LMVM *mctx;
  PetscBool allocatedM;

  PetscFunctionBegin;
  if (str == DIFFERENT_NONZERO_PATTERN) {
    PetscCall(MatLMVMReset(M, PETSC_TRUE));
    PetscCall(MatLMVMAllocate(M, bctx->Xprev, bctx->Fprev));
  } else {
    PetscCall(MatLMVMIsAllocated(M, &allocatedM));
    PetscCheck(allocatedM, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Target matrix must be allocated first");
    MatCheckSameSize(B, 1, M, 2);
  }

  mctx = (Mat_LMVM *)M->data;
  /* The copied pairs occupy slots 0..k of M's S/Y arrays, which only hold mctx->m vectors */
  PetscCheck(bctx->k < mctx->m, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_SIZ, "Target matrix history size %" PetscInt_FMT " cannot hold the %" PetscInt_FMT " stored updates of the source matrix", mctx->m, bctx->k + 1);
  if (bctx->user_pc) {
    PetscCall(MatLMVMSetJ0PC(M, bctx->J0pc));
  } else if (bctx->user_ksp) {
    PetscCall(MatLMVMSetJ0KSP(M, bctx->J0ksp));
  } else if (bctx->J0) {
    PetscCall(MatLMVMSetJ0(M, bctx->J0));
  } else if (bctx->user_scale) {
    if (bctx->J0diag) {
      PetscCall(MatLMVMSetJ0Diag(M, bctx->J0diag));
    } else {
      PetscCall(MatLMVMSetJ0Scale(M, bctx->J0scalar));
    }
  }
  mctx->nupdates = bctx->nupdates;
  mctx->nrejects = bctx->nrejects;
  mctx->k        = bctx->k;
  /* The shift is part of the operator applied by MatMult_LMVM(), so M must carry it too */
  mctx->shift = bctx->shift;
  for (PetscInt i = 0; i <= bctx->k; ++i) {
    if (bctx->S) PetscCall(VecCopy(bctx->S[i], mctx->S[i]));
    if (bctx->Y) PetscCall(VecCopy(bctx->Y[i], mctx->Y[i]));
  }
  /* The previous point is consumed by the next update; copy it once, along with the flag that
     says whether it is valid (previously it was copied k+1 times inside the loop, never when
     k == -1, and prev_set was not transferred). */
  mctx->prev_set = bctx->prev_set;
  if (bctx->prev_set) {
    PetscCall(VecCopy(bctx->Xprev, mctx->Xprev));
    PetscCall(VecCopy(bctx->Fprev, mctx->Fprev));
  }
  if (bctx->ops->copy) PetscCall((*bctx->ops->copy)(B, M, str));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDuplicate_LMVM - Creates a new LMVM matrix of the same type as B.

  Input Parameters:
+ B  - the LMVM matrix to duplicate
- op - with MAT_COPY_VALUES the stored state is copied via MatCopy(B, *mat, SAME_NONZERO_PATTERN)

  Output Parameter:
. mat - the new matrix

  Notes:
  The history size m, the J0 KSP settings (ksp_max_it, ksp_rtol, ksp_atol), the shift and eps are
  carried over, and the new matrix is allocated from B's Xprev/Fprev.
*/
static PetscErrorCode MatDuplicate_LMVM(Mat B, MatDuplicateOption op, Mat *mat)
{
  Mat_LMVM *bctx = (Mat_LMVM *)B->data;
  Mat_LMVM *mctx;
  MatType   lmvmType;
  Mat       A;

  PetscFunctionBegin;
  PetscCall(MatGetType(B, &lmvmType));
  PetscCall(MatCreate(PetscObjectComm((PetscObject)B), mat));
  PetscCall(MatSetType(*mat, lmvmType));

  A                = *mat;
  mctx             = (Mat_LMVM *)A->data;
  mctx->m          = bctx->m;
  mctx->ksp_max_it = bctx->ksp_max_it;
  mctx->ksp_rtol   = bctx->ksp_rtol;
  mctx->ksp_atol   = bctx->ksp_atol;
  mctx->shift      = bctx->shift;
  /* eps is user-tunable (-mat_lmvm_eps); without this the duplicate silently reverts to the default */
  mctx->eps = bctx->eps;
  PetscCall(KSPSetTolerances(mctx->J0ksp, mctx->ksp_rtol, mctx->ksp_atol, PETSC_CURRENT, mctx->ksp_max_it));

  PetscCall(MatLMVMAllocate(*mat, bctx->Xprev, bctx->Fprev));
  if (op == MAT_COPY_VALUES) PetscCall(MatCopy(B, *mat, SAME_NONZERO_PATTERN));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatShift_LMVM - Implements MatShift() for LMVM matrices.

  Accumulates the real part of a into lmvm->shift (any imaginary part is dropped by PetscRealPart());
  MatMult_LMVM() later adds shift * X to its result. The matrix must be allocated.
*/
static PetscErrorCode MatShift_LMVM(Mat B, PetscScalar a)
{
  Mat_LMVM *ctx = (Mat_LMVM *)B->data;

  PetscFunctionBegin;
  PetscCheck(ctx->allocated, PetscObjectComm((PetscObject)B), PETSC_ERR_ORDER, "LMVM matrix must be allocated first");
  ctx->shift = ctx->shift + PetscRealPart(a);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatView_LMVM - Prints base LMVM information to an ASCII viewer.

  Reports the maximum history size (m), the number of stored pairs (k + 1), the update/reject/reset
  counters and, when a J0 matrix is set, a summary view of J0 (PETSC_VIEWER_ASCII_INFO format).
  Non-ASCII viewers are ignored.
*/
PetscErrorCode MatView_LMVM(Mat B, PetscViewer pv)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
  PetscBool isascii;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompare((PetscObject)pv, PETSCVIEWERASCII, &isascii));
  if (isascii) {
    PetscCall(PetscViewerASCIIPrintf(pv, "Max. storage: %" PetscInt_FMT "\n", lmvm->m));
    PetscCall(PetscViewerASCIIPrintf(pv, "Used storage: %" PetscInt_FMT "\n", lmvm->k + 1));
    PetscCall(PetscViewerASCIIPrintf(pv, "Number of updates: %" PetscInt_FMT "\n", lmvm->nupdates));
    PetscCall(PetscViewerASCIIPrintf(pv, "Number of rejects: %" PetscInt_FMT "\n", lmvm->nrejects));
    PetscCall(PetscViewerASCIIPrintf(pv, "Number of resets: %" PetscInt_FMT "\n", lmvm->nresets));
    if (lmvm->J0) {
      PetscCall(PetscViewerASCIIPrintf(pv, "J0 Matrix:\n"));
      PetscCall(PetscViewerPushFormat(pv, PETSC_VIEWER_ASCII_INFO));
      PetscCall(MatView(lmvm->J0, pv));
      PetscCall(PetscViewerPopFormat(pv));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatSetFromOptions_LMVM - Processes the base LMVM options.

  Options Database Keys:
+ -mat_lmvm_hist_size - number of past updates kept (applied through MatLMVMSetHistorySize() when changed)
. -mat_lmvm_ksp_its   - (developer) maximum iterations of the KSP used to invert J0
- -mat_lmvm_eps       - (developer) machine-zero threshold

  Afterwards the J0 KSP processes its own options (prefix "mat_lmvm_").
*/
PetscErrorCode MatSetFromOptions_LMVM(Mat B, PetscOptionItems *PetscOptionsObject)
{
  Mat_LMVM *lmvm    = (Mat_LMVM *)B->data;
  PetscInt  m_new   = lmvm->m;
  PetscBool its_set = PETSC_FALSE;

  PetscFunctionBegin;
  PetscOptionsHeadBegin(PetscOptionsObject, "Limited-memory Variable Metric matrix for approximating Jacobians");
  PetscCall(PetscOptionsInt("-mat_lmvm_hist_size", "number of past updates kept in memory for the approximation", "", m_new, &m_new, NULL));
  PetscCall(PetscOptionsInt("-mat_lmvm_ksp_its", "(developer) fixed number of KSP iterations to take when inverting J0", "", lmvm->ksp_max_it, &lmvm->ksp_max_it, &its_set));
  PetscCall(PetscOptionsReal("-mat_lmvm_eps", "(developer) machine zero definition", "", lmvm->eps, &lmvm->eps, NULL));
  PetscOptionsHeadEnd();
  if (m_new != lmvm->m) PetscCall(MatLMVMSetHistorySize(B, m_new));
  /* ksp_max_it is otherwise only pushed into J0ksp when the matrix is created, so a value given on
     the command line would never reach the solver. Forward only the iteration count (PETSC_CURRENT
     keeps any user-set tolerances). KSPSetFromOptions() below can still override it. */
  if (its_set) PetscCall(KSPSetTolerances(lmvm->J0ksp, PETSC_CURRENT, PETSC_CURRENT, PETSC_CURRENT, lmvm->ksp_max_it));
  PetscCall(KSPSetFromOptions(lmvm->J0ksp));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatSetUp_LMVM - Allocates base LMVM storage from B's own layouts (no-op if already allocated).

  Sets up the row/column layouts, creates Xprev/Fprev with MatCreateVecs(), creates m history
  vectors each for S and Y when m > 0, records m_old = m, and flags B as preallocated and assembled.
*/
PetscErrorCode MatSetUp_LMVM(Mat B)
{
  Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

  PetscFunctionBegin;
  if (lmvm->allocated) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCall(PetscLayoutSetUp(B->rmap));
  PetscCall(PetscLayoutSetUp(B->cmap));
  PetscCall(MatCreateVecs(B, &lmvm->Xprev, &lmvm->Fprev));
  if (lmvm->m > 0) {
    PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lmvm->S));
    PetscCall(VecDuplicateVecs(lmvm->Fprev, lmvm->m, &lmvm->Y));
  }
  lmvm->m_old     = lmvm->m;
  lmvm->allocated = PETSC_TRUE;
  B->preallocated = PETSC_TRUE;
  B->assembled    = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatDestroy_LMVM - Releases everything owned by the base LMVM context.

  The S/Y history and Xprev/Fprev vectors are destroyed only if the matrix was allocated. The J0
  KSP is destroyed, J0 is cleared with MatLMVMClearJ0(), and finally the context itself is freed.
*/
PetscErrorCode MatDestroy_LMVM(Mat B)
{
  Mat_LMVM *ctx = (Mat_LMVM *)B->data;

  PetscFunctionBegin;
  if (ctx->allocated) {
    /* history pairs first, then the previous-point vectors */
    PetscCall(VecDestroyVecs(ctx->m, &ctx->S));
    PetscCall(VecDestroyVecs(ctx->m, &ctx->Y));
    PetscCall(VecDestroy(&ctx->Xprev));
    PetscCall(VecDestroy(&ctx->Fprev));
  }
  PetscCall(KSPDestroy(&ctx->J0ksp));
  PetscCall(MatLMVMClearJ0(B));
  PetscCall(PetscFree(B->data));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*MC
   MATLMVM - MATLMVM = "lmvm" - A matrix type used for Limited-Memory Variable Metric (LMVM) matrices.

   Level: intermediate

   Developer Note:
   Improve this manual page, as well as many others in the `MATLMVM` family.

.seealso: [](sec_matlmvm), `Mat`
M*/
PetscErrorCode MatCreate_LMVM(Mat B)
{
  Mat_LMVM *lmvm;

  PetscFunctionBegin;
  /* Context allocation */
  PetscCall(PetscNew(&lmvm));
  B->data = (void *)lmvm;

  /* Mat operations implemented by the base LMVM type */
  B->ops->destroy        = MatDestroy_LMVM;
  B->ops->setfromoptions = MatSetFromOptions_LMVM;
  B->ops->view           = MatView_LMVM;
  B->ops->setup          = MatSetUp_LMVM;
  B->ops->shift          = MatShift_LMVM;
  B->ops->duplicate      = MatDuplicate_LMVM;
  B->ops->mult           = MatMult_LMVM;
  B->ops->multadd        = MatMultAdd_LMVM;
  B->ops->solve          = MatSolve_LMVM;
  B->ops->copy           = MatCopy_LMVM;

  /* LMVM-level operations */
  lmvm->ops->update   = MatUpdate_LMVM;
  lmvm->ops->allocate = MatAllocate_LMVM;
  lmvm->ops->reset    = MatReset_LMVM;

  /* History bookkeeping: default history size 5, nothing stored yet (k = -1) */
  lmvm->m        = 5;
  lmvm->m_old    = 0;
  lmvm->k        = -1;
  lmvm->nupdates = 0;
  lmvm->nrejects = 0;
  lmvm->nresets  = 0;

  /* Settings for the KSP that inverts J0: 20 iterations, zero tolerances */
  lmvm->ksp_max_it = 20;
  lmvm->ksp_rtol   = 0.0;
  lmvm->ksp_atol   = 0.0;

  /* Misc. state and flags */
  lmvm->shift      = 0.0;
  lmvm->eps        = PetscPowReal(PETSC_MACHINE_EPSILON, 2.0 / 3.0);
  lmvm->allocated  = PETSC_FALSE;
  lmvm->prev_set   = PETSC_FALSE;
  lmvm->user_scale = PETSC_FALSE;
  lmvm->user_pc    = PETSC_FALSE;
  lmvm->user_ksp   = PETSC_FALSE;
  lmvm->square     = PETSC_FALSE;

  /* J0 solver: GMRES with the "mat_lmvm_" options prefix, nested one tab level under B */
  PetscCall(KSPCreate(PetscObjectComm((PetscObject)B), &lmvm->J0ksp));
  PetscCall(PetscObjectIncrementTabLevel((PetscObject)lmvm->J0ksp, (PetscObject)B, 1));
  PetscCall(KSPSetOptionsPrefix(lmvm->J0ksp, "mat_lmvm_"));
  PetscCall(KSPSetType(lmvm->J0ksp, KSPGMRES));
  PetscCall(KSPSetTolerances(lmvm->J0ksp, lmvm->ksp_rtol, lmvm->ksp_atol, PETSC_CURRENT, lmvm->ksp_max_it));
  PetscFunctionReturn(PETSC_SUCCESS);
}
