4#include <ginkgo/ginkgo.hpp>
5#include <hip/hip_runtime.h>
6#include <hip/hip_runtime_api.h>
7#include <hipblas/hipblas.h>
8#include <hipsolver/hipsolver.h>
// QR factorisation fragment (hipSOLVER).
// - A_T is factored in place with hipsolverDnDgeqrf.
// - It is then overwritten in place with the explicit Q factor by hipsolverDnDorgqr.
// NOTE(review): this is a scraped listing. The numbers at the start of lines are
// the original source line numbers, and gaps in them (e.g. 14-18, 38-43, 46-49)
// mean lines are missing from this view. Among them are the declarations of dWork,
// devInfo, dTau, lwork, lwork_geqrf, lwork_orgqr and calculateQRDecompEvent.
// NOTE(review): every check below uses assert(), which is compiled out when NDEBUG
// is defined. Release builds therefore silently ignore HIP/hipSOLVER failures.
// Presumably makes exec's device the current HIP device until the guard goes out of scope.
13 auto scope_guard = exec->get_scoped_device_id_guard();
19 hipError_t hipErrorCode;
// NOTE(review): dWork is allocated here (one double) and allocated again below
// with lwork doubles, but hipFree(dWork) is called only once. This first buffer
// therefore leaks; it looks unnecessary and could be removed.
20 hipErrorCode = hipMalloc((
void **) &dWork,
sizeof(
double));
21 assert(hipErrorCode == hipSuccess);
// Device-side int that geqrf/orgqr write their LAPACK-style info code into.
22 hipErrorCode = hipMalloc((
void **) &devInfo,
sizeof(
int));
23 assert(hipErrorCode == hipSuccess);
25 hipsolverDnHandle_t solverHandle;
// NOTE(review): the hipsolverStatus_t returned by hipsolverDnCreate is not checked.
26 hipsolverDnCreate(&solverHandle);
// A_T is created with A_Q's dimensions swapped. Presumably it is filled with
// A_Q's transpose in lines missing from this listing.
29 auto A_T = gko::share(gko::matrix::Dense<>::create(exec, gko::dim<2>(A_Q->get_size()[1], A_Q->get_size()[0])));
// M/N are the row/column counts handed to hipSOLVER (column-major).
// Presumably Ginkgo Dense is row-major, so A_T's buffer read column-major is the
// M x N matrix A_Q. Confirm.
33 const unsigned int M = A_T->get_size()[1];
34 const unsigned int N = A_T->get_size()[0];
// NOTE(review): M is unsigned int while the literal 1 is int. std::max would not
// accept these mixed arguments, so confirm which max() overload is in scope here.
36 const int lda =
max(1, M);
// NOTE(review): hipsolverDnDorgqr requires M >= N >= k >= 0, where k is the number
// of Householder reflectors produced by geqrf, i.e. min(M, N).
// k = max(M, N) violates N >= k whenever M > N (a tall matrix), and orgqr will then
// reject its arguments. min(M, N) is probably intended.
37 const int k =
max(M, N);
// Householder scalars (tau). geqrf writes min(M, N) of them, so M entries suffice.
44 hipErrorCode = hipMalloc((
void **) &dTau,
sizeof(
double) * M);
45 assert(hipErrorCode == hipSuccess);
// Query workspace sizes for both routines and size one shared buffer for the larger.
50 hipsolverStatus_t hipsolverStatus = hipsolverDnDgeqrf_bufferSize(solverHandle, M, N, A_T->get_values(), lda, &lwork_geqrf);
51 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
52 hipsolverStatus = hipsolverDnDorgqr_bufferSize(solverHandle, M, N, k, A_T->get_values(), lda, dTau, &lwork_orgqr);
53 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
54 lwork = (lwork_geqrf > lwork_orgqr) ? lwork_geqrf : lwork_orgqr;
// Workspace shared by geqrf and orgqr (second allocation into dWork; see leak note above).
55 hipErrorCode = hipMalloc((
void **) &dWork,
sizeof(
double) * lwork);
56 assert(hipSuccess == hipErrorCode);
// In-place QR factorisation of A_T: R ends up in the upper triangle, reflectors below it, scalars in dTau.
// NOTE(review): devInfo is not copied back to the host and checked in the visible lines.
60 hipsolverStatus = hipsolverDnDgeqrf(solverHandle, M, N, A_T->get_values(), lda, dTau, (
double *) dWork, lwork, devInfo);
// Wait for the device so asynchronous failures surface before the checks.
61 hipErrorCode = hipDeviceSynchronize();
62 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
63 assert(hipSuccess == hipErrorCode);
// Overwrite A_T in place with the explicit Q factor built from the stored reflectors.
69 hipsolverStatus = hipsolverDnDorgqr(solverHandle, M, N, k, A_T->get_values(), lda, dTau, (
double *) dWork, lwork, devInfo);
70 hipErrorCode = hipDeviceSynchronize();
71 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
72 assert(hipSuccess == hipErrorCode);
76 hipErrorCode = hipDeviceSynchronize();
77 assert(hipSuccess == hipErrorCode);
// Stop the timing event for this decomposition (it is started in lines not shown here).
78 calculateQRDecompEvent.
stop();
// Release the device buffers and the hipSOLVER handle.
81 hipErrorCode = hipFree(dTau);
82 assert(hipSuccess == hipErrorCode);
83 hipErrorCode = hipFree(dWork);
84 assert(hipSuccess == hipErrorCode);
85 hipErrorCode = hipFree(devInfo);
86 assert(hipSuccess == hipErrorCode);
87 hipsolverStatus = hipsolverDnDestroy(solverHandle);
88 assert(hipsolverStatus == HIPSOLVER_STATUS_SUCCESS);
// Solve fragment (hipBLAS).
// - hipblasDgemv with HIPBLAS_OP_T computes rhs from op(matQ) * in_vec.
// - hipblasDtrsv then solves a triangular system with U, in place in rhs.
// NOTE(review): this is a scraped listing. The numbers at the start of lines are
// original source line numbers, and gaps in them (e.g. 98-99, 102, 105, 111-112,
// 118) mean lines are missing from this view.
// NOTE(review): every check below uses assert(), which is compiled out when NDEBUG
// is defined. Release builds therefore silently ignore hipBLAS/HIP failures.
// Presumably makes exec's device the current HIP device until the guard goes out of scope.
93 auto scope_guard = exec->get_scoped_device_id_guard();
95 hipblasHandle_t handle;
96 hipblasStatus_t hipblasStatus = hipblasCreate(&handle);
97 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
// NOTE(review): original lines 102 and 105 are absent from this listing, so the
// alpha and beta arguments of this call are not shown. Presumably alpha = 1 and
// beta = 0, giving rhs = Q^T * in_vec.
// NOTE(review): lda = matQ->get_size()[0] reads matQ's buffer as column-major with
// its row count as the leading dimension. If matQ is a row-major Ginkgo Dense,
// possibly with a padded stride, confirm this matches the intended layout.
100 hipblasStatus = hipblasDgemv(handle, HIPBLAS_OP_T,
101 matQ->get_size()[0], matQ->get_size()[1],
103 matQ->get_values(), matQ->get_size()[0],
104 in_vec->get_values(), 1,
106 rhs->get_values(), 1);
107 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
// Triangular-solve parameters.
// Presumably U is upper-triangular and row-major, so its buffer read column-major is
// lower-triangular (U^T). LOWER + OP_T then solves U * x = rhs. Confirm.
109 hipblasFillMode_t uplo = HIPBLAS_FILL_MODE_LOWER;
110 hipblasOperation_t trans = HIPBLAS_OP_T;
113 hipblasDiagType_t diag = HIPBLAS_DIAG_NON_UNIT;
114 int rows = rhs->get_size()[0];
// NOTE(review): assumes U is rows x rows with no stride padding. Confirm.
115 const int lda =
max(1, rows);
// NOTE(review): original line 118 is missing from this listing. It presumably passes
// trans and diag, which are otherwise unused in the visible call.
117 hipblasStatus = hipblasDtrsv(handle, uplo,
119 rows, U->get_values(), lda,
120 rhs->get_values(), 1);
121 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
// Wait for the device so asynchronous failures surface before the check.
123 hipError_t hipErrorCode = hipDeviceSynchronize();
124 assert(hipSuccess == hipErrorCode);
126 hipblasStatus = hipblasDestroy(handle);
127 assert(hipblasStatus == HIPBLAS_STATUS_SUCCESS);
gko::matrix::Dense<> GinkgoMatrix
gko::matrix::Dense<> GinkgoVector
void stop()
Stops a running event.