From ab9dee062d791ef343ff5f9e8c2c85dc094219ed Mon Sep 17 00:00:00 2001
From: justheuristic <justheuristic@gmail.com>
Date: Sun, 18 Sep 2022 00:36:46 +0300
Subject: cast edge case

---
 bitsandbytes/autograd/_functions.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

(limited to 'bitsandbytes/autograd')

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index d0e48b7..1d0002c 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -221,9 +221,6 @@ class MatMul8bitLt(torch.autograd.Function):
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
@@ -330,7 +327,7 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.grad_shape = input_shape
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
-- 
cgit v1.2.3