mirror of
https://github.com/TREX-CoE/Sherman-Morrison.git
synced 2025-01-12 05:58:28 +01:00
- Sync and Async version
- OpenMP version - PP defines cleanup
This commit is contained in:
parent
2d5a34faed
commit
f7dbe3ddd8
@ -4,13 +4,13 @@ CFLAGS=-std=c99 -O3 -Wall -g -xCORE-AVX2
|
||||
LDFLAGS=-L/usr/lib/x86_64-linux-gnu/hdf5/serial -lhdf5 -lhdf5_hl
|
||||
LDFLAGS+=-L$(MKLROOT)/lib/intel64 -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -lm -ldl
|
||||
|
||||
all: test_icc_mkl
|
||||
all: clean test_icc_mkl_sequential
|
||||
|
||||
test_icc_mkl: kernels.o test.o helper.o
|
||||
$(CC) $(LDFLAGS) -o $@ $^
|
||||
test_icc_mkl_sequential: kernels.o test.o helper.o
|
||||
$(CC) -o $@ $^ $(LDFLAGS)
|
||||
|
||||
%.o : %.c
|
||||
$(CC) $(CFLAGS) $(INCLUDE) -c -o $@ $<
|
||||
|
||||
clean:
|
||||
rm -rf *.o *genmod* test_icc_mkl
|
||||
rm -rf *.o *genmod* test_icc_mkl_sequential
|
18
independent_test_harness/Makefile.icc_mkl_threaded.cpu
Normal file
18
independent_test_harness/Makefile.icc_mkl_threaded.cpu
Normal file
@ -0,0 +1,18 @@
|
||||
CC=icc
|
||||
|
||||
CFLAGS=-std=c99 -O3 -Wall -g -xCORE-AVX2 -DUSE_OMP -qopenmp
|
||||
INCLUDE=-I$(MKLROOT)/include
|
||||
|
||||
LDFLAGS=-L/usr/lib/x86_64-linux-gnu/hdf5/serial -lhdf5 -lhdf5_hl
|
||||
LDFLAGS+=-L$(MKLROOT)/lib/intel64 -lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -liomp5 -lpthread -lm -ldl
|
||||
|
||||
all: clean test_icc_mkl_threaded
|
||||
|
||||
test_icc_mkl_threaded: kernels.o test.o helper.o
|
||||
$(CC) -o $@ $^ $(LDFLAGS)
|
||||
|
||||
%.o : %.c
|
||||
$(CC) $(CFLAGS) $(INCLUDE) -c -o $@ $<
|
||||
|
||||
clean:
|
||||
rm -rf *.o *genmod* test_icc_mkl_threaded
|
@ -1,18 +1,19 @@
|
||||
#FC = ifx
|
||||
CC = nvc
|
||||
|
||||
#CFLAGS=-std=c99 -O0 -Wall -g -DHAVE_CUBLAS_OFFLOAD -DUSE_NVTX -mp -target=gpu
|
||||
CFLAGS=-std=c99 -O3 -Wall -g -DHAVE_CUBLAS_OFFLOAD -DUSE_NVTX -mp -target=gpu
|
||||
#CFLAGS=-std=c99 -O0 -Wall -g -DUSE_OMP_OFFLOAD_CUDA -DUSE_NVTX -mp -target=gpu
|
||||
CFLAGS=-std=c99 -O3 -Wall -g -DUSE_OMP_OFFLOAD_CUDA -DUSE_NVTX -mp -target=gpu
|
||||
|
||||
INCLUDE=-I$(NVHPC_ROOT)/math_libs/include
|
||||
INCLUDE =-I$(NVHPC_ROOT)/math_libs/include
|
||||
INCLUDE+=-I$(NVHPC_ROOT)/cuda/11.7/targets/x86_64-linux/include
|
||||
INCLUDE+=-I$(NVHPC_ROOT)/profilers/Nsight_Systems/target-linux-x64/nvtx/include
|
||||
INCLUDE+=-I$(MKLROOT)/include
|
||||
|
||||
LDFLAGS=-L/usr/lib/x86_64-linux-gnu/hdf5/serial -lhdf5 -lhdf5_hl
|
||||
LDFLAGS+=-L$(MKLROOT)/lib/intel64 -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -lm -ldl
|
||||
LDFLAGS+=-L${MKLROOT}/lib/intel64 -lmkl_intel_lp64 -lmkl_pgi_thread -lmkl_core -mp -lpthread -lm -ldl
|
||||
LDFLAGS+=-L$(NVHPC_ROOT)/math_libs/lib64 -lcublas -lcusolver -mp -target=gpu
|
||||
|
||||
all: test_nvc_ompol
|
||||
all: clean test_nvc_ompol
|
||||
|
||||
test_nvc_ompol: kernels.o test.o helper.o
|
||||
$(CC) $(LDFLAGS) -o $@ $^
|
||||
|
@ -1,6 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
for SIZE in 32 64 128 256 512 1024 2048 4096 8192 16384
|
||||
export OMP_NUM_THREADS=10
|
||||
export MKL_NUM_THREADS=$OMP_NUM_THREADS
|
||||
|
||||
for SIZE in 32 64 128 256 512 1024 2048 4096 8192 #16384
|
||||
do
|
||||
echo $SIZE >> SIZES
|
||||
for LOAD in 25 50 75 100
|
||||
@ -11,13 +14,13 @@ do
|
||||
do
|
||||
case $KERNEL in
|
||||
MKL)
|
||||
./test_nvc_ompol m | awk 'NR==5 {print $11}' >> ${KERNEL}_${LOAD}.dat
|
||||
././test_icc_mkl_threaded m | awk 'NR==7 {print $11}' >> ${KERNEL}_${LOAD}.dat
|
||||
;;
|
||||
WBK_CPU)
|
||||
./test_nvc_ompol k | awk 'NR==5 {print $11}' >> ${KERNEL}_${LOAD}.dat
|
||||
././test_icc_mkl_threaded o | awk 'NR==7 {print $11}' >> ${KERNEL}_${LOAD}.dat
|
||||
;;
|
||||
WBK_GPU)
|
||||
./test_nvc_ompol c | awk 'NR==5 {print $11}' >> ${KERNEL}_${LOAD}.dat
|
||||
./test_nvc_ompol c | awk 'NR==7 {print $11}' >> ${KERNEL}_${LOAD}.dat
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
@ -2,7 +2,7 @@
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
cublasHandle_t init_cublas() {
|
||||
cublasHandle_t handle;
|
||||
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
|
||||
@ -24,6 +24,7 @@
|
||||
|
||||
void copy(double* Slater_invT_copy, uint64_t Lds, double* tmp, uint64_t Dim) {
|
||||
for (uint32_t i = 0; i < Dim; i++) {
|
||||
// #pragma omp parallel for
|
||||
for (uint32_t j = 0; j < Lds; j++) {
|
||||
if (j < Dim) Slater_invT_copy[i * Lds + j] = tmp[i * Dim + j];
|
||||
else Slater_invT_copy[i * Lds + j] = 0.0;
|
||||
@ -32,29 +33,29 @@ void copy(double* Slater_invT_copy, uint64_t Lds, double* tmp, uint64_t Dim) {
|
||||
}
|
||||
|
||||
void update(double* slaterT,double* upds, uint64_t* ui, uint64_t nupds,uint64_t Dim, u_int64_t Lds) {
|
||||
// #pragma omp parallel for collapse(2)
|
||||
for (int i = 0; i < nupds; i++) {
|
||||
int col = ui[i] - 1;
|
||||
for (int j = 0; j < Dim; j++) {
|
||||
int col = ui[i] - 1;
|
||||
slaterT[col + j * Dim] += upds[i * Lds + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void convert(double* upds, uint64_t nupds, uint64_t* ui, double* slaterT, uint64_t Dim, u_int64_t Lds) {
|
||||
// #pragma omp parallel for collapse(2)
|
||||
for (int i = 0; i < nupds; i++) {
|
||||
int col = ui[i] - 1;
|
||||
for (int j = 0; j < Lds; j++) {
|
||||
int col = ui[i] - 1;
|
||||
upds[i * Lds + j] -= slaterT[col + j * Dim];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void transpose(double* a, uint16_t lda, double *b, uint16_t ldb, uint16_t m, uint16_t n)
|
||||
{
|
||||
for(uint16_t i = 0; i < m; i++)
|
||||
{
|
||||
for( uint16_t j = 0; j < n; j++)
|
||||
{
|
||||
void transpose(double* a, uint16_t lda, double *b, uint16_t ldb, uint16_t m, uint16_t n) {
|
||||
// #pragma omp parallel for collapse(2)
|
||||
for(uint16_t i = 0; i < m; i++) {
|
||||
for( uint16_t j = 0; j < n; j++) {
|
||||
b[j * ldb + i] = a[i * lda + j];
|
||||
}
|
||||
}
|
||||
@ -167,10 +168,10 @@ void read_double(hid_t file_id, const char *key, double *data) {
|
||||
void update_slater_matrix(const uint64_t Lds, const uint64_t Dim,
|
||||
const uint64_t N_updates, const double *Updates,
|
||||
const uint64_t *Updates_index, double *Slater) {
|
||||
|
||||
// #pragma omp parallel for collapse(2)
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
uint32_t col = Updates_index[i] - 1;
|
||||
for (uint32_t j = 0; j < Dim; j++) {
|
||||
uint32_t col = Updates_index[i] - 1;
|
||||
Slater[col * Dim + j] += Updates[i * Lds + j];
|
||||
}
|
||||
}
|
||||
@ -185,7 +186,6 @@ uint32_t check_error(const uint64_t Lds, const uint64_t Dim, double *Slater_invT
|
||||
Dim, Dim, Dim,
|
||||
alpha, Slater, Dim, Slater_invT, Lds,
|
||||
beta, res, Dim);
|
||||
|
||||
for (uint32_t i = 0; i < Dim; i++) {
|
||||
for (uint32_t j = 0; j < Dim; j++) {
|
||||
double elm = res[i * Dim + j];
|
||||
@ -216,6 +216,7 @@ int32_t check_error_better(const double max, const double tolerance) {
|
||||
}
|
||||
|
||||
void residual(double *a, double *res, const uint64_t Dim) {
|
||||
// #pragma omp parallel for collapse(2)
|
||||
for (uint32_t i = 0; i < Dim; i++) {
|
||||
for (uint32_t j = 0; j < Dim; j++) {
|
||||
if (i == j) res[i * Dim + j] = a[i * Dim + j] - 1.0;
|
||||
|
@ -15,7 +15,7 @@ typedef struct Error {
|
||||
uint64_t error;
|
||||
} Error;
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
cublasHandle_t init_cublas();
|
||||
cusolverDnHandle_t init_cusolver();
|
||||
#endif
|
||||
|
@ -281,7 +281,7 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
const uint32_t Lds = vLDS;
|
||||
|
||||
// Compute C = S^{-1} U : Dim x K : standard dgemm
|
||||
double *C = calloc(1, Dim * N_updates * sizeof(double));
|
||||
double* __restrict __attribute__ ((aligned(8))) C = calloc(1, Dim * N_updates * sizeof(double));
|
||||
double alpha = 1.0, beta = 0.0;
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
Dim, N_updates, Lds,
|
||||
@ -289,8 +289,8 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
beta, C, N_updates);
|
||||
|
||||
// Construct B = 1 + V C : K x K, construct D = V S^{-1} : K x LDS
|
||||
double* B = calloc(1, sizeof *B * N_updates * N_updates);
|
||||
double* D = calloc(1, sizeof *D * N_updates * Lds);
|
||||
double* __restrict __attribute__ ((aligned(8))) B = malloc(sizeof *B * N_updates * N_updates);
|
||||
double* __restrict __attribute__ ((aligned(8))) D = malloc(sizeof *D * N_updates * Lds);
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
const uint32_t row = Updates_index[i] - 1;
|
||||
for (uint32_t j = 0; j < N_updates ; j++) B[i * N_updates + j] = C[row * N_updates + j] + (i == j);
|
||||
@ -298,7 +298,7 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
}
|
||||
|
||||
// Compute determinant by LU decomposition
|
||||
int* pivot = calloc(1, sizeof *pivot * N_updates);
|
||||
int* pivot = malloc(sizeof *pivot * N_updates);
|
||||
(void) LAPACKE_dgetrf(LAPACK_ROW_MAJOR, N_updates, N_updates, B, N_updates, pivot);
|
||||
|
||||
bool swap = false; uint32_t j = 0; double det = 1.0f;
|
||||
@ -317,7 +317,7 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
(void) LAPACKE_dgetri(LAPACK_ROW_MAJOR, N_updates, B, N_updates, pivot);
|
||||
|
||||
// tmp1 = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
|
||||
double* tmp1 = calloc(1, sizeof *tmp1 * N_updates * Lds);
|
||||
double* __restrict __attribute__ ((aligned(8))) tmp1 = calloc(1, sizeof *tmp1 * N_updates * Lds);
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||
N_updates, Lds, N_updates,
|
||||
alpha, B, N_updates, D, Lds,
|
||||
@ -338,33 +338,118 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
return 0;
|
||||
}
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHandle_t s_handle,
|
||||
#ifdef USE_OMP
|
||||
uint32_t qmckl_woodbury_k_omp(const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
const double *__restrict __attribute__((aligned(8))) Updates,
|
||||
const uint64_t *__restrict Updates_index,
|
||||
const double breakdown,
|
||||
double *__restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double *__restrict determinant) {
|
||||
|
||||
const uint32_t Dim = vDim;
|
||||
const uint32_t Lds = vLDS;
|
||||
|
||||
// Compute C = S^{-1} U : Dim x K : standard dgemm
|
||||
double* __restrict __attribute__ ((aligned(8))) C = calloc(1, Dim * N_updates * sizeof(double));
|
||||
double alpha = 1.0, beta = 0.0;
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
Dim, N_updates, Lds,
|
||||
alpha, Slater_inv, Lds, Updates, Lds,
|
||||
beta, C, N_updates);
|
||||
|
||||
// Construct B = 1 + V C : K x K
|
||||
double* __restrict __attribute__ ((aligned(8))) B = malloc(sizeof *B * N_updates * N_updates);
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
for (uint32_t j = 0; j < N_updates ; j++) {
|
||||
const uint32_t row = Updates_index[i] - 1;
|
||||
B[i * N_updates + j] = C[row * N_updates + j] + (i == j);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute determinant by LU decomposition
|
||||
int* pivot = malloc(sizeof *pivot * N_updates);
|
||||
(void) LAPACKE_dgetrf(LAPACK_ROW_MAJOR, N_updates, N_updates, B, N_updates, pivot);
|
||||
|
||||
bool swap = false; uint32_t j = 0; double det = 1.0f;
|
||||
#pragma omp parallel for reduction(+: j) reduction(*: det)
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
swap = (bool)(pivot[i] - (i + 1)); // swap = {0->false: no swap, >0->true: swap}
|
||||
j += (uint32_t)swap; // count # of swaps
|
||||
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
|
||||
}
|
||||
if (fabs(det) < breakdown) return 1; // check if determinant of B is too close to zero. If so, exit early.
|
||||
if (determinant) { // update det(Slater) if determinant != NULL
|
||||
if ((j & 1) != 0) det = -det; // multiply det with -1 if # of swaps is odd
|
||||
*determinant *= det;
|
||||
}
|
||||
|
||||
// Compute B^{-1} with explicit formula for K x K inversion
|
||||
(void) LAPACKE_dgetri(LAPACK_ROW_MAJOR, N_updates, B, N_updates, pivot);
|
||||
|
||||
// Construct D = V S^{-1} : K x LDS
|
||||
double* __restrict __attribute__ ((aligned(8))) D = malloc(sizeof *D * N_updates * Lds);
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
for (uint32_t j = 0; j < Lds; j++) {
|
||||
const uint32_t row = Updates_index[i] - 1;
|
||||
D[i * Lds + j] = Slater_inv[row * Lds + j];
|
||||
}
|
||||
}
|
||||
|
||||
// tmp1 = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
|
||||
double* __restrict __attribute__ ((aligned(8))) tmp1 = calloc(1, sizeof *tmp1 * N_updates * Lds);
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||
N_updates, Lds, N_updates,
|
||||
alpha, B, N_updates, D, Lds,
|
||||
beta, tmp1, Lds);
|
||||
|
||||
// Compute S^{-1} - C * tmp1 : Dim x LDS : standard dgemm
|
||||
alpha = -1.0, beta = 1.0;
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||
Dim, Lds, N_updates,
|
||||
alpha, C, N_updates, tmp1, Lds,
|
||||
beta, Slater_inv, Lds);
|
||||
|
||||
free(C);
|
||||
free(B);
|
||||
free(D);
|
||||
free(tmp1);
|
||||
free(pivot);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
uint32_t qmckl_woodbury_k_ompol_cuda_async(cublasHandle_t b_handle, cusolverDnHandle_t s_handle,
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
const double* Updates,
|
||||
const uint64_t* Updates_index,
|
||||
const double* __restrict __attribute__((aligned(8))) Updates,
|
||||
const uint64_t* __restrict Updates_index,
|
||||
const double breakdown,
|
||||
double* Slater_inv,
|
||||
double* determinant)
|
||||
double* __restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double* __restrict determinant)
|
||||
{
|
||||
PUSH_RANGE("Kernel execution", 1)
|
||||
const uint32_t Dim = vDim;
|
||||
const uint32_t Lds = vLDS;
|
||||
|
||||
bool swap;
|
||||
uint32_t j;
|
||||
double alpha, beta, det;
|
||||
int* pivot = malloc(sizeof *pivot * N_updates);
|
||||
double* C = malloc(sizeof *C * Dim * N_updates);
|
||||
double* B = malloc(sizeof *B * N_updates * N_updates);
|
||||
double* Binv = malloc(sizeof *Binv * N_updates * N_updates);
|
||||
double* D = malloc(sizeof *D * N_updates * Lds);
|
||||
double* T1 = malloc(sizeof *T1 * N_updates * Lds);
|
||||
double* T2 = malloc(sizeof *T2 * Dim * Lds);
|
||||
int* __restrict pivot = malloc(sizeof *pivot * N_updates);
|
||||
double* __restrict __attribute__((aligned(8))) C = malloc(sizeof *C * Dim * N_updates);
|
||||
double* __restrict __attribute__((aligned(8))) B = malloc(sizeof *B * N_updates * N_updates);
|
||||
double* __restrict __attribute__((aligned(8))) Binv = malloc(sizeof *Binv * N_updates * N_updates);
|
||||
double* __restrict __attribute__((aligned(8))) D = malloc(sizeof *D * N_updates * Lds);
|
||||
double* __restrict __attribute__((aligned(8))) T1 = malloc(sizeof *T1 * N_updates * Lds);
|
||||
double* __restrict __attribute__((aligned(8))) T2 = malloc(sizeof *T2 * Dim * Lds);
|
||||
|
||||
int workspace_size = 0, *info = NULL;
|
||||
double* workspace = NULL;
|
||||
double* __restrict __attribute__((aligned(8))) workspace = NULL;
|
||||
cusolverDnDgetrf_bufferSize(s_handle, N_updates, N_updates, B, N_updates, &workspace_size);
|
||||
// printf("SIZE OF CUSOLVER WORKSPACE: %d doubles of %lu byte = %lu byte\n", workspace_size, sizeof *workspace, sizeof *workspace * workspace_size);
|
||||
workspace = malloc(sizeof *workspace * workspace_size);
|
||||
@ -412,7 +497,7 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
||||
POP_RANGE
|
||||
#pragma omp target exit data map(delete: workspace[0:workspace_size])
|
||||
swap = false; j = 0; det = 1.0f;
|
||||
PUSH_RANGE("Test det(B) and count LU swaps", 4)
|
||||
PUSH_RANGE("Compute |det(B)| and count # of LU swaps", 4)
|
||||
#pragma omp target teams distribute parallel for reduction(+: j) reduction(*: det)
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
swap = (bool)(pivot[i] - (i + 1)); // swap = {0->false: no swap, >0->true: swap}
|
||||
@ -420,11 +505,13 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
||||
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
|
||||
}
|
||||
POP_RANGE
|
||||
PUSH_RANGE("A bunch of branches: test break, det-sign", 4)
|
||||
if (fabs(det) < breakdown) return 1; // check if determinant of B is too close to zero. If so, exit early.
|
||||
if (determinant) { // update det(Slater) if determinant != NULL
|
||||
if ((j & 1) != 0) det = -det; // multiply det with -1 if # of swaps is odd
|
||||
*determinant *= det;
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// Compute B^{-1} : initialise as I for solving BX=I
|
||||
PUSH_RANGE("Allocate Binv ON GPU", 2)
|
||||
@ -442,7 +529,8 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
||||
#pragma omp target data use_device_ptr(B, pivot, Binv)
|
||||
{
|
||||
PUSH_RANGE("Compute B^{-1}", 3)
|
||||
(void) cusolverDnDgetrs(s_handle, CUBLAS_OP_T, N_updates, N_updates, B, N_updates, pivot, Binv, N_updates, info); // Needs op(B) = B^T because of line 403
|
||||
(void) cusolverDnDgetrs(s_handle, CUBLAS_OP_T, N_updates, N_updates, B,N_updates,
|
||||
pivot, Binv, N_updates, info); // Needs op(B) = B^T because of line 403
|
||||
POP_RANGE
|
||||
}
|
||||
PUSH_RANGE("Deallocate B, pivot ON GPU", 2)
|
||||
@ -518,6 +606,154 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
||||
free(D);
|
||||
free(T1);
|
||||
|
||||
POP_RANGE
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t qmckl_woodbury_k_ompol_cuda_sync(cublasHandle_t b_handle, cusolverDnHandle_t s_handle,
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
const double* __restrict __attribute__((aligned(8))) Updates,
|
||||
const uint64_t* __restrict Updates_index,
|
||||
const double breakdown,
|
||||
double* __restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double* __restrict determinant) {
|
||||
const uint32_t Dim = vDim;
|
||||
const uint32_t Lds = vLDS;
|
||||
|
||||
uint32_t j;
|
||||
double alpha, beta, det;
|
||||
int *__restrict pivot = malloc(sizeof *pivot * N_updates);
|
||||
double *__restrict __attribute__((aligned(8))) C = malloc(sizeof *C * Dim * N_updates);
|
||||
double *__restrict __attribute__((aligned(8))) B = malloc(sizeof *B * N_updates * N_updates);
|
||||
double *__restrict __attribute__((aligned(8))) Binv = malloc(sizeof *Binv * N_updates * N_updates);
|
||||
double *__restrict __attribute__((aligned(8))) D = malloc(sizeof *D * N_updates * Lds);
|
||||
double *__restrict __attribute__((aligned(8))) T1 = malloc(sizeof *T1 * N_updates * Lds);
|
||||
double *__restrict __attribute__((aligned(8))) T2 = malloc(sizeof *T2 * Dim * Lds);
|
||||
|
||||
int workspace_size = 0, *info = NULL;
|
||||
double *__restrict __attribute__((aligned(8))) workspace = NULL;
|
||||
cusolverDnDgetrf_bufferSize(s_handle, N_updates, N_updates, B, N_updates, &workspace_size);
|
||||
workspace = malloc(sizeof *workspace * workspace_size);
|
||||
PUSH_RANGE("OpenMP OL Synchronous region", 1)
|
||||
// PUSH_RANGE("Data init from host to device", 2)
|
||||
#pragma omp target data map(to: Updates[0:Lds*N_updates], \
|
||||
Updates_index[0:N_updates]) \
|
||||
map(tofrom: Slater_inv[0:Dim*Lds]) \
|
||||
map(alloc: C[0:Dim*N_updates], \
|
||||
B[0:N_updates*N_updates], \
|
||||
workspace[0:workspace_size], \
|
||||
pivot[0:N_updates], \
|
||||
Binv[0:N_updates*N_updates], \
|
||||
D[0:N_updates*Lds], \
|
||||
T1[0:N_updates*Lds])
|
||||
// POP_RANGE
|
||||
{
|
||||
|
||||
// Compute C <- S^{-1} U : Dim x K : standard dgemm
|
||||
alpha = 1.0f, beta = 0.0f;
|
||||
PUSH_RANGE("Compute C <- S^{-1} U", 3)
|
||||
#pragma omp target data use_device_ptr(Slater_inv, Updates, C)
|
||||
{
|
||||
(void) cublasDgemm_v2(b_handle,
|
||||
CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
N_updates, Dim, Lds,
|
||||
&alpha, Updates, Lds, Slater_inv, Lds,
|
||||
&beta, C, N_updates);
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// Construct B <- 1 + V C : K x K
|
||||
PUSH_RANGE("Construct B <- 1 + V C", 4)
|
||||
#pragma omp target teams distribute parallel for collapse(2)
|
||||
for (int i = 0; i < N_updates; ++i) {
|
||||
for (int j = 0; j < N_updates; ++j) {
|
||||
const uint32_t row = Updates_index[i] - 1;
|
||||
B[i * N_updates + j] = C[row * N_updates + j] + (i == j);
|
||||
}
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// Compute det(B) via LU(B)
|
||||
PUSH_RANGE("Compute LU(B)", 3)
|
||||
#pragma omp target data use_device_ptr(B, workspace, pivot)
|
||||
{
|
||||
(void) cusolverDnDgetrf(s_handle, N_updates, N_updates, B, N_updates, workspace, pivot,
|
||||
info); // col-maj enforced, so res. is LU(B)^T
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
det = 1.0f;
|
||||
PUSH_RANGE("Compute |det(B)|", 4)
|
||||
#pragma omp target teams distribute parallel for reduction(+: j) reduction(*: det)
|
||||
for (uint32_t i = 0; i < N_updates; i++) {
|
||||
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// Compute B^{-1} : initialise as I for solving BX=I
|
||||
PUSH_RANGE("Construct and init B^{-1} as Id", 4)
|
||||
#pragma omp target teams distribute parallel for collapse(2)
|
||||
for (int i = 0; i < N_updates; ++i) {
|
||||
for (int j = 0; j < N_updates; ++j) {
|
||||
Binv[i * N_updates + j] = (i == j);
|
||||
}
|
||||
}
|
||||
POP_RANGE
|
||||
PUSH_RANGE("Compute B^{-1} from LU(B)", 3)
|
||||
#pragma omp target data use_device_ptr(B, pivot, Binv)
|
||||
{
|
||||
(void) cusolverDnDgetrs(s_handle, CUBLAS_OP_T, N_updates, N_updates, B, N_updates,
|
||||
pivot, Binv, N_updates, info); // Needs op(B) = B^T because of line 403
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// Construct D = V S^{-1} : K x LDS
|
||||
PUSH_RANGE("Construct D = V S^{-1}", 4)
|
||||
#pragma omp target teams distribute parallel for collapse(2)
|
||||
for (uint32_t i = 0; i < N_updates; ++i) {
|
||||
for (uint32_t j = 0; j < Lds; ++j) {
|
||||
const uint32_t row = Updates_index[i] - 1;
|
||||
D[i * Lds + j] = Slater_inv[row * Lds + j];
|
||||
}
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// T1 <- B^{-1} D : KxLDS : standard dgemm
|
||||
PUSH_RANGE("Compute T1 <- B^{-1} D", 3)
|
||||
#pragma omp target data use_device_ptr(D, Binv, T1)
|
||||
{
|
||||
(void) cublasDgemm_v2(b_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_T, // REMEMBER THIS IS Binv TRANSPOSED because of cusolverDnDgetrs CALL ON l.434 !!!
|
||||
Lds, N_updates, N_updates,
|
||||
&alpha, D, Lds, Binv, N_updates,
|
||||
&beta, T1, Lds);
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
// Compute S^{-1} <- S^{-1} - C * T1 : Dim x LDS : standard dgemm
|
||||
alpha = -1.0f, beta = 1.0f;
|
||||
PUSH_RANGE("Compute S^{-1} <- S^{-1} - C * T1", 3)
|
||||
#pragma omp target data use_device_ptr(T1, C, Slater_inv)
|
||||
{
|
||||
(void) cublasDgemm_v2(b_handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
Dim, Lds, N_updates,
|
||||
&alpha, T1, Lds, C, N_updates,
|
||||
&beta, Slater_inv, Lds);
|
||||
}
|
||||
POP_RANGE
|
||||
}
|
||||
POP_RANGE
|
||||
|
||||
free(pivot);
|
||||
free(B);
|
||||
free(Binv);
|
||||
free(C);
|
||||
free(D);
|
||||
free(T1);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
@ -3,9 +3,10 @@
|
||||
#include <mkl_lapacke.h>
|
||||
#include <mkl.h>
|
||||
|
||||
#define HAVE_CUBLAS_OFFLOAD
|
||||
//#define USE_OMP
|
||||
//#define USE_OMP_OFFLOAD_CUDA
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
#include <stdio.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cusolverDn.h>
|
||||
@ -38,7 +39,8 @@ uint32_t qmckl_sherman_morrison_smw32s(
|
||||
double *__restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double *__restrict determinant);
|
||||
|
||||
uint32_t qmckl_woodbury_3(const uint64_t vLDS, const uint64_t vDim,
|
||||
uint32_t qmckl_woodbury_3(
|
||||
const uint64_t vLDS, const uint64_t vDim,
|
||||
const double *__restrict __attribute__((aligned(8)))
|
||||
Updates,
|
||||
const uint64_t *__restrict Updates_index,
|
||||
@ -47,7 +49,8 @@ uint32_t qmckl_woodbury_3(const uint64_t vLDS, const uint64_t vDim,
|
||||
Slater_inv,
|
||||
double *__restrict determinant);
|
||||
|
||||
uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
uint32_t qmckl_woodbury_k(
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
const double *__restrict __attribute__((aligned(8))) Updates,
|
||||
@ -56,8 +59,8 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
||||
double *__restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double *__restrict determinant);
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHandle_t s_handle,
|
||||
#ifdef USE_OMP
|
||||
uint32_t qmckl_woodbury_k_omp(
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
@ -68,7 +71,35 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
||||
double *__restrict determinant);
|
||||
#endif
|
||||
|
||||
uint32_t qmckl_woodbury_2(const uint64_t vLDS, const uint64_t vDim,
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
uint32_t qmckl_woodbury_k_ompol_cuda_async(
|
||||
cublasHandle_t b_handle,
|
||||
cusolverDnHandle_t s_handle,
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
const double *__restrict __attribute__((aligned(8))) Updates,
|
||||
const uint64_t *__restrict Updates_index,
|
||||
const double breakdown,
|
||||
double *__restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double *__restrict determinant);
|
||||
|
||||
uint32_t qmckl_woodbury_k_ompol_cuda_sync(
|
||||
cublasHandle_t b_handle,
|
||||
cusolverDnHandle_t s_handle,
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const uint64_t N_updates,
|
||||
const double *__restrict __attribute__((aligned(8))) Updates,
|
||||
const uint64_t *__restrict Updates_index,
|
||||
const double breakdown,
|
||||
double *__restrict __attribute__((aligned(8))) Slater_inv,
|
||||
double *__restrict determinant);
|
||||
#endif
|
||||
|
||||
uint32_t qmckl_woodbury_2(
|
||||
const uint64_t vLDS,
|
||||
const uint64_t vDim,
|
||||
const double *__restrict __attribute__((aligned(8)))
|
||||
Updates,
|
||||
const uint64_t *__restrict Updates_index,
|
||||
@ -77,7 +108,9 @@ uint32_t qmckl_woodbury_2(const uint64_t vLDS, const uint64_t vDim,
|
||||
Slater_inv,
|
||||
double *__restrict determinant);
|
||||
|
||||
void detupd(const uint64_t Dim, const uint64_t Lds,
|
||||
void detupd(
|
||||
const uint64_t Dim,
|
||||
const uint64_t Lds,
|
||||
const double *__restrict __attribute__((aligned(8))) Updates,
|
||||
const uint64_t *__restrict Updates_index,
|
||||
double *__restrict __attribute__((aligned(8))) Slater_inv,
|
||||
|
1499
independent_test_harness/nvToolsExt.h
Executable file
1499
independent_test_harness/nvToolsExt.h
Executable file
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#define USE_NVTX
|
||||
//#define USE_NVTX
|
||||
|
||||
#ifdef USE_NVTX
|
||||
#include "nvToolsExt.h"
|
||||
|
469
independent_test_harness/nvtxDetail/nvtxImpl.h
Executable file
469
independent_test_harness/nvtxDetail/nvtxImpl.h
Executable file
@ -0,0 +1,469 @@
|
||||
/* This file was procedurally generated! Do not modify this file by hand. */
|
||||
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
/* ---- Include required platform headers ---- */
|
||||
|
||||
#if defined(_WIN32)
|
||||
|
||||
#include <Windows.h>
|
||||
|
||||
#else
|
||||
#include <unistd.h>
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
#include <android/api-level.h>
|
||||
#endif
|
||||
|
||||
#if defined(__linux__) || defined(__CYGWIN__)
|
||||
#include <sched.h>
|
||||
#endif
|
||||
|
||||
#include <limits.h>
|
||||
#include <dlfcn.h>
|
||||
#include <fcntl.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <errno.h>
|
||||
|
||||
#include <string.h>
|
||||
#include <sys/types.h>
|
||||
#include <pthread.h>
|
||||
#include <stdlib.h>
|
||||
#include <wchar.h>
|
||||
|
||||
#endif
|
||||
|
||||
/* ---- Define macros used in this file ---- */
|
||||
|
||||
#define NVTX_INIT_STATE_FRESH 0
|
||||
#define NVTX_INIT_STATE_STARTED 1
|
||||
#define NVTX_INIT_STATE_COMPLETE 2
|
||||
|
||||
#ifdef NVTX_DEBUG_PRINT
|
||||
#ifdef __ANDROID__
|
||||
#include <android/log.h>
|
||||
#define NVTX_ERR(...) __android_log_print(ANDROID_LOG_ERROR, "NVTOOLSEXT", __VA_ARGS__);
|
||||
#define NVTX_INFO(...) __android_log_print(ANDROID_LOG_INFO, "NVTOOLSEXT", __VA_ARGS__);
|
||||
#else
|
||||
#include <stdio.h>
|
||||
#define NVTX_ERR(...) fprintf(stderr, "NVTX_ERROR: " __VA_ARGS__)
|
||||
#define NVTX_INFO(...) fprintf(stderr, "NVTX_INFO: " __VA_ARGS__)
|
||||
#endif
|
||||
#else /* !defined(NVTX_DEBUG_PRINT) */
|
||||
#define NVTX_ERR(...)
|
||||
#define NVTX_INFO(...)
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC visibility push(hidden)
|
||||
#endif
|
||||
|
||||
/* ---- Forward declare all functions referenced in globals ---- */
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)(void);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiGetModuleFunctionTable)(
|
||||
NvtxCallbackModule module,
|
||||
NvtxFunctionTable* out_table,
|
||||
unsigned int* out_size);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiSetInjectionNvtxVersion)(
|
||||
uint32_t version);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION const void* NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxGetExportTable)(
|
||||
uint32_t exportTableId);
|
||||
|
||||
#include "nvtxInitDecls.h"
|
||||
|
||||
/* ---- Define all globals ---- */
|
||||
|
||||
typedef struct nvtxGlobals_t
|
||||
{
|
||||
volatile unsigned int initState;
|
||||
NvtxExportTableCallbacks etblCallbacks;
|
||||
NvtxExportTableVersionInfo etblVersionInfo;
|
||||
|
||||
/* Implementation function pointers */
|
||||
nvtxMarkEx_impl_fntype nvtxMarkEx_impl_fnptr;
|
||||
nvtxMarkA_impl_fntype nvtxMarkA_impl_fnptr;
|
||||
nvtxMarkW_impl_fntype nvtxMarkW_impl_fnptr;
|
||||
nvtxRangeStartEx_impl_fntype nvtxRangeStartEx_impl_fnptr;
|
||||
nvtxRangeStartA_impl_fntype nvtxRangeStartA_impl_fnptr;
|
||||
nvtxRangeStartW_impl_fntype nvtxRangeStartW_impl_fnptr;
|
||||
nvtxRangeEnd_impl_fntype nvtxRangeEnd_impl_fnptr;
|
||||
nvtxRangePushEx_impl_fntype nvtxRangePushEx_impl_fnptr;
|
||||
nvtxRangePushA_impl_fntype nvtxRangePushA_impl_fnptr;
|
||||
nvtxRangePushW_impl_fntype nvtxRangePushW_impl_fnptr;
|
||||
nvtxRangePop_impl_fntype nvtxRangePop_impl_fnptr;
|
||||
nvtxNameCategoryA_impl_fntype nvtxNameCategoryA_impl_fnptr;
|
||||
nvtxNameCategoryW_impl_fntype nvtxNameCategoryW_impl_fnptr;
|
||||
nvtxNameOsThreadA_impl_fntype nvtxNameOsThreadA_impl_fnptr;
|
||||
nvtxNameOsThreadW_impl_fntype nvtxNameOsThreadW_impl_fnptr;
|
||||
|
||||
nvtxNameCuDeviceA_fakeimpl_fntype nvtxNameCuDeviceA_impl_fnptr;
|
||||
nvtxNameCuDeviceW_fakeimpl_fntype nvtxNameCuDeviceW_impl_fnptr;
|
||||
nvtxNameCuContextA_fakeimpl_fntype nvtxNameCuContextA_impl_fnptr;
|
||||
nvtxNameCuContextW_fakeimpl_fntype nvtxNameCuContextW_impl_fnptr;
|
||||
nvtxNameCuStreamA_fakeimpl_fntype nvtxNameCuStreamA_impl_fnptr;
|
||||
nvtxNameCuStreamW_fakeimpl_fntype nvtxNameCuStreamW_impl_fnptr;
|
||||
nvtxNameCuEventA_fakeimpl_fntype nvtxNameCuEventA_impl_fnptr;
|
||||
nvtxNameCuEventW_fakeimpl_fntype nvtxNameCuEventW_impl_fnptr;
|
||||
|
||||
nvtxNameClDeviceA_fakeimpl_fntype nvtxNameClDeviceA_impl_fnptr;
|
||||
nvtxNameClDeviceW_fakeimpl_fntype nvtxNameClDeviceW_impl_fnptr;
|
||||
nvtxNameClContextA_fakeimpl_fntype nvtxNameClContextA_impl_fnptr;
|
||||
nvtxNameClContextW_fakeimpl_fntype nvtxNameClContextW_impl_fnptr;
|
||||
nvtxNameClCommandQueueA_fakeimpl_fntype nvtxNameClCommandQueueA_impl_fnptr;
|
||||
nvtxNameClCommandQueueW_fakeimpl_fntype nvtxNameClCommandQueueW_impl_fnptr;
|
||||
nvtxNameClMemObjectA_fakeimpl_fntype nvtxNameClMemObjectA_impl_fnptr;
|
||||
nvtxNameClMemObjectW_fakeimpl_fntype nvtxNameClMemObjectW_impl_fnptr;
|
||||
nvtxNameClSamplerA_fakeimpl_fntype nvtxNameClSamplerA_impl_fnptr;
|
||||
nvtxNameClSamplerW_fakeimpl_fntype nvtxNameClSamplerW_impl_fnptr;
|
||||
nvtxNameClProgramA_fakeimpl_fntype nvtxNameClProgramA_impl_fnptr;
|
||||
nvtxNameClProgramW_fakeimpl_fntype nvtxNameClProgramW_impl_fnptr;
|
||||
nvtxNameClEventA_fakeimpl_fntype nvtxNameClEventA_impl_fnptr;
|
||||
nvtxNameClEventW_fakeimpl_fntype nvtxNameClEventW_impl_fnptr;
|
||||
|
||||
nvtxNameCudaDeviceA_impl_fntype nvtxNameCudaDeviceA_impl_fnptr;
|
||||
nvtxNameCudaDeviceW_impl_fntype nvtxNameCudaDeviceW_impl_fnptr;
|
||||
nvtxNameCudaStreamA_fakeimpl_fntype nvtxNameCudaStreamA_impl_fnptr;
|
||||
nvtxNameCudaStreamW_fakeimpl_fntype nvtxNameCudaStreamW_impl_fnptr;
|
||||
nvtxNameCudaEventA_fakeimpl_fntype nvtxNameCudaEventA_impl_fnptr;
|
||||
nvtxNameCudaEventW_fakeimpl_fntype nvtxNameCudaEventW_impl_fnptr;
|
||||
|
||||
nvtxDomainMarkEx_impl_fntype nvtxDomainMarkEx_impl_fnptr;
|
||||
nvtxDomainRangeStartEx_impl_fntype nvtxDomainRangeStartEx_impl_fnptr;
|
||||
nvtxDomainRangeEnd_impl_fntype nvtxDomainRangeEnd_impl_fnptr;
|
||||
nvtxDomainRangePushEx_impl_fntype nvtxDomainRangePushEx_impl_fnptr;
|
||||
nvtxDomainRangePop_impl_fntype nvtxDomainRangePop_impl_fnptr;
|
||||
nvtxDomainResourceCreate_impl_fntype nvtxDomainResourceCreate_impl_fnptr;
|
||||
nvtxDomainResourceDestroy_impl_fntype nvtxDomainResourceDestroy_impl_fnptr;
|
||||
nvtxDomainNameCategoryA_impl_fntype nvtxDomainNameCategoryA_impl_fnptr;
|
||||
nvtxDomainNameCategoryW_impl_fntype nvtxDomainNameCategoryW_impl_fnptr;
|
||||
nvtxDomainRegisterStringA_impl_fntype nvtxDomainRegisterStringA_impl_fnptr;
|
||||
nvtxDomainRegisterStringW_impl_fntype nvtxDomainRegisterStringW_impl_fnptr;
|
||||
nvtxDomainCreateA_impl_fntype nvtxDomainCreateA_impl_fnptr;
|
||||
nvtxDomainCreateW_impl_fntype nvtxDomainCreateW_impl_fnptr;
|
||||
nvtxDomainDestroy_impl_fntype nvtxDomainDestroy_impl_fnptr;
|
||||
nvtxInitialize_impl_fntype nvtxInitialize_impl_fnptr;
|
||||
|
||||
nvtxDomainSyncUserCreate_impl_fntype nvtxDomainSyncUserCreate_impl_fnptr;
|
||||
nvtxDomainSyncUserDestroy_impl_fntype nvtxDomainSyncUserDestroy_impl_fnptr;
|
||||
nvtxDomainSyncUserAcquireStart_impl_fntype nvtxDomainSyncUserAcquireStart_impl_fnptr;
|
||||
nvtxDomainSyncUserAcquireFailed_impl_fntype nvtxDomainSyncUserAcquireFailed_impl_fnptr;
|
||||
nvtxDomainSyncUserAcquireSuccess_impl_fntype nvtxDomainSyncUserAcquireSuccess_impl_fnptr;
|
||||
nvtxDomainSyncUserReleasing_impl_fntype nvtxDomainSyncUserReleasing_impl_fnptr;
|
||||
|
||||
/* Tables of function pointers -- Extra null added to the end to ensure
|
||||
* a crash instead of silent corruption if a tool reads off the end. */
|
||||
NvtxFunctionPointer* functionTable_CORE [NVTX_CBID_CORE_SIZE + 1];
|
||||
NvtxFunctionPointer* functionTable_CUDA [NVTX_CBID_CUDA_SIZE + 1];
|
||||
NvtxFunctionPointer* functionTable_OPENCL[NVTX_CBID_OPENCL_SIZE + 1];
|
||||
NvtxFunctionPointer* functionTable_CUDART[NVTX_CBID_CUDART_SIZE + 1];
|
||||
NvtxFunctionPointer* functionTable_CORE2 [NVTX_CBID_CORE2_SIZE + 1];
|
||||
NvtxFunctionPointer* functionTable_SYNC [NVTX_CBID_SYNC_SIZE + 1];
|
||||
} nvtxGlobals_t;
|
||||
|
||||
NVTX_LINKONCE_DEFINE_GLOBAL nvtxGlobals_t NVTX_VERSIONED_IDENTIFIER(nvtxGlobals) =
|
||||
{
|
||||
NVTX_INIT_STATE_FRESH,
|
||||
|
||||
{
|
||||
sizeof(NvtxExportTableCallbacks),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxEtiGetModuleFunctionTable)
|
||||
},
|
||||
{
|
||||
sizeof(NvtxExportTableVersionInfo),
|
||||
NVTX_VERSION,
|
||||
0,
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxEtiSetInjectionNvtxVersion)
|
||||
},
|
||||
|
||||
/* Implementation function pointers */
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init),
|
||||
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init),
|
||||
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init),
|
||||
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init),
|
||||
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init),
|
||||
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init),
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init),
|
||||
|
||||
/* Tables of function pointers */
|
||||
{
|
||||
0,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr,
|
||||
0
|
||||
},
|
||||
{
|
||||
0,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr,
|
||||
0
|
||||
},
|
||||
{
|
||||
0,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr,
|
||||
0
|
||||
},
|
||||
{
|
||||
0,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr,
|
||||
0
|
||||
},
|
||||
{
|
||||
0,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr,
|
||||
0
|
||||
},
|
||||
{
|
||||
0,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr,
|
||||
(NvtxFunctionPointer*)&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr,
|
||||
0
|
||||
}
|
||||
};
|
||||
|
||||
/* ---- Define static inline implementations of core API functions ---- */
|
||||
|
||||
#include "nvtxImplCore.h"
|
||||
|
||||
/* ---- Define implementations of export table functions ---- */
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiGetModuleFunctionTable)(
|
||||
NvtxCallbackModule module,
|
||||
NvtxFunctionTable* out_table,
|
||||
unsigned int* out_size)
|
||||
{
|
||||
unsigned int bytes = 0;
|
||||
NvtxFunctionTable table = (NvtxFunctionTable)0;
|
||||
|
||||
switch (module)
|
||||
{
|
||||
case NVTX_CB_MODULE_CORE:
|
||||
table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE;
|
||||
bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE);
|
||||
break;
|
||||
case NVTX_CB_MODULE_CUDA:
|
||||
table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDA;
|
||||
bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDA);
|
||||
break;
|
||||
case NVTX_CB_MODULE_OPENCL:
|
||||
table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_OPENCL;
|
||||
bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_OPENCL);
|
||||
break;
|
||||
case NVTX_CB_MODULE_CUDART:
|
||||
table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDART;
|
||||
bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CUDART);
|
||||
break;
|
||||
case NVTX_CB_MODULE_CORE2:
|
||||
table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE2;
|
||||
bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_CORE2);
|
||||
break;
|
||||
case NVTX_CB_MODULE_SYNC:
|
||||
table = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_SYNC;
|
||||
bytes = (unsigned int)sizeof(NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).functionTable_SYNC);
|
||||
break;
|
||||
default: return 0;
|
||||
}
|
||||
|
||||
if (out_size)
|
||||
*out_size = (bytes / (unsigned int)sizeof(NvtxFunctionPointer*)) - 1;
|
||||
|
||||
if (out_table)
|
||||
*out_table = table;
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION const void* NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxGetExportTable)(uint32_t exportTableId)
|
||||
{
|
||||
switch (exportTableId)
|
||||
{
|
||||
case NVTX_ETID_CALLBACKS: return &NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).etblCallbacks;
|
||||
case NVTX_ETID_VERSIONINFO: return &NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).etblVersionInfo;
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxEtiSetInjectionNvtxVersion)(uint32_t version)
|
||||
{
|
||||
/* Reserved for custom implementations to resolve problems with tools */
|
||||
(void)version;
|
||||
}
|
||||
|
||||
/* ---- Define implementations of init versions of all API functions ---- */
|
||||
|
||||
#include "nvtxInitDefs.h"
|
||||
|
||||
/* ---- Define implementations of initialization functions ---- */
|
||||
|
||||
#include "nvtxInit.h"
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC visibility pop
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif /* __cplusplus */
|
299
independent_test_harness/nvtxDetail/nvtxImplCore.h
Executable file
299
independent_test_harness/nvtxDetail/nvtxImplCore.h
Executable file
@ -0,0 +1,299 @@
|
||||
NVTX_DECLSPEC void NVTX_API nvtxMarkEx(const nvtxEventAttributes_t* eventAttrib)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxMarkEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(eventAttrib);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxMarkA(const char* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxMarkA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(message);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxMarkW(const wchar_t* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxMarkW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(message);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartEx(const nvtxEventAttributes_t* eventAttrib)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangeStartEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(eventAttrib);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxRangeId_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartA(const char* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangeStartA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(message);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxRangeId_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxRangeStartW(const wchar_t* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangeStartW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(message);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxRangeId_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxRangeEnd(nvtxRangeId_t id)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangeEnd_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(id);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC int NVTX_API nvtxRangePushEx(const nvtxEventAttributes_t* eventAttrib)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangePushEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(eventAttrib);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (int)NVTX_NO_PUSH_POP_TRACKING;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC int NVTX_API nvtxRangePushA(const char* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangePushA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(message);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (int)NVTX_NO_PUSH_POP_TRACKING;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC int NVTX_API nvtxRangePushW(const wchar_t* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangePushW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(message);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (int)NVTX_NO_PUSH_POP_TRACKING;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC int NVTX_API nvtxRangePop(void)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxRangePop_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)();
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (int)NVTX_NO_PUSH_POP_TRACKING;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCategoryA(uint32_t category, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCategoryA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(category, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCategoryW(uint32_t category, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCategoryW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(category, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameOsThreadA(uint32_t threadId, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameOsThreadA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(threadId, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameOsThreadW(uint32_t threadId, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameOsThreadW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(threadId, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainMarkEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainMarkEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(domain, eventAttrib);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxRangeId_t NVTX_API nvtxDomainRangeStartEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainRangeStartEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain, eventAttrib);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxRangeId_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainRangeEnd(nvtxDomainHandle_t domain, nvtxRangeId_t id)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainRangeEnd_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(domain, id);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC int NVTX_API nvtxDomainRangePushEx(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainRangePushEx_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain, eventAttrib);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (int)NVTX_NO_PUSH_POP_TRACKING;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC int NVTX_API nvtxDomainRangePop(nvtxDomainHandle_t domain)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainRangePop_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (int)NVTX_NO_PUSH_POP_TRACKING;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxResourceHandle_t NVTX_API nvtxDomainResourceCreate(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainResourceCreate_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain, attribs);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxResourceHandle_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainResourceDestroy(nvtxResourceHandle_t resource)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainResourceDestroy_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(resource);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainNameCategoryA(nvtxDomainHandle_t domain, uint32_t category, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainNameCategoryA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(domain, category, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainNameCategoryW(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainNameCategoryW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(domain, category, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxStringHandle_t NVTX_API nvtxDomainRegisterStringA(nvtxDomainHandle_t domain, const char* string)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainRegisterStringA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain, string);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxStringHandle_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxStringHandle_t NVTX_API nvtxDomainRegisterStringW(nvtxDomainHandle_t domain, const wchar_t* string)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainRegisterStringW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain, string);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxStringHandle_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxDomainHandle_t NVTX_API nvtxDomainCreateA(const char* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainCreateA_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(message);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxDomainHandle_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC nvtxDomainHandle_t NVTX_API nvtxDomainCreateW(const wchar_t* message)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainCreateW_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(message);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxDomainHandle_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainDestroy(nvtxDomainHandle_t domain)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainDestroy_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(domain);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxInitialize(const void* reserved)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxInitialize_impl_fntype local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(reserved);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
112
independent_test_harness/nvtxDetail/nvtxImplCudaRt_v3.h
Executable file
112
independent_test_harness/nvtxDetail/nvtxImplCudaRt_v3.h
Executable file
@ -0,0 +1,112 @@
|
||||
/* This file was procedurally generated! Do not modify this file by hand. */
|
||||
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD_CUDART
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExtCudaRt.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
//typedef void (NVTX_API * nvtxNameCudaDeviceA_impl_fntype)(int device, const char* name);
|
||||
//typedef void (NVTX_API * nvtxNameCudaDeviceW_impl_fntype)(int device, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaStreamA_impl_fntype)(cudaStream_t stream, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaStreamW_impl_fntype)(cudaStream_t stream, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaEventA_impl_fntype)(cudaEvent_t event, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaEventW_impl_fntype)(cudaEvent_t event, const wchar_t* name);
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCudaDeviceA(int device, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCudaDeviceA_impl_fntype local = (nvtxNameCudaDeviceA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(device, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCudaDeviceW(int device, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCudaDeviceW_impl_fntype local = (nvtxNameCudaDeviceW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(device, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCudaStreamA(cudaStream_t stream, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCudaStreamA_impl_fntype local = (nvtxNameCudaStreamA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(stream, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCudaStreamW(cudaStream_t stream, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCudaStreamW_impl_fntype local = (nvtxNameCudaStreamW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(stream, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCudaEventA(cudaEvent_t event, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCudaEventA_impl_fntype local = (nvtxNameCudaEventA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(event, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCudaEventW(cudaEvent_t event, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCudaEventW_impl_fntype local = (nvtxNameCudaEventW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(event, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif /* __cplusplus */
|
||||
|
133
independent_test_harness/nvtxDetail/nvtxImplCuda_v3.h
Executable file
133
independent_test_harness/nvtxDetail/nvtxImplCuda_v3.h
Executable file
@ -0,0 +1,133 @@
|
||||
/* This file was procedurally generated! Do not modify this file by hand. */
|
||||
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD_CUDA
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExtCuda.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
typedef void (NVTX_API * nvtxNameCuDeviceA_impl_fntype)(CUdevice device, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuDeviceW_impl_fntype)(CUdevice device, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCuContextA_impl_fntype)(CUcontext context, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuContextW_impl_fntype)(CUcontext context, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCuStreamA_impl_fntype)(CUstream stream, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuStreamW_impl_fntype)(CUstream stream, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCuEventA_impl_fntype)(CUevent event, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuEventW_impl_fntype)(CUevent event, const wchar_t* name);
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuDeviceA(CUdevice device, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuDeviceA_impl_fntype local = (nvtxNameCuDeviceA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(device, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuDeviceW(CUdevice device, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuDeviceW_impl_fntype local = (nvtxNameCuDeviceW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(device, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuContextA(CUcontext context, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuContextA_impl_fntype local = (nvtxNameCuContextA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(context, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuContextW(CUcontext context, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuContextW_impl_fntype local = (nvtxNameCuContextW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(context, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuStreamA(CUstream stream, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuStreamA_impl_fntype local = (nvtxNameCuStreamA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(stream, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuStreamW(CUstream stream, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuStreamW_impl_fntype local = (nvtxNameCuStreamW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(stream, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuEventA(CUevent event, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuEventA_impl_fntype local = (nvtxNameCuEventA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(event, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameCuEventW(CUevent event, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameCuEventW_impl_fntype local = (nvtxNameCuEventW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(event, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif /* __cplusplus */
|
||||
|
192
independent_test_harness/nvtxDetail/nvtxImplOpenCL_v3.h
Executable file
192
independent_test_harness/nvtxDetail/nvtxImplOpenCL_v3.h
Executable file
@ -0,0 +1,192 @@
|
||||
/* This file was procedurally generated! Do not modify this file by hand. */
|
||||
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD_OPENCL
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExtCuda.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
typedef void (NVTX_API * nvtxNameClDeviceA_impl_fntype)(cl_device_id device, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClDeviceW_impl_fntype)(cl_device_id device, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClContextA_impl_fntype)(cl_context context, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClContextW_impl_fntype)(cl_context context, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClCommandQueueA_impl_fntype)(cl_command_queue command_queue, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClCommandQueueW_impl_fntype)(cl_command_queue command_queue, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClMemObjectA_impl_fntype)(cl_mem memobj, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClMemObjectW_impl_fntype)(cl_mem memobj, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClSamplerA_impl_fntype)(cl_sampler sampler, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClSamplerW_impl_fntype)(cl_sampler sampler, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClProgramA_impl_fntype)(cl_program program, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClProgramW_impl_fntype)(cl_program program, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClEventA_impl_fntype)(cl_event evnt, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClEventW_impl_fntype)(cl_event evnt, const wchar_t* name);
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClDeviceA(cl_device_id device, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClDeviceA_impl_fntype local = (nvtxNameClDeviceA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(device, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClDeviceW(cl_device_id device, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClDeviceW_impl_fntype local = (nvtxNameClDeviceW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(device, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClContextA(cl_context context, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClContextA_impl_fntype local = (nvtxNameClContextA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(context, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClContextW(cl_context context, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClContextW_impl_fntype local = (nvtxNameClContextW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(context, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClCommandQueueA(cl_command_queue command_queue, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClCommandQueueA_impl_fntype local = (nvtxNameClCommandQueueA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(command_queue, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClCommandQueueW(cl_command_queue command_queue, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClCommandQueueW_impl_fntype local = (nvtxNameClCommandQueueW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(command_queue, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClMemObjectA(cl_mem memobj, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClMemObjectA_impl_fntype local = (nvtxNameClMemObjectA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(memobj, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClMemObjectW(cl_mem memobj, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClMemObjectW_impl_fntype local = (nvtxNameClMemObjectW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(memobj, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClSamplerA(cl_sampler sampler, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClSamplerA_impl_fntype local = (nvtxNameClSamplerA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(sampler, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClSamplerW(cl_sampler sampler, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClSamplerW_impl_fntype local = (nvtxNameClSamplerW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(sampler, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClProgramA(cl_program program, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClProgramA_impl_fntype local = (nvtxNameClProgramA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(program, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClProgramW(cl_program program, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClProgramW_impl_fntype local = (nvtxNameClProgramW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(program, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClEventA(cl_event evnt, const char* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClEventA_impl_fntype local = (nvtxNameClEventA_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(evnt, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxNameClEventW(cl_event evnt, const wchar_t* name)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxNameClEventW_impl_fntype local = (nvtxNameClEventW_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(evnt, name);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif /* __cplusplus */
|
114
independent_test_harness/nvtxDetail/nvtxImplSync_v3.h
Executable file
114
independent_test_harness/nvtxDetail/nvtxImplSync_v3.h
Executable file
@ -0,0 +1,114 @@
|
||||
/* This file was procedurally generated! Do not modify this file by hand. */
|
||||
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD_SYNC
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExtCuda.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
typedef nvtxSyncUser_t (NVTX_API * nvtxDomainSyncUserCreate_impl_fntype)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserDestroy_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserAcquireStart_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserAcquireFailed_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserAcquireSuccess_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserReleasing_impl_fntype)(nvtxSyncUser_t handle);
|
||||
|
||||
NVTX_DECLSPEC nvtxSyncUser_t NVTX_API nvtxDomainSyncUserCreate(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainSyncUserCreate_impl_fntype local = (nvtxDomainSyncUserCreate_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr;
|
||||
if(local!=0)
|
||||
return (*local)(domain, attribs);
|
||||
else
|
||||
#endif /*NVTX_DISABLE*/
|
||||
return (nvtxSyncUser_t)0;
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserDestroy(nvtxSyncUser_t handle)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainSyncUserDestroy_impl_fntype local = (nvtxDomainSyncUserDestroy_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(handle);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireStart(nvtxSyncUser_t handle)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainSyncUserAcquireStart_impl_fntype local = (nvtxDomainSyncUserAcquireStart_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(handle);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireFailed(nvtxSyncUser_t handle)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainSyncUserAcquireFailed_impl_fntype local = (nvtxDomainSyncUserAcquireFailed_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(handle);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserAcquireSuccess(nvtxSyncUser_t handle)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainSyncUserAcquireSuccess_impl_fntype local = (nvtxDomainSyncUserAcquireSuccess_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(handle);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
NVTX_DECLSPEC void NVTX_API nvtxDomainSyncUserReleasing(nvtxSyncUser_t handle)
|
||||
{
|
||||
#ifndef NVTX_DISABLE
|
||||
nvtxDomainSyncUserReleasing_impl_fntype local = (nvtxDomainSyncUserReleasing_impl_fntype)NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr;
|
||||
if(local!=0)
|
||||
(*local)(handle);
|
||||
#endif /*NVTX_DISABLE*/
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif /* __cplusplus */
|
343
independent_test_harness/nvtxDetail/nvtxInit.h
Executable file
343
independent_test_harness/nvtxDetail/nvtxInit.h
Executable file
@ -0,0 +1,343 @@
|
||||
/* This file was procedurally generated! Do not modify this file by hand. */
|
||||
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
/* ---- Platform-independent helper definitions and functions ---- */
|
||||
|
||||
/* Prefer macros over inline functions to reduce symbol resolution at link time */
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define NVTX_PATHCHAR wchar_t
|
||||
#define NVTX_STR(x) L##x
|
||||
#define NVTX_GETENV _wgetenv
|
||||
#define NVTX_BUFSIZE MAX_PATH
|
||||
#define NVTX_DLLHANDLE HMODULE
|
||||
#define NVTX_DLLOPEN(x) LoadLibraryW(x)
|
||||
#define NVTX_DLLFUNC GetProcAddress
|
||||
#define NVTX_DLLCLOSE FreeLibrary
|
||||
#define NVTX_YIELD() SwitchToThread()
|
||||
#define NVTX_MEMBAR() MemoryBarrier()
|
||||
#define NVTX_ATOMIC_WRITE_32(address, value) InterlockedExchange((volatile LONG*)address, value)
|
||||
#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) old = InterlockedCompareExchange((volatile LONG*)address, exchange, comparand)
|
||||
#elif defined(__GNUC__)
|
||||
#define NVTX_PATHCHAR char
|
||||
#define NVTX_STR(x) x
|
||||
#define NVTX_GETENV getenv
|
||||
#define NVTX_BUFSIZE PATH_MAX
|
||||
#define NVTX_DLLHANDLE void*
|
||||
#define NVTX_DLLOPEN(x) dlopen(x, RTLD_LAZY)
|
||||
#define NVTX_DLLFUNC dlsym
|
||||
#define NVTX_DLLCLOSE dlclose
|
||||
#define NVTX_YIELD() sched_yield()
|
||||
#define NVTX_MEMBAR() __sync_synchronize()
|
||||
/* Ensure full memory barrier for atomics, to match Windows functions */
|
||||
#define NVTX_ATOMIC_WRITE_32(address, value) __sync_synchronize(); __sync_lock_test_and_set(address, value)
|
||||
#define NVTX_ATOMIC_CAS_32(old, address, exchange, comparand) __sync_synchronize(); old = __sync_val_compare_and_swap(address, exchange, comparand)
|
||||
#else
|
||||
#error The library does not support your configuration!
|
||||
#endif
|
||||
|
||||
/* Define this to 1 for platforms that where pre-injected libraries can be discovered. */
|
||||
#if defined(_WIN32)
|
||||
/* TODO */
|
||||
#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0
|
||||
#else
|
||||
#define NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY 0
|
||||
#endif
|
||||
|
||||
/* Define this to 1 for platforms that support environment variables */
|
||||
/* TODO: Detect UWP, a.k.a. Windows Store app, and set this to 0. */
|
||||
/* Try: #if defined(WINAPI_FAMILY_PARTITION) && WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) */
|
||||
#define NVTX_SUPPORT_ENV_VARS 1
|
||||
|
||||
/* Define this to 1 for platforms that support dynamic/shared libraries */
|
||||
#define NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY 1
|
||||
|
||||
/* Injection libraries implementing InitializeInjectionNvtx2 may be statically linked,
|
||||
* and this will override any dynamic injection. Useful for platforms where dynamic
|
||||
* injection is not available. Since weak symbols not explicitly marked extern are
|
||||
* guaranteed to be initialized to zero if no definitions are found by the linker, the
|
||||
* dynamic injection process proceeds normally if pfnInitializeInjectionNvtx2 is 0. */
|
||||
#if defined(__GNUC__) && !defined(_WIN32) && !defined(__CYGWIN__)
|
||||
#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 1
|
||||
/* To statically inject an NVTX library, define InitializeInjectionNvtx2_fnptr as a normal
|
||||
* symbol (not weak) pointing to the implementation of InitializeInjectionNvtx2 (which
|
||||
* does not need to be named "InitializeInjectionNvtx2" as is necessary in a dynamic
|
||||
* injection library. */
|
||||
__attribute__((weak)) NvtxInitializeInjectionNvtxFunc_t InitializeInjectionNvtx2_fnptr;
|
||||
#else
|
||||
#define NVTX_SUPPORT_STATIC_INJECTION_LIBRARY 0
|
||||
#endif
|
||||
|
||||
/* This function tries to find or load an NVTX injection library and get the
|
||||
* address of its InitializeInjection2 function. If such a function pointer
|
||||
* is found, it is called, and passed the address of this NVTX instance's
|
||||
* nvtxGetExportTable function, so the injection can attach to this instance.
|
||||
* If the initialization fails for any reason, any dynamic library loaded will
|
||||
* be freed, and all NVTX implementation functions will be set to no-ops. If
|
||||
* initialization succeeds, NVTX functions not attached to the tool will be set
|
||||
* to no-ops. This is implemented as one function instead of several small
|
||||
* functions to minimize the number of weak symbols the linker must resolve.
|
||||
* Order of search is:
|
||||
* - Pre-injected library exporting InitializeInjectionNvtx2
|
||||
* - Loadable library exporting InitializeInjectionNvtx2
|
||||
* - Path specified by env var NVTX_INJECTION??_PATH (?? is 32 or 64)
|
||||
* - On Android, libNvtxInjection??.so within the package (?? is 32 or 64)
|
||||
* - Statically-linked injection library defining InitializeInjectionNvtx2_fnptr
|
||||
*/
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxInitializeInjectionLibrary)(void);
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_VERSIONED_IDENTIFIER(nvtxInitializeInjectionLibrary)(void)
|
||||
{
|
||||
const char* const initFuncName = "InitializeInjectionNvtx2";
|
||||
NvtxInitializeInjectionNvtxFunc_t init_fnptr = (NvtxInitializeInjectionNvtxFunc_t)0;
|
||||
NVTX_DLLHANDLE injectionLibraryHandle = (NVTX_DLLHANDLE)0;
|
||||
int entryPointStatus = 0;
|
||||
|
||||
#if NVTX_SUPPORT_ALREADY_INJECTED_LIBRARY
|
||||
/* Use POSIX global symbol chain to query for init function from any module */
|
||||
init_fnptr = (NvtxInitializeInjectionNvtxFunc_t)NVTX_DLLFUNC(0, initFuncName);
|
||||
#endif
|
||||
|
||||
#if NVTX_SUPPORT_DYNAMIC_INJECTION_LIBRARY
|
||||
/* Try discovering dynamic injection library to load */
|
||||
if (!init_fnptr)
|
||||
{
|
||||
#if NVTX_SUPPORT_ENV_VARS
|
||||
/* If env var NVTX_INJECTION64_PATH is set, it should contain the path
|
||||
* to a 64-bit dynamic NVTX injection library (and similar for 32-bit). */
|
||||
const NVTX_PATHCHAR* const nvtxEnvVarName = (sizeof(void*) == 4)
|
||||
? NVTX_STR("NVTX_INJECTION32_PATH")
|
||||
: NVTX_STR("NVTX_INJECTION64_PATH");
|
||||
#endif /* NVTX_SUPPORT_ENV_VARS */
|
||||
NVTX_PATHCHAR injectionLibraryPathBuf[NVTX_BUFSIZE];
|
||||
const NVTX_PATHCHAR* injectionLibraryPath = (const NVTX_PATHCHAR*)0;
|
||||
|
||||
/* Refer to this variable explicitly in case all references to it are #if'ed out */
|
||||
(void)injectionLibraryPathBuf;
|
||||
|
||||
#if NVTX_SUPPORT_ENV_VARS
|
||||
/* Disable the warning for getenv & _wgetenv -- this usage is safe because
|
||||
* these functions are not called again before using the returned value. */
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning( push )
|
||||
#pragma warning( disable : 4996 )
|
||||
#endif
|
||||
injectionLibraryPath = NVTX_GETENV(nvtxEnvVarName);
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning( pop )
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__ANDROID__)
|
||||
if (!injectionLibraryPath)
|
||||
{
|
||||
const char *bits = (sizeof(void*) == 4) ? "32" : "64";
|
||||
char cmdlineBuf[32];
|
||||
char pkgName[PATH_MAX];
|
||||
int count;
|
||||
int pid;
|
||||
FILE *fp;
|
||||
size_t bytesRead;
|
||||
size_t pos;
|
||||
|
||||
pid = (int)getpid();
|
||||
count = snprintf(cmdlineBuf, sizeof(cmdlineBuf), "/proc/%d/cmdline", pid);
|
||||
if (count <= 0 || count >= (int)sizeof(cmdlineBuf))
|
||||
{
|
||||
NVTX_ERR("Path buffer too small for: /proc/%d/cmdline\n", pid);
|
||||
return NVTX_ERR_INIT_ACCESS_LIBRARY;
|
||||
}
|
||||
|
||||
fp = fopen(cmdlineBuf, "r");
|
||||
if (!fp)
|
||||
{
|
||||
NVTX_ERR("File couldn't be opened: %s\n", cmdlineBuf);
|
||||
return NVTX_ERR_INIT_ACCESS_LIBRARY;
|
||||
}
|
||||
|
||||
bytesRead = fread(pkgName, 1, sizeof(pkgName) - 1, fp);
|
||||
fclose(fp);
|
||||
if (bytesRead == 0)
|
||||
{
|
||||
NVTX_ERR("Package name couldn't be read from file: %s\n", cmdlineBuf);
|
||||
return NVTX_ERR_INIT_ACCESS_LIBRARY;
|
||||
}
|
||||
|
||||
pkgName[bytesRead] = 0;
|
||||
|
||||
/* String can contain colon as a process separator. In this case the package name is before the colon. */
|
||||
pos = 0;
|
||||
while (pos < bytesRead && pkgName[pos] != ':' && pkgName[pos] != '\0')
|
||||
{
|
||||
++pos;
|
||||
}
|
||||
pkgName[pos] = 0;
|
||||
|
||||
count = snprintf(injectionLibraryPathBuf, NVTX_BUFSIZE, "/data/data/%s/files/libNvtxInjection%s.so", pkgName, bits);
|
||||
if (count <= 0 || count >= NVTX_BUFSIZE)
|
||||
{
|
||||
NVTX_ERR("Path buffer too small for: /data/data/%s/files/libNvtxInjection%s.so\n", pkgName, bits);
|
||||
return NVTX_ERR_INIT_ACCESS_LIBRARY;
|
||||
}
|
||||
|
||||
/* On Android, verify path is accessible due to aggressive file access restrictions. */
|
||||
/* For dlopen, if the filename contains a leading slash, then it is interpreted as a */
|
||||
/* relative or absolute pathname; otherwise it will follow the rules in ld.so. */
|
||||
if (injectionLibraryPathBuf[0] == '/')
|
||||
{
|
||||
#if (__ANDROID_API__ < 21)
|
||||
int access_err = access(injectionLibraryPathBuf, F_OK | R_OK);
|
||||
#else
|
||||
int access_err = faccessat(AT_FDCWD, injectionLibraryPathBuf, F_OK | R_OK, 0);
|
||||
#endif
|
||||
if (access_err != 0)
|
||||
{
|
||||
NVTX_ERR("Injection library path wasn't accessible [code=%s] [path=%s]\n", strerror(errno), injectionLibraryPathBuf);
|
||||
return NVTX_ERR_INIT_ACCESS_LIBRARY;
|
||||
}
|
||||
}
|
||||
injectionLibraryPath = injectionLibraryPathBuf;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* At this point, injectionLibraryPath is specified if a dynamic
|
||||
* injection library was specified by a tool. */
|
||||
if (injectionLibraryPath)
|
||||
{
|
||||
/* Load the injection library */
|
||||
injectionLibraryHandle = NVTX_DLLOPEN(injectionLibraryPath);
|
||||
if (!injectionLibraryHandle)
|
||||
{
|
||||
NVTX_ERR("Failed to load injection library\n");
|
||||
return NVTX_ERR_INIT_LOAD_LIBRARY;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* Attempt to get the injection library's entry-point */
|
||||
init_fnptr = (NvtxInitializeInjectionNvtxFunc_t)NVTX_DLLFUNC(injectionLibraryHandle, initFuncName);
|
||||
if (!init_fnptr)
|
||||
{
|
||||
NVTX_DLLCLOSE(injectionLibraryHandle);
|
||||
NVTX_ERR("Failed to get address of function InitializeInjectionNvtx2 from injection library\n");
|
||||
return NVTX_ERR_INIT_MISSING_LIBRARY_ENTRY_POINT;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if NVTX_SUPPORT_STATIC_INJECTION_LIBRARY
|
||||
if (!init_fnptr)
|
||||
{
|
||||
/* Check weakly-defined function pointer. A statically-linked injection can define this as
|
||||
* a normal symbol and it will take precedence over a dynamic injection. */
|
||||
if (InitializeInjectionNvtx2_fnptr)
|
||||
{
|
||||
init_fnptr = InitializeInjectionNvtx2_fnptr;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/* At this point, if init_fnptr is not set, then no tool has specified
|
||||
* an NVTX injection library -- return non-success result so all NVTX
|
||||
* API functions will be set to no-ops. */
|
||||
if (!init_fnptr)
|
||||
{
|
||||
return NVTX_ERR_NO_INJECTION_LIBRARY_AVAILABLE;
|
||||
}
|
||||
|
||||
/* Invoke injection library's initialization function. If it returns
|
||||
* 0 (failure) and a dynamic injection was loaded, unload it. */
|
||||
entryPointStatus = init_fnptr(NVTX_VERSIONED_IDENTIFIER(nvtxGetExportTable));
|
||||
if (entryPointStatus == 0)
|
||||
{
|
||||
NVTX_ERR("Failed to initialize injection library -- initialization function returned 0\n");
|
||||
if (injectionLibraryHandle)
|
||||
{
|
||||
NVTX_DLLCLOSE(injectionLibraryHandle);
|
||||
}
|
||||
return NVTX_ERR_INIT_FAILED_LIBRARY_ENTRY_POINT;
|
||||
}
|
||||
|
||||
return NVTX_SUCCESS;
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)(void)
|
||||
{
|
||||
unsigned int old;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState == NVTX_INIT_STATE_COMPLETE)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
NVTX_ATOMIC_CAS_32(
|
||||
old,
|
||||
&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState,
|
||||
NVTX_INIT_STATE_STARTED,
|
||||
NVTX_INIT_STATE_FRESH);
|
||||
if (old == NVTX_INIT_STATE_FRESH)
|
||||
{
|
||||
int result;
|
||||
int forceAllToNoops;
|
||||
|
||||
/* Load & initialize injection library -- it will assign the function pointers */
|
||||
result = NVTX_VERSIONED_IDENTIFIER(nvtxInitializeInjectionLibrary)();
|
||||
|
||||
/* Set all pointers not assigned by the injection to null */
|
||||
forceAllToNoops = result != NVTX_SUCCESS; /* Set all to null if injection init failed */
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxSetInitFunctionsToNoops)(forceAllToNoops);
|
||||
|
||||
/* Signal that initialization has finished, so now the assigned function pointers will be used */
|
||||
NVTX_ATOMIC_WRITE_32(
|
||||
&NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState,
|
||||
NVTX_INIT_STATE_COMPLETE);
|
||||
}
|
||||
else /* Spin-wait until initialization has finished */
|
||||
{
|
||||
NVTX_MEMBAR();
|
||||
while (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).initState != NVTX_INIT_STATE_COMPLETE)
|
||||
{
|
||||
NVTX_YIELD();
|
||||
NVTX_MEMBAR();
|
||||
}
|
||||
}
|
||||
}
|
73
independent_test_harness/nvtxDetail/nvtxInitDecls.h
Executable file
73
independent_test_harness/nvtxDetail/nvtxInitDecls.h
Executable file
@ -0,0 +1,73 @@
|
||||
#ifndef NVTX_IMPL_GUARD
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init)(const nvtxEventAttributes_t* eventAttrib);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init)(const char* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init)(const wchar_t* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init)(const nvtxEventAttributes_t* eventAttrib);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init)(const char* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init)(const wchar_t* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init)(nvtxRangeId_t id);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init)(const nvtxEventAttributes_t* eventAttrib);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init)(const char* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init)(const wchar_t* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init)(void);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init)(uint32_t category, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init)(uint32_t category, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init)(uint32_t threadId, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init)(uint32_t threadId, const wchar_t* name);
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init)(nvtx_CUdevice device, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init)(nvtx_CUdevice device, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init)(nvtx_CUcontext context, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init)(nvtx_CUcontext context, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init)(nvtx_CUstream stream, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init)(nvtx_CUstream stream, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init)(nvtx_CUevent event, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init)(nvtx_CUevent event, const wchar_t* name);
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init)(nvtx_cl_device_id device, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init)(nvtx_cl_device_id device, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init)(nvtx_cl_context context, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init)(nvtx_cl_context context, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init)(nvtx_cl_command_queue command_queue, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init)(nvtx_cl_command_queue command_queue, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init)(nvtx_cl_mem memobj, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init)(nvtx_cl_mem memobj, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init)(nvtx_cl_sampler sampler, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init)(nvtx_cl_sampler sampler, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init)(nvtx_cl_program program, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init)(nvtx_cl_program program, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init)(nvtx_cl_event evnt, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init)(nvtx_cl_event evnt, const wchar_t* name);
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init)(int device, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init)(int device, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init)(nvtx_cudaStream_t stream, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init)(nvtx_cudaStream_t stream, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init)(nvtx_cudaEvent_t event, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init)(nvtx_cudaEvent_t event, const wchar_t* name);
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init)(nvtxDomainHandle_t domain, nvtxRangeId_t id);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init)(nvtxDomainHandle_t domain);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxResourceHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init)(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init)(nvtxResourceHandle_t resource);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const char* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init)(nvtxDomainHandle_t domain, const char* string);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init)(nvtxDomainHandle_t domain, const wchar_t* string);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init)(const char* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init)(const wchar_t* message);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init)(nvtxDomainHandle_t domain);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init)(const void* reserved);
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION nvtxSyncUser_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init)(nvtxSyncUser_t handle);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init)(nvtxSyncUser_t handle);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init)(nvtxSyncUser_t handle);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init)(nvtxSyncUser_t handle);
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init)(nvtxSyncUser_t handle);
|
565
independent_test_harness/nvtxDetail/nvtxInitDefs.h
Executable file
565
independent_test_harness/nvtxDetail/nvtxInitDefs.h
Executable file
@ -0,0 +1,565 @@
|
||||
#ifndef NVTX_IMPL_GUARD
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExt.h (except when NVTX_NO_IMPL is defined).
|
||||
#endif
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init)(const nvtxEventAttributes_t* eventAttrib){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxMarkEx(eventAttrib);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init)(const char* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxMarkA(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init)(const wchar_t* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxMarkW(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init)(const nvtxEventAttributes_t* eventAttrib){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangeStartEx(eventAttrib);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init)(const char* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangeStartA(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init)(const wchar_t* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangeStartW(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init)(nvtxRangeId_t id){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxRangeEnd(id);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init)(const nvtxEventAttributes_t* eventAttrib){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangePushEx(eventAttrib);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init)(const char* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangePushA(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init)(const wchar_t* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangePushW(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init)(void){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxRangePop();
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init)(uint32_t category, const char* name){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxNameCategoryA(category, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init)(uint32_t category, const wchar_t* name){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxNameCategoryW(category, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init)(uint32_t threadId, const char* name){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxNameOsThreadA(threadId, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init)(uint32_t threadId, const wchar_t* name){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxNameOsThreadW(threadId, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxDomainMarkEx(domain, eventAttrib);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxRangeId_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainRangeStartEx(domain, eventAttrib);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init)(nvtxDomainHandle_t domain, nvtxRangeId_t id){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxDomainRangeEnd(domain, id);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainRangePushEx(domain, eventAttrib);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION int NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init)(nvtxDomainHandle_t domain){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainRangePop(domain);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxResourceHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init)(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainResourceCreate(domain, attribs);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init)(nvtxResourceHandle_t resource){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxDomainResourceDestroy(resource);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const char* name){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxDomainNameCategoryA(domain, category, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init)(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxDomainNameCategoryW(domain, category, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init)(nvtxDomainHandle_t domain, const char* string){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainRegisterStringA(domain, string);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxStringHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init)(nvtxDomainHandle_t domain, const wchar_t* string){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainRegisterStringW(domain, string);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init)(const char* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainCreateA(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxDomainHandle_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init)(const wchar_t* message){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
return nvtxDomainCreateW(message);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init)(nvtxDomainHandle_t domain){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxDomainDestroy(domain);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init)(const void* reserved){
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
nvtxInitialize(reserved);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init)(nvtx_CUdevice device, const char* name){
|
||||
nvtxNameCuDeviceA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr;
|
||||
if (local)
|
||||
local(device, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init)(nvtx_CUdevice device, const wchar_t* name){
|
||||
nvtxNameCuDeviceW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr;
|
||||
if (local)
|
||||
local(device, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init)(nvtx_CUcontext context, const char* name){
|
||||
nvtxNameCuContextA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr;
|
||||
if (local)
|
||||
local(context, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init)(nvtx_CUcontext context, const wchar_t* name){
|
||||
nvtxNameCuContextW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr;
|
||||
if (local)
|
||||
local(context, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init)(nvtx_CUstream stream, const char* name){
|
||||
nvtxNameCuStreamA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr;
|
||||
if (local)
|
||||
local(stream, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init)(nvtx_CUstream stream, const wchar_t* name){
|
||||
nvtxNameCuStreamW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr;
|
||||
if (local)
|
||||
local(stream, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init)(nvtx_CUevent event, const char* name){
|
||||
nvtxNameCuEventA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr;
|
||||
if (local)
|
||||
local(event, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init)(nvtx_CUevent event, const wchar_t* name){
|
||||
nvtxNameCuEventW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr;
|
||||
if (local)
|
||||
local(event, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init)(int device, const char* name){
|
||||
nvtxNameCudaDeviceA_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr;
|
||||
if (local)
|
||||
local(device, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init)(int device, const wchar_t* name){
|
||||
nvtxNameCudaDeviceW_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr;
|
||||
if (local)
|
||||
local(device, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init)(nvtx_cudaStream_t stream, const char* name){
|
||||
nvtxNameCudaStreamA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr;
|
||||
if (local)
|
||||
local(stream, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init)(nvtx_cudaStream_t stream, const wchar_t* name){
|
||||
nvtxNameCudaStreamW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr;
|
||||
if (local)
|
||||
local(stream, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init)(nvtx_cudaEvent_t event, const char* name){
|
||||
nvtxNameCudaEventA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr;
|
||||
if (local)
|
||||
local(event, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init)(nvtx_cudaEvent_t event, const wchar_t* name){
|
||||
nvtxNameCudaEventW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr;
|
||||
if (local)
|
||||
local(event, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init)(nvtx_cl_device_id device, const char* name){
|
||||
nvtxNameClDeviceA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr;
|
||||
if (local)
|
||||
local(device, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init)(nvtx_cl_device_id device, const wchar_t* name){
|
||||
nvtxNameClDeviceW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr;
|
||||
if (local)
|
||||
local(device, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init)(nvtx_cl_context context, const char* name){
|
||||
nvtxNameClContextA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr;
|
||||
if (local)
|
||||
local(context, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init)(nvtx_cl_context context, const wchar_t* name){
|
||||
nvtxNameClContextW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr;
|
||||
if (local)
|
||||
local(context, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init)(nvtx_cl_command_queue command_queue, const char* name){
|
||||
nvtxNameClCommandQueueA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr;
|
||||
if (local)
|
||||
local(command_queue, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init)(nvtx_cl_command_queue command_queue, const wchar_t* name){
|
||||
nvtxNameClCommandQueueW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr;
|
||||
if (local)
|
||||
local(command_queue, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init)(nvtx_cl_mem memobj, const char* name){
|
||||
nvtxNameClMemObjectA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr;
|
||||
if (local)
|
||||
local(memobj, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init)(nvtx_cl_mem memobj, const wchar_t* name){
|
||||
nvtxNameClMemObjectW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr;
|
||||
if (local)
|
||||
local(memobj, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init)(nvtx_cl_sampler sampler, const char* name){
|
||||
nvtxNameClSamplerA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr;
|
||||
if (local)
|
||||
local(sampler, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init)(nvtx_cl_sampler sampler, const wchar_t* name){
|
||||
nvtxNameClSamplerW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr;
|
||||
if (local)
|
||||
local(sampler, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init)(nvtx_cl_program program, const char* name){
|
||||
nvtxNameClProgramA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr;
|
||||
if (local)
|
||||
local(program, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init)(nvtx_cl_program program, const wchar_t* name){
|
||||
nvtxNameClProgramW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr;
|
||||
if (local)
|
||||
local(program, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init)(nvtx_cl_event evnt, const char* name){
|
||||
nvtxNameClEventA_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr;
|
||||
if (local)
|
||||
local(evnt, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init)(nvtx_cl_event evnt, const wchar_t* name){
|
||||
nvtxNameClEventW_fakeimpl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr;
|
||||
if (local)
|
||||
local(evnt, name);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION nvtxSyncUser_t NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs){
|
||||
nvtxDomainSyncUserCreate_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr;
|
||||
if (local) {
|
||||
return local(domain, attribs);
|
||||
}
|
||||
return (nvtxSyncUser_t)0;
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init)(nvtxSyncUser_t handle){
|
||||
nvtxDomainSyncUserDestroy_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr;
|
||||
if (local)
|
||||
local(handle);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init)(nvtxSyncUser_t handle){
|
||||
nvtxDomainSyncUserAcquireStart_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr;
|
||||
if (local)
|
||||
local(handle);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init)(nvtxSyncUser_t handle){
|
||||
nvtxDomainSyncUserAcquireFailed_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr;
|
||||
if (local)
|
||||
local(handle);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init)(nvtxSyncUser_t handle){
|
||||
nvtxDomainSyncUserAcquireSuccess_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr;
|
||||
if (local)
|
||||
local(handle);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_API NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init)(nvtxSyncUser_t handle){
|
||||
nvtxDomainSyncUserReleasing_impl_fntype local;
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxInitOnce)();
|
||||
local = NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr;
|
||||
if (local)
|
||||
local(handle);
|
||||
}
|
||||
|
||||
NVTX_LINKONCE_FWDDECL_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxSetInitFunctionsToNoops)(int forceAllToNoops);
|
||||
NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_VERSIONED_IDENTIFIER(nvtxSetInitFunctionsToNoops)(int forceAllToNoops)
|
||||
{
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxMarkEx_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkEx_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxMarkA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxMarkW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxMarkW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartEx_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartEx_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeStartW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeStartW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangeEnd_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangeEnd_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePushEx_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushEx_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePushA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePushW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePushW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxRangePop_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxRangePop_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCategoryW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCategoryW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameOsThreadW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameOsThreadW_impl_fnptr = NULL;
|
||||
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuDeviceW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuDeviceW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuContextW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuContextW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuStreamW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuStreamW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCuEventW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCuEventW_impl_fnptr = NULL;
|
||||
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClDeviceW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClDeviceW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClContextW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClContextW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClCommandQueueW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClCommandQueueW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClMemObjectW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClMemObjectW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClSamplerW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClSamplerW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClProgramW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClProgramW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameClEventW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameClEventW_impl_fnptr = NULL;
|
||||
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaDeviceW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaDeviceW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaStreamW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaStreamW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxNameCudaEventW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxNameCudaEventW_impl_fnptr = NULL;
|
||||
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainMarkEx_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainMarkEx_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeStartEx_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeStartEx_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangeEnd_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangeEnd_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePushEx_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePushEx_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRangePop_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRangePop_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceCreate_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceCreate_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainResourceDestroy_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainResourceDestroy_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainNameCategoryW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainNameCategoryW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainRegisterStringW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainRegisterStringW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateA_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateA_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainCreateW_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainCreateW_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainDestroy_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainDestroy_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxInitialize_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxInitialize_impl_fnptr = NULL;
|
||||
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserCreate_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserCreate_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserDestroy_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserDestroy_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireStart_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireStart_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireFailed_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireFailed_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserAcquireSuccess_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserAcquireSuccess_impl_fnptr = NULL;
|
||||
if (NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr == NVTX_VERSIONED_IDENTIFIER(nvtxDomainSyncUserReleasing_impl_init) || forceAllToNoops)
|
||||
NVTX_VERSIONED_IDENTIFIER(nvtxGlobals).nvtxDomainSyncUserReleasing_impl_fnptr = NULL;
|
||||
}
|
75
independent_test_harness/nvtxDetail/nvtxLinkOnce.h
Executable file
75
independent_test_harness/nvtxDetail/nvtxLinkOnce.h
Executable file
@ -0,0 +1,75 @@
|
||||
#ifndef __NVTX_LINKONCE_H__
|
||||
#define __NVTX_LINKONCE_H__
|
||||
|
||||
/* This header defines macros to permit making definitions of global variables
|
||||
* and functions in C/C++ header files which may be included multiple times in
|
||||
* a translation unit or linkage unit. It allows authoring header-only libraries
|
||||
* which can be used by multiple other header-only libraries (either as the same
|
||||
* copy or multiple copies), and does not require any build changes, such as
|
||||
* adding another .c file, linking a static library, or deploying a dynamic
|
||||
* library. Globals defined with these macros have the property that they have
|
||||
* the same address, pointing to a single instance, for the entire linkage unit.
|
||||
* It is expected but not guaranteed that each linkage unit will have a separate
|
||||
* instance.
|
||||
*
|
||||
* In some situations it is desirable to declare a variable without initializing
|
||||
* it, refer to it in code or other variables' initializers, and then initialize
|
||||
* it later. Similarly, functions can be prototyped, have their address taken,
|
||||
* and then have their body defined later. In such cases, use the FWDDECL macros
|
||||
* when forward-declaring LINKONCE global variables without initializers and
|
||||
* function prototypes, and then use the DEFINE macros when later defining them.
|
||||
* Although in many cases the FWDDECL macro is equivalent to the DEFINE macro,
|
||||
* following this pattern makes code maximally portable.
|
||||
*/
|
||||
|
||||
#if defined(__MINGW32__) /* MinGW */
|
||||
#define NVTX_LINKONCE_WEAK __attribute__((section(".gnu.linkonce.0.")))
|
||||
#if defined(__cplusplus)
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" inline NVTX_LINKONCE_WEAK
|
||||
#else
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_WEAK
|
||||
#endif
|
||||
#elif defined(_MSC_VER) /* MSVC */
|
||||
#if defined(__cplusplus)
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL extern "C" __declspec(selectany)
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" inline
|
||||
#else
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION __inline
|
||||
#endif
|
||||
#elif defined(__CYGWIN__) && defined(__clang__) /* Clang on Cygwin */
|
||||
#define NVTX_LINKONCE_WEAK __attribute__((section(".gnu.linkonce.0.")))
|
||||
#if defined(__cplusplus)
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_WEAK
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" NVTX_LINKONCE_WEAK
|
||||
#else
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_WEAK
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_WEAK
|
||||
#endif
|
||||
#elif defined(__CYGWIN__) /* Assume GCC or compatible */
|
||||
#define NVTX_LINKONCE_WEAK __attribute__((weak))
|
||||
#if defined(__cplusplus)
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL __declspec(selectany)
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" inline
|
||||
#else
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_WEAK
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_WEAK
|
||||
#endif
|
||||
#else /* All others: Assume GCC, clang, or compatible */
|
||||
#define NVTX_LINKONCE_WEAK __attribute__((weak))
|
||||
#define NVTX_LINKONCE_HIDDEN __attribute__((visibility("hidden")))
|
||||
#if defined(__cplusplus)
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_HIDDEN NVTX_LINKONCE_WEAK
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION extern "C" NVTX_LINKONCE_HIDDEN inline
|
||||
#else
|
||||
#define NVTX_LINKONCE_DEFINE_GLOBAL NVTX_LINKONCE_HIDDEN NVTX_LINKONCE_WEAK
|
||||
#define NVTX_LINKONCE_DEFINE_FUNCTION NVTX_LINKONCE_HIDDEN NVTX_LINKONCE_WEAK
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define NVTX_LINKONCE_FWDDECL_GLOBAL NVTX_LINKONCE_DEFINE_GLOBAL extern
|
||||
#define NVTX_LINKONCE_FWDDECL_FUNCTION NVTX_LINKONCE_DEFINE_FUNCTION
|
||||
|
||||
#endif /* __NVTX_LINKONCE_H__ */
|
333
independent_test_harness/nvtxDetail/nvtxTypes.h
Executable file
333
independent_test_harness/nvtxDetail/nvtxTypes.h
Executable file
@ -0,0 +1,333 @@
|
||||
/*
|
||||
* Copyright 2009-2016 NVIDIA Corporation. All rights reserved.
|
||||
*
|
||||
* NOTICE TO USER:
|
||||
*
|
||||
* This source code is subject to NVIDIA ownership rights under U.S. and
|
||||
* international Copyright laws.
|
||||
*
|
||||
* This software and the information contained herein is PROPRIETARY and
|
||||
* CONFIDENTIAL to NVIDIA and is being provided under the terms and conditions
|
||||
* of a form of NVIDIA software license agreement.
|
||||
*
|
||||
* NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
|
||||
* CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
|
||||
* IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
|
||||
* REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
|
||||
* OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
|
||||
* OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||
* OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
|
||||
* OR PERFORMANCE OF THIS SOURCE CODE.
|
||||
*
|
||||
* U.S. Government End Users. This source code is a "commercial item" as
|
||||
* that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
|
||||
* "commercial computer software" and "commercial computer software
|
||||
* documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
|
||||
* and is provided to the U.S. Government only as a commercial end item.
|
||||
* Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
|
||||
* 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
|
||||
* source code with only those rights set forth herein.
|
||||
*
|
||||
* Any use of this source code in individual and commercial software must
|
||||
* include, in the user documentation and internal comments to the code,
|
||||
* the above Disclaimer and U.S. Government End Users Notice.
|
||||
*/
|
||||
|
||||
/* This header defines types which are used by the internal implementation
|
||||
* of NVTX and callback subscribers. API clients do not use these types,
|
||||
* so they are defined here instead of in nvToolsExt.h to clarify they are
|
||||
* not part of the NVTX client API. */
|
||||
|
||||
#ifndef NVTX_IMPL_GUARD
|
||||
#error Never include this file directly -- it is automatically included by nvToolsExt.h.
|
||||
#endif
|
||||
|
||||
/* ------ Dependency-free types binary-compatible with real types ------- */
|
||||
|
||||
/* In order to avoid having the NVTX core API headers depend on non-NVTX
|
||||
* headers like cuda.h, NVTX defines binary-compatible types to use for
|
||||
* safely making the initialization versions of all NVTX functions without
|
||||
* needing to have definitions for the real types. */
|
||||
|
||||
typedef int nvtx_CUdevice;
|
||||
typedef void* nvtx_CUcontext;
|
||||
typedef void* nvtx_CUstream;
|
||||
typedef void* nvtx_CUevent;
|
||||
|
||||
typedef void* nvtx_cudaStream_t;
|
||||
typedef void* nvtx_cudaEvent_t;
|
||||
|
||||
typedef void* nvtx_cl_platform_id;
|
||||
typedef void* nvtx_cl_device_id;
|
||||
typedef void* nvtx_cl_context;
|
||||
typedef void* nvtx_cl_command_queue;
|
||||
typedef void* nvtx_cl_mem;
|
||||
typedef void* nvtx_cl_program;
|
||||
typedef void* nvtx_cl_kernel;
|
||||
typedef void* nvtx_cl_event;
|
||||
typedef void* nvtx_cl_sampler;
|
||||
|
||||
typedef struct nvtxSyncUser* nvtxSyncUser_t;
|
||||
struct nvtxSyncUserAttributes_v0;
|
||||
typedef struct nvtxSyncUserAttributes_v0 nvtxSyncUserAttributes_t;
|
||||
|
||||
/* --------- Types for function pointers (with fake API types) ---------- */
|
||||
|
||||
typedef void (NVTX_API * nvtxMarkEx_impl_fntype)(const nvtxEventAttributes_t* eventAttrib);
|
||||
typedef void (NVTX_API * nvtxMarkA_impl_fntype)(const char* message);
|
||||
typedef void (NVTX_API * nvtxMarkW_impl_fntype)(const wchar_t* message);
|
||||
typedef nvtxRangeId_t (NVTX_API * nvtxRangeStartEx_impl_fntype)(const nvtxEventAttributes_t* eventAttrib);
|
||||
typedef nvtxRangeId_t (NVTX_API * nvtxRangeStartA_impl_fntype)(const char* message);
|
||||
typedef nvtxRangeId_t (NVTX_API * nvtxRangeStartW_impl_fntype)(const wchar_t* message);
|
||||
typedef void (NVTX_API * nvtxRangeEnd_impl_fntype)(nvtxRangeId_t id);
|
||||
typedef int (NVTX_API * nvtxRangePushEx_impl_fntype)(const nvtxEventAttributes_t* eventAttrib);
|
||||
typedef int (NVTX_API * nvtxRangePushA_impl_fntype)(const char* message);
|
||||
typedef int (NVTX_API * nvtxRangePushW_impl_fntype)(const wchar_t* message);
|
||||
typedef int (NVTX_API * nvtxRangePop_impl_fntype)(void);
|
||||
typedef void (NVTX_API * nvtxNameCategoryA_impl_fntype)(uint32_t category, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCategoryW_impl_fntype)(uint32_t category, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameOsThreadA_impl_fntype)(uint32_t threadId, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameOsThreadW_impl_fntype)(uint32_t threadId, const wchar_t* name);
|
||||
|
||||
/* Real impl types are defined in nvtxImplCuda_v3.h, where CUDA headers are included */
|
||||
typedef void (NVTX_API * nvtxNameCuDeviceA_fakeimpl_fntype)(nvtx_CUdevice device, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuDeviceW_fakeimpl_fntype)(nvtx_CUdevice device, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCuContextA_fakeimpl_fntype)(nvtx_CUcontext context, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuContextW_fakeimpl_fntype)(nvtx_CUcontext context, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCuStreamA_fakeimpl_fntype)(nvtx_CUstream stream, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuStreamW_fakeimpl_fntype)(nvtx_CUstream stream, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCuEventA_fakeimpl_fntype)(nvtx_CUevent event, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCuEventW_fakeimpl_fntype)(nvtx_CUevent event, const wchar_t* name);
|
||||
|
||||
/* Real impl types are defined in nvtxImplOpenCL_v3.h, where OPENCL headers are included */
|
||||
typedef void (NVTX_API * nvtxNameClDeviceA_fakeimpl_fntype)(nvtx_cl_device_id device, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClDeviceW_fakeimpl_fntype)(nvtx_cl_device_id device, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClContextA_fakeimpl_fntype)(nvtx_cl_context context, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClContextW_fakeimpl_fntype)(nvtx_cl_context context, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClCommandQueueA_fakeimpl_fntype)(nvtx_cl_command_queue command_queue, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClCommandQueueW_fakeimpl_fntype)(nvtx_cl_command_queue command_queue, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClMemObjectA_fakeimpl_fntype)(nvtx_cl_mem memobj, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClMemObjectW_fakeimpl_fntype)(nvtx_cl_mem memobj, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClSamplerA_fakeimpl_fntype)(nvtx_cl_sampler sampler, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClSamplerW_fakeimpl_fntype)(nvtx_cl_sampler sampler, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClProgramA_fakeimpl_fntype)(nvtx_cl_program program, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClProgramW_fakeimpl_fntype)(nvtx_cl_program program, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameClEventA_fakeimpl_fntype)(nvtx_cl_event evnt, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameClEventW_fakeimpl_fntype)(nvtx_cl_event evnt, const wchar_t* name);
|
||||
|
||||
/* Real impl types are defined in nvtxImplCudaRt_v3.h, where CUDART headers are included */
|
||||
typedef void (NVTX_API * nvtxNameCudaDeviceA_impl_fntype)(int device, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaDeviceW_impl_fntype)(int device, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaStreamA_fakeimpl_fntype)(nvtx_cudaStream_t stream, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaStreamW_fakeimpl_fntype)(nvtx_cudaStream_t stream, const wchar_t* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaEventA_fakeimpl_fntype)(nvtx_cudaEvent_t event, const char* name);
|
||||
typedef void (NVTX_API * nvtxNameCudaEventW_fakeimpl_fntype)(nvtx_cudaEvent_t event, const wchar_t* name);
|
||||
|
||||
typedef void (NVTX_API * nvtxDomainMarkEx_impl_fntype)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
|
||||
typedef nvtxRangeId_t (NVTX_API * nvtxDomainRangeStartEx_impl_fntype)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
|
||||
typedef void (NVTX_API * nvtxDomainRangeEnd_impl_fntype)(nvtxDomainHandle_t domain, nvtxRangeId_t id);
|
||||
typedef int (NVTX_API * nvtxDomainRangePushEx_impl_fntype)(nvtxDomainHandle_t domain, const nvtxEventAttributes_t* eventAttrib);
|
||||
typedef int (NVTX_API * nvtxDomainRangePop_impl_fntype)(nvtxDomainHandle_t domain);
|
||||
typedef nvtxResourceHandle_t (NVTX_API * nvtxDomainResourceCreate_impl_fntype)(nvtxDomainHandle_t domain, nvtxResourceAttributes_t* attribs);
|
||||
typedef void (NVTX_API * nvtxDomainResourceDestroy_impl_fntype)(nvtxResourceHandle_t resource);
|
||||
typedef void (NVTX_API * nvtxDomainNameCategoryA_impl_fntype)(nvtxDomainHandle_t domain, uint32_t category, const char* name);
|
||||
typedef void (NVTX_API * nvtxDomainNameCategoryW_impl_fntype)(nvtxDomainHandle_t domain, uint32_t category, const wchar_t* name);
|
||||
typedef nvtxStringHandle_t (NVTX_API * nvtxDomainRegisterStringA_impl_fntype)(nvtxDomainHandle_t domain, const char* string);
|
||||
typedef nvtxStringHandle_t (NVTX_API * nvtxDomainRegisterStringW_impl_fntype)(nvtxDomainHandle_t domain, const wchar_t* string);
|
||||
typedef nvtxDomainHandle_t (NVTX_API * nvtxDomainCreateA_impl_fntype)(const char* message);
|
||||
typedef nvtxDomainHandle_t (NVTX_API * nvtxDomainCreateW_impl_fntype)(const wchar_t* message);
|
||||
typedef void (NVTX_API * nvtxDomainDestroy_impl_fntype)(nvtxDomainHandle_t domain);
|
||||
typedef void (NVTX_API * nvtxInitialize_impl_fntype)(const void* reserved);
|
||||
|
||||
typedef nvtxSyncUser_t (NVTX_API * nvtxDomainSyncUserCreate_impl_fntype)(nvtxDomainHandle_t domain, const nvtxSyncUserAttributes_t* attribs);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserDestroy_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserAcquireStart_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserAcquireFailed_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserAcquireSuccess_impl_fntype)(nvtxSyncUser_t handle);
|
||||
typedef void (NVTX_API * nvtxDomainSyncUserReleasing_impl_fntype)(nvtxSyncUser_t handle);
|
||||
|
||||
/* ---------------- Types for callback subscription --------------------- */
|
||||
|
||||
typedef const void *(NVTX_API * NvtxGetExportTableFunc_t)(uint32_t exportTableId);
|
||||
typedef int (NVTX_API * NvtxInitializeInjectionNvtxFunc_t)(NvtxGetExportTableFunc_t exportTable);
|
||||
|
||||
typedef enum NvtxCallbackModule
|
||||
{
|
||||
NVTX_CB_MODULE_INVALID = 0,
|
||||
NVTX_CB_MODULE_CORE = 1,
|
||||
NVTX_CB_MODULE_CUDA = 2,
|
||||
NVTX_CB_MODULE_OPENCL = 3,
|
||||
NVTX_CB_MODULE_CUDART = 4,
|
||||
NVTX_CB_MODULE_CORE2 = 5,
|
||||
NVTX_CB_MODULE_SYNC = 6,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CB_MODULE_SIZE,
|
||||
NVTX_CB_MODULE_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackModule;
|
||||
|
||||
typedef enum NvtxCallbackIdCore
|
||||
{
|
||||
NVTX_CBID_CORE_INVALID = 0,
|
||||
NVTX_CBID_CORE_MarkEx = 1,
|
||||
NVTX_CBID_CORE_MarkA = 2,
|
||||
NVTX_CBID_CORE_MarkW = 3,
|
||||
NVTX_CBID_CORE_RangeStartEx = 4,
|
||||
NVTX_CBID_CORE_RangeStartA = 5,
|
||||
NVTX_CBID_CORE_RangeStartW = 6,
|
||||
NVTX_CBID_CORE_RangeEnd = 7,
|
||||
NVTX_CBID_CORE_RangePushEx = 8,
|
||||
NVTX_CBID_CORE_RangePushA = 9,
|
||||
NVTX_CBID_CORE_RangePushW = 10,
|
||||
NVTX_CBID_CORE_RangePop = 11,
|
||||
NVTX_CBID_CORE_NameCategoryA = 12,
|
||||
NVTX_CBID_CORE_NameCategoryW = 13,
|
||||
NVTX_CBID_CORE_NameOsThreadA = 14,
|
||||
NVTX_CBID_CORE_NameOsThreadW = 15,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CBID_CORE_SIZE,
|
||||
NVTX_CBID_CORE_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackIdCore;
|
||||
|
||||
typedef enum NvtxCallbackIdCore2
|
||||
{
|
||||
NVTX_CBID_CORE2_INVALID = 0,
|
||||
NVTX_CBID_CORE2_DomainMarkEx = 1,
|
||||
NVTX_CBID_CORE2_DomainRangeStartEx = 2,
|
||||
NVTX_CBID_CORE2_DomainRangeEnd = 3,
|
||||
NVTX_CBID_CORE2_DomainRangePushEx = 4,
|
||||
NVTX_CBID_CORE2_DomainRangePop = 5,
|
||||
NVTX_CBID_CORE2_DomainResourceCreate = 6,
|
||||
NVTX_CBID_CORE2_DomainResourceDestroy = 7,
|
||||
NVTX_CBID_CORE2_DomainNameCategoryA = 8,
|
||||
NVTX_CBID_CORE2_DomainNameCategoryW = 9,
|
||||
NVTX_CBID_CORE2_DomainRegisterStringA = 10,
|
||||
NVTX_CBID_CORE2_DomainRegisterStringW = 11,
|
||||
NVTX_CBID_CORE2_DomainCreateA = 12,
|
||||
NVTX_CBID_CORE2_DomainCreateW = 13,
|
||||
NVTX_CBID_CORE2_DomainDestroy = 14,
|
||||
NVTX_CBID_CORE2_Initialize = 15,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CBID_CORE2_SIZE,
|
||||
NVTX_CBID_CORE2_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackIdCore2;
|
||||
|
||||
typedef enum NvtxCallbackIdCuda
|
||||
{
|
||||
NVTX_CBID_CUDA_INVALID = 0,
|
||||
NVTX_CBID_CUDA_NameCuDeviceA = 1,
|
||||
NVTX_CBID_CUDA_NameCuDeviceW = 2,
|
||||
NVTX_CBID_CUDA_NameCuContextA = 3,
|
||||
NVTX_CBID_CUDA_NameCuContextW = 4,
|
||||
NVTX_CBID_CUDA_NameCuStreamA = 5,
|
||||
NVTX_CBID_CUDA_NameCuStreamW = 6,
|
||||
NVTX_CBID_CUDA_NameCuEventA = 7,
|
||||
NVTX_CBID_CUDA_NameCuEventW = 8,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CBID_CUDA_SIZE,
|
||||
NVTX_CBID_CUDA_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackIdCuda;
|
||||
|
||||
typedef enum NvtxCallbackIdCudaRt
|
||||
{
|
||||
NVTX_CBID_CUDART_INVALID = 0,
|
||||
NVTX_CBID_CUDART_NameCudaDeviceA = 1,
|
||||
NVTX_CBID_CUDART_NameCudaDeviceW = 2,
|
||||
NVTX_CBID_CUDART_NameCudaStreamA = 3,
|
||||
NVTX_CBID_CUDART_NameCudaStreamW = 4,
|
||||
NVTX_CBID_CUDART_NameCudaEventA = 5,
|
||||
NVTX_CBID_CUDART_NameCudaEventW = 6,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CBID_CUDART_SIZE,
|
||||
NVTX_CBID_CUDART_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackIdCudaRt;
|
||||
|
||||
typedef enum NvtxCallbackIdOpenCL
|
||||
{
|
||||
NVTX_CBID_OPENCL_INVALID = 0,
|
||||
NVTX_CBID_OPENCL_NameClDeviceA = 1,
|
||||
NVTX_CBID_OPENCL_NameClDeviceW = 2,
|
||||
NVTX_CBID_OPENCL_NameClContextA = 3,
|
||||
NVTX_CBID_OPENCL_NameClContextW = 4,
|
||||
NVTX_CBID_OPENCL_NameClCommandQueueA = 5,
|
||||
NVTX_CBID_OPENCL_NameClCommandQueueW = 6,
|
||||
NVTX_CBID_OPENCL_NameClMemObjectA = 7,
|
||||
NVTX_CBID_OPENCL_NameClMemObjectW = 8,
|
||||
NVTX_CBID_OPENCL_NameClSamplerA = 9,
|
||||
NVTX_CBID_OPENCL_NameClSamplerW = 10,
|
||||
NVTX_CBID_OPENCL_NameClProgramA = 11,
|
||||
NVTX_CBID_OPENCL_NameClProgramW = 12,
|
||||
NVTX_CBID_OPENCL_NameClEventA = 13,
|
||||
NVTX_CBID_OPENCL_NameClEventW = 14,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CBID_OPENCL_SIZE,
|
||||
NVTX_CBID_OPENCL_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackIdOpenCL;
|
||||
|
||||
typedef enum NvtxCallbackIdSync
|
||||
{
|
||||
NVTX_CBID_SYNC_INVALID = 0,
|
||||
NVTX_CBID_SYNC_DomainSyncUserCreate = 1,
|
||||
NVTX_CBID_SYNC_DomainSyncUserDestroy = 2,
|
||||
NVTX_CBID_SYNC_DomainSyncUserAcquireStart = 3,
|
||||
NVTX_CBID_SYNC_DomainSyncUserAcquireFailed = 4,
|
||||
NVTX_CBID_SYNC_DomainSyncUserAcquireSuccess = 5,
|
||||
NVTX_CBID_SYNC_DomainSyncUserReleasing = 6,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_CBID_SYNC_SIZE,
|
||||
NVTX_CBID_SYNC_FORCE_INT = 0x7fffffff
|
||||
} NvtxCallbackIdSync;
|
||||
|
||||
/* IDs for NVTX Export Tables */
|
||||
typedef enum NvtxExportTableID
|
||||
{
|
||||
NVTX_ETID_INVALID = 0,
|
||||
NVTX_ETID_CALLBACKS = 1,
|
||||
NVTX_ETID_RESERVED0 = 2,
|
||||
NVTX_ETID_VERSIONINFO = 3,
|
||||
/* --- New constants must only be added directly above this line --- */
|
||||
NVTX_ETID_SIZE,
|
||||
NVTX_ETID_FORCE_INT = 0x7fffffff
|
||||
} NvtxExportTableID;
|
||||
|
||||
typedef void (* NvtxFunctionPointer)(void); /* generic uncallable function pointer, must be casted to appropriate function type */
|
||||
typedef NvtxFunctionPointer** NvtxFunctionTable; /* double pointer because array(1) of pointers(2) to function pointers */
|
||||
|
||||
typedef struct NvtxExportTableCallbacks
|
||||
{
|
||||
size_t struct_size;
|
||||
|
||||
/* returns an array of pointer to function pointers*/
|
||||
int (NVTX_API *GetModuleFunctionTable)(
|
||||
NvtxCallbackModule module,
|
||||
NvtxFunctionTable* out_table,
|
||||
unsigned int* out_size);
|
||||
} NvtxExportTableCallbacks;
|
||||
|
||||
typedef struct NvtxExportTableVersionInfo
|
||||
{
|
||||
/* sizeof(NvtxExportTableVersionInfo) */
|
||||
size_t struct_size;
|
||||
|
||||
/* The API version comes from the NVTX library linked to the app. The
|
||||
* injection library is can use this info to make some assumptions */
|
||||
uint32_t version;
|
||||
|
||||
/* Reserved for alignment, do not use */
|
||||
uint32_t reserved0;
|
||||
|
||||
/* This must be set by tools when attaching to provide applications
|
||||
* the ability to, in emergency situations, detect problematic tools
|
||||
* versions and modify the NVTX source to prevent attaching anything
|
||||
* that causes trouble in the app. Currently, this value is ignored. */
|
||||
void (NVTX_API *SetInjectionNvtxVersion)(
|
||||
uint32_t version);
|
||||
} NvtxExportTableVersionInfo;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@ int main(int argc, char **argv) {
|
||||
assert(argc == 2);
|
||||
char *version = argv[1];
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
cublasHandle_t handle = init_cublas();
|
||||
cusolverDnHandle_t s_handle = init_cusolver();
|
||||
#endif
|
||||
@ -21,16 +21,17 @@ int main(int argc, char **argv) {
|
||||
// SETUP DATA ACCESS
|
||||
hid_t file_id = H5Fopen(DATASET, H5F_ACC_RDONLY, H5P_DEFAULT);
|
||||
|
||||
printf("\n# %d REPETITIONS\n", REPETITIONS);
|
||||
printf("#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n");
|
||||
printf("#1\t2\t3\t4\t\t5\t6\t\t7\t\t8\t\t9\t\t10\t\t11\t\t12\t\t13\t\t14\n");
|
||||
printf("#CYCLE\tUPDS\tERR_IN\tERR_BREAK\tERR_OUT\tSPLITS\t\tBLK_FAILS\tMAX\t\tFROB\t\tCOND\t\tCPU_CYC\t\tCPU_CYC/UPD\tCUMUL\t\tREC\n");
|
||||
printf("#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n");
|
||||
|
||||
// FOR EACH UPDATE CYCLE DO:
|
||||
for (uint32_t cycles_index = 0; cycles_index < 1; cycles_index++) {
|
||||
for (uint32_t cycles_index = 0; cycles_index < n_cycles; cycles_index++) {
|
||||
|
||||
// SETUP TEST PARAMETERS
|
||||
const uint32_t GHz = 2800000000;
|
||||
// const uint32_t GHz = 2800000000; // 2.8 giga-cycles per second
|
||||
const double breakdown = 0.001; // default = 0.001. 1e-9 might be too small
|
||||
const double tolerance = 0.001; // default = 0.001
|
||||
double cumulative = 0;
|
||||
@ -154,14 +155,31 @@ int main(int argc, char **argv) {
|
||||
// 4. ADD TIME DIFFERENCE TO TIME CUMMULATOR
|
||||
accumulator += (double) (after - before);
|
||||
}
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
#ifdef USE_OMP
|
||||
else if (version[0] == 'o') { // Woodbury K OMP
|
||||
|
||||
// 1. FETCH START TIME
|
||||
uint64_t before = rdtsc();
|
||||
|
||||
// 2. EXECUTE KERNEL AND REMEMBER EXIT STATUS
|
||||
err_break = qmckl_woodbury_k_omp(Lds, Dim, N_updates, Updates,
|
||||
Updates_index, breakdown, Slater_invT_copy, &determinant);
|
||||
|
||||
// 3. FETCH FINISH TIME
|
||||
uint64_t after = rdtsc();
|
||||
|
||||
// 4. ADD TIME DIFFERENCE TO TIME CUMMULATOR
|
||||
accumulator += (double) (after - before);
|
||||
}
|
||||
#endif
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
else if (version[0] == 'c') { // Woodbury K cuBLAS
|
||||
|
||||
// 1. FETCH START TIME
|
||||
uint64_t before = rdtsc();
|
||||
|
||||
// 2. EXECUTE KERNEL AND REMEMBER EXIT STATUS
|
||||
err_break = qmckl_woodbury_k_cublas_offload(handle, s_handle, Lds, Dim, N_updates, Updates,
|
||||
err_break = qmckl_woodbury_k_ompol_cuda_sync(handle, s_handle, Lds, Dim, N_updates, Updates,
|
||||
Updates_index, breakdown, Slater_invT_copy, &determinant);
|
||||
|
||||
// 3. FETCH FINISH TIME
|
||||
@ -233,7 +251,6 @@ int main(int argc, char **argv) {
|
||||
// 4. COPY RESULT BACK TO ORIGINAL
|
||||
memcpy(Slater_invT, Slater_invT_copy, Lds * Dim * sizeof(double));
|
||||
determinant = determinant_copy;
|
||||
// At this point Slater_invT contains the correct inverse matrix
|
||||
|
||||
// 5. DIVIDE CYCLE- AND SPLIT-ACCUMULATOR BY NUMBER OF REPETITIONS AND RECORD
|
||||
// DIVIDE CYCLE-ACCUMULATOR BY NUMBER OF UPDATES AND RECORD
|
||||
@ -272,7 +289,8 @@ int main(int argc, char **argv) {
|
||||
free(Res);
|
||||
|
||||
// 10. WRITE RESULTS TO FILE: CYCLE#, #UPDS, ERR_INP, ERR_BREAK, #SPLITS, ERR_OUT, COND, #CLCK_TCKS
|
||||
printf("%u\t%lu\t%u\t%u\t\t%u\t%lu\t\t%lu\t\t%e\t%e\t%e\t%9.6f\t%9.6f\t%9.6f\t%lu\n", cycle, N_updates, err_inp, err_break, err_out, n_splits, block_fail, max, frob, condnr, (double)accumulator/GHz, (double)cycles_per_update/GHz, (double)cumulative/GHz, recursive_calls);
|
||||
printf("%u\t%lu\t%u\t%u\t\t%u\t%lu\t\t%lu\t\t%e\t%e\t%e\t%e\t%e\t%e\t%lu\n", cycle, N_updates, err_inp, err_break, err_out, n_splits, block_fail, max, frob, condnr, accumulator, cycles_per_update, cumulative, recursive_calls);
|
||||
// printf("%u\t%lu\t%u\t%u\t\t%u\t%lu\t\t%lu\t\t%e\t%e\t%e\t%9.6f\t%9.6f\t%9.6f\t%lu\n", cycle, N_updates, err_inp, err_break, err_out, n_splits, block_fail, max, frob, condnr, (double)accumulator/GHz, (double)cycles_per_update/GHz, (double)cumulative/GHz, recursive_calls);
|
||||
|
||||
free(Updates_index);
|
||||
free(Updates);
|
||||
@ -282,14 +300,14 @@ int main(int argc, char **argv) {
|
||||
|
||||
} // END OF CYCLE LOOP
|
||||
|
||||
printf("#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n");
|
||||
printf("#1\t2\t3\t4\t\t5\t6\t\t7\t\t8\t\t9\t\t10\t\t11\t\t12\t\t13\t\t14\n");
|
||||
printf("#CYCLE\tUPDS\tERR_IN\tERR_BREAK\tERR_OUT\tSPLITS\t\tBLK_FAILS\tMAX\t\tFROB\t\tCOND\t\tCPU_CYC\t\tCPU_CYC/UPD\tCUMUL\t\tREC\n");
|
||||
printf("#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n");
|
||||
// printf("#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n");
|
||||
// printf("#1\t2\t3\t4\t\t5\t6\t\t7\t\t8\t\t9\t\t10\t\t11\t\t12\t\t13\t\t14\n");
|
||||
// printf("#CYCLE\tUPDS\tERR_IN\tERR_BREAK\tERR_OUT\tSPLITS\t\tBLK_FAILS\tMAX\t\tFROB\t\tCOND\t\tCPU_CYC\t\tCPU_CYC/UPD\tCUMUL\t\tREC\n");
|
||||
// printf("#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n");
|
||||
|
||||
(void) H5Fclose(file_id);
|
||||
|
||||
#ifdef HAVE_CUBLAS_OFFLOAD
|
||||
#ifdef USE_OMP_OFFLOAD_CUDA
|
||||
cublasDestroy_v2(handle);
|
||||
cusolverDnDestroy(s_handle);
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user