preCICE v3.2.0
Loading...
Searching...
No Matches
HipQRSolver.hip.cpp
Go to the documentation of this file.
1#ifdef PRECICE_WITH_HIP
2
4#include <ginkgo/ginkgo.hpp>
5#include <hip/hip_runtime.h>
6#include <hip/hip_runtime_api.h>
7#include <hipblas/hipblas.h>
8#include <hipsolver/hipsolver.h>
9#include "profiling/Event.hpp"
10
11void computeQRDecompositionHip(const std::shared_ptr<gko::Executor> &exec, GinkgoMatrix *A_Q, GinkgoVector *R)
12{
13 auto scope_guard = exec->get_scoped_device_id_guard();
14
15 void *dWork{};
16 int *devInfo{};
17
18 // Allocating important HIP variables
19 hipError_t hipErrorCode;
20 hipErrorCode = hipMalloc((void **) &dWork, sizeof(double));
21 assert(hipErrorCode == hipSuccess);
22 hipErrorCode = hipMalloc((void **) &devInfo, sizeof(int));
23 assert(hipErrorCode == hipSuccess);
24
25 hipsolverDnHandle_t solverHandle;
26 hipsolverDnCreate(&solverHandle);
27 // NOTE: It's important to transpose since hipsolver assumes column-major memory layout
28 // Making a copy since every value will be overridden
29 auto A_T = gko::share(gko::matrix::Dense<>::create(exec, gko::dim<2>(A_Q->get_size()[1], A_Q->get_size()[0])));
30 A_Q->transpose(A_T);
31
32 // Setting dimensions for solver
33 const unsigned int M = A_T->get_size()[1];
34 const unsigned int N = A_T->get_size()[0];
35
36 const int lda = max(1, M); // 1 > M ? 1 : M;
37 const int k = max(M, N); // M < N ? M : N;
38
39 int lwork_geqrf = 0;
40 int lwork_orgqr = 0;
41 int lwork = 0;
42
43 double *dTau{};
44 hipErrorCode = hipMalloc((void **) &dTau, sizeof(double) * M);
45 assert(hipErrorCode == hipSuccess);
46
47 precice::profiling::Event calculateQRDecompEvent{"calculateQRDecomp"};
48
49 // Query working space of geqrf and orgqr
50 hipsolverStatus_t hipsolverStatus = hipsolverDnDgeqrf_bufferSize(solverHandle, M, N, A_T->get_values(), lda, &lwork_geqrf);
51 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
52 hipsolverStatus = hipsolverDnDorgqr_bufferSize(solverHandle, M, N, k, A_T->get_values(), lda, dTau, &lwork_orgqr);
53 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
54 lwork = (lwork_geqrf > lwork_orgqr) ? lwork_geqrf : lwork_orgqr;
55 hipErrorCode = hipMalloc((void **) &dWork, sizeof(double) * lwork);
56 assert(hipSuccess == hipErrorCode);
57
58 void *hWork{};
59 // Compute QR factorization
60 hipsolverStatus = hipsolverDnDgeqrf(solverHandle, M, N, A_T->get_values(), lda, dTau, (double *) dWork, lwork, devInfo);
61 hipErrorCode = hipDeviceSynchronize();
62 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
63 assert(hipSuccess == hipErrorCode);
64
65 // Copy A_T to R s.t. the upper triangle corresponds to R
66 A_T->transpose(R);
67
68 // Compute Q
69 hipsolverStatus = hipsolverDnDorgqr(solverHandle, M, N, k, A_T->get_values(), lda, dTau, (double *) dWork, lwork, devInfo);
70 hipErrorCode = hipDeviceSynchronize();
71 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
72 assert(hipSuccess == hipErrorCode);
73
74 A_T->transpose(A_Q);
75
76 hipErrorCode = hipDeviceSynchronize();
77 assert(hipSuccess == hipErrorCode);
78 calculateQRDecompEvent.stop();
79
80 // Free the utilized memory
81 hipErrorCode = hipFree(dTau);
82 assert(hipSuccess == hipErrorCode);
83 hipErrorCode = hipFree(dWork);
84 assert(hipSuccess == hipErrorCode);
85 hipErrorCode = hipFree(devInfo);
86 assert(hipSuccess == hipErrorCode);
87 hipsolverStatus = hipsolverDnDestroy(solverHandle);
88 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
89}
90
91void solvewithQRDecompositionHip(const std::shared_ptr<gko::Executor> &exec, GinkgoMatrix *U, GinkgoVector *x, GinkgoVector *rhs, GinkgoMatrix *matQ, GinkgoVector *in_vec)
92{
93 auto scope_guard = exec->get_scoped_device_id_guard();
94
95 hipblasHandle_t handle;
96 hipblasStatus_t hipblasStatus = hipblasCreate(&handle);
97 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
98 double a = 1;
99 double b = 0;
100 hipblasStatus = hipblasDgemv(handle, HIPBLAS_OP_T,
101 matQ->get_size()[0], matQ->get_size()[1],
102 &a,
103 matQ->get_values(), matQ->get_size()[0],
104 in_vec->get_values(), 1,
105 &b,
106 rhs->get_values(), 1);
107 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
108
109 hipblasFillMode_t uplo = HIPBLAS_FILL_MODE_LOWER;
110 hipblasOperation_t trans = HIPBLAS_OP_T;
111
112 // unit triangular = diag = 1
113 hipblasDiagType_t diag = HIPBLAS_DIAG_NON_UNIT;
114 int rows = rhs->get_size()[0];
115 const int lda = max(1, rows);
116
117 hipblasStatus = hipblasDtrsv(handle, uplo,
118 trans, diag,
119 rows, U->get_values(), lda,
120 rhs->get_values(), 1);
121 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
122
123 hipError_t hipErrorCode = hipDeviceSynchronize();
124 assert(hipSuccess == hipErrorCode);
125 *x = *rhs;
126 hipblasStatus = hipblasDestroy(handle);
127 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
128}
129#endif
gko::matrix::Dense<> GinkgoMatrix
gko::matrix::Dense<> GinkgoVector
void stop()
Stops a running event.
Definition Event.cpp:51
T max(T... args)