From 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Sun, 11 Sep 2022 11:55:09 -0700
Subject: Fixed 2^31 max size issue for cpu blockwise quant.

---
 bitsandbytes/functional.py | 90 ++++++++--------------------------------------
 1 file changed, 14 insertions(+), 76 deletions(-)

(limited to 'bitsandbytes/functional.py')

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 22200f2..c104ebd 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -369,13 +369,7 @@ def estimate_quantiles(
     return out
 
 
-def quantize_blockwise(
-    A: Tensor,
-    code: Tensor = None,
-    absmax: Tensor = None,
-    rand=None,
-    out: Tensor = None,
-) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
 
@@ -412,9 +406,9 @@ def quantize_blockwise(
 
     if absmax is None:
         n = A.numel()
-        num_blocks = 4096
-        blocks = n // num_blocks
-        blocks += 1 if n % num_blocks > 0 else 0
+        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+        blocks = n // blocksize
+        blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
 
     if out is None:
@@ -426,46 +420,18 @@ def quantize_blockwise(
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
             if A.dtype == torch.float32:
-                lib.cquantize_blockwise_stochastic_fp32(
-                    get_ptr(code),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    get_ptr(rand),
-                    ct.c_int32(rand_offset),
-                    ct.c_int(A.numel()),
-                )
+                lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_stochastic_fp16(
-                    get_ptr(code),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    get_ptr(rand),
-                    ct.c_int32(rand_offset),
-                    ct.c_int(A.numel()),
-                )
+                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
             else:
                 raise ValueError(
                     f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
                 )
         else:
             if A.dtype == torch.float32:
-                lib.cquantize_blockwise_fp32(
-                    get_ptr(code),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(A.numel()),
-                )
+                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_fp16(
-                    get_ptr(code),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(A.numel()),
-                )
+                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
             else:
                 raise ValueError(
                     f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
@@ -473,13 +439,7 @@ def quantize_blockwise(
     else:
         # cpu
         assert rand is None
-        lib.cquantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_int(A.numel()),
-        )
+        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
     return out, (absmax, code)
 
@@ -529,43 +489,21 @@ def dequantize_blockwise(
     if quant_state is None:
         quant_state = (absmax, code)
 
-    if blocksize not in [2048, 4096]:
-        raise ValueError(
-            f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
-        )
 
     if A.device.type != 'cpu':
+        if blocksize not in [2048, 4096]:
+            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(
-                get_ptr(quant_state[1]),
-                get_ptr(A),
-                get_ptr(quant_state[0]),
-                get_ptr(out),
-                ct.c_int(blocksize),
-                ct.c_int(A.numel()),
-            )
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(
-                get_ptr(quant_state[1]),
-                get_ptr(A),
-                get_ptr(quant_state[0]),
-                get_ptr(out),
-                ct.c_int(blocksize),
-                ct.c_int(A.numel()),
-            )
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
             )
     else:
-        lib.cdequantize_blockwise_cpu_fp32(
-            get_ptr(quant_state[1]),
-            get_ptr(A),
-            get_ptr(quant_state[0]),
-            get_ptr(out),
-            ct.c_int(A.numel()),
-        )
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
     return out
 
-- 
cgit v1.2.3