/*******************************************************************************
* Copyright (C) 2025 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       DPCPP USM API oneapi::mkl::sparse::gemv to perform general
*       sparse matrix-vector  multiplication on a SYCL device (CPU, GPU). This
*       example uses a sparse matrix in CSC format.
*
*       y = alpha * op(A) * x + beta * y
*
*       where op() is defined by one of
*           oneapi::mkl::transpose::{nontrans,trans,conjtrans}
*
*       The supported floating point data types for gemv are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*       The supported matrix formats for gemv are:
*           CSR
*           CSC
*           COO
*           BSR
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

//
// Main example for Sparse Matrix-Vector Multiply consisting of
// initialization of A matrix, x and y vectors as well as
// scalars alpha and beta.  Then the product
//
// y = alpha * op(A) * x + beta * y
//
// is performed on the device and the result is compared against a
// reference solution computed on the host.
//
// Template parameters:
//   dataType - floating point type of the matrix values, vectors and scalars
//   intType  - integer type of the CSC index arrays
//
// Parameters:
//   q - SYCL queue used for the USM allocations, the copies and the
//       oneMKL sparse calls
//
// Returns:
//   0 if the device result passed the check against the host reference and
//   no exception was caught inside this function, 1 otherwise.
//
template <typename dataType, typename intType>
int run_sparse_blas_example(sycl::queue &q)
{
    bool good = true;

    //
    // handle for sparse matrix
    //
    oneapi::mkl::sparse::matrix_handle_t cscA = nullptr;

    // create arrays to handle deallocation
    std::vector<intType *> int_ptr_vec;
    std::vector<dataType *> data_ptr_vec;

    // Register every successful USM allocation for deallocation as soon as
    // it is made.  The allocations below are checked in groups; if only one
    // pointer of a group is null, the others must still reach cleanup_arrays
    // even though we throw, otherwise they would leak.  Null pointers are
    // never registered, so the vectors hold exactly what they held before on
    // the success path (and in the same order).
    auto track_int = [&int_ptr_vec](intType *ptr) {
        if (ptr) {
            int_ptr_vec.push_back(ptr);
        }
        return ptr;
    };
    auto track_data = [&data_ptr_vec](dataType *ptr) {
        if (ptr) {
            data_ptr_vec.push_back(ptr);
        }
        return ptr;
    };

    try {

        // Initialize data for Sparse Matrix-Vector Multiply
        oneapi::mkl::transpose transpose_val = oneapi::mkl::transpose::trans;
        oneapi::mkl::index_base index_base_val = oneapi::mkl::index_base::zero;
        intType int_index = (index_base_val == oneapi::mkl::index_base::zero ? 0 : 1);

        // Matrix data size
        intType size  = 4;
        const std::int64_t nrows = size * size * size;
        const std::int64_t ncols = nrows;

        // Set up data for the sparse matrix in CSC format on the host.
        // ja/a are sized for 27 entries per row, which is the upper bound
        // this example reserves for the generated matrix.
        intType *ia_host = track_int(sycl::malloc_host<intType>(ncols + 1, q));
        intType *ja_host = track_int(sycl::malloc_host<intType>(27 * nrows, q));
        dataType *a_host = track_data(sycl::malloc_host<dataType>(27 * nrows, q));
        if (!ia_host || !ja_host || !a_host) {
            std::string errorMessage =
                "Failed to allocate USM host memory arrays \n"
                " for CSC A matrix: ia(" + std::to_string((ncols+1)*sizeof(intType)) + " bytes)\n"
                "                   ja(" + std::to_string((nrows * 27)*sizeof(intType)) + " bytes)\n"
                "                   a(" + std::to_string((nrows * 27)*sizeof(dataType)) + " bytes)";

            throw std::runtime_error(errorMessage);
        }

        // Generate a sparse matrix in CSC format with int_index indexing
        // Note: the below function technically generates a CSR matrix, but
        // because nrows == ncols == size^3, the generated data can be
        // interpreted as CSC for the purpose of this example.
        generate_sparse_matrix<dataType, intType>(size, ia_host, ja_host, a_host, int_index);
        const std::int64_t nnz = ia_host[ncols] - int_index;

        // Set up data for the sparse matrix in CSC format on the device
        intType *ia = track_int(sycl::malloc_device<intType>(ncols + 1, q));
        intType *ja = track_int(sycl::malloc_device<intType>(nnz, q));
        dataType *a = track_data(sycl::malloc_device<dataType>(nnz, q));

        if (!ia || !ja || !a) {
            std::string errorMessage =
                "Failed to allocate USM device memory arrays \n"
                " for CSC A matrix: ia(" + std::to_string((ncols+1)*sizeof(intType)) + " bytes)\n"
                "                   ja(" + std::to_string((nnz)*sizeof(intType)) + " bytes)\n"
                "                   a(" + std::to_string((nnz)*sizeof(dataType)) + " bytes)";

            throw std::runtime_error(errorMessage);
        }

        // copy A matrix USM data from host to device
        auto ev_cpy_ia = q.copy<intType>(ia_host, ia, ncols + 1);
        auto ev_cpy_ja = q.copy<intType>(ja_host, ja, nnz);
        auto ev_cpy_a  = q.copy<dataType>(a_host, a, nnz);

        // Init vectors x and y on the host.  Their lengths depend on op():
        // x has as many entries as op(A) has columns, y as many as op(A)
        // has rows.
        const std::int64_t x_len = (transpose_val == oneapi::mkl::transpose::nontrans) ? ncols : nrows;
        const std::int64_t y_len = (transpose_val == oneapi::mkl::transpose::nontrans) ? nrows : ncols;
        dataType *x_host = track_data(sycl::malloc_host<dataType>(x_len, q));
        dataType *y_host = track_data(sycl::malloc_host<dataType>(y_len, q));
        dataType *y_ref_host = track_data(sycl::malloc_host<dataType>(y_len, q));
        if (!x_host || !y_host || !y_ref_host) {
            std::string errorMessage =
                "Failed to allocate USM host memory arrays \n"
                " for x vector(" + std::to_string((x_len)*sizeof(dataType)) + " bytes)\n"
                " for y vector(" + std::to_string((y_len)*sizeof(dataType)) + " bytes)\n"
                " for y_ref vector(" + std::to_string((y_len)*sizeof(dataType)) + " bytes)";
            throw std::runtime_error(errorMessage);
        }

        for (intType i = 0; i < x_len; i++) {
            x_host[i]     = set_fp_value(dataType(1.0), dataType( 0.0));
        }
        // y and its host reference start from identical values so that the
        // beta * y term agrees between device and reference.
        for (intType i = 0; i < y_len; i++) {
            y_host[i]     = set_fp_value(dataType(1.0), dataType(-1.0));
            y_ref_host[i] = set_fp_value(dataType(1.0), dataType(-1.0));
        }

        // Allocate vectors x and y on the device
        dataType *x = track_data(sycl::malloc_device<dataType>(x_len, q));
        dataType *y = track_data(sycl::malloc_device<dataType>(y_len, q));
        if (!x || !y) {
            std::string errorMessage =
                "Failed to allocate USM device memory arrays \n"
                " for x vector(" + std::to_string((x_len)*sizeof(dataType)) + " bytes)\n"
                " for y vector(" + std::to_string((y_len)*sizeof(dataType)) + " bytes)";
            throw std::runtime_error(errorMessage);
        }

        // Copy x and y vectors from host to device
        auto ev_cpy_x = q.copy<dataType>(x_host, x, x_len);
        auto ev_cpy_y = q.copy<dataType>(y_host, y, y_len);

        // Set scalar dataType values
        dataType alpha, beta;
        if constexpr (is_complex<dataType>()) {
            alpha = set_fp_value(dataType(1.0), dataType(-1.0));
            beta  = set_fp_value(dataType(2.0), dataType(1.0));
        }
        else {
            alpha = set_fp_value(dataType(1.0), dataType(0.0));
            beta  = set_fp_value(dataType(2.0), dataType(0.0));
        }

        //
        // Execute Sparse Matrix - Dense Vector Multiply
        //

        std::cout << "\n\t\tsparse::gemv parameters:\n";
        std::cout << "\t\t\ttranspose_val  = " << transpose_val << std::endl;
        std::cout << "\t\t\tindex_base_val = " << index_base_val << std::endl;
        std::cout << "\t\t\tnrows          = " << nrows << std::endl;
        std::cout << "\t\t\tncols          = " << ncols << std::endl;
        std::cout << "\t\t\tnnz            = " << nnz << std::endl;
        std::cout << "\t\t\talpha          = " << alpha << std::endl;
        std::cout << "\t\t\tbeta           = " << beta << std::endl;

        oneapi::mkl::sparse::init_matrix_handle(&cscA);

        // The matrix data may only be attached once all three host-to-device
        // copies of the CSC arrays are done.
        sycl::event ev_set_csc = oneapi::mkl::sparse::set_csc_data(q, cscA, nrows, ncols, nnz, index_base_val,
                                          ia, ja, a, {ev_cpy_ia, ev_cpy_ja, ev_cpy_a});

        // gemv additionally depends on x and y having arrived on the device.
        sycl::event ev_gemv = oneapi::mkl::sparse::gemv(q, transpose_val, alpha, cscA,
                                  x, beta, y, {ev_set_csc, ev_cpy_x, ev_cpy_y});

        // The release is submitted with ev_gemv as its dependency and we
        // block on it before y is read back below.
        oneapi::mkl::sparse::release_matrix_handle(q, &cscA, {ev_gemv}).wait();

        //
        // Post Processing
        //

        // Copy result vector y from device to host
        ev_cpy_y = q.copy<dataType>(y, y_host, y_len);
        ev_cpy_y.wait();

        // We validate against CSC GEMV reference solution here
        if (transpose_val == oneapi::mkl::transpose::nontrans) {
            // y = beta * y first, then scatter alpha * x[col] * A(row, col)
            // column by column into y[row].
            for (intType row = 0; row < nrows; row++) {
                y_ref_host[row] *= beta;
            }
            for (intType col = 0; col < ncols; col++) {
                dataType tmp = alpha * x_host[col];
                for (intType i = ia_host[col]-int_index; i < ia_host[col+1]-int_index; i++) {
                    intType row = ja_host[i]-int_index;
                    dataType val = a_host[i];
                    y_ref_host[row] += tmp * val;
                }
            }
        }
        else { // transpose_val is trans or conjtrans
            // Each CSC column of A is a row of op(A), so y[col] is a dot
            // product of that column (conjugated for conjtrans) with x.
            const bool isConj = (transpose_val == oneapi::mkl::transpose::conjtrans);
            for (intType col = 0; col < ncols; col++) {
                dataType tmp = dataType(0.0);
                for (intType i = ia_host[col]-int_index; i < ia_host[col+1]-int_index; i++) {
                    intType row = ja_host[i]-int_index;
                    dataType val = a_host[i];
                    if constexpr (is_complex<dataType>()) {
                        tmp += (isConj ? std::conj(val) : val) * x_host[row];
                    }
                    else {
                        tmp += val * x_host[row];
                    }
                }
                y_ref_host[col] = alpha * tmp + beta * y_ref_host[col];
            }
        }

        for (intType i = 0; i < y_len; i++) {
            good &= check_result(y_host[i], y_ref_host[i], y_len, i);
        }

        std::cout << "\n\t\t sparse::gemv example " << (good ? "passed" : "failed")
                  << "\n\tFinished" << std::endl;

        q.wait_and_throw();

    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;
        good = false;

    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        good = false;
    }

    // Make sure nothing submitted above is still running before the handle
    // and the USM arrays are released.
    q.wait();

    // backup cleaning of release_handle if exceptions happened
    oneapi::mkl::sparse::release_matrix_handle(q, &cscA).wait();

    // cleanup allocations
    cleanup_arrays<dataType, intType>(data_ptr_vec, int_ptr_vec, q);

    q.wait();

    return good ? 0 : 1;
}

//
// Prints the banner for this example to std::cout: the operation that is
// computed, the sparse matrix format, the api that is exercised and the
// supported floating point type precisions.  Each banner line is written
// on its own and terminated with std::endl.
//
void print_example_banner()
{
    // Separator line that frames the banner at the top and the bottom.
    const char *const rule =
        "###############################################################"
        "#########";

    // Banner text, one entry per output line; empty entries produce the
    // blank lines before and after the framed text.
    const char *const banner_lines[] = {
        "",
        rule,
        "# Sparse Matrix-Vector Multiply Example: ",
        "# ",
        "# y = alpha * op(A) * x + beta * y",
        "# ",
        "# where A is a sparse matrix in CSC format, x and y are "
        "dense vectors",
        "# and alpha, beta are floating point type precision scalars.",
        "# ",
        "# Using apis:",
        "#   sparse::gemv",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        rule,
        "",
    };

    for (const char *text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type,
//  run_sparse_blas_example is run with all supported data types,
//  if any fail, we move on to the next device.
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        try {
            sycl::device my_dev;
            bool my_dev_is_found = false;
            get_sycl_device(my_dev, my_dev_is_found, *it);

            if (my_dev_is_found) {
                std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

                // Catch asynchronous exceptions
                auto exception_handler = [](sycl::exception_list exceptions) {
                    for (std::exception_ptr const &e : exceptions) {
                        try {
                            std::rethrow_exception(e);
                        }
                        catch (sycl::exception const &e) {
                            std::cout << "Caught asynchronous SYCL exception: \n"
                                << e.what() << std::endl;
                        }
                    }
                };

                sycl::queue q(my_dev, exception_handler);

                std::cout << "\tRunning with single precision real data type:" << std::endl;
                status |= run_sparse_blas_example<float, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision real data type:" << std::endl;
                    status |= run_sparse_blas_example<double, std::int32_t>(q);
                }

                std::cout << "\tRunning with single precision complex data type:" << std::endl;
                status |= run_sparse_blas_example<std::complex<float>, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision complex data type:" << std::endl;
                    status |= run_sparse_blas_example<std::complex<double>, std::int32_t>(q);
                }

            }
            else {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it]
                    << " devices found; Fail on missing devices "
                    "is enabled.\n";
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                    << sycl_device_names[*it] << " tests.\n";
#endif
            }
        }
        catch (sycl::exception const &e) {
            std::cout << "\t\tCaught SYCL exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }
        catch (std::exception const &e) {
            std::cout << "\t\tCaught std exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }


    } // for loop over devices

    mkl_free_buffers();
    return status;
}
