From 95dafc6475bc36490e213269d1028adfd4f75363 Mon Sep 17 00:00:00 2001
From: justheuristic <justheuristic@gmail.com>
Date: Sun, 18 Sep 2022 01:22:31 +0300
Subject: cast before allclose

---
 tests/test_modules.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/test_modules.py b/tests/test_modules.py
index 8108b35..dbadea9 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -541,8 +541,8 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     mlp = MLP8bit(
             32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
         )
-    w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone()
-    mlp = mlp.cuda().half()
+    w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone()  # note: we grab the original weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -567,8 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
 
         mlp.zero_grad()
         (o1 * grad_proj).sum().backward()
-        assert False, (w1, w2)
-        grad_ref = grad_proj.flatten(2) @ w2 @ w1
+        grad_ref = grad_proj.flatten(2) @ w2.to(grad_proj.device) @ w1.to(grad_proj.device)
         assert torch.allclose(b1.grad, grad_ref)
 
 
-- 
cgit v1.2.3
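
The cast added above moves the fp32 weight copies, which were cloned on CPU before mlp.cuda().half() quantized the layers, onto grad_proj's device, so the reference-gradient matmul and the torch.allclose check see same-device tensors. Below is a minimal sketch of that device-cast pattern; the layer, shapes, and names (layer, w, x) are hypothetical stand-ins, not the bitsandbytes test itself.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the test setup: clone the weight while it still lives on CPU,
# then move the layer itself to `device`.
layer = torch.nn.Linear(64, 32, bias=False)
w = layer.weight.clone()          # CPU copy, like w1/w2 in the test
layer = layer.to(device)

x = torch.randn(16, 8, 64, device=device, requires_grad=True)
out = layer(x)
out.sum().backward()

# Manual reference gradient: d(sum(out))/dx = ones_like(out) @ W. The CPU copy
# must be cast to x's device before the matmul / allclose, which is the point
# of the fix above; otherwise the matmul mixes CUDA and CPU tensors and errors.
grad_ref = torch.ones_like(out) @ w.to(x.device)
assert torch.allclose(x.grad, grad_ref, atol=1e-5)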