{{py:

implementation_specific_values = [
    # Each tuple parameterizes one generated specialization of the classes
    # declared below. Its fields are, in order:
    #
    #       name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE
    #
    # - name_suffix: appended to every generated class name
    #   (e.g. DenseDenseMiddleTermComputer64 and ...32).
    # - upcast_to_float64: when True, the dense-dense computer additionally
    #   declares float64 buffers used to upcast chunks of X and Y.
    # - INPUT_DTYPE_t: C type of the dense input memoryviews.
    # - INPUT_DTYPE: the matching NumPy dtype name. It is not used by the
    #   declarations in this file; presumably the companion .pyx.tp uses it.
    #
    # The float64/float32 C-type names are the ones defined in
    # `sklearn.utils._typedefs` (cimported below), which keeps them consistent
    # with the rest of the package.
    #
    ('64', False, 'float64_t', 'np.float64'),
    ('32', True, 'float32_t', 'np.float32')
]

}}
from libcpp.vector cimport vector

from ...utils._typedefs cimport float64_t, float32_t, int32_t, intp_t


# Sparse-sparse middle-term kernel. It is declared once, outside the dtype
# template loop below, and takes float64 data for both operands whatever the
# specialization.
#
# X and Y are each passed as three arrays (data, indices, indptr). These are
# presumably the components of a CSR matrix; confirm against the .pyx
# implementation.
# (X_start, X_end) and (Y_start, Y_end) select the rows of the current chunk
# of each operand. The bounds are presumably half-open; TODO confirm.
# `D` is a caller-supplied, non-const buffer that presumably receives the
# computed middle-term values. Its layout is defined by the .pyx
# implementation.
# `noexcept nogil`: the function cannot propagate Python exceptions and may
# be called without holding the GIL.
cdef void _middle_term_sparse_sparse_64(
    const float64_t[:] X_data,
    const int32_t[:] X_indices,
    const int32_t[:] X_indptr,
    intp_t X_start,
    intp_t X_end,
    const float64_t[:] Y_data,
    const int32_t[:] Y_indices,
    const int32_t[:] Y_indptr,
    intp_t Y_start,
    intp_t Y_end,
    float64_t * D,
) noexcept nogil


{{for name_suffix, upcast_to_float64, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}


cdef class MiddleTermComputer{{name_suffix}}:
    # Base class declaring the hooks used to compute the middle term of the
    # pairwise distances chunk by chunk, under two parallelization strategies:
    # over X (the `_parallel_on_X_*` hooks) and over Y (the `_parallel_on_Y_*`
    # hooks). The subclasses below specialize it for dense/dense,
    # sparse/sparse and sparse/dense inputs.
    #
    # NOTE: attribute and method declaration order here defines the object
    # layout and vtable that the .pyx implementation and cimporting modules
    # compile against. Keep it in sync with the .pyx.
    cdef:
        # Thread counts. Presumably the overall number of threads in use and
        # the number used when processing chunks; confirm in the .pyx.
        intp_t effective_n_threads
        intp_t chunks_n_threads
        # Presumably the number of elements in each buffer of
        # `dist_middle_terms_chunks`; TODO confirm.
        intp_t dist_middle_terms_chunks_size
        intp_t n_features
        intp_t chunk_size

        # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM.
        # Presumably one inner vector per thread, indexed by `thread_num`;
        # confirm in the .pyx.
        vector[vector[float64_t]] dist_middle_terms_chunks

    # Strategy "parallel on X": hook run for the (X, Y) chunk pair selected by
    # the given row bounds, before the distances of that pair are reduced.
    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil

    # Strategy "parallel on X": per-thread initialization hook.
    cdef void _parallel_on_X_parallel_init(self, intp_t thread_num) noexcept nogil

    # Strategy "parallel on X": hook run when a thread starts a new X chunk.
    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    # Strategy "parallel on Y": initialization hook; takes no thread index.
    cdef void _parallel_on_Y_init(self) noexcept nogil

    # Strategy "parallel on Y": per-thread initialization hook for the X chunk
    # given by (X_start, X_end).
    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    # Strategy "parallel on Y": counterpart of the "parallel on X"
    # pre-compute hook above.
    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num
    ) noexcept nogil

    # Returns a raw float64 pointer to the middle-term values of the given
    # (X, Y) chunk pair for `thread_num`. Presumably it points into that
    # thread's `dist_middle_terms_chunks` buffer; confirm in the .pyx.
    cdef float64_t * _compute_dist_middle_terms(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil


cdef class DenseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}):
    # Middle-term computer for two dense inputs. It overrides the chunk-init,
    # parallel-init, pre-compute and compute hooks of the base class; the
    # remaining hooks come from the base declaration.
    cdef:
        # Read-only, C-contiguous 2D views (`[:, ::1]`) of the dense inputs,
        # typed with this specialization's input dtype (float64 or float32).
        const {{INPUT_DTYPE_t}}[:, ::1] X
        const {{INPUT_DTYPE_t}}[:, ::1] Y

    {{if upcast_to_float64}}
        # Buffers for upcasting chunks of X and Y from 32bit to 64bit.
        # Only generated for the float32 specialization (upcast_to_float64 is
        # True only for the '32' entry of the template values).
        vector[vector[float64_t]] X_c_upcast
        vector[vector[float64_t]] Y_c_upcast
    {{endif}}

    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil

    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil

    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num
    ) noexcept nogil

    cdef float64_t * _compute_dist_middle_terms(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil


cdef class SparseSparseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}):
    # Middle-term computer for two sparse inputs. Each operand is stored as
    # three arrays (data, indices, indptr), presumably CSR components; confirm
    # against the .pyx.
    #
    # The data arrays are float64 in *every* specialization, unlike the dense
    # views of the other computers, which use the template's input dtype. The
    # sparse data are therefore presumably converted to float64 before
    # reaching this class; TODO confirm with the caller.
    cdef:
        const float64_t[:] X_data
        const int32_t[:] X_indices
        const int32_t[:] X_indptr

        const float64_t[:] Y_data
        const int32_t[:] Y_indices
        const int32_t[:] Y_indptr

    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num
    ) noexcept nogil

    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num
    ) noexcept nogil

    cdef float64_t * _compute_dist_middle_terms(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil


cdef class SparseDenseMiddleTermComputer{{name_suffix}}(MiddleTermComputer{{name_suffix}}):
    # Middle-term computer for one sparse operand and one dense operand.
    # The sparse operand is held as float64 (data, indices, indptr) arrays.
    # The dense operand is a read-only, C-contiguous 2D view typed with this
    # specialization's input dtype.
    cdef:
        const float64_t[:] X_data
        const int32_t[:] X_indices
        const int32_t[:] X_indptr

        const {{INPUT_DTYPE_t}}[:, ::1] Y

        # The dense-sparse case is handled by this same class: the arguments
        # are swapped, and dist_middle_terms is treated as F-ordered instead
        # of C-ordered. This flag records which of the two cases applies, so
        # the logic can adapt accordingly.
        bint c_ordered_middle_term

    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num
    ) noexcept nogil

    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num
    ) noexcept nogil

    cdef float64_t * _compute_dist_middle_terms(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil

{{endfor}}
