preCICE v3.1.2
Loading...
Searching...
No Matches
HipQRSolver.hip.cpp
Go to the documentation of this file.
1#ifdef PRECICE_WITH_HIP
2
4#include <ginkgo/ginkgo.hpp>
5#include <hip/hip_runtime.h>
6#include <hip/hip_runtime_api.h>
7#include <hipsolver.h>
8
9void computeQRDecompositionHip(const int deviceId, const std::shared_ptr<gko::Executor> &exec, gko::matrix::Dense<> *A_Q, gko::matrix::Dense<> *R)
10{
11 int backupDeviceId{};
12 hipGetDevice(&backupDeviceId);
13 hipSetDevice(deviceId);
14
15 void *dWork{};
16 int * devInfo{};
17
18 // Allocating important HIP variables
19 hipMalloc((void **) &dWork, sizeof(double));
20 hipMalloc((void **) &devInfo, sizeof(int));
21
22 hipsolverDnHandle_t solverHandle;
23 hipsolverDnCreate(&solverHandle);
24 // NOTE: It's important to transpose since hipsolver assumes column-major memory layout
25 // Making a copy since every value will be overridden
26 auto A_T = gko::share(gko::matrix::Dense<>::create(exec, gko::dim<2>(A_Q->get_size()[1], A_Q->get_size()[0])));
27 A_Q->transpose(gko::lend(A_T));
28
29 // Setting dimensions for solver
30 const unsigned int M = A_T->get_size()[1];
31 const unsigned int N = A_T->get_size()[0];
32
33 const int lda = max(1, M); // 1 > M ? 1 : M;
34 const int k = max(M, N); // M < N ? M : N;
35
36 int lwork_geqrf = 0;
37 int lwork_orgqr = 0;
38 int lwork = 0;
39
40 double *dTau{};
41 hipMalloc((void **) &dTau, sizeof(double) * M);
42
43 // Query working space of geqrf and orgqr
44 hipsolverStatus_t hipsolverStatus = hipsolverDnDgeqrf_bufferSize(solverHandle, M, N, A_T->get_values(), lda, &lwork_geqrf);
45 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
46 hipsolverStatus = hipsolverDnDorgqr_bufferSize(solverHandle, M, N, k, A_T->get_values(), lda, dTau, &lwork_orgqr);
47 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
48 lwork = (lwork_geqrf > lwork_orgqr) ? lwork_geqrf : lwork_orgqr;
49 hipError_t hipErrorCode = hipMalloc((void **) &dWork, sizeof(double) * lwork);
50 assert(hipSuccess == hipErrorCode);
51
52 void *hWork{};
53 // Compute QR factorization
54 hipsolverStatus = hipsolverDnDgeqrf(solverHandle, M, N, A_T->get_values(), lda, dTau, (double *) dWork, lwork, devInfo);
55 hipErrorCode = hipDeviceSynchronize();
56 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
57 assert(hipSuccess == hipErrorCode);
58
59 // Copy A_T to R s.t. the upper triangle corresponds to R
60 A_T->transpose(gko::lend(R));
61
62 // Compute Q
63 hipsolverStatus = hipsolverDnDorgqr(solverHandle, M, N, k, A_T->get_values(), lda, dTau, (double *) dWork, lwork, devInfo);
64 hipErrorCode = hipDeviceSynchronize();
65 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
66 assert(hipSuccess == hipErrorCode);
67
68 A_T->transpose(gko::lend(A_Q));
69
70 hipDeviceSynchronize();
71
72 // Free the utilized memory
73 hipFree(dTau);
74 hipFree(dWork);
75 hipFree(devInfo);
76 hipsolverDnDestroy(solverHandle);
77
78 // ...and switch back to the GPU used for all coupled solvers
79 hipSetDevice(backupDeviceId);
80}
81#endif
T max(T... args)