diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5660fbe92..a4a2d2840 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -38,6 +38,8 @@ jobs: arch: aarch64 - os: windows-2025 arch: x86_64 + - os: windows-11-arm + arch: arm64 - os: macos-15 arch: arm64 runs-on: ${{ matrix.os }} @@ -46,6 +48,8 @@ jobs: - name: Setup MSVC if: startsWith(matrix.os, 'windows') uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + with: + arch: ${{ matrix.arch == 'arm64' && 'arm64' || 'x64' }} - name: Build C++ run: bash .github/scripts/build-cpu.sh env: @@ -188,20 +192,27 @@ jobs: - build-xpu strategy: matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15] + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, windows-11-arm, macos-15] include: - os: ubuntu-22.04 arch: x86_64 + python-version: "3.10" - os: ubuntu-22.04-arm arch: aarch64 + python-version: "3.10" - os: windows-2025 arch: x86_64 + python-version: "3.10" + - os: windows-11-arm + arch: arm64 + # Python for Windows ARM64 is only available from 3.12+ + python-version: "3.12" - os: macos-15 arch: arm64 + python-version: "3.10" # The specific Python version is irrelevant in this context as we are only packaging non-C extension # code. This ensures compatibility across Python versions, as compatibility is # dictated by the packaged code itself, not the Python version used for packaging. - python-version: ["3.10"] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -327,6 +338,8 @@ jobs: echo "### Linux (aarch64)" >> body.md elif [[ "$fname" == *"win_amd64"* ]]; then echo "### Windows (x86_64)" >> body.md + elif [[ "$fname" == *"win_arm64"* ]]; then + echo "### Windows (arm64)" >> body.md elif [[ "$fname" == *"macosx"* ]]; then echo "### macOS 14+ (arm64)" >> body.md else diff --git a/CMakeLists.txt b/CMakeLists.txt index d9c691e79..a787866f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -337,7 +337,13 @@ if(WIN32) endif() if(MSVC) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") + # /arch:AVX2 is only valid for x86/x64 targets, not ARM64 + string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" _msvc_arch) + if(_msvc_arch MATCHES "x86|x64|amd64") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /fp:fast") + endif() endif() set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 9e20153b5..2a8912674 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -37,6 +37,181 @@ inline unsigned char lookup_code_index(const float* codebook, float value) { } // namespace +#if defined(_M_ARM64) || defined(__aarch64__) +#include + +// ARM NEON NF4 lookup table (16 float values indexed by 4-bit code) +static inline void neon_nf4_lut(float32x4_t lut[4]) { + // Indices 0-15 map to NF4 dequantized values + // Order: index 0 = -1.0, index 1 = -0.6962, ..., index 7 = 0.0, ..., index 15 = 1.0 + static const float nf4_values[16] = { + -1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f + }; + lut[0] = vld1q_f32(nf4_values); + lut[1] = vld1q_f32(nf4_values + 4); + lut[2] = vld1q_f32(nf4_values + 8); + lut[3] = vld1q_f32(nf4_values + 12); +} + +// ARM NEON FP4 lookup table +static inline void neon_fp4_lut(float32x4_t lut[4]) { + static const float fp4_values[16] = { + 0.0f, 5.208333333e-03f, 0.66666667f, 1.0f, 0.33333333f, 0.5f, 0.16666667f, 0.25f, + 0.0f, -5.208333333e-03f, -0.66666667f, -1.0f, -0.33333333f, -0.5f, -0.16666667f, -0.25f + }; + lut[0] = vld1q_f32(fp4_values); + lut[1] = vld1q_f32(fp4_values + 4); + lut[2] = vld1q_f32(fp4_values + 8); + lut[3] = vld1q_f32(fp4_values + 12); +} + +// Efficient NEON float LUT lookup using indexed lane extraction +// The LUT has 16 float entries stored as a flat array for direct indexing +static inline float neon_lut_lookup_flat(const float* flat_lut, uint8_t idx) { return flat_lut[idx]; } + +// Vectorized 4-bit dequantization: process 8 packed bytes = 16 output values +// Each byte contains two 4-bit values: high nibble first, low nibble second +static inline void + neon_dequant_4bit_16values(const uint8_t* packed, float scale, const float32x4_t lut[4], float* out) { + // Load 8 bytes = 16 x 4-bit values + uint8x8_t raw = vld1_u8(packed); + + // Extract high and low nibbles + uint8x8_t mask4 = vdup_n_u8(0x0F); + uint8x8_t lo_nibbles = vand_u8(raw, mask4); // low nibble (second value) + uint8x8_t hi_nibbles = vshr_n_u8(raw, 4); // high nibble (first value) + + // Interleave hi/lo into 16-element index array for output ordering + // output[2*i] = hi_nibble[i], output[2*i+1] = lo_nibble[i] + uint8x8x2_t interleaved = vzip_u8(hi_nibbles, lo_nibbles); + // interleaved.val[0] has elements 0-7, interleaved.val[1] has elements 8-15 + uint8x16_t indices = vcombine_u8(interleaved.val[0], interleaved.val[1]); + + // Use flat LUT for fast indexed access + // Store LUT as flat float array on stack (likely in L1 cache) + float flat_lut[16]; + vst1q_f32(flat_lut, lut[0]); + vst1q_f32(flat_lut + 4, lut[1]); + vst1q_f32(flat_lut + 8, lut[2]); + vst1q_f32(flat_lut + 12, lut[3]); + + // Extract indices and do lookups in groups of 4 for NEON multiply + uint8_t idx_arr[16]; + vst1q_u8(idx_arr, indices); + + float32x4_t vscale = vdupq_n_f32(scale); + + // Process 4 values at a time with NEON - load from temp buffer + float tmp_vals[16]; + tmp_vals[0] = flat_lut[idx_arr[0]]; + tmp_vals[1] = flat_lut[idx_arr[1]]; + tmp_vals[2] = flat_lut[idx_arr[2]]; + tmp_vals[3] = flat_lut[idx_arr[3]]; + tmp_vals[4] = flat_lut[idx_arr[4]]; + tmp_vals[5] = flat_lut[idx_arr[5]]; + tmp_vals[6] = flat_lut[idx_arr[6]]; + tmp_vals[7] = flat_lut[idx_arr[7]]; + tmp_vals[8] = flat_lut[idx_arr[8]]; + tmp_vals[9] = flat_lut[idx_arr[9]]; + tmp_vals[10] = flat_lut[idx_arr[10]]; + tmp_vals[11] = flat_lut[idx_arr[11]]; + tmp_vals[12] = flat_lut[idx_arr[12]]; + tmp_vals[13] = flat_lut[idx_arr[13]]; + tmp_vals[14] = flat_lut[idx_arr[14]]; + tmp_vals[15] = flat_lut[idx_arr[15]]; + float32x4_t v0 = vld1q_f32(tmp_vals); + float32x4_t v1 = vld1q_f32(tmp_vals + 4); + float32x4_t v2 = vld1q_f32(tmp_vals + 8); + float32x4_t v3 = vld1q_f32(tmp_vals + 12); + + vst1q_f32(out, vmulq_f32(v0, vscale)); + vst1q_f32(out + 4, vmulq_f32(v1, vscale)); + vst1q_f32(out + 8, vmulq_f32(v2, vscale)); + vst1q_f32(out + 12, vmulq_f32(v3, vscale)); +} + +// NEON-optimized BF16 to float conversion (4 values at a time) +static inline float32x4_t neon_bf16x4_to_f32(const bf16_t* src) { + // BF16 is upper 16 bits of float32, so shift left by 16 + uint16x4_t raw = vld1_u16(reinterpret_cast(src)); + uint32x4_t wide = vshll_n_u16(raw, 16); + return vreinterpretq_f32_u32(wide); +} + +// NEON-optimized float to BF16 conversion (4 values at a time, with rounding) +static inline void neon_f32_to_bf16x4(const float32x4_t src, bf16_t* dst) { + uint32x4_t bits = vreinterpretq_u32_f32(src); + // Round to nearest even: add 0x7FFF + ((bits >> 16) & 1) + uint32x4_t lsb = vshrq_n_u32(bits, 16); + lsb = vandq_u32(lsb, vdupq_n_u32(1)); + uint32x4_t rounding = vaddq_u32(vdupq_n_u32(0x7FFF), lsb); + bits = vaddq_u32(bits, rounding); + // Extract upper 16 bits + uint16x4_t result = vshrn_n_u32(bits, 16); + vst1_u16(reinterpret_cast(dst), result); +} + +// NEON-optimized float to FP16 conversion (4 values at a time) +static inline void neon_f32_to_fp16x4(const float32x4_t src, fp16_t* dst) { + // ARM64 has native FP16 conversion + float16x4_t half = vcvt_f16_f32(src); + vst1_u16(reinterpret_cast(dst), vreinterpret_u16_f16(half)); +} + +// NEON-optimized absmax computation for a block of float32 +static inline float neon_absmax_f32(const float* data, long long n) { + float32x4_t vmax = vdupq_n_f32(0.0f); + long long i = 0; + // Process 16 elements per iteration for better throughput + for (; i + 16 <= n; i += 16) { + float32x4_t v0 = vabsq_f32(vld1q_f32(data + i)); + float32x4_t v1 = vabsq_f32(vld1q_f32(data + i + 4)); + float32x4_t v2 = vabsq_f32(vld1q_f32(data + i + 8)); + float32x4_t v3 = vabsq_f32(vld1q_f32(data + i + 12)); + vmax = vmaxq_f32(vmax, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3))); + } + for (; i + 4 <= n; i += 4) { + float32x4_t v = vld1q_f32(data + i); + vmax = vmaxq_f32(vmax, vabsq_f32(v)); + } + // Horizontal max + float result = vmaxvq_f32(vmax); + // Handle remainder + for (; i < n; ++i) { + result = std::max(result, std::fabs(data[i])); + } + return result; +} + +// NEON-optimized norm_to_lut_index for 4 float values at a time +// Maps [-1, 1] → [0, 65535] +static inline uint16x4_t neon_norm_to_lut_index_x4(float32x4_t vals) { + // clamp to [-1, 1] + vals = vmaxq_f32(vals, vdupq_n_f32(-1.0f)); + vals = vminq_f32(vals, vdupq_n_f32(1.0f)); + // (val + 1.0) * 0.5 * 65535 + 0.5 + float32x4_t result = vmlaq_f32(vdupq_n_f32(0.5f), vaddq_f32(vals, vdupq_n_f32(1.0f)), vdupq_n_f32(0.5f * 65535.0f)); + uint32x4_t u32 = vcvtq_u32_f32(result); + return vmovn_u32(u32); +} + +#endif // _M_ARM64 || __aarch64__ + #if defined(__AVX512F__) #include @@ -94,6 +269,57 @@ void dequantizeBlockwise4bitCpu( if (blocksize <= 0 || m < 0 || n <= 0) return; +#if defined(_M_ARM64) || defined(__aarch64__) + { + long long dim_0 = m; + long long dim_1 = n; + long long input_dim_1 = dim_1 >> 1; + long long absmax_dim_1 = dim_1 / blocksize; + // NEON path: process 16 output values at a time (8 packed bytes) + // Only use when blocksize evenly divides dim_1 to ensure correct scale indexing + constexpr long long VEC_LEN = 16; + if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN && (dim_1 % blocksize == 0)) { + float32x4_t lut[4]; + if constexpr (DATA_TYPE == 1) { + neon_fp4_lut(lut); + } else { + neon_nf4_lut(lut); + } + constexpr long long k_step = VEC_LEN / 2; // 8 bytes per iteration + BNB_OMP_PARALLEL_FOR + for (long long block_idx = 0; block_idx < dim_0; ++block_idx) { + for (long long k = 0; k < input_dim_1; k += k_step) { + long long scale_idx = k * 2 / blocksize; + float scale = absmax[block_idx * absmax_dim_1 + scale_idx]; + const uint8_t* p = &A[block_idx * input_dim_1 + k]; + + // Dequantize 16 values into a temp float buffer + float tmp_f32[16]; + neon_dequant_4bit_16values(p, scale, lut, tmp_f32); + + // Store results (convert to output type using NEON) + T* pout = &out[block_idx * dim_1 + k * 2]; + if constexpr (std::is_same()) { + // Direct copy - already float + std::memcpy(pout, tmp_f32, 16 * sizeof(float)); + } else if constexpr (std::is_same()) { + neon_f32_to_bf16x4(vld1q_f32(tmp_f32), pout); + neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 4), pout + 4); + neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 8), pout + 8); + neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 12), pout + 12); + } else if constexpr (std::is_same()) { + neon_f32_to_fp16x4(vld1q_f32(tmp_f32), pout); + neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 4), pout + 4); + neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 8), pout + 8); + neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 12), pout + 12); + } + } + } + return; + } + } +#endif // _M_ARM64 || __aarch64__ + #if defined(__AVX512F__) if (has_avx512f()) { long long dim_0 = m; @@ -305,19 +531,29 @@ void quantize_cpu_impl(float* code, const T* A, float* absmax, unsigned char* ou for (long long b = 0; b < num_blocks; ++b) { const long long block_start = b * blocksize; const long long block_end = std::min(block_start + blocksize, n); + const long long block_len = block_end - block_start; // Compute absmax for this block float absmax_block = 0.0f; - for (long long i = block_start; i < block_end; ++i) { - float val; - if constexpr (std::is_same::value) { - val = A[i]; - } else if constexpr (std::is_same::value) { - val = bf16_to_float(A[i].v); - } else if constexpr (std::is_same::value) { - val = fp16_to_float(A[i].v); + +#if defined(_M_ARM64) || defined(__aarch64__) + if constexpr (std::is_same::value) { + // Use NEON-optimized absmax for float32 + absmax_block = neon_absmax_f32(reinterpret_cast(A + block_start), block_len); + } else +#endif + { + for (long long i = block_start; i < block_end; ++i) { + float val; + if constexpr (std::is_same::value) { + val = A[i]; + } else if constexpr (std::is_same::value) { + val = bf16_to_float(A[i].v); + } else if constexpr (std::is_same::value) { + val = fp16_to_float(A[i].v); + } + absmax_block = std::max(absmax_block, std::fabs(val)); } - absmax_block = std::max(absmax_block, std::fabs(val)); } absmax[b] = absmax_block; @@ -330,17 +566,42 @@ void quantize_cpu_impl(float* code, const T* A, float* absmax, unsigned char* ou } const float inv_absmax = 1.0f / absmax_block; - for (long long i = block_start; i < block_end; ++i) { - float val; - if constexpr (std::is_same::value) { - val = A[i]; - } else if constexpr (std::is_same::value) { - val = bf16_to_float(A[i].v); - } else if constexpr (std::is_same::value) { - val = fp16_to_float(A[i].v); + +#if defined(_M_ARM64) || defined(__aarch64__) + if constexpr (std::is_same::value) { + // NEON-optimized normalize + LUT index for float32 + const float* src = A + block_start; + long long i = 0; + float32x4_t vinv = vdupq_n_f32(inv_absmax); + for (; i + 4 <= block_len; i += 4) { + float32x4_t v = vmulq_f32(vld1q_f32(src + i), vinv); + uint16x4_t indices = neon_norm_to_lut_index_x4(v); + uint16_t idx_arr[4]; + vst1_u16(idx_arr, indices); + out[block_start + i] = lut[idx_arr[0]]; + out[block_start + i + 1] = lut[idx_arr[1]]; + out[block_start + i + 2] = lut[idx_arr[2]]; + out[block_start + i + 3] = lut[idx_arr[3]]; + } + for (; i < block_len; ++i) { + float normed_value = src[i] * inv_absmax; + out[block_start + i] = lut[norm_to_lut_index(normed_value)]; + } + } else +#endif + { + for (long long i = block_start; i < block_end; ++i) { + float val; + if constexpr (std::is_same::value) { + val = A[i]; + } else if constexpr (std::is_same::value) { + val = bf16_to_float(A[i].v); + } else if constexpr (std::is_same::value) { + val = fp16_to_float(A[i].v); + } + float normed_value = val * inv_absmax; + out[i] = lut[norm_to_lut_index(normed_value)]; } - float normed_value = val * inv_absmax; - out[i] = lut[norm_to_lut_index(normed_value)]; } } } diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 211e9b2f8..2f650f3b2 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -29,6 +29,9 @@ These are the minimum requirements for `bitsandbytes` across all platforms. Plea * Python >= 3.10 * PyTorch >= 2.4 +> [!NOTE] +> Windows ARM64 requires Python >= 3.12, as earlier Python versions do not provide official ARM64 Windows builds. + ## NVIDIA CUDA[[cuda]] `bitsandbytes` is currently supported on NVIDIA GPUs with [Compute Capability](https://developer.nvidia.com/cuda-gpus) 6.0+. @@ -172,6 +175,7 @@ The currently distributed `bitsandbytes` packages are built with the following c | **Linux x86-64** | GCC 11.4 | AVX2 | | **Linux aarch64** | GCC 11.4 | | | **Windows x86-64** | MSVC 19.43+ (VS2022) | AVX2 | +| **Windows arm64** | MSVC 19.43+ (VS2022) | ARM NEON | | **macOS arm64** | Apple Clang 17 | | The Linux build has a minimum glibc version of 2.24. @@ -186,11 +190,38 @@ pip install bitsandbytes To compile from source, simply install the package from source using `pip`. The package will be built for CPU only at this time. + + + +```bash +git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ +pip install -e . +``` + + + + ```bash git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ pip install -e . ``` + + + +Requires Visual Studio 2022 with the **ARM64 C++ build tools** component and Python >= **3.12**. + +```bash +git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ +pip install -e . +``` + +> [!NOTE] +> The build system will detect the ARM64 architecture automatically via CMake. Only the CPU backend is supported on Windows ARM64 at this time (no CUDA). + + + + ## AMD ROCm (Preview)[[rocm]] * Support for AMD GPUs is currently in a preview state. @@ -294,13 +325,21 @@ pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsand ``` - + ```bash # Note: if you don't want to reinstall our dependencies, append the `--no-deps` flag! pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ``` + + +```bash +# Note: if you don't want to reinstall our dependencies, append the `--no-deps` flag! +# Requires Python >= 3.12 +pip install --force-reinstall https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_arm64.whl +``` + ```bash