Update modeling_rwkv5.py
Browse files - modeling_rwkv5.py (+20 -0)
    	
        modeling_rwkv5.py
    CHANGED
    
    | @@ -735,6 +735,26 @@ class Rwkv5Model(Rwkv5PreTrainedModel): | |
| 735 | 
             
                        hidden_states=all_hidden_states,  # None
         | 
| 736 | 
             
                        attentions=all_self_attentions,  # None
         | 
| 737 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 738 |  | 
| 739 | 
             
                def _rescale_layers(self):
         | 
| 740 | 
             
                    # Layers should be rescaled for inference only.
         | 
|  | |
| 735 | 
             
                        hidden_states=all_hidden_states,  # None
         | 
| 736 | 
             
                        attentions=all_self_attentions,  # None
         | 
| 737 | 
             
                    )
         | 
| 738 | 
            +
                def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
                    r"""
                    Dequantize the 4-bit weight of `target_layer`, divide it in place by
                    `2 ** int(block_id // self.config.rescale_every)`, then wrap the result in a new `bnb.nn.Params4bit` and
                    assign it back to `target_layer.weight`.

                    Args:
                        target_layer: Module whose `weight` exposes `.data` and `.quant_state`; its `weight` attribute is
                            replaced by this call.
                        block_id: Block index used together with `self.config.rescale_every` to pick the divisor.

                    Raises:
                        ImportError: If bitsandbytes is not available.
                    """
                    if not is_bitsandbytes_available():
                        raise ImportError("Please install bitsandbytes to use this method.")
                    import bitsandbytes as bnb

                    packed = target_layer.weight
                    weights = bnb.functional.dequantize_4bit(packed.data, packed.quant_state)

                    # Power-of-two divisor selected by how many `rescale_every` groups precede this block.
                    divisor = 2 ** int(block_id // self.config.rescale_every)
                    weights.div_(divisor)

                    # Re-quantization: the tensor is moved to CPU first and then back to its original device, which
                    # costs an extra round trip. `requires_grad=False` because gradients cannot be computed on 4-bit
                    # parameters anyway, and it avoids bugs with bnb.
                    device = weights.device
                    requantized = bnb.nn.Params4bit(weights.to("cpu"), requires_grad=False).to(device)
                    target_layer.weight = requantized
         | 
| 758 |  | 
| 759 | 
             
                def _rescale_layers(self):
         | 
| 760 | 
             
                    # Layers should be rescaled for inference only.
         |