From 4d6174bc6336fb6fba712f1d2c903de1de677747 Mon Sep 17 00:00:00 2001
From: dbaranchuk <dmitrybaranchuk@gmail.com>
Date: Thu, 25 Aug 2022 19:09:23 +0300
Subject: Memory-efficient fp16 backward

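Stop caching the transposed quantization statistics (SCBt) on
Int8Params and Linear8bitLt.state: the memory-efficient fp16 backward
only reads the row-major quantized weight CB and its scales SCB, so
SCBt can be freed right after bnb.functional.double_quant(), alongside
the existing `del CBt`.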
---
 bitsandbytes/nn/modules.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 03ffd3b..3e32c8e 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
-        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
-        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter):
             B = self.data.contiguous().half().cuda(device)
             CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
             del CBt
+            del SCBt
             self.data = CB
             setattr(self, "CB", CB)
             setattr(self, "SCB", SCB)
-            setattr(self, "SCBt", SCBt)
 
         return self
 
@@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
-            new_param.SCBt = self.SCBt
 
             return new_param
 
@@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
-        self.weight.SCBt = None
 
     def forward(self, x):
         self.state.is_training = self.training
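For context, a minimal usage sketch (not part of the patch) of the code
path this commit touches, assuming a CUDA device is available; the layer
sizes and threshold are arbitrary illustration values:

    import torch
    import bitsandbytes as bnb

    # has_fp16_weights=False takes the quantized branch of
    # Int8Params.cuda() shown in the second hunk above.
    linear = bnb.nn.Linear8bitLt(64, 64, bias=False,
                                 has_fp16_weights=False, threshold=6.0)
    linear = linear.cuda()  # runs double_quant(); CBt and SCBt are freed here

    x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
    out = linear(x)  # forward moves CB/SCB onto self.state via init_8bit_state()

After this patch, only CB (the int8 row-major weight) and SCB (its
quantization scales) survive quantization; the transposed statistics
are deleted immediately instead of being carried on the parameter.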