From 6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Sun, 6 Nov 2022 16:27:48 -0800
Subject: Added blocksizes 2048, 1024, and 512 to blockwise quant.

---
 bitsandbytes/cextension.py | 11 +++++++++--
 bitsandbytes/functional.py | 20 ++++++++++----------
 2 files changed, 19 insertions(+), 12 deletions(-)

(limited to 'bitsandbytes')

diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 8125202..ead8502 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -52,8 +52,13 @@ class CUDASetup(object):
         self.add_log_entry('python setup.py install')
 
     def initialize(self):
-        self.cuda_setup_log = []
+        self.has_printed = False
         self.lib = None
+        self.run_cuda_setup()
+
+    def run_cuda_setup(self):
+        self.initialized = True
+        self.cuda_setup_log = []
 
         from .cuda_setup.main import evaluate_cuda_setup
         binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
@@ -89,7 +94,9 @@ class CUDASetup(object):
             else:
                 self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
                 self.lib = ct.cdll.LoadLibrary(binary_path)
-        except:
+                print(self.lib)
+        except Exception as ex:
+            self.add_log_entry(str(ex))
             self.print_log_stack()
 
     def add_log_entry(self, msg, is_warning=False):
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 076414d..49d4db1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -130,10 +130,10 @@ class Cusparse_Context(object):
         return cls._instance
 
 
-def create_linear_map(signed=True, bits=8):
+def create_linear_map(signed=True, total_bits=8):
     sign = (-1.0 if signed else 0.0)
 
-    values = torch.linspace(sign, 1.0, 2**bits)
+    values = torch.linspace(sign, 1.0, 2**total_bits)
     gap = 256 - values.numel()
     if gap == 0:
         return values
@@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         The quantization state to undo the quantization.
     """
 
+
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)
 
     if A.device.type != 'cpu':
+        assert blocksize in [4096, 2048, 1024, 512]
         is_on_gpu([code, A, absmax, out, rand])
+        cblocksize = ct.c_int32(blocksize)
         if rand is not None:
+            assert blocksize==4096
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
             if A.dtype == torch.float32:
@@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
             elif A.dtype == torch.float16:
                 lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
             else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         else:
             if A.dtype == torch.float32:
-                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     else:
         # cpu
         assert rand is None
-- 
cgit v1.2.3