fix(nn/Linear4bit): consume QuantState keys in _load_from_state_dict (#1946)#1958
Open
Anai-Guo wants to merge 1 commit into
Open
fix(nn/Linear4bit): consume QuantState keys in _load_from_state_dict (#1946)#1958Anai-Guo wants to merge 1 commit into
Anai-Guo wants to merge 1 commit into
Conversation
…itsandbytes-foundation#1946) Linear4bit overrides _save_to_state_dict to write weight.absmax / weight.quant_map / weight.nested_* / weight.quant_state.bitsandbytes__* alongside the packed weight, but inherits nn.Linear._load_from_state_dict which only consumes weight and bias. Result: - strict=True load raises Unexpected key(s) in state_dict for every QuantState component. - strict=False silently drops them and the destination layer keeps the freshly-quantized quant_state from the prior .to('cuda') call, which does not match the packed bytes that were just loaded. This mirrors what Linear8bitLt already does for SCB (_load_from_state_dict at modules.py:1119): walk unexpected_keys for entries under '<prefix>weight.', collect them into a qs_dict, reconstruct via QuantState.from_dict, install on self.weight, and remove the consumed keys from unexpected_keys. Fixes bitsandbytes-foundation#1946
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #1946.
Problem
Linear4bit._save_to_state_dictwrites the packedweightand the QuantState components:weight.absmax,weight.quant_mapweight.nested_absmax,weight.nested_quant_map(whencompress_statistics=True)weight.quant_state.bitsandbytes__nf4/__fp4But
Linear4bitdoes not override_load_from_state_dict; it inheritsnn.Linear._load_from_state_dict, which only consumesweightandbias. So:load_state_dict(strict=True)raisesRuntimeError: Unexpected key(s) in state_dictfor every QuantState entry.load_state_dict(strict=False)silently drops them. The destination layer keeps the freshly-quantizedquant_statefrom the prior.to('cuda')call, which does not match the packed bytes that were just loaded — so forward passes return garbage. This is asymmetric vsLinear8bitLt, which already implements both halves (save atmodules.py:1095, load atmodules.py:1119).Fix
Add
Linear4bit._load_from_state_dict. After delegating tosuper(), walkunexpected_keysfor entries under<prefix>weight., collect them into aqs_dict, rebuild viaQuantState.from_dict(...), install onself.weight, and remove the consumed keys fromunexpected_keys. Mirrors the existingLinear8bitLtpattern.Reproducer (from the issue, abbreviated)
Notes
weight(andbias): when noqs_prefixkeys remain inunexpected_keys, the new code is a no-op early-return.load_state_dictmachinery already expects modules to implement when they extend_save_to_state_dict.Embedding4bithas the same issue but a separate_save_to_state_dictshape; happy to extend the fix there in a follow-up if maintainers prefer one PR per class.🤖 Generated with Claude Code