From 9d60b3c5279641ba936facd710c722ebe52fcf40 Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Wed, 17 Aug 2022 03:45:57 -0700
Subject: Fixed bug in Linear8bitLt when the bias is None.
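
The module-side change is not included in this test-only diff. A minimal
sketch of the guard it implies inside Linear8bitLt.forward (assumed, not
verbatim from bnb.nn.modules):

    # Assumed fix: cast the fp32 bias to fp16 lazily on the first forward
    # pass, and skip the cast entirely when the layer was built with
    # bias=False (self.bias is None).
    if self.bias is not None and self.bias.dtype != torch.float16:
        self.bias.data = self.bias.data.half()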

---
 tests/test_modules.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/test_modules.py b/tests/test_modules.py
index 7faadb8..c0b3311 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
+
+
+def test_linear8bitlt_fp32_bias():
+    # .cuda() casts the weights fp32 -> fp16 -> int8; the bias stays fp32
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias.dtype == torch.float32
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        # the first forward pass casts the fp32 bias to fp16
+        o1 = l1(b1)
+        assert l1.bias.dtype == torch.float16
+
+    # same automatic cast to int8, but built without a bias
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias is None
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        o1 = l1(b1)
+        assert l1.bias is None
-- 
cgit v1.2.3