fix(nn/Linear4bit): consume QuantState keys in _load_from_state_dict (#1946) #1958

Open
Anai-Guo wants to merge 1 commit into bitsandbytes-foundation:main from Anai-Guo:fix/linear4bit-load-from-state-dict

Conversation

@Anai-Guo

Fixes #1946.

Problem

Linear4bit._save_to_state_dict writes the packed weight and the QuantState components:

  • weight.absmax, weight.quant_map
  • weight.nested_absmax, weight.nested_quant_map (when compress_statistics=True)
  • weight.quant_state.bitsandbytes__nf4 / __fp4

But Linear4bit does not override _load_from_state_dict; it inherits nn.Linear._load_from_state_dict, which only consumes weight and bias. So:

  • load_state_dict(strict=True) raises RuntimeError: Unexpected key(s) in state_dict for every QuantState entry.
  • load_state_dict(strict=False) silently drops them. The destination layer keeps the freshly-quantized quant_state from the prior .to('cuda') call, which does not match the packed bytes that were just loaded, so forward passes return garbage.

This is asymmetric with Linear8bitLt, which already implements both halves (save at modules.py:1095, load at modules.py:1119).

Fix

Add Linear4bit._load_from_state_dict. After delegating to super(), walk unexpected_keys for entries under <prefix>weight., collect them into a qs_dict, rebuild via QuantState.from_dict(...), install on self.weight, and remove the consumed keys from unexpected_keys. Mirrors the existing Linear8bitLt pattern.
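
The key-consumption step can be sketched in plain Python. This is a minimal illustration of the pattern described above, not the code in this PR: consume_prefixed_keys is a hypothetical helper name, the values are placeholder strings rather than tensors, and keeping the collected keys relative to the module prefix is an assumption. The rebuild via QuantState.from_dict and the install on self.weight are left out.

```python
def consume_prefixed_keys(state_dict, prefix, unexpected_keys):
    """Collect entries under '<prefix>weight.' and remove them from
    unexpected_keys in place. Returns the collected entries (hypothetical
    helper; key naming relative to the module prefix is an assumption)."""
    qs_prefix = prefix + "weight."
    qs_dict = {}
    # Iterate over a copy because unexpected_keys is mutated in the loop.
    for key in list(unexpected_keys):
        if key.startswith(qs_prefix):
            qs_dict[key[len(prefix):]] = state_dict[key]
            unexpected_keys.remove(key)
    return qs_dict


# Placeholder state dict for a layer registered under the prefix "layer.".
state = {
    "layer.weight": "packed-bytes",
    "layer.weight.absmax": "absmax-tensor",
    "layer.weight.quant_map": "quant-map-tensor",
    "layer.extra": "unrelated",
}
unexpected = ["layer.weight.absmax", "layer.weight.quant_map", "layer.extra"]

qs_dict = consume_prefixed_keys(state, "layer.", unexpected)
# qs_dict holds the two weight.* entries; unexpected now lists only "layer.extra".

# With no matching keys the helper returns an empty dict and changes nothing.
nothing = consume_prefixed_keys(state, "layer.", [])
```

Keys that do not sit under the weight. prefix stay in unexpected_keys, so a strict load would still report them.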

Reproducer (from the issue, abbreviated)

import torch
import bitsandbytes as bnb

src = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4', compute_dtype=torch.bfloat16, compress_statistics=True).to('cuda')
dst = bnb.nn.Linear4bit(64, 64, bias=False, quant_type='nf4', compute_dtype=torch.bfloat16, compress_statistics=True).to('cuda')

dst.load_state_dict(src.state_dict())
# before: RuntimeError: Unexpected key(s) ... weight.absmax, weight.quant_map, ...
# after:  dst.weight.quant_state matches src.weight.quant_state

Notes

  • Behavior is unchanged for state_dicts that only contain weight (and bias): when no qs_prefix keys remain in unexpected_keys, the new code is a no-op early-return.
  • No public API is changed; this only adds a private hook that PyTorch's load_state_dict machinery already expects modules to implement when they extend _save_to_state_dict.
  • Embedding4bit has the same issue, but its _save_to_state_dict is shaped differently; happy to extend the fix there in a follow-up if maintainers prefer one PR per class.

🤖 Generated with Claude Code

…itsandbytes-foundation#1946)

Linear4bit overrides _save_to_state_dict to write
weight.absmax / weight.quant_map / weight.nested_* /
weight.quant_state.bitsandbytes__* alongside the packed weight, but
inherits nn.Linear._load_from_state_dict which only consumes weight
and bias. Result:
- strict=True load raises Unexpected key(s) in state_dict for every
  QuantState component.
- strict=False silently drops them and the destination layer keeps the
  freshly-quantized quant_state from the prior .to('cuda') call, which
  does not match the packed bytes that were just loaded.

The fix mirrors what Linear8bitLt already does for SCB
(_load_from_state_dict at modules.py:1119): walk unexpected_keys for
entries under '<prefix>weight.', collect them into a qs_dict,
reconstruct via QuantState.from_dict, install on self.weight, and
remove the consumed keys from unexpected_keys.

Fixes bitsandbytes-foundation#1946
Development

Successfully merging this pull request may close these issues.

Linear4bit._save_to_state_dict writes QuantState keys but no _load_from_state_dict consumes them (asymmetric serialization)