// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"

struct ggml_kleidiai_context {
    cpu_feature features;
    ggml_kleidiai_kernels * kernels_q4;
    ggml_kleidiai_kernels * kernels_q8;
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };

static const char* cpu_feature_to_string(cpu_feature f) {
    if (f == CPU_FEATURE_NONE) {
        return "NONE";
    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
        return "SME";
    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
        return "SVE";
    }
    else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
        return "I8MM";
    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
        return "DOTPROD";
    }
    else {
        return "UNKNOWN";
    }
}

static void init_kleidiai_context(void) {

    ggml_critical_section_start();
    static bool initialized = false;

    if (!initialized) {
        initialized = true;
        const char *env_var = getenv("GGML_KLEIDIAI_SME");
        int sme_enabled = 0;

        ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                        (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                        ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

        if (env_var) {
            sme_enabled = atoi(env_var);
        }

        if (sme_enabled != 0) {
            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
        }
        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
        if (ctx.kernels_q4) {
            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
        }
        if (ctx.kernels_q8) {
            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
        }
#endif
    }
    ggml_critical_section_end();
}

static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
    return tensor->ne[dim];
}

namespace ggml::cpu::kleidiai {

static size_t round_down(size_t x, size_t y) {
    return y == 0 ? x : x - (x % y);
}

static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
    size_t src_stride = rhs_stride / sizeof(uint16_t);
    size_t dst_stride = n;

    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
            uint16_t v = *(src + k_idx + n_idx * src_stride);
            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
        }
    }
}

class tensor_traits : public ggml::cpu::tensor_traits {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        if (op->op != GGML_OP_MUL_MAT) {
            return false;
        }
        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
        if (!kernels) {
            return false;
        }
        bool is_gemv = op->src[1]->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        size_t k = op->src[0]->ne[0];
        size_t n = op->src[0]->ne[1];
        size_t m = op->src[1]->ne[1];

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
            if (!lhs_info->packed_size_ex) return false;
            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
            if (!lhs_info->packed_size_ex) return false;
            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
        } else if (kernels->rhs_type == GGML_TYPE_F16) {
            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
                   k * n * sizeof(float) + n * sizeof(float);
        } else {
            return false;
        }

        return true;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
        if (dst->op == GGML_OP_MUL_MAT) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                return compute_forward_q4_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
                return compute_forward_q8_0(params, dst);
            } else if (dst->src[0]->type == GGML_TYPE_F16) {
                return compute_forward_fp16(params, dst);
            }
        } else if (dst->op == GGML_OP_GET_ROWS) {
            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
                return compute_forward_get_rows(params, dst);
            }
        }
        return false;
    }

    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        const bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
        GGML_ASSERT(kernel);
        if (!kernels->rhs_info.pack_func_ex ||
            !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
            return false;
        }

        const int nth = params->nth;
        const int ith = params->ith;

        const int64_t lhs_batch_size0 = ne12;
        const int64_t rhs_batch_size0 = ne02;
        const int64_t batch_size      = lhs_batch_size0;

        GGML_ASSERT(rhs_batch_size0 > 0);
        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
        const int64_t r = lhs_batch_size0 / rhs_batch_size0;

        const int64_t m_group = ne11;
        const int64_t m       = m_group;
        const int64_t n       = ne01;
        const int64_t k       = ne00;

        const size_t lhs_stride = src1->nb[1];
        const size_t rhs_stride = src0->nb[1];
        const size_t dst_stride = dst->nb[1];

        const int64_t mr = (int64_t) kernel->get_mr();
        const int64_t nr = (int64_t) kernel->get_nr();
        const int64_t kr = (int64_t) kernel->get_kr();
        const int64_t sr = (int64_t) kernel->get_sr();

        const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
        const size_t kxn_size        = k * n * sizeof(float);
        const size_t bias_size       = n * sizeof(float);

        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
        GGML_ASSERT(wsize_required <= params->wsize);

        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;
        uint8_t * bias       = rhs_kxn + kxn_size;

        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
            const int64_t rhs_batch_idx = batch_idx / r;
            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];

            // LHS packing (threaded over m, honoring mr alignment and KV groups)
            {
                const int64_t m_roundup_mr = kai_roundup(m, mr);
                const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);

                if (ith < num_threads) {
                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

                    const int64_t m_start = ith * num_m_per_thread0;
                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;

                    // Base packed offset (aligned) and per-row stride in bytes
                    const size_t base_packed_off  = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
                    const size_t next_block_off   = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;

                    int64_t remaining = m_count;
                    int64_t cur       = m_start;

                    while (remaining > 0) {
                        const int64_t row_in_group = cur;
                        const int64_t avail        = m_group - row_in_group;
                        const int64_t take         = std::min(avail, remaining);

                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
                        void * dst_ptr       = lhs_packed + dst_off;

                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);

                        cur       += take;
                        remaining -= take;
                    }
                }
            }

            // RHS packing (single thread), then synchronize
            if (ith == 0) {
                memset(bias, 0, (size_t)n * sizeof(float));
                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
                                        reinterpret_cast<float *>(rhs_kxn),
                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
                                        rhs_stride);

                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
                             rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
            }

            ggml_barrier(params->threadpool);

            // Matmul (threaded over n)
            {
                const int64_t n_step  = (int64_t) kernel->get_n_step();
                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
                if (num_threads_n <= 0) {
                    num_threads_n = 1;
                }

                if (ith < num_threads_n) {
                    const int64_t num_n_per_thread0   = round_down((size_t)(n / num_threads_n), (size_t)n_step);
                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                    const int64_t n_start      = ith * num_n_per_thread0;
                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

                    // LHS packed base at row 0 (consistent with packing above)
                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
                    const size_t rhs_packed_offset  = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
                    const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch_base + dst_offset);

                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                }
            }

            if (batch_idx != batch_size - 1) {
                ggml_barrier(params->threadpool);
            }
        }

        return true;
    }

    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        GGML_ASSERT(kernel);
        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int ith = params->ith;
        const int nth_raw = params->nth;
        const int nth = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed       = (uint8_t*)params->wdata;
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        // Calculate number of columns to be processed per thread
        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            // Transform LHS
            const size_t src_stride        = src1->nb[1];
            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);

            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        // Perform the operation
        const size_t dst_stride        = dst->nb[1];
        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
        float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                               sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }

    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);

        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
        if (!kernels) {
            return false;
        }

        bool is_gemv = src1->ne[1] == 1;
        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;

        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
            return false;
        }

        const int ith = params->ith;
        const int nth_raw = params->nth;
        const int nth = nth_raw > 0 ? nth_raw : 1;

        const size_t k = ne00;
        const size_t m = ne11;
        const size_t n = ne01;

        size_t mr = kernel->get_mr();
        size_t kr = kernel->get_kr();
        size_t sr = kernel->get_sr();

        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);

        const size_t n_step = kernel->get_n_step();
        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
        const size_t n_start = ith * num_n_per_thread;

        size_t n_to_process = 0;
        if (n_start < n) {
            n_to_process = num_n_per_thread;
            if ((n_start + n_to_process) > n) {
                n_to_process = n - n_start;
            }
        }

        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
        const size_t m_start = ith * num_m_per_thread;
        size_t m_to_process = num_m_per_thread;
        if ((m_start + m_to_process) > m) {
            m_to_process = m - m_start;
        }

        if (m_start < m) {
            const size_t src_stride        = src1->nb[1];
            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);

            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
        }

        ggml_barrier(params->threadpool);

        const size_t dst_stride        = dst->nb[1];
        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);

        if (n_to_process > 0) {
            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                  sizeof(float), -FLT_MAX, FLT_MAX);
        }

        return true;
    }

    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
        const ggml_tensor * src0 = dst->src[0];
        const ggml_tensor * src1 = dst->src[1];

        GGML_TENSOR_BINARY_OP_LOCALS

        ggml_kleidiai_kernels * kernels = nullptr;
        size_t block_len = 0;
        size_t num_bytes_multiplier = 0;

        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
            if (!ctx.kernels_q4) {
                return false;
            }
            kernels = ctx.kernels_q4;
            block_len = QK4_0;
            num_bytes_multiplier = sizeof(uint16_t);
        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
            if (!ctx.kernels_q8) {
                return false;
            }
            kernels = ctx.kernels_q8;
            block_len = QK8_0;
            num_bytes_multiplier = sizeof(float);
        } else {
            return false;
        }

        rhs_packing_info * rhs_info = &kernels->rhs_info;
        kernel_info * kernel        = &kernels->gemm;
        if (!rhs_info->to_float || !kernel->get_nr) {
            return false;
        }

        const int64_t nc     = ne00;
        const int64_t nr     = ggml_nelements(src1);

        const size_t block_rows = kernel->get_nr();
        const size_t kr         = kernel->get_kr();

        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

        const int ith = params->ith;
        const int nth = params->nth;

        const int dr = (nr + nth - 1) / nth;
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);

        for (int64_t i = ir0; i < ir1; ++i) {
            GGML_ASSERT(src1->type == GGML_TYPE_I32);
            int64_t row_idx = ((const int32_t *)src1->data)[i];
            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

            float *out = (float *)((char *)dst->data + i * nb1);
            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
        }

        return true;
    }

public:
    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
        const size_t n = tensor->ne[1];
        const size_t k = tensor->ne[0];

        if (tensor->type == GGML_TYPE_Q4_0) {
            if (!ctx.kernels_q4) {
                return -1;
            }
            size_t nr = ctx.kernels_q4->gemm.get_nr();
            size_t kr = ctx.kernels_q4->gemm.get_kr();
            size_t sr = ctx.kernels_q4->gemm.get_sr();

            struct kai_rhs_pack_qs4cxs1s0_param params;
            params.lhs_zero_point = 1;
            params.rhs_zero_point = 8;
            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
                                                  static_cast<const uint8_t *>(data),
                                                  nullptr, nullptr, tensor->data, 0, &params);
            GGML_UNUSED(data_size);
            return 0;
        } else if (tensor->type == GGML_TYPE_Q8_0) {
            if (!ctx.kernels_q8) {
                return -1;
            }

            const size_t row_stride = tensor->nb[1];
            const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;

            std::vector<int8_t> qdata(n * k, 0);
            std::vector<float> scales(n, 0.0f);

            for (size_t row = 0; row < n; ++row) {
                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
                    static_cast<const uint8_t *>(data) + row * row_stride);

                float max_abs = 0.0f;
                for (size_t block = 0; block < k_blocks; ++block) {
                    const block_q8_0 & blk = row_blocks[block];
                    const float d = GGML_FP16_TO_FP32(blk.d);
                    for (size_t l = 0; l < QK8_0; ++l) {
                        const size_t linear_idx = block * QK8_0 + l;
                        if (linear_idx >= k) {
                            break;
                        }
                        const float value = d * blk.qs[l];
                        max_abs = std::max(max_abs, std::fabs(value));
                    }
                }

                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
                scales[row] = scale;
                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;

                for (size_t block = 0; block < k_blocks; ++block) {
                    const block_q8_0 & blk = row_blocks[block];
                    const float d = GGML_FP16_TO_FP32(blk.d);
                    for (size_t l = 0; l < QK8_0; ++l) {
                        const size_t linear_idx = block * QK8_0 + l;
                        if (linear_idx >= k) {
                            break;
                        }
                        const float value = d * blk.qs[l];
                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
                        q = std::clamp(q, -127, 127);
                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
                    }
                }
            }

            size_t nr = ctx.kernels_q8->gemm.get_nr();
            size_t kr = ctx.kernels_q8->gemm.get_kr();
            size_t sr = ctx.kernels_q8->gemm.get_sr();

            struct kai_rhs_pack_qsi8cx_params params;
            params.lhs_zero_point = 1;
            params.scale_multiplier = 1.0f;

            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
                                                  qdata.data(), nullptr, scales.data(),
                                                  tensor->data, 0, &params);
            GGML_UNUSED(data_size);
            return 0;
        }

        GGML_UNUSED(data_size);
        return -1;
    }
};

static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits traits;
    return &traits;
}
}  // namespace ggml::cpu::kleidiai

static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

    return GGML_STATUS_SUCCESS;
    GGML_UNUSED(buffer);
}

static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                       const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
    auto OK            = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_KLEIDIAI";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft              = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    buffer->iface.get_tensor  = nullptr;
    buffer->iface.cpy_tensor  = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    const size_t n = tensor->ne[1];
    const size_t k = tensor->ne[0];

    ggml_kleidiai_kernels * kernels = nullptr;
    size_t block_len = 0;

    if (tensor->type == GGML_TYPE_Q4_0) {
        GGML_ASSERT(ctx.kernels_q4);
        kernels = ctx.kernels_q4;
        block_len = QK4_0;
    } else if (tensor->type == GGML_TYPE_Q8_0) {
        GGML_ASSERT(ctx.kernels_q8);
        kernels = ctx.kernels_q8;
        block_len = QK8_0;
    } else {
        return 0;
    }

    const size_t nr = kernels->gemm.get_nr();
    const size_t kr = kernels->gemm.get_kr();
    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
    const size_t raw     = ggml_nbytes(tensor);

    return packed > raw ? packed : raw;
}

namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
                return false;
            }
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                return true;
            }
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
            else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                    return nullptr;
                }

                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
            }
        }
        return nullptr;
    }
};
}  // namespace ggml::cpu::kleidiai

ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    static ggml::cpu::kleidiai::extra_buffer_type ctx;
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
        /* .iface    = */ {
                           /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
                           /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
                           /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
                           /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
                           /* .is_host          = */ nullptr,
                           },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ &ctx,
    };

    init_kleidiai_context();

    return &ggml_backend_cpu_buffer_type_kleidiai;
}
