From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001
From: Titus von Koeller <titus@vonkoeller.com>
Date: Mon, 1 Aug 2022 09:32:47 -0700
Subject: reran black with line length 80 for greater readability

---
 bitsandbytes/nn/modules.py | 34 ++++++++++++++++++++++++++++------
 1 file changed, 28 insertions(+), 6 deletions(-)
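
Note: the exact formatter invocation is not recorded in this commit; only the
80-column limit is stated in the subject line. Assuming the stock black CLI, a
reformat like this one would typically be reproduced with something along the
lines of

    # assumed invocation; only the line length of 80 is confirmed by the subject
    black --line-length 80 bitsandbytes/nn/modules.py

or by setting line-length = 80 under [tool.black] in pyproject.toml so the same
limit is picked up on later runs.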

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9ce3ac8..454dba5 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -2,8 +2,19 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
-                    Tuple, TypeVar, Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import torch
 import torch.nn.functional as F
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):
 
 class Int8Params(torch.nn.Parameter):
     def __new__(
-        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+        cls,
+        data=None,
+        requires_grad=True,
+        has_fp16_weights=False,
+        CB=None,
+        SCB=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                super().to(
+                    device=device, dtype=dtype, non_blocking=non_blocking
+                ),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        super(Linear8bitLt, self).__init__(
+            input_features, output_features, bias
+        )
         self.state = bnb.MatmulLtState()
         self.index = index
 
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights
+        )
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
-- 