From 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b Mon Sep 17 00:00:00 2001
From: dbaranchuk <dmitrybaranchuk@gmail.com>
Date: Tue, 23 Aug 2022 23:39:54 +0300
Subject: Add memory-efficient backward
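
Allow backprop through int8 weights when has_fp16_weights=False. Instead of
requiring fp16 weights and a cached state.CxBt, backward now reconstructs the
column-wise quantized weight CBt on the fly from the cached row-major CB and
its quantization statistics and transforms it to the turing/ampere layout
locally, so no extra transposed copy of B has to be kept between steps. The
fp16-only assert in backward is dropped accordingly, the dead (commented-out)
outlier-pool code is removed, and grad_output is flattened with reshape()
instead of view() so non-contiguous gradients are handled.

Rough sketch of the reconstruction (assuming the forward pass populated
state.CB with the row-wise int8 weight and state.SCB / state.SCBt with the
row-wise / column-wise scales; the 1/127 factors cancel between
dequantization and requantization):

    B = state.CB.half() * state.SCB.unsqueeze(1).half()                # dequantize row-wise
    CBt = (B.t() / state.SCBt.unsqueeze(1).half()).t().to(torch.int8)  # requantize column-wise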

---
 bitsandbytes/autograd/_functions.py | 39 ++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 4dbf129..63e8ad5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            else:
-                if state.CxB is None:
-                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            elif state.CxB is None:
+                # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+                # we also need to convert it to the turing/ampere format
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):
 
             outlier_idx = torch.unique(coo_tensorA.colidx)
             state.idx = outlier_idx
-            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-            #    # do not use pool for 2nd FFN layer
-            #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            # else:
-            #    state.idx = outlier_idx
             outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
             state.subB = (
                 (outliers * state.SCB.view(-1, 1) / 127.0)
@@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
 
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
 
@@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+            if state.CxBt is None and state.has_fp16_weights:
+                CBt = state.CBt
+            elif state.CxBt is None:
+                assert state.CBt is None
+                CB = state.CB.half()
+                SCB = state.SCB.unsqueeze(1).half()
+                SCBt = state.SCBt.unsqueeze(1).half()
+                Bt = (CB * SCB).t().contiguous()
+                CBt = (Bt / SCBt).t().to(torch.int8)
+
+            CxBt, SBt = F.transform(
+                CBt, to_order=formatB, transpose=True
+            )
+            gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
 
         if req_gradBias:
-- 
cgit v1.2.3