From 98cbc4bc4f15f5c094cd8575ddb0380a19516099 Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Sun, 6 Nov 2022 11:59:37 -0800
Subject: Added k-bit fp8 map.

---
 bitsandbytes/functional.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

(limited to 'bitsandbytes')

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 65eccf2..ff48b7f 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
         return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
 
 
-def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
     e = exponent_bits
     p = precision_bits
-    assert e+p == 7
+    has_sign = 1 if signed else 0
+    assert e+p == total_bits-has_sign
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
     pvalues = []
-    for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+    for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
         evalues.append(2**val)
 
 
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
             value += pval*(2**-(i+1))
         pvalues.append(value)
 
-    assert len(evalues)*len(pvalues) == 128
+    assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
     values = []
     for ev in evalues:
         for pv in pvalues:
-            values.append(-ev*pv)
+            if signed:
+                values.append(-ev*pv)
             values.append(ev*pv)
+    if total_bits < 8:
+        gap = 256 - len(values)
+        for i in range(gap):
+            values.append(0)
     values.sort()
     code = torch.Tensor(values)
     code /= code.max()
-- 
cgit v1.2.3