From de354f7ded52bfa857089769225cdf1ee694bfd6 Mon Sep 17 00:00:00 2001
From: Tim Dettmers <tim.dettmers@gmail.com>
Date: Tue, 16 Aug 2022 12:00:54 -0700
Subject: Added fused bias to matmullt.

---
 tests/test_autograd.py | 49 ++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 11 deletions(-)

(limited to 'tests')

diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index f1a15f5..0cd17c9 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -1,4 +1,4 @@
-from itertools import product
+from itertools import product, permutations
 
 import pytest
 import torch
@@ -241,11 +241,20 @@ decomp = [0.0, 6.0]
 funcs = [(torch.matmul, bnb.matmul)]
 str_funcs = ["matmul"]
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ["FF", "TF", "TT", "FT"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True: strval += 'T'
+        else: strval += 'F'
+    req_grad_str.append(strval)
+
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16]
 has_fp16_weights = [True, False]
+has_bias = [True, False]
 values = list(
     product(
         dim1,
@@ -258,6 +267,7 @@ values = list(
         transpose,
         decomp,
         has_fp16_weights,
+        has_bias
     )
 )
 str_values = list(
@@ -272,18 +282,14 @@ str_values = list(
         str_transpose,
         decomp,
         has_fp16_weights,
+        has_bias
     )
 )
-names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
-        *vals
-    )
-    for vals in str_values
-]
+names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
 
 
 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
     values,
     ids=names,
 )
@@ -298,10 +304,14 @@ def test_matmullt(
     transpose,
     decomp,
     has_fp16_weights,
+    has_bias
 ):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
 
     for i in range(k):
 
@@ -322,6 +332,11 @@ def test_matmullt(
                 requires_grad=req_grad[1],
                 dtype=dtype,
             )
+            bias = None
+            bias2 = None
+            if has_bias: 
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
 
@@ -342,10 +357,13 @@ def test_matmullt(
 
             if not transpose[0] and transpose[1]:
                 out_torch = funcs[0](A, B.t())
-                out_bnb = funcs[1](A, B2, state=state)
+                out_bnb = funcs[1](A, B2, state=state, bias=bias2)
             elif not transpose[0] and not transpose[1]:
                 out_torch = funcs[0](A, B)
-                out_bnb = funcs[1](A, B2.t(), state=state)
+                out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
 
             n = out_bnb.numel()
             err = torch.abs(out_bnb - out_torch).mean().item()
@@ -367,6 +385,9 @@ def test_matmullt(
                     gradB1 = B.grad
                     A.grad = None
                     B.grad = None
+                    if has_bias:
+                        gradBias1 = bias.grad
+                        bias.grad = None
 
                     loss_torch = torch.nn.functional.mse_loss(
                         out_torch, target
@@ -376,6 +397,9 @@ def test_matmullt(
                     gradB2 = B.grad
                     A.grad = None
                     B.grad = None
+                    if has_bias:
+                        gradBias2 = bias.grad
+                        bias.grad = None
 
                 if req_grad[0]:
                     torch.testing.assert_allclose(
@@ -397,3 +421,6 @@ def test_matmullt(
                     torch.testing.assert_allclose(
                         gradB1, gradB2, atol=0.18, rtol=0.3
                     )
+
+                if req_grad[2]:
+                    torch.testing.assert_allclose(gradBias1, gradBias2)
-- 
cgit v1.2.3