From 98cbc4bc4f15f5c094cd8575ddb0380a19516099 Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Sun, 6 Nov 2022 11:59:37 -0800
Subject: Added k-bit fp8 map.

---
 bitsandbytes/functional.py | 16 ++++++---
 tests/test_functional.py   | 88 +++++++++++++++++++++-------------------------
 2 files changed, 52 insertions(+), 52 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 65eccf2..ff48b7f 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
         return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
 
 
-def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
     e = exponent_bits
     p = precision_bits
-    assert e+p == 7
+    has_sign = 1 if signed else 0
+    assert e+p == total_bits-has_sign
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
     pvalues = []
-    for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
         evalues.append(2**val)
 
 
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
             value += pval*(2**-(i+1))
         pvalues.append(value)
 
-    assert len(evalues)*len(pvalues) == 128
+    assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
     values = []
     for ev in evalues:
         for pv in pvalues:
-            values.append(-ev*pv)
+            if signed:
+                values.append(-ev*pv)
             values.append(ev*pv)
+    if total_bits < 8:
+        gap = 256 - len(values)
+        for i in range(gap):
+            values.append(0)
     values.sort()
     code = torch.Tensor(values)
     code /= code.max()
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 494bf51..bd4dafe 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -11,7 +11,7 @@ import bitsandbytes as bnb
 from bitsandbytes import functional as F
 
 torch.set_printoptions(
-    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
+    precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
 )
 k = 20
 
@@ -2095,49 +2095,43 @@ def test_fp8_quant():
 def test_few_bit_quant():
 
     for bits in range(2, 9):
-        code = F.create_linear_map(True, bits=bits).cuda()
-        assert code.numel() == 256
-        print(bits)
-        for i in range(100):
-
-            values = torch.randn(1, 24, device='cuda')
-            values /= values.abs().max()
-            #values[values.abs() < 1e-6] += 1e-5
-
-            q1 = []
-            v1 = []
-            for v in values[0]:
-                idx = torch.abs(v-code).argmin()
-                q1.append(idx.item())
-                v1.append(code[idx].item())
-
-            q1 = torch.Tensor(q1).cuda()
-            v1 = torch.Tensor(v1).cuda()
-
-            q2, S2 = F.quantize(values, code=code)
-            v2 = F.dequantize(q2, S2)
-
-            idx = torch.isclose(q1.int(), q2.int())
-            if idx.sum():
-                # some weird cases
-                err1 = torch.abs(v1-values).mean()
-                err2 = torch.abs(v2-values).mean()
-                assert err2 <= err1
-
-            else:
-                torch.testing.assert_allclose(q1, q2)
-
-    #print(e_bits, p_bits)
-    #abserr = []
-    #relerr = []
-    #for i in range(100):
-    #    A1 = torch.randn(1024, 1024, device="cuda")
-    #    C, SC = F.quantize_blockwise(A1, code=code)
-    #    A2 = F.dequantize_blockwise(C, SC)
-    #    diff = torch.abs(A1 - A2)
-    #    reldiff = diff/torch.abs(A1+1e-8)
-    #    abserr.append(diff.mean().item())
-    #    relerr.append(reldiff.mean().item())
-    #    #assert diff < 0.0075
-    #print(sum(abserr)/len(abserr))
-    #print(sum(relerr)/len(relerr))
+        for method in ['linear', 'fp8']:
+            code = None
+            if method == 'linear':
+                code = F.create_linear_map(True, bits=bits).cuda()
+            elif method == 'fp8':
+                ebits = math.ceil(bits/2)
+                pbits = bits-ebits-1
+                code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+                print(ebits, pbits, bits)
+                print(code)
+            assert code.numel() == 256
+            print(bits)
+            for i in range(10):
+
+                values = torch.randn(1, 32, device='cuda')
+                values /= values.abs().max()
+                #values[values.abs() < 1e-6] += 1e-5
+
+                q1 = []
+                v1 = []
+                for v in values[0]:
+                    idx = torch.abs(v-code).argmin()
+                    q1.append(idx.item())
+                    v1.append(code[idx].item())
+
+                q1 = torch.Tensor(q1).cuda()
+                v1 = torch.Tensor(v1).cuda()
+
+                q2, S2 = F.quantize(values, code=code)
+                v2 = F.dequantize(q2, S2)
+
+                idx = torch.isclose(q1.int(), q2.int())
+                if idx.sum():
+                    # some weird cases
+                    err1 = torch.abs(v1-values).mean()
+                    err2 = torch.abs(v2-values).mean()
+                    assert err2 <= err1
+
+                else:
+                    torch.testing.assert_allclose(q1, q2)
-- 
cgit v1.2.3