diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index ceada0b1..80d150e2 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -35,6 +35,7 @@ from .gaussian_conditional import GaussianConditionalLatentCodec from .hyper import HyperLatentCodec from .hyperprior import HyperpriorLatentCodec +from .multi_context_checkerboard import MultiContextCheckerboardLatentCodec from .rasterscan import RasterScanLatentCodec __all__ = [ @@ -47,5 +48,6 @@ "GaussianConditionalLatentCodec", "HyperLatentCodec", "HyperpriorLatentCodec", + "MultiContextCheckerboardLatentCodec", "RasterScanLatentCodec", ] diff --git a/compressai/latent_codecs/_checkerboard_helpers.py b/compressai/latent_codecs/_checkerboard_helpers.py new file mode 100644 index 00000000..84b65a3a --- /dev/null +++ b/compressai/latent_codecs/_checkerboard_helpers.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Pure functional helpers shared by checkerboard latent codecs. + +These are extracted from :class:`CheckerboardLatentCodec` so that sibling +codecs (e.g. :class:`MultiContextCheckerboardLatentCodec`) can reuse the +exact same checkerboard split / merge / mask logic without duplicating it. +A single source of truth here also means an anchor-parity boundary fix +applies to every checkerboard codec at once. +""" + +from __future__ import annotations + +import torch + +from torch import Tensor + +__all__ = [ + "embed", + "embed_step", + "mask_all", + "mask_all_but_step", + "merge", + "step_parity", + "unembed", + "write_step", +] + + +def step_parity(step: str, anchor_parity: str) -> str: + """Resolve a ``step`` ('anchor' / 'non_anchor') to a parity string.""" + if step == "anchor": + return anchor_parity + if step == "non_anchor": + return "odd" if anchor_parity == "even" else "even" + raise ValueError(f'Invalid "step" value "{step}"') + + +def unembed(y: Tensor, *, anchor_parity: str) -> Tensor: + """Separate single tensor into two even/odd checkerboard chunks. + + .. code-block:: none + + ■ □ ■ □ ■ ■ □ □ + □ ■ □ ■ ---> ■ ■ □ □ + ■ □ ■ □ ■ ■ □ □ + """ + n, c, h, w = y.shape + y_packed = y.new_zeros((2, n, c, h, w // 2)) + if anchor_parity == "even": + y_packed[0, ..., 0::2, :] = y[..., 0::2, 0::2] + y_packed[0, ..., 1::2, :] = y[..., 1::2, 1::2] + y_packed[1, ..., 0::2, :] = y[..., 0::2, 1::2] + y_packed[1, ..., 1::2, :] = y[..., 1::2, 0::2] + else: + y_packed[0, ..., 0::2, :] = y[..., 0::2, 1::2] + y_packed[0, ..., 1::2, :] = y[..., 1::2, 0::2] + y_packed[1, ..., 0::2, :] = y[..., 0::2, 0::2] + y_packed[1, ..., 1::2, :] = y[..., 1::2, 1::2] + return y_packed + + +def embed(y_packed: Tensor, *, anchor_parity: str) -> Tensor: + """Combine two even/odd checkerboard chunks into single tensor. + + .. code-block:: none + + ■ ■ □ □ ■ □ ■ □ + ■ ■ □ □ ---> □ ■ □ ■ + ■ ■ □ □ ■ □ ■ □ + """ + num_chunks, n, c, h, w_half = y_packed.shape + assert num_chunks == 2 + y = y_packed.new_zeros((n, c, h, w_half * 2)) + if anchor_parity == "even": + y[..., 0::2, 0::2] = y_packed[0, ..., 0::2, :] + y[..., 1::2, 1::2] = y_packed[0, ..., 1::2, :] + y[..., 0::2, 1::2] = y_packed[1, ..., 0::2, :] + y[..., 1::2, 0::2] = y_packed[1, ..., 1::2, :] + else: + y[..., 0::2, 1::2] = y_packed[0, ..., 0::2, :] + y[..., 1::2, 0::2] = y_packed[0, ..., 1::2, :] + y[..., 0::2, 0::2] = y_packed[1, ..., 0::2, :] + y[..., 1::2, 1::2] = y_packed[1, ..., 1::2, :] + return y + + +def embed_step( + step_index: int, y_i: Tensor, width: int, *, anchor_parity: str +) -> Tensor: + """Embed a per-step half-width tensor back into a full-grid tensor.""" + n, c, h, _ = y_i.shape + y_packed = y_i.new_zeros((2, n, c, h, width // 2)) + y_packed[step_index] = y_i + return embed(y_packed, anchor_parity=anchor_parity) + + +def write_step(dest: Tensor, src: Tensor, step: str, *, anchor_parity: str) -> None: + """Copy ``src`` pixels at the current step's positions into ``dest`` in-place.""" + parity = step_parity(step, anchor_parity) + if parity == "even": + dest[..., 0::2, 0::2] = src[..., 0::2, 0::2] + dest[..., 1::2, 1::2] = src[..., 1::2, 1::2] + else: + dest[..., 0::2, 1::2] = src[..., 0::2, 1::2] + dest[..., 1::2, 0::2] = src[..., 1::2, 0::2] + + +def mask_all_but_step(y: Tensor, step: str, *, anchor_parity: str) -> Tensor: + """Keep only pixels in the current step, and zero out the rest.""" + y = y.clone() + parity = step_parity(step, anchor_parity) + if parity == "even": + y[..., 0::2, 1::2] = 0 + y[..., 1::2, 0::2] = 0 + else: + y[..., 0::2, 0::2] = 0 + y[..., 1::2, 1::2] = 0 + return y + + +def mask_all(y: Tensor) -> Tensor: + """Return a zero tensor with the same shape, dtype and device as ``y``.""" + return torch.zeros_like(y) + + +def merge(*args: Tensor) -> Tensor: + """Concatenate tensors along the channel dimension.""" + return torch.cat(args, dim=1) diff --git a/compressai/latent_codecs/_selective_checkerboard.py b/compressai/latent_codecs/_selective_checkerboard.py new file mode 100644 index 00000000..f4813cea --- /dev/null +++ b/compressai/latent_codecs/_selective_checkerboard.py @@ -0,0 +1,215 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import GaussianConditional + +from . import _checkerboard_helpers as _ckb + +__all__ = [ + "apply_selective_y_hat", + "apply_selective_y_hat_packed", + "apply_selective_compression", + "apply_selective_decompression", + "compress_selected", + "decompress_selected", + "selective_mask", + "selective_mask_packed", +] + + +def selective_mask( + selective_predictor: Optional[nn.Module], + step: str, + side_params: Tensor, + scales: Tensor, + means: Tensor, + *, + anchor_parity: str, +) -> Optional[Tensor]: + if selective_predictor is None: + return None + selective_map = selective_predictor( + side_params=side_params, + scales=scales, + means=means, + step=step, + ) + if isinstance(selective_map, dict): + selective_map = selective_map["selective_map"] + if selective_map.shape != scales.shape: + selective_map = selective_map.expand_as(scales) + if selective_map.dtype == torch.bool: + mask = selective_map + else: + mask = selective_map >= 0.5 + return _ckb.mask_all_but_step(mask, step, anchor_parity=anchor_parity) + + +def selective_mask_packed( + selective_predictor: Optional[nn.Module], + step_index: int, + step: str, + side_params: Tensor, + scales: Tensor, + means: Tensor, + *, + anchor_parity: str, +) -> Optional[Tensor]: + if selective_predictor is None: + return None + width = side_params.shape[-1] + scales_full = _ckb.embed_step( + step_index, scales, width, anchor_parity=anchor_parity + ) + means_full = _ckb.embed_step(step_index, means, width, anchor_parity=anchor_parity) + mask = selective_mask( + selective_predictor, + step, + side_params, + scales_full, + means_full, + anchor_parity=anchor_parity, + ) + if mask is None: + return None + return _ckb.unembed(mask, anchor_parity=anchor_parity)[step_index] + + +def apply_selective_y_hat( + step: str, + y_hat: Tensor, + means: Tensor, + selective_mask: Optional[Tensor], + *, + anchor_parity: str, +) -> Tensor: + if selective_mask is None: + return y_hat + y_hat = torch.where(selective_mask, y_hat, means) + return _ckb.mask_all_but_step(y_hat, step, anchor_parity=anchor_parity) + + +def apply_selective_y_hat_packed( + y_hat: Tensor, + means: Tensor, + selective_mask: Optional[Tensor], +) -> Tensor: + if selective_mask is None: + return y_hat + return torch.where(selective_mask, y_hat, means) + + +def apply_selective_compression( + latent_codec: Any, + y: Tensor, + params: Tensor, + scales: Tensor, + means: Tensor, + selective_mask: Optional[Tensor], +) -> Dict[str, Any]: + if selective_mask is None: + return latent_codec.compress(y, params) + return compress_selected( + latent_codec.gaussian_conditional, y, scales, means, selective_mask + ) + + +def apply_selective_decompression( + latent_codec: Any, + strings: List[bytes], + shape: Tuple[int, ...], + params: Tensor, + scales: Tensor, + means: Tensor, + selective_mask: Optional[Tensor], +) -> Dict[str, Any]: + if selective_mask is None: + return latent_codec.decompress([strings], shape, params) + y_hat = decompress_selected( + latent_codec.gaussian_conditional, strings, scales, means, selective_mask + ) + assert y_hat.shape[1:] == shape + return {"y_hat": y_hat} + + +def compress_selected( + gaussian_conditional: GaussianConditional, + y: Tensor, + scales: Tensor, + means: Tensor, + selective_mask: Tensor, +) -> Dict[str, Any]: + indexes = gaussian_conditional.build_indexes(scales) + y_strings = [] + y_hat = means.clone() + + for sample_index in range(y.shape[0]): + mask = selective_mask[sample_index].reshape(-1) + if not mask.any(): + y_strings.append(b"") + continue + + y_i = y[sample_index].reshape(-1)[mask].unsqueeze(0) + indexes_i = indexes[sample_index].reshape(-1)[mask].unsqueeze(0) + means_i = means[sample_index].reshape(-1)[mask].unsqueeze(0) + y_string = gaussian_conditional.compress(y_i, indexes_i, means_i)[0] + y_hat_i = gaussian_conditional.decompress([y_string], indexes_i, means=means_i) + y_hat[sample_index].reshape(-1)[mask] = y_hat_i.reshape(-1).to(y_hat.dtype) + y_strings.append(y_string) + + return {"strings": [y_strings], "shape": y.shape[2:4], "y_hat": y_hat} + + +def decompress_selected( + gaussian_conditional: GaussianConditional, + strings: List[bytes], + scales: Tensor, + means: Tensor, + selective_mask: Tensor, +) -> Tensor: + indexes = gaussian_conditional.build_indexes(scales) + y_hat = means.clone() + + for sample_index, y_string in enumerate(strings): + mask = selective_mask[sample_index].reshape(-1) + if not mask.any(): + continue + indexes_i = indexes[sample_index].reshape(-1)[mask].unsqueeze(0) + means_i = means[sample_index].reshape(-1)[mask].unsqueeze(0) + y_hat_i = gaussian_conditional.decompress([y_string], indexes_i, means=means_i) + y_hat[sample_index].reshape(-1)[mask] = y_hat_i.reshape(-1).to(y_hat.dtype) + + return y_hat diff --git a/compressai/latent_codecs/multi_context_checkerboard.py b/compressai/latent_codecs/multi_context_checkerboard.py new file mode 100644 index 00000000..3def49f8 --- /dev/null +++ b/compressai/latent_codecs/multi_context_checkerboard.py @@ -0,0 +1,418 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import GaussianConditional +from compressai.ops import quantize_ste +from compressai.registry import register_module + +from . import _checkerboard_helpers as _ckb +from . import _selective_checkerboard as _sel +from .base import LatentCodec +from .gaussian_conditional import GaussianConditionalLatentCodec + +__all__ = ["MultiContextCheckerboardLatentCodec"] +LrpInputBuilder = Callable[[Tensor, Tensor, Tensor], Tensor] +LrpActivation = Optional[Callable[[Tensor], Tensor]] + + +@register_module("MultiContextCheckerboardLatentCodec") +class MultiContextCheckerboardLatentCodec(LatentCodec): + """Two-pass checkerboard codec with separate heads and optional contexts. + + This is a sibling of :class:`CheckerboardLatentCodec` for models whose + anchor and non-anchor passes use distinct entropy-parameter heads and + optional per-pass latent residual prediction. + + Optional context hooks (``spatial_context_anchor`` / + ``spatial_context_nonanchor`` / ``intra_channel_context_nonanchor``) + are *omitted* from the entropy-parameters input when ``None``; they + do not contribute zero-padding. The entropy-parameter heads must be + sized to ``side_params.shape[1]`` plus the channel widths produced + by whichever context modules are supplied for that pass. + + LRP modules are treated as raw residual predictors by default: + ``lrp_activation`` (default: ``torch.tanh``) is applied before scaling. + Set ``lrp_activation=None`` when the supplied LRP module already applies + its own bounded activation. + """ + + def __init__( + self, + *, + entropy_parameters_anchor: nn.Module, + entropy_parameters_nonanchor: nn.Module, + latent_codec: Optional[Mapping[str, LatentCodec]] = None, + scale_table: Optional[List[float]] = None, + gaussian_conditional: Optional[GaussianConditional] = None, + spatial_context_anchor: Optional[nn.Module] = None, + spatial_context_nonanchor: Optional[nn.Module] = None, + intra_channel_context_nonanchor: Optional[nn.Module] = None, + selective_predictor: Optional[nn.Module] = None, + lrp_anchor: Optional[nn.Module] = None, + lrp_nonanchor: Optional[nn.Module] = None, + lrp_input_builder: Optional[LrpInputBuilder] = None, + lrp_activation: LrpActivation = torch.tanh, + lrp_scale: float = 0.5, + anchor_parity: str = "even", + quantizer: str = "ste", + **kwargs: Any, + ) -> None: + super().__init__() + if anchor_parity not in ("even", "odd"): + raise ValueError(f'Invalid "anchor_parity" value "{anchor_parity}"') + if quantizer != "ste": + raise ValueError(f'Invalid quantizer "{quantizer}"') + + self._kwargs = kwargs + self.anchor_parity = anchor_parity + self.non_anchor_parity = {"odd": "even", "even": "odd"}[anchor_parity] + self.quantizer = quantizer + self.entropy_parameters_anchor = entropy_parameters_anchor + self.entropy_parameters_nonanchor = entropy_parameters_nonanchor + self.spatial_context_anchor = spatial_context_anchor + self.spatial_context_nonanchor = spatial_context_nonanchor + self.intra_channel_context_nonanchor = intra_channel_context_nonanchor + self.selective_predictor = selective_predictor + self.lrp_anchor = lrp_anchor + self.lrp_nonanchor = lrp_nonanchor + self.lrp_input_builder = lrp_input_builder + self.lrp_activation = lrp_activation + self.lrp_scale = float(lrp_scale) + + if latent_codec is None: + latent_codec = { + "y": GaussianConditionalLatentCodec( + scale_table=scale_table, + gaussian_conditional=gaussian_conditional, + ) + } + self.y = latent_codec["y"] + self.latent_codec = latent_codec + + def forward(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + b, c, h, w = y.shape + params = y.new_zeros((b, c * 2, h, w)) + y_hat_steps = [] + selective_masks = [] + + for step in ("anchor", "non_anchor"): + ctx_params = self._ctx_params(step, y, side_params, y_hat_steps) + params_i = self._entropy_parameters(step)(ctx_params) + params_i = _ckb.mask_all_but_step( + params_i, step, anchor_parity=self.anchor_parity + ) + _ckb.write_step(params, params_i, step, anchor_parity=self.anchor_parity) + + scales_i, means_i = self.y._chunk(params_i) + selective_mask_i = _sel.selective_mask( + self.selective_predictor, + step, + side_params, + scales_i, + means_i, + anchor_parity=self.anchor_parity, + ) + if selective_mask_i is not None: + selective_masks.append(selective_mask_i) + y_i = _ckb.mask_all_but_step(y, step, anchor_parity=self.anchor_parity) + y_hat_i = self._quantize(y_i, means_i) + y_hat_i = _ckb.mask_all_but_step( + y_hat_i, step, anchor_parity=self.anchor_parity + ) + y_hat_i = _sel.apply_selective_y_hat( + step, + y_hat_i, + means_i, + selective_mask_i, + anchor_parity=self.anchor_parity, + ) + lrp_input_y_hat = y_hat_i + if step == "non_anchor": + lrp_input_y_hat = y_hat_steps[0] + y_hat_i + y_hat_i = self._apply_lrp( + step, + side_params, + params_i, + y_hat_i, + lrp_input_y_hat, + ) + y_hat_i = _sel.apply_selective_y_hat( + step, + y_hat_i, + means_i, + selective_mask_i, + anchor_parity=self.anchor_parity, + ) + y_hat_steps.append(y_hat_i) + + y_hat = y_hat_steps[0] + y_hat_steps[1] + y_out = self.y(y, params) + y_likelihoods = y_out["likelihoods"]["y"] + if selective_masks: + selective_mask = selective_masks[0] | selective_masks[1] + y_likelihoods = torch.where( + selective_mask, y_likelihoods, torch.ones_like(y_likelihoods) + ) + + return { + "likelihoods": { + "y": y_likelihoods, + }, + "y_hat": y_hat, + } + + def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: + n, c, h, w = y.shape + y_hat = y.new_zeros((n, c, h, w)) + y_hat_packed = y.new_zeros((2, n, c, h, w // 2)) + y_packed = _ckb.unembed(y, anchor_parity=self.anchor_parity) + side_params_packed = _ckb.unembed(side_params, anchor_parity=self.anchor_parity) + y_strings = [None] * 2 + + for i, step in enumerate(("anchor", "non_anchor")): + ctx_params_i = self._ctx_params_packed( + i, step, side_params, side_params_packed, y_hat_packed + ) + params_i = self._entropy_parameters(step)(ctx_params_i) + scales_i, means_i = self.y._chunk(params_i) + selective_mask_i = _sel.selective_mask_packed( + self.selective_predictor, + i, + step, + side_params, + scales_i, + means_i, + anchor_parity=self.anchor_parity, + ) + y_out = _sel.apply_selective_compression( + self.y, y_packed[i], params_i, scales_i, means_i, selective_mask_i + ) + y_hat_for_lrp = y_hat_packed.clone() + y_hat_for_lrp[i] = y_out["y_hat"] + y_hat_i = self._apply_lrp_packed( + i, step, side_params, params_i, y_hat_for_lrp + ) + y_hat_i = _sel.apply_selective_y_hat_packed( + y_hat_i, means_i, selective_mask_i + ) + y_hat_packed[i] = y_hat_i + [y_strings[i]] = y_out["strings"] + + y_hat[:] = _ckb.embed(y_hat_packed, anchor_parity=self.anchor_parity) + + return { + "strings": y_strings, + "shape": y_hat.shape[1:], + "y_hat": y_hat, + } + + def decompress( + self, + strings: List[List[bytes]], + shape: Tuple[int, ...], + side_params: Tensor, + **kwargs: Any, + ) -> Dict[str, Any]: + y_strings = strings + n = len(y_strings[0]) + assert len(y_strings) == 2 + assert all(len(x) == n for x in y_strings) + + c, h, w = shape + y_i_shape = (c, h, w // 2) + y_hat_packed = side_params.new_zeros((2, n, c, h, w // 2)) + side_params_packed = _ckb.unembed(side_params, anchor_parity=self.anchor_parity) + + for i, step in enumerate(("anchor", "non_anchor")): + ctx_params_i = self._ctx_params_packed( + i, step, side_params, side_params_packed, y_hat_packed + ) + params_i = self._entropy_parameters(step)(ctx_params_i) + scales_i, means_i = self.y._chunk(params_i) + selective_mask_i = _sel.selective_mask_packed( + self.selective_predictor, + i, + step, + side_params, + scales_i, + means_i, + anchor_parity=self.anchor_parity, + ) + y_out = _sel.apply_selective_decompression( + self.y, + y_strings[i], + y_i_shape, + params_i, + scales_i, + means_i, + selective_mask_i, + ) + y_hat_for_lrp = y_hat_packed.clone() + y_hat_for_lrp[i] = y_out["y_hat"] + y_hat_i = self._apply_lrp_packed( + i, step, side_params, params_i, y_hat_for_lrp + ) + y_hat_packed[i] = _sel.apply_selective_y_hat_packed( + y_hat_i, means_i, selective_mask_i + ) + + return { + "y_hat": _ckb.embed(y_hat_packed, anchor_parity=self.anchor_parity), + } + + def _ctx_params( + self, + step: str, + y: Tensor, + side_params: Tensor, + y_hat_steps: List[Tensor], + ) -> Tensor: + ctx_parts: List[Tensor] = [] + spatial = self._spatial_context_module(step) + if spatial is not None: + y_hat_for_ctx = _ckb.mask_all(y) if step == "anchor" else y_hat_steps[0] + ctx_parts.append( + self._apply_spatial_context(spatial, y_hat_for_ctx, side_params) + ) + ctx_parts.append(side_params) + if step == "non_anchor" and self.intra_channel_context_nonanchor is not None: + ctx_parts.append( + self.intra_channel_context_nonanchor(side_params, y_hat_steps[0]) + ) + return _ckb.merge(*ctx_parts) + + def _ctx_params_packed( + self, + step_index: int, + step: str, + side_params: Tensor, + side_params_packed: Tensor, + y_hat_packed: Tensor, + ) -> Tensor: + ctx_parts: List[Tensor] = [] + spatial = self._spatial_context_module(step) + if spatial is not None: + y_hat_full = _ckb.embed(y_hat_packed, anchor_parity=self.anchor_parity) + if step == "anchor": + y_hat_full = _ckb.mask_all(y_hat_full) + y_ctx = self._apply_spatial_context(spatial, y_hat_full, side_params) + ctx_parts.append( + _ckb.unembed(y_ctx, anchor_parity=self.anchor_parity)[step_index] + ) + ctx_parts.append(side_params_packed[step_index]) + if step == "non_anchor" and self.intra_channel_context_nonanchor is not None: + y_hat_full = _ckb.embed(y_hat_packed, anchor_parity=self.anchor_parity) + intra_ctx = self.intra_channel_context_nonanchor(side_params, y_hat_full) + ctx_parts.append( + _ckb.unembed(intra_ctx, anchor_parity=self.anchor_parity)[step_index] + ) + return _ckb.merge(*ctx_parts) + + def _spatial_context_module(self, step: str) -> Optional[nn.Module]: + if step == "anchor": + return self.spatial_context_anchor + return self.spatial_context_nonanchor + + def _apply_spatial_context( + self, + spatial: nn.Module, + y_hat: Tensor, + side_params: Tensor, + ) -> Tensor: + if getattr(spatial, "requires_side_params", False): + return spatial(y_hat, side_params=side_params) + return spatial(y_hat) + + def _entropy_parameters(self, step: str) -> nn.Module: + if step == "anchor": + return self.entropy_parameters_anchor + return self.entropy_parameters_nonanchor + + def _quantize(self, y: Tensor, means: Tensor) -> Tensor: + return quantize_ste(y - means) + means + + def _apply_lrp( + self, + step: str, + side_params: Tensor, + params: Tensor, + y_hat: Tensor, + lrp_input_y_hat: Optional[Tensor] = None, + ) -> Tensor: + lrp = self.lrp_anchor if step == "anchor" else self.lrp_nonanchor + if lrp is None: + return y_hat + if lrp_input_y_hat is None: + lrp_input_y_hat = y_hat + lrp_input = self._build_lrp_input(side_params, params, lrp_input_y_hat) + y_hat = y_hat + self.lrp_scale * self._activate_lrp(lrp(lrp_input)) + return _ckb.mask_all_but_step(y_hat, step, anchor_parity=self.anchor_parity) + + def _apply_lrp_packed( + self, + step_index: int, + step: str, + side_params: Tensor, + params: Tensor, + y_hat_packed: Tensor, + ) -> Tensor: + lrp = self.lrp_anchor if step == "anchor" else self.lrp_nonanchor + if lrp is None: + return y_hat_packed[step_index] + y_hat = _ckb.embed(y_hat_packed, anchor_parity=self.anchor_parity) + params_full = _ckb.embed_step( + step_index, + params, + side_params.shape[-1], + anchor_parity=self.anchor_parity, + ) + lrp_input = self._build_lrp_input(side_params, params_full, y_hat) + lrp_out = _ckb.unembed(lrp(lrp_input), anchor_parity=self.anchor_parity)[ + step_index + ] + return y_hat_packed[step_index] + self.lrp_scale * self._activate_lrp(lrp_out) + + def _build_lrp_input( + self, side_params: Tensor, params: Tensor, y_hat: Tensor + ) -> Tensor: + if self.lrp_input_builder is not None: + return self.lrp_input_builder(side_params, params, y_hat) + return _ckb.merge(side_params, y_hat) + + def _activate_lrp(self, residual: Tensor) -> Tensor: + if self.lrp_activation is None: + return residual + return self.lrp_activation(residual) diff --git a/compressai/models/_helpers/mlic/__init__.py b/compressai/models/_helpers/mlic/__init__.py new file mode 100644 index 00000000..b3d7ff6b --- /dev/null +++ b/compressai/models/_helpers/mlic/__init__.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .context import ( + ChannelContext, + LinearGlobalInterContext, + LinearGlobalIntraContext, + LocalContext, + StackedCheckerboardConv, + VanillaGlobalInterContext, + VanillaGlobalIntraContext, + WindowCheckerboardAttn, +) +from .transforms import ( + AnalysisTransform, + EntropyParameters, + HyperAnalysis, + HyperSynthesis, + LatentResidualPrediction, + SynthesisTransform, +) +from .utils import ( + checkerboard_anchor, + checkerboard_merge, + checkerboard_nonanchor, + checkerboard_split, + compress_symbols, + decompress_symbols, + squeeze_anchor, + squeeze_nonanchor, + unsqueeze_anchor, + unsqueeze_nonanchor, +) + +__all__ = [ + "AnalysisTransform", + "ChannelContext", + "EntropyParameters", + "HyperAnalysis", + "HyperSynthesis", + "LatentResidualPrediction", + "LinearGlobalInterContext", + "LinearGlobalIntraContext", + "LocalContext", + "SynthesisTransform", + "StackedCheckerboardConv", + "VanillaGlobalInterContext", + "VanillaGlobalIntraContext", + "WindowCheckerboardAttn", + "checkerboard_anchor", + "checkerboard_merge", + "checkerboard_nonanchor", + "checkerboard_split", + "compress_symbols", + "decompress_symbols", + "squeeze_anchor", + "squeeze_nonanchor", + "unsqueeze_anchor", + "unsqueeze_nonanchor", +] diff --git a/compressai/models/_helpers/mlic/context.py b/compressai/models/_helpers/mlic/context.py new file mode 100644 index 00000000..4b6f578f --- /dev/null +++ b/compressai/models/_helpers/mlic/context.py @@ -0,0 +1,656 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch import Tensor + +from compressai.layers.attn.swin import Mlp + +from .utils import ( + build_position_index, + checkerboard_split, + squeeze_anchor, + squeeze_nonanchor, + unsqueeze_anchor, + unsqueeze_nonanchor, +) + +__all__ = [ + "ChannelContext", + "LinearGlobalInterContext", + "LinearGlobalIntraContext", + "LocalContext", + "StackedCheckerboardConv", + "VanillaGlobalInterContext", + "VanillaGlobalIntraContext", + "WindowCheckerboardAttn", +] + + +def _pointwise_then_dwconv(dim: int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + ) + + +def _checkerboard_coordinates( + height: int, + width: int, + *, + anchor: bool, + device: torch.device, +) -> Tensor: + rows = torch.arange(height, device=device) + cols = torch.arange(width, device=device) + coords = torch.stack(torch.meshgrid(rows, cols, indexing="ij"), dim=-1) + coords = coords.reshape(height * width, 2) + parity = 1 if anchor else 0 + return coords[(coords[:, 0] + coords[:, 1]) % 2 == parity] + + +def _local_exclusion_mask( + query_coords: Tensor, + key_coords: Tensor, + radius: int, +) -> Tensor: + if radius < 0: + return query_coords.new_zeros( + (query_coords.shape[0], key_coords.shape[0]), + dtype=torch.bool, + ) + distance = (query_coords[:, None, :] - key_coords[None, :, :]).abs() + return (distance <= radius).all(dim=-1) + + +def _quadratic_attention( + queries: Tensor, + keys: Tensor, + values: Tensor, + num_heads: int, + attention_mask: Optional[Tensor] = None, +) -> Tensor: + batch_size, channels, query_count = queries.shape + key_count = keys.shape[-1] + head_dim = channels // num_heads + + query = queries.reshape(batch_size, num_heads, head_dim, query_count).transpose( + -2, + -1, + ) + key = keys.reshape(batch_size, num_heads, head_dim, key_count).transpose(-2, -1) + value = values.reshape(batch_size, num_heads, head_dim, key_count).transpose( + -2, + -1, + ) + + attention = (query @ key.transpose(-2, -1)) * (head_dim**-0.5) + if attention_mask is not None: + attention = attention.masked_fill( + attention_mask.unsqueeze(0).unsqueeze(0), + -100.0, + ) + attention = F.softmax(attention, dim=-1) + output = attention @ value + return output.transpose(-2, -1).reshape(batch_size, channels, query_count) + + +class StackedCheckerboardConv(nn.Module): + """Stacked convolutional local context used by MLIC.""" + + def __init__(self, dim: int, kernel: int = 5, num_layers: int = 3) -> None: + super().__init__() + if kernel <= 0 or kernel % 2 == 0: + raise ValueError( + "StackedCheckerboardConv kernel must be a positive odd integer" + ) + if num_layers <= 0 or num_layers % 2 == 0: + raise ValueError( + "StackedCheckerboardConv num_layers must be a positive odd" + ) + + layers: List[nn.Module] = [] + padding = kernel // 2 + for index in range(num_layers): + out_channels = dim * 2 if index == num_layers - 1 else dim + layers.append( + nn.Conv2d(dim, out_channels, kernel_size=kernel, padding=padding) + ) + if index != num_layers - 1: + layers.append(nn.GELU()) + self.context = nn.Sequential(*layers) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.context(input_tensor) + + +class LocalContext(nn.Module): + """Windowed intra-slice context used by MLIC++.""" + + def __init__( + self, + dim: int = 32, + window_size: int = 5, + mlp_ratio: float = 2.0, + num_heads: int = 2, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError("LocalContext dim must be divisible by num_heads") + + self.H = -1 + self.W = -1 + self.dim = dim + self.window_size = window_size + self.window_area = window_size * window_size + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.unfold = nn.Unfold( + kernel_size=window_size, + stride=1, + padding=(window_size - 1) // 2, + ) + self.relative_position_table = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) + ) + nn.init.trunc_normal_(self.relative_position_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + self.proj = nn.Linear(dim * 2, dim * 2) + self.mlp = Mlp( + in_features=dim * 2, + hidden_features=int(dim * 2 * mlp_ratio), + out_features=dim * 2, + ) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim * 2) + self.register_buffer( + "relative_position_index", + build_position_index((window_size, window_size)), + ) + self.attn_mask: Optional[Tensor] = None + self.fusion = nn.Conv2d(dim, dim * 2, kernel_size=window_size) + + def update_resolution( + self, + height: int, + width: int, + device: torch.device, + mask: Optional[Tensor] = None, + ) -> bool: + updated = False + if self.H != height or self.W != width: + self.H = height + self.W = width + if mask is not None: + self.attn_mask = mask.to(device) + return True + + checkerboard = torch.zeros( + (1, 2, height, width), + device=device, + requires_grad=False, + ) + checkerboard[:, :, 0::2, 1::2] = 1 + checkerboard[:, :, 1::2, 0::2] = 1 + qk_windows = self.unfold(checkerboard).permute(0, 2, 1) + qk_windows = qk_windows.reshape( + 1, + height * width, + 2, + 1, + self.window_size, + self.window_size, + ).permute(2, 0, 1, 3, 4, 5) + q_windows, k_windows = qk_windows[0], qk_windows[1] + query = q_windows.reshape(1, height * width, 1, self.window_area).permute( + 0, + 1, + 3, + 2, + ) + key = k_windows.reshape(1, height * width, 1, self.window_area).permute( + 0, + 1, + 3, + 2, + ) + attn_mask = query @ key.transpose(-2, -1) + attn_mask = attn_mask.masked_fill(attn_mask == 0.0, float(-100.0)) + self.attn_mask = attn_mask.masked_fill(attn_mask == 1, 0.0)[0].detach() + updated = True + return updated + + def forward(self, input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + num_tokens = height * width + self.update_resolution(height, width, input_tensor.device) + + output = input_tensor.reshape(batch_size, channels, num_tokens).permute(0, 2, 1) + output = self.norm1(output) + + qkv = self.qkv_proj(output).reshape(batch_size, height, width, 3, channels) + qkv = qkv.permute(3, 0, 4, 1, 2).contiguous() + query, key, value = qkv[0], qkv[1], qkv[2] + + qkv_windows = self.unfold(torch.cat([query, key, value], dim=1)).permute( + 0, 2, 1 + ) + qkv_windows = qkv_windows.reshape( + batch_size, + num_tokens, + 3, + channels, + self.window_size, + self.window_size, + ).permute(2, 0, 1, 3, 4, 5) + query_windows, key_windows, value_windows = qkv_windows + + query = query_windows.reshape( + batch_size, + num_tokens, + self.head_dim, + self.num_heads, + self.window_area, + ).permute(0, 1, 3, 4, 2) + key = key_windows.reshape( + batch_size, + num_tokens, + self.head_dim, + self.num_heads, + self.window_area, + ).permute(0, 1, 3, 4, 2) + value = value_windows.reshape( + batch_size, + num_tokens, + self.head_dim, + self.num_heads, + self.window_area, + ).permute(0, 1, 3, 4, 2) + + attention = (query * self.scale) @ key.transpose(-2, -1) + relative_position_bias = self.relative_position_table[ + self.relative_position_index.reshape(-1) + ].view(self.window_area, self.window_area, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention = attention + relative_position_bias.unsqueeze(0).unsqueeze(1) + + if self.attn_mask is None: + raise RuntimeError("LocalContext attention mask is not initialized") + attention = attention + self.attn_mask.unsqueeze(0).unsqueeze(2) + attention = self.softmax(attention) + + output = (attention @ value).reshape( + batch_size, + num_tokens, + self.num_heads, + self.window_size, + self.window_size, + self.head_dim, + ) + output = output.permute(0, 1, 3, 4, 2, 5).reshape( + batch_size * num_tokens, + self.window_size, + self.window_size, + channels, + ) + output = output.permute(0, 3, 1, 2) + output = self.fusion(output).reshape(batch_size, num_tokens, channels * 2) + output = self.proj(output) + output = output + self.mlp(self.norm2(output)) + return output.permute(0, 2, 1).reshape(batch_size, channels * 2, height, width) + + +class WindowCheckerboardAttn(LocalContext): + """Overlapped window checkerboard attention used by MLIC+.""" + + +class ChannelContext(nn.Module): + def __init__(self, in_dim: int, out_dim: int) -> None: + super().__init__() + self.fushion = nn.Sequential( + nn.Conv2d(in_dim, 192, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Conv2d(128, out_dim * 4, kernel_size=3, stride=1, padding=1), + ) + + def forward(self, channel_params: Tensor) -> Tensor: + return self.fushion(channel_params) + + +class LinearGlobalIntraContext(nn.Module): + def __init__(self, dim: int = 32, num_heads: int = 2) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + "LinearGlobalIntraContext dim must be divisible by num_heads" + ) + + self.dim = dim + self.num_heads = num_heads + self.keys = _pointwise_then_dwconv(dim) + self.queries = _pointwise_then_dwconv(dim) + self.values = _pointwise_then_dwconv(dim) + self.reprojection = nn.Conv2d(dim, dim * 2, kernel_size=5, stride=1, padding=2) + self.mlp = nn.Sequential( + nn.Conv2d(dim * 2, dim * 4, kernel_size=1, stride=1), + nn.GELU(), + nn.Conv2d( + dim * 4, + dim * 4, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 4, + ), + nn.GELU(), + nn.Conv2d(dim * 4, dim * 2, kernel_size=1, stride=1), + ) + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + batch_size, _, height, width = x1.shape + x1_anchor, x1_nonanchor = checkerboard_split(x1) + queries = squeeze_nonanchor(self.queries(x1_nonanchor)).reshape( + batch_size, + self.dim, + height * width // 2, + ) + keys = squeeze_anchor(self.keys(x1_anchor)).reshape( + batch_size, + self.dim, + height * width // 2, + ) + values = squeeze_anchor(self.values(x2)).reshape( + batch_size, + self.dim, + height * width // 2, + ) + head_dim = self.dim // self.num_heads + + attended_values = [] + for index in range(self.num_heads): + key = F.softmax( + keys[:, index * head_dim : (index + 1) * head_dim, :], dim=2 + ) + query = F.softmax( + queries[:, index * head_dim : (index + 1) * head_dim, :], + dim=1, + ) + value = values[:, index * head_dim : (index + 1) * head_dim, :] + key = unsqueeze_anchor( + key.reshape(batch_size, head_dim, height, width // 2) + ) + key = key.reshape(batch_size, head_dim, height * width) + value = unsqueeze_anchor( + value.reshape(batch_size, head_dim, height, width // 2) + ) + value = value.reshape(batch_size, head_dim, height * width) + query = unsqueeze_nonanchor( + query.reshape(batch_size, head_dim, height, width // 2) + ) + query = query.reshape(batch_size, head_dim, height * width) + context = key @ value.transpose(1, 2) + attended_values.append( + (context.transpose(1, 2) @ query).reshape( + batch_size, head_dim, height, width + ) + ) + + attention = self.reprojection(torch.cat(attended_values, dim=1)) + return attention + self.mlp(attention) + + +class LinearGlobalInterContext(nn.Module): + def __init__(self, dim: int = 32, out_dim: int = 64, num_heads: int = 2) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + "LinearGlobalInterContext dim must be divisible by num_heads" + ) + + self.dim = dim + self.num_heads = num_heads + self.keys = _pointwise_then_dwconv(dim) + self.queries = _pointwise_then_dwconv(dim) + self.values = _pointwise_then_dwconv(dim) + self.reprojection = nn.Conv2d( + dim, + out_dim * 3 // 2, + kernel_size=5, + stride=1, + padding=2, + ) + self.mlp = nn.Sequential( + nn.Conv2d(out_dim * 3 // 2, out_dim * 2, kernel_size=1, stride=1), + nn.GELU(), + nn.Conv2d( + out_dim * 2, + out_dim * 2, + kernel_size=3, + stride=1, + padding=1, + groups=out_dim * 2, + ), + nn.GELU(), + nn.Conv2d(out_dim * 2, out_dim, kernel_size=1, stride=1), + ) + self.skip = nn.Conv2d(out_dim * 3 // 2, out_dim, kernel_size=1, stride=1) + + def forward(self, input_tensor: Tensor) -> Tensor: + batch_size, _, height, width = input_tensor.shape + queries = self.queries(input_tensor).reshape( + batch_size, self.dim, height * width + ) + keys = self.keys(input_tensor).reshape(batch_size, self.dim, height * width) + values = self.values(input_tensor).reshape(batch_size, self.dim, height * width) + head_dim = self.dim // self.num_heads + + attended_values = [] + for index in range(self.num_heads): + key = F.softmax( + keys[:, index * head_dim : (index + 1) * head_dim, :], dim=2 + ) + query = F.softmax( + queries[:, index * head_dim : (index + 1) * head_dim, :], + dim=1, + ) + value = values[:, index * head_dim : (index + 1) * head_dim, :] + context = key @ value.transpose(1, 2) + attended_values.append( + (context.transpose(1, 2) @ query).reshape( + batch_size, head_dim, height, width + ) + ) + + attention = self.reprojection(torch.cat(attended_values, dim=1)) + return self.skip(attention) + self.mlp(attention) + + +class VanillaGlobalIntraContext(nn.Module): + """Quadratic intra-slice global context used by MLIC and MLIC+.""" + + def __init__( + self, + dim: int = 32, + num_heads: int = 2, + local_mask_radius: int = 2, + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + "VanillaGlobalIntraContext dim must be divisible by num_heads" + ) + + self.dim = dim + self.num_heads = num_heads + self.local_mask_radius = local_mask_radius + self.keys = _pointwise_then_dwconv(dim) + self.queries = _pointwise_then_dwconv(dim) + self.values = _pointwise_then_dwconv(dim) + self.reprojection = nn.Conv2d(dim, dim * 2, kernel_size=5, stride=1, padding=2) + self.mlp = nn.Sequential( + nn.Conv2d(dim * 2, dim * 4, kernel_size=1, stride=1), + nn.GELU(), + nn.Conv2d( + dim * 4, + dim * 4, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 4, + ), + nn.GELU(), + nn.Conv2d(dim * 4, dim * 2, kernel_size=1, stride=1), + ) + + def _attention_mask( + self, + height: int, + width: int, + device: torch.device, + ) -> Tensor: + query_coords = _checkerboard_coordinates( + height, + width, + anchor=False, + device=device, + ) + key_coords = _checkerboard_coordinates( + height, + width, + anchor=True, + device=device, + ) + return _local_exclusion_mask(query_coords, key_coords, self.local_mask_radius) + + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: + batch_size, _, height, width = x1.shape + x1_anchor, x1_nonanchor = checkerboard_split(x1) + queries = squeeze_nonanchor(self.queries(x1_nonanchor)).reshape( + batch_size, + self.dim, + height * width // 2, + ) + keys = squeeze_anchor(self.keys(x1_anchor)).reshape( + batch_size, + self.dim, + height * width // 2, + ) + values = squeeze_anchor(self.values(x2)).reshape( + batch_size, + self.dim, + height * width // 2, + ) + + attention_mask = self._attention_mask(height, width, x1.device) + attention = _quadratic_attention( + queries, + keys, + values, + self.num_heads, + attention_mask, + ) + attention = unsqueeze_nonanchor( + attention.reshape(batch_size, self.dim, height, width // 2) + ) + attention = self.reprojection(attention) + return attention + self.mlp(attention) + + +class VanillaGlobalInterContext(nn.Module): + """Quadratic inter-slice global context used by MLIC+.""" + + def __init__( + self, + in_dim: int = 32, + out_dim: int = 64, + num_heads: int = 2, + ) -> None: + super().__init__() + if in_dim % num_heads != 0: + raise ValueError( + "VanillaGlobalInterContext in_dim must be divisible by num_heads" + ) + + self.dim = in_dim + self.num_heads = num_heads + self.keys = _pointwise_then_dwconv(in_dim) + self.queries = _pointwise_then_dwconv(in_dim) + self.values = _pointwise_then_dwconv(in_dim) + self.reprojection = nn.Conv2d( + in_dim, + out_dim * 3 // 2, + kernel_size=5, + stride=1, + padding=2, + ) + self.mlp = nn.Sequential( + nn.Conv2d(out_dim * 3 // 2, out_dim * 2, kernel_size=1, stride=1), + nn.GELU(), + nn.Conv2d( + out_dim * 2, + out_dim * 2, + kernel_size=3, + stride=1, + padding=1, + groups=out_dim * 2, + ), + nn.GELU(), + nn.Conv2d(out_dim * 2, out_dim, kernel_size=1, stride=1), + ) + self.skip = nn.Conv2d(out_dim * 3 // 2, out_dim, kernel_size=1, stride=1) + + def forward(self, input_tensor: Tensor) -> Tensor: + batch_size, _, height, width = input_tensor.shape + num_tokens = height * width + queries = self.queries(input_tensor).reshape(batch_size, self.dim, num_tokens) + keys = self.keys(input_tensor).reshape(batch_size, self.dim, num_tokens) + values = self.values(input_tensor).reshape(batch_size, self.dim, num_tokens) + + attention = _quadratic_attention(queries, keys, values, self.num_heads) + attention = attention.reshape(batch_size, self.dim, height, width) + attention = self.reprojection(attention) + return self.skip(attention) + self.mlp(attention) diff --git a/compressai/models/_helpers/mlic/transforms.py b/compressai/models/_helpers/mlic/transforms.py new file mode 100644 index 00000000..46db7b27 --- /dev/null +++ b/compressai/models/_helpers/mlic/transforms.py @@ -0,0 +1,233 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Type + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.layers import ( + GDN, + conv1x1, + conv3x3, + subpel_conv3x3, +) + +__all__ = [ + "AnalysisTransform", + "EntropyParameters", + "HyperAnalysis", + "HyperSynthesis", + "LatentResidualPrediction", + "SynthesisTransform", +] + + +class _ResidualBlockWithStride(nn.Module): + def __init__(self, in_ch: int, out_ch: int, stride: int = 2) -> None: + super().__init__() + self.conv1 = conv3x3(in_ch, out_ch, stride=stride) + self.act = nn.GELU() + self.conv2 = conv3x3(out_ch, out_ch) + self.gdn = GDN(out_ch) + self.skip = ( + conv1x1(in_ch, out_ch, stride=stride) + if stride != 1 or in_ch != out_ch + else None + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + identity = input_tensor + output = self.gdn(self.conv2(self.act(self.conv1(input_tensor)))) + if self.skip is not None: + identity = self.skip(input_tensor) + return output + identity + + +class _ResidualBlockUpsample(nn.Module): + def __init__(self, in_ch: int, out_ch: int, upsample: int = 2) -> None: + super().__init__() + self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample) + self.act = nn.GELU() + self.conv = conv3x3(out_ch, out_ch) + self.igdn = GDN(out_ch, inverse=True) + self.upsample = subpel_conv3x3(in_ch, out_ch, upsample) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.subpel_conv(input_tensor) + output = self.igdn(self.conv(self.act(output))) + return output + self.upsample(input_tensor) + + +class _ResidualBlock(nn.Module): + def __init__(self, in_ch: int, out_ch: int) -> None: + super().__init__() + self.conv1 = conv3x3(in_ch, out_ch) + self.act = nn.GELU() + self.conv2 = conv3x3(out_ch, out_ch) + self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else None + + def forward(self, input_tensor: Tensor) -> Tensor: + identity = input_tensor + output = self.act(self.conv2(self.act(self.conv1(input_tensor)))) + if self.skip is not None: + identity = self.skip(input_tensor) + return output + identity + + +class AnalysisTransform(nn.Module): + """MLIC++ analysis transform.""" + + def __init__(self, N: int, M: int) -> None: + super().__init__() + self.analysis_transform = nn.Sequential( + _ResidualBlockWithStride(3, N, stride=2), + _ResidualBlock(N, N), + _ResidualBlockWithStride(N, N, stride=2), + _ResidualBlock(N, N), + _ResidualBlockWithStride(N, N, stride=2), + _ResidualBlock(N, N), + conv3x3(N, M, stride=2), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.analysis_transform(input_tensor) + + +class HyperAnalysis(nn.Module): + def __init__(self, M: int = 192, N: int = 192) -> None: + super().__init__() + self.M = M + self.N = N + self.reduction = nn.Sequential( + conv3x3(M, N), + nn.GELU(), + conv3x3(N, N), + nn.GELU(), + conv3x3(N, N, stride=2), + nn.GELU(), + conv3x3(N, N), + nn.GELU(), + conv3x3(N, N, stride=2), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.reduction(input_tensor) + + +class HyperSynthesis(nn.Module): + def __init__(self, M: int = 192, N: int = 192) -> None: + super().__init__() + self.M = M + self.N = N + self.increase = nn.Sequential( + conv3x3(N, M), + nn.GELU(), + subpel_conv3x3(M, M, 2), + nn.GELU(), + conv3x3(M, M * 3 // 2), + nn.GELU(), + subpel_conv3x3(M * 3 // 2, M * 3 // 2, 2), + nn.GELU(), + conv3x3(M * 3 // 2, M * 2), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.increase(input_tensor) + + +class SynthesisTransform(nn.Module): + """MLIC++ synthesis transform.""" + + def __init__(self, N: int, M: int) -> None: + super().__init__() + self.synthesis_transform = nn.Sequential( + _ResidualBlock(M, M), + _ResidualBlockUpsample(M, N, 2), + _ResidualBlock(N, N), + _ResidualBlockUpsample(N, N, 2), + _ResidualBlock(N, N), + _ResidualBlockUpsample(N, N, 2), + _ResidualBlock(N, N), + subpel_conv3x3(N, 3, 2), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.synthesis_transform(input_tensor) + + +class EntropyParameters(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.fusion = nn.Sequential( + nn.Conv2d(in_dim, 320, kernel_size=1, stride=1, padding=0), + act(), + nn.Conv2d(320, 256, kernel_size=1, stride=1, padding=0), + act(), + nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0), + act(), + nn.Conv2d(128, out_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, params: Tensor) -> Tensor: + return self.fusion(params) + + +class LatentResidualPrediction(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lrp_transform = nn.Sequential( + conv3x3(in_dim, 224), + act(), + conv3x3(224, 128), + act(), + conv3x3(128, out_dim), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return 0.5 * torch.tanh(self.lrp_transform(input_tensor)) diff --git a/compressai/models/_helpers/mlic/utils.py b/compressai/models/_helpers/mlic/utils.py new file mode 100644 index 00000000..b1eb8c9e --- /dev/null +++ b/compressai/models/_helpers/mlic/utils.py @@ -0,0 +1,181 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Callable, List, Sequence, Tuple, Union + +import torch + +from torch import Tensor + +from compressai.entropy_models import EntropyModel + +__all__ = [ + "build_position_index", + "checkerboard_anchor", + "checkerboard_merge", + "checkerboard_nonanchor", + "checkerboard_split", + "compress_symbols", + "decompress_symbols", + "squeeze_anchor", + "squeeze_nonanchor", + "unsqueeze_anchor", + "unsqueeze_nonanchor", +] + + +def build_position_index(window_size: Union[int, Tuple[int, int]]) -> Tensor: + if isinstance(window_size, int): + window_height = window_width = window_size + else: + window_height, window_width = window_size + + coords = torch.stack( + torch.meshgrid( + torch.arange(window_height), + torch.arange(window_width), + indexing="ij", + ) + ) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += window_height - 1 + relative_coords[:, :, 1] += window_width - 1 + relative_coords[:, :, 0] *= 2 * window_width - 1 + return relative_coords.sum(-1) + + +def checkerboard_split(input_tensor: Tensor) -> Tuple[Tensor, Tensor]: + return checkerboard_anchor(input_tensor), checkerboard_nonanchor(input_tensor) + + +def checkerboard_merge(anchor: Tensor, nonanchor: Tensor) -> Tensor: + return anchor + nonanchor + + +def checkerboard_anchor(input_tensor: Tensor) -> Tensor: + output = torch.zeros_like(input_tensor) + output[:, :, 0::2, 1::2] = input_tensor[:, :, 0::2, 1::2] + output[:, :, 1::2, 0::2] = input_tensor[:, :, 1::2, 0::2] + return output + + +def checkerboard_nonanchor(input_tensor: Tensor) -> Tensor: + output = torch.zeros_like(input_tensor) + output[:, :, 0::2, 0::2] = input_tensor[:, :, 0::2, 0::2] + output[:, :, 1::2, 1::2] = input_tensor[:, :, 1::2, 1::2] + return output + + +def squeeze_anchor(input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + output = input_tensor.new_zeros((batch_size, channels, height, width // 2)) + output[:, :, 0::2, :] = input_tensor[:, :, 0::2, 1::2] + output[:, :, 1::2, :] = input_tensor[:, :, 1::2, 0::2] + return output + + +def squeeze_nonanchor(input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + output = input_tensor.new_zeros((batch_size, channels, height, width // 2)) + output[:, :, 0::2, :] = input_tensor[:, :, 0::2, 0::2] + output[:, :, 1::2, :] = input_tensor[:, :, 1::2, 1::2] + return output + + +def unsqueeze_anchor(input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + output = input_tensor.new_zeros((batch_size, channels, height, width * 2)) + output[:, :, 0::2, 1::2] = input_tensor[:, :, 0::2, :] + output[:, :, 1::2, 0::2] = input_tensor[:, :, 1::2, :] + return output + + +def unsqueeze_nonanchor(input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + output = input_tensor.new_zeros((batch_size, channels, height, width * 2)) + output[:, :, 0::2, 0::2] = input_tensor[:, :, 0::2, :] + output[:, :, 1::2, 1::2] = input_tensor[:, :, 1::2, :] + return output + + +def compress_symbols( + gaussian_conditional: EntropyModel, + input_tensor: Tensor, + scales: Tensor, + means: Tensor, + squeeze_fn: Callable[[Tensor], Tensor], + unsqueeze_fn: Callable[[Tensor], Tensor], + symbols_list: List[int], + indexes_list: List[int], +) -> Tensor: + input_half = squeeze_fn(input_tensor) + scales_half = squeeze_fn(scales) + means_half = squeeze_fn(means) + indexes = gaussian_conditional.build_indexes(scales_half) + quantized = gaussian_conditional.quantize(input_half, "symbols", means_half) + symbols_list.extend(quantized.reshape(-1).tolist()) + indexes_list.extend(indexes.reshape(-1).tolist()) + return unsqueeze_fn(quantized + means_half) + + +def decompress_symbols( + gaussian_conditional: EntropyModel, + scales: Tensor, + means: Tensor, + decoder: object, + cdf: Sequence[Sequence[int]], + cdf_lengths: Sequence[int], + offsets: Sequence[int], + squeeze_fn: Callable[[Tensor], Tensor], + unsqueeze_fn: Callable[[Tensor], Tensor], +) -> Tensor: + scales_half = squeeze_fn(scales) + means_half = squeeze_fn(means) + indexes = gaussian_conditional.build_indexes(scales_half) + decoded = decoder.decode_stream( + indexes.reshape(-1).tolist(), + cdf, + cdf_lengths, + offsets, + ) + decoded_tensor = torch.tensor( + decoded, + device=scales.device, + dtype=means_half.dtype, + ).reshape(scales_half.shape) + return unsqueeze_fn(decoded_tensor + means_half) diff --git a/compressai/models/_helpers/mlicv2/__init__.py b/compressai/models/_helpers/mlicv2/__init__.py new file mode 100644 index 00000000..8d95b8d0 --- /dev/null +++ b/compressai/models/_helpers/mlicv2/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts the MLIC family design from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .context import ContextReweighting, GSCModule, HGCPModule, RoPE2D +from .transforms import SimpleTokenMixing, STMAnalysis, STMSynthesis + +__all__ = [ + "ContextReweighting", + "GSCModule", + "HGCPModule", + "RoPE2D", + "STMAnalysis", + "STMSynthesis", + "SimpleTokenMixing", +] diff --git a/compressai/models/_helpers/mlicv2/context.py b/compressai/models/_helpers/mlicv2/context.py new file mode 100644 index 00000000..6627a701 --- /dev/null +++ b/compressai/models/_helpers/mlicv2/context.py @@ -0,0 +1,318 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts the MLIC family design from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import LayerNorm2d +from torch import Tensor + +from compressai.layers import conv3x3 +from compressai.models._helpers.mlic.utils import ( + checkerboard_anchor, + checkerboard_merge, + checkerboard_nonanchor, +) + +__all__ = [ + "ContextReweighting", + "GSCModule", + "HGCPModule", + "RoPE2D", +] + + +class _Gate(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.net = nn.Sequential( + LayerNorm2d(dim), + nn.Conv2d(dim, dim, kernel_size=1), + nn.GELU(), + nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim), + nn.GELU(), + nn.Conv2d(dim, dim, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.net(x) * x + + +def _pointwise_then_dwconv(dim: int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=1), + nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim), + ) + + +class ContextReweighting(nn.Module): + """Channel-wise attention over an already captured spatial context.""" + + def __init__(self, dim: int) -> None: + super().__init__() + self.norm = LayerNorm2d(dim) + self.queries = nn.Conv2d(dim, dim, kernel_size=1) + self.keys = nn.Conv2d(dim, dim, kernel_size=1) + self.values = nn.Conv2d(dim, dim, kernel_size=1) + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + self.out_norm = LayerNorm2d(dim) + self.gate = _Gate(dim) + + def channel_attention(self, input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + num_positions = height * width + normalized = self.norm(input_tensor) + queries = self.queries(normalized).reshape(batch_size, channels, num_positions) + keys = self.keys(normalized).reshape(batch_size, channels, num_positions) + return F.softmax( + queries @ keys.transpose(1, 2) * (num_positions**-0.5), + dim=-1, + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + attention = self.channel_attention(input_tensor) + normalized = self.norm(input_tensor) + values = self.values(normalized).reshape(batch_size, channels, height * width) + output = attention @ values + output = output.reshape(batch_size, channels, height, width) + output = self.proj(output) + output = self.out_norm(output) + return input_tensor + output + self.gate(output) + + +class RoPE2D(nn.Module): + """Two-dimensional rotary position embedding for NCHW tensors.""" + + def __init__(self, dim: int, learnable_thetas: bool = True) -> None: + super().__init__() + if dim % 2 != 0: + raise ValueError("RoPE2D dim must be even") + self.dim = int(dim) + theta_x = torch.tensor(10000.0) + theta_y = torch.tensor(10000.0) + if learnable_thetas: + self.theta_x = nn.Parameter(theta_x) + self.theta_y = nn.Parameter(theta_y) + else: + self.register_buffer("theta_x", theta_x) + self.register_buffer("theta_y", theta_y) + freq = torch.arange(0, dim, 2, dtype=torch.float32) / float(dim) + self.register_buffer("frequency", freq) + + def _angles(self, height: int, width: int, device: torch.device) -> Tensor: + rows = torch.arange(height, device=device, dtype=self.frequency.dtype) + cols = torch.arange(width, device=device, dtype=self.frequency.dtype) + yy, xx = torch.meshgrid(rows, cols, indexing="ij") + theta_x = self.theta_x.to(device=device, dtype=self.frequency.dtype).abs() + theta_y = self.theta_y.to(device=device, dtype=self.frequency.dtype).abs() + inv_x = theta_x.clamp_min(1.0).pow(-self.frequency.to(device)) + inv_y = theta_y.clamp_min(1.0).pow(-self.frequency.to(device)) + return xx[..., None] * inv_x + yy[..., None] * inv_y + + def rotate(self, input_tensor: Tensor) -> Tensor: + batch_size, channels, height, width = input_tensor.shape + if channels != self.dim: + raise ValueError(f"Expected {self.dim} channels, got {channels}") + angles = self._angles(height, width, input_tensor.device).to(input_tensor.dtype) + cos = angles.cos().permute(2, 0, 1).unsqueeze(0) + sin = angles.sin().permute(2, 0, 1).unsqueeze(0) + pairs = input_tensor.reshape(batch_size, channels // 2, 2, height, width) + x_even = pairs[:, :, 0] + x_odd = pairs[:, :, 1] + rotated = torch.stack( + (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), + dim=2, + ) + return rotated.reshape(batch_size, channels, height, width) + + def forward( + self, query: Tensor, key: Optional[Tensor] = None + ) -> Tensor | Tuple[Tensor, Tensor]: + query = self.rotate(query) + if key is None: + return query + return query, self.rotate(key) + + +class HGCPModule(nn.Module): + """Hyperprior-guided global correlation prediction for the first slice.""" + + def __init__( + self, + M: int, + slice_ch: int, + out_ch: Optional[int] = None, + num_heads: int = 2, + ) -> None: + super().__init__() + if slice_ch % num_heads != 0: + raise ValueError("HGCPModule slice_ch must be divisible by num_heads") + self.M = int(M) + self.slice_ch = int(slice_ch) + self.out_ch = int(out_ch or 2 * slice_ch) + self.num_heads = int(num_heads) + self.queries = _pointwise_then_dwconv(M) + self.keys = _pointwise_then_dwconv(M) + self.hyper_values = _pointwise_then_dwconv(M) + self.slice_values = nn.Conv2d(slice_ch, M, kernel_size=1) + self.proj = conv3x3(M, self.out_ch) + self.gate = _Gate(self.out_ch) + + def _linear_attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: + batch_size, channels, height, width = query.shape + head_dim = channels // self.num_heads + token_count = height * width + outputs = [] + for index in range(self.num_heads): + start = index * head_dim + end = (index + 1) * head_dim + query_i = query[:, start:end].reshape(batch_size, head_dim, token_count) + key_i = key[:, start:end].reshape(batch_size, head_dim, token_count) + value_i = value[:, start:end].reshape(batch_size, head_dim, token_count) + key_i = F.softmax(key_i, dim=2) + query_i = F.softmax(query_i, dim=1) + context = key_i @ value_i.transpose(1, 2) + outputs.append( + (context.transpose(1, 2) @ query_i).reshape( + batch_size, + head_dim, + height, + width, + ) + ) + return torch.cat(outputs, dim=1) + + def forward( + self, + hyper_params: Tensor, + anchor_y_hat: Optional[Tensor] = None, + ) -> Tensor: + hyper = hyper_params[:, : self.M] + if anchor_y_hat is None: + values = self.hyper_values(hyper) + else: + values = self.slice_values(anchor_y_hat) + + hyper_anchor = checkerboard_anchor(hyper) + hyper_nonanchor = checkerboard_nonanchor(hyper) + value_anchor = checkerboard_anchor(values) + value_nonanchor = checkerboard_nonanchor(values) + + anchor_context = self._linear_attention( + self.queries(hyper_anchor), + self.keys(hyper_nonanchor), + value_nonanchor, + ) + nonanchor_context = self._linear_attention( + self.queries(hyper_nonanchor), + self.keys(hyper_anchor), + value_anchor, + ) + context = checkerboard_merge(anchor_context, nonanchor_context) + context = self.proj(context) + return context + self.gate(context) + + +class GSCModule(nn.Module): + """Guided selective compression predictor compatible with leaf hooks.""" + + def __init__( + self, + slice_ch: int, + side_ch: Optional[int] = None, + hidden_ch: Optional[int] = None, + threshold: float = 0.3, + ) -> None: + super().__init__() + hidden_ch = int(hidden_ch or max(16, slice_ch)) + self.slice_ch = int(slice_ch) + self.threshold = float(threshold) + self.side_proj: nn.Module + if side_ch is None: + self.side_proj = nn.LazyConv2d(hidden_ch, kernel_size=1) + else: + self.side_proj = nn.Conv2d(side_ch, hidden_ch, kernel_size=1) + self.predictor = nn.Sequential( + conv3x3(hidden_ch + 3 * slice_ch, hidden_ch), + nn.GELU(), + conv3x3(hidden_ch, hidden_ch), + nn.GELU(), + nn.Conv2d(hidden_ch, slice_ch, kernel_size=1), + ) + self.step_bias = nn.Parameter(torch.zeros(2, slice_ch, 1, 1)) + self.scale_slope = nn.Parameter(torch.tensor(8.0)) + + def extra_repr(self) -> str: + return f"slice_ch={self.slice_ch}, threshold={self.threshold}" + + def forward( + self, + *, + side_params: Tensor, + scales: Tensor, + means: Tensor, + step: str, + ) -> Tensor | Dict[str, Tensor]: + if step == "anchor": + step_index = 0 + elif step == "non_anchor": + step_index = 1 + else: + raise ValueError(f'Invalid checkerboard step "{step}"') + + scale_prior = (scales >= self.threshold).to(scales.dtype) + features = torch.cat( + [ + self.side_proj(side_params), + scales, + means, + scale_prior, + ], + dim=1, + ) + logits = self.predictor(features) + logits = logits + self.scale_slope * (scales - self.threshold) + logits = logits + self.step_bias[step_index].unsqueeze(0) + selective_map = torch.sigmoid(logits) + return { + "selective_map": selective_map, + "scale_prior": scale_prior, + } diff --git a/compressai/models/_helpers/mlicv2/transforms.py b/compressai/models/_helpers/mlicv2/transforms.py new file mode 100644 index 00000000..0be4a112 --- /dev/null +++ b/compressai/models/_helpers/mlicv2/transforms.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts the MLIC family design from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Type + +import torch.nn as nn + +from timm.layers import LayerNorm2d +from torch import Tensor + +from compressai.layers import conv3x3, subpel_conv3x3 +from compressai.models._helpers.mlic.transforms import ( + _ResidualBlockUpsample, + _ResidualBlockWithStride, +) + +__all__ = [ + "STMAnalysis", + "STMSynthesis", + "SimpleTokenMixing", +] + + +class _DepthwiseResidualBlock(nn.Module): + def __init__(self, dim: int, act: Type[nn.Module] = nn.GELU) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim), + act(), + nn.Conv2d(dim, dim, kernel_size=1), + act(), + ) + + def forward(self, x: Tensor) -> Tensor: + return x + self.block(x) + + +class _Gate(nn.Module): + def __init__(self, dim: int, act: Type[nn.Module] = nn.GELU) -> None: + super().__init__() + self.gate = nn.Sequential( + LayerNorm2d(dim), + nn.Conv2d(dim, dim, kernel_size=1), + act(), + nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim), + act(), + nn.Conv2d(dim, dim, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.gate(x) * x + + +class SimpleTokenMixing(nn.Module): + """MetaFormer-style token mixing block used by MLICv2 transforms.""" + + def __init__(self, dim: int, act: Type[nn.Module] = nn.GELU) -> None: + super().__init__() + self.norm1 = LayerNorm2d(dim) + self.token_mixer = nn.Sequential( + _DepthwiseResidualBlock(dim, act=act), + nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim), + nn.Conv2d(dim, dim, kernel_size=1), + ) + self.norm2 = LayerNorm2d(dim) + self.gate = _Gate(dim, act=act) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.token_mixer(self.norm1(x)) + return x + self.gate(self.norm2(x)) + + +def _stm_pair(dim: int) -> nn.Sequential: + return nn.Sequential(SimpleTokenMixing(dim), SimpleTokenMixing(dim)) + + +class STMAnalysis(nn.Module): + """MLICv2 analysis transform with STM blocks replacing residual blocks.""" + + def __init__(self, N: int, M: int) -> None: + super().__init__() + self.analysis_transform = nn.Sequential( + _ResidualBlockWithStride(3, N, stride=2), + _stm_pair(N), + _ResidualBlockWithStride(N, N, stride=2), + _stm_pair(N), + _ResidualBlockWithStride(N, N, stride=2), + _stm_pair(N), + conv3x3(N, M, stride=2), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.analysis_transform(input_tensor) + + +class STMSynthesis(nn.Module): + """MLICv2 synthesis transform with STM blocks replacing residual blocks.""" + + def __init__(self, N: int, M: int) -> None: + super().__init__() + self.synthesis_transform = nn.Sequential( + _stm_pair(M), + _ResidualBlockUpsample(M, N, 2), + _stm_pair(N), + _ResidualBlockUpsample(N, N, 2), + _stm_pair(N), + _ResidualBlockUpsample(N, N, 2), + _stm_pair(N), + subpel_conv3x3(N, 3, 2), + ) + + def forward(self, input_tensor: Tensor) -> Tensor: + return self.synthesis_transform(input_tensor) diff --git a/compressai/models/_helpers/multi_context_slice.py b/compressai/models/_helpers/multi_context_slice.py new file mode 100644 index 00000000..3b0549d2 --- /dev/null +++ b/compressai/models/_helpers/multi_context_slice.py @@ -0,0 +1,376 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""MLIC-family per-slice entropy building blocks. + +These parameter-holding modules (prior aggregation, entropy-parameter fusion, +intra-channel context wrappers, MLICv2 refinements) are the pieces the MLIC +family models wire together, ELIC-style, inside their ``__init__``. The +per-variant assembly itself lives in ``compressai.models.mlic`` so the model +module owns its entropy-stack layout (matching the TCM/STF/CCA convention). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.models._helpers.mlic import ( + ChannelContext, + EntropyParameters, + LinearGlobalInterContext, + LinearGlobalIntraContext, + VanillaGlobalInterContext, + VanillaGlobalIntraContext, +) +from compressai.models._helpers.mlicv2 import ( + ContextReweighting, + HGCPModule, + RoPE2D, +) + +__all__ = [ + "_MlicppSideLayout", + "_MlicppPriorAggregation", + "_MlicppEntropyParameters", + "_MlicppIntraWrapper", + "_MlicIntraWrapper", + "_Mlicv2ContextRefinement", + "_Mlicv2HgcpAnchorContext", + "_build_lrp_input_builder", + "_select_global_inter_factory", +] + + +def _select_num_heads(channels: int) -> int: + target = max(1, channels // 32) + while channels % target != 0: + target -= 1 + return target + + +GlobalInterFactory = Callable[[int, int, int], nn.Module] + + +def _build_linear_global_inter_context( + prior_ch: int, + slice_ch: int, + num_heads: int, +) -> nn.Module: + return LinearGlobalInterContext( + dim=prior_ch, + out_dim=2 * slice_ch, + num_heads=num_heads, + ) + + +def _build_vanilla_global_inter_context( + prior_ch: int, + slice_ch: int, + num_heads: int, +) -> nn.Module: + return VanillaGlobalInterContext( + in_dim=prior_ch, + out_dim=2 * slice_ch, + num_heads=num_heads, + ) + + +@dataclass(frozen=True) +class _MlicppSideLayout: + M: int + slice_ch: int + slice_index: int + use_global_inter: bool = True + + @property + def hyper_ch(self) -> int: + return 2 * self.M + + @property + def prior_ch(self) -> int: + return self.slice_index * self.slice_ch + + @property + def inter_ch(self) -> int: + if self.slice_index and self.use_global_inter: + return 2 * self.slice_ch + return 0 + + @property + def channel_ch(self) -> int: + return 4 * self.slice_ch if self.slice_index else 0 + + @property + def side_ch(self) -> int: + return self.hyper_ch + self.prior_ch + self.inter_ch + self.channel_ch + + def split(self, side_params: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + hyper_end = self.hyper_ch + prior_end = hyper_end + self.prior_ch + inter_end = prior_end + self.inter_ch + channel_end = inter_end + self.channel_ch + return ( + side_params[:, :hyper_end], + side_params[:, hyper_end:prior_end], + side_params[:, prior_end:inter_end], + side_params[:, inter_end:channel_end], + ) + + def hyper_means(self, side_params: Tensor) -> Tensor: + hyper_params = side_params[:, : self.hyper_ch] + return hyper_params[:, self.M :] + + def previous_slice(self, side_params: Tensor) -> Tensor: + if self.prior_ch == 0: + batch_size, _, height, width = side_params.shape + return side_params.new_zeros( + batch_size, + self.slice_ch, + height, + width, + ) + start = self.hyper_ch + self.prior_ch - self.slice_ch + end = self.hyper_ch + self.prior_ch + return side_params[:, start:end] + + +class _MlicppPriorAggregation(nn.Module): + """Build the per-slice side layout consumed by MLIC-family leaf codecs.""" + + def __init__( + self, + M: int, + slice_ch: int, + slice_index: int, + global_inter_factory: Optional[ + GlobalInterFactory + ] = _build_linear_global_inter_context, + ) -> None: + super().__init__() + self.layout = _MlicppSideLayout( + M=M, + slice_ch=slice_ch, + slice_index=slice_index, + use_global_inter=global_inter_factory is not None, + ) + if slice_index: + prior_ch = self.layout.prior_ch + self.channel_part = ChannelContext(in_dim=prior_ch, out_dim=slice_ch) + self.global_inter_part = ( + global_inter_factory( + prior_ch, + slice_ch, + _select_num_heads(prior_ch), + ) + if global_inter_factory is not None + else None + ) + else: + self.channel_part = None + self.global_inter_part = None + + def forward(self, params: Tensor) -> Tensor: + hyper_params = params[:, : self.layout.hyper_ch] + if self.layout.slice_index == 0: + return hyper_params + + prior_y_hat = params[:, self.layout.hyper_ch :] + if self.channel_part is None: + raise RuntimeError("Expected prior aggregation modules for slice k > 0") + parts = [hyper_params, prior_y_hat] + if self.global_inter_part is not None: + parts.append(self.global_inter_part(prior_y_hat)) + parts.append(self.channel_part(prior_y_hat)) + return torch.cat(parts, dim=1) + + +class _MlicppEntropyParameters(EntropyParameters): + def __init__( + self, + layout: _MlicppSideLayout, + *, + step: str, + anchor_context_ch: int = 0, + ) -> None: + if step == "anchor": + in_dim = ( + anchor_context_ch + + layout.hyper_ch + + layout.inter_ch + + layout.channel_ch + ) + elif step == "non_anchor": + in_dim = ( + 2 * layout.slice_ch + + layout.hyper_ch + + layout.inter_ch + + layout.channel_ch + ) + if layout.slice_index: + in_dim += 2 * layout.slice_ch + else: + raise ValueError(f'Invalid checkerboard step "{step}"') + super().__init__(in_dim=in_dim, out_dim=2 * layout.slice_ch) + self.layout = layout + self.step = step + self.anchor_context_ch = int(anchor_context_ch) + + def forward(self, params: Tensor) -> Tensor: + if self.step == "anchor": + anchor_ctx: Optional[Tensor] = None + if self.anchor_context_ch: + anchor_ctx = params[:, : self.anchor_context_ch] + side_params = params[:, self.anchor_context_ch :] + else: + side_params = params + local_ctx: Optional[Tensor] = None + intra_ctx: Optional[Tensor] = None + else: + anchor_ctx = None + local_ctx = params[:, : 2 * self.layout.slice_ch] + side_start = 2 * self.layout.slice_ch + side_end = side_start + self.layout.side_ch + side_params = params[:, side_start:side_end] + intra_ctx = params[:, side_end:] + + hyper_params, _, global_inter_ctx, channel_ctx = self.layout.split(side_params) + if self.step == "anchor": + parts = [hyper_params] + if self.layout.slice_index: + parts = [global_inter_ctx, channel_ctx, hyper_params] + if anchor_ctx is not None: + parts = [anchor_ctx, *parts] + else: + if local_ctx is None: + raise RuntimeError("Expected local context for non-anchor step") + parts = [local_ctx, hyper_params] + if self.layout.slice_index: + if intra_ctx is None: + raise RuntimeError("Expected intra context for slice k > 0") + parts = [ + local_ctx, + intra_ctx, + global_inter_ctx, + channel_ctx, + hyper_params, + ] + return super().forward(torch.cat(parts, dim=1)) + + +class _MlicppIntraWrapper(LinearGlobalIntraContext): + def __init__(self, layout: _MlicppSideLayout) -> None: + super().__init__(dim=layout.slice_ch) + self.layout = layout + + def forward(self, side_params: Tensor, anchor_y_hat: Tensor) -> Tensor: + return super().forward(self.layout.previous_slice(side_params), anchor_y_hat) + + +class _MlicIntraWrapper(VanillaGlobalIntraContext): + def __init__(self, layout: _MlicppSideLayout) -> None: + super().__init__(dim=layout.slice_ch) + self.layout = layout + + def forward(self, side_params: Tensor, anchor_y_hat: Tensor) -> Tensor: + return super().forward(self.layout.previous_slice(side_params), anchor_y_hat) + + +class _Mlicv2ContextRefinement(nn.Module): + def __init__(self, context: nn.Module, dim: int) -> None: + super().__init__() + self.context = context + self.rope = RoPE2D(dim=dim) + self.reweighting = ContextReweighting(dim=dim) + + def forward(self, *args: Tensor) -> Tensor: + context = self.context(*args) + context = self.rope(context) + if not isinstance(context, Tensor): + raise RuntimeError("Expected RoPE2D to return a tensor for one input") + return self.reweighting(context) + + +class _Mlicv2HgcpAnchorContext(nn.Module): + requires_side_params = True + + def __init__(self, M: int, slice_ch: int) -> None: + super().__init__() + self.hgcp = HGCPModule( + M=M, + slice_ch=slice_ch, + num_heads=_select_num_heads(slice_ch), + ) + + def forward(self, _y_hat: Tensor, *, side_params: Tensor) -> Tensor: + return self.hgcp(side_params) + + +def _build_lrp_input_builder( + layout: _MlicppSideLayout, +) -> Callable[[Tensor, Tensor, Tensor], Tensor]: + def _lrp_inputs(side_params: Tensor, _params: Tensor, y_hat: Tensor) -> Tensor: + _, prior_y_hat, _, _ = layout.split(side_params) + return torch.cat([layout.hyper_means(side_params), prior_y_hat, y_hat], dim=1) + + return _lrp_inputs + + +def _build_mlicv2_global_inter_context( + prior_ch: int, + slice_ch: int, + num_heads: int, +) -> nn.Module: + return _Mlicv2ContextRefinement( + _build_linear_global_inter_context(prior_ch, slice_ch, num_heads), + dim=2 * slice_ch, + ) + + +def _select_global_inter_factory( + variant: str, +) -> Optional[GlobalInterFactory]: + """Return the per-slice global-inter context factory for ``variant``. + + ``None`` means the variant has no inter-slice global context (``mlic``). + """ + if variant == "mlic+": + return _build_vanilla_global_inter_context + if variant == "mlicv2": + return _build_mlicv2_global_inter_context + if variant == "mlicpp": + return _build_linear_global_inter_context + if variant == "mlic": + return None + raise ValueError('variant must be one of "mlic", "mlic+", "mlicpp", or "mlicv2"') diff --git a/compressai/models/mlic.py b/compressai/models/mlic.py new file mode 100644 index 00000000..0299685f --- /dev/null +++ b/compressai/models/mlic.py @@ -0,0 +1,497 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts the MLIC family design from https://github.com/JiangWeibeta/MLIC +# (originally distributed under the Apache License 2.0). Modifications by +# InterDigital Communications, Inc. are released under the BSD 3-Clause Clear +# License terms below. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import re + +from typing import Dict, Iterable, List, Sequence, Tuple, Union + +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + EntropyBottleneckLatentCodec, + HyperpriorLatentCodec, + MultiContextCheckerboardLatentCodec, +) +from compressai.models._helpers.mlic import ( + AnalysisTransform, + HyperAnalysis, + HyperSynthesis, + LatentResidualPrediction, + LocalContext, + StackedCheckerboardConv, + SynthesisTransform, + WindowCheckerboardAttn, +) +from compressai.models._helpers.mlicv2 import GSCModule, STMAnalysis, STMSynthesis +from compressai.models._helpers.multi_context_slice import ( + _build_lrp_input_builder, + _MlicIntraWrapper, + _MlicppEntropyParameters, + _MlicppIntraWrapper, + _MlicppPriorAggregation, + _MlicppSideLayout, + _Mlicv2ContextRefinement, + _Mlicv2HgcpAnchorContext, + _select_global_inter_factory, +) +from compressai.models.base import CompressionModel +from compressai.registry import register_model + +__all__ = [ + "MLIC", + "MLICPlus", + "MLICPlusPlus", + "MLICv2", +] + + +# MLIC family evolution map: +# +# MLIC --- + inter-slice global ---> MLIC+ +# | | +# | conv stacked checkerboard | overlapped window attention +# | quadratic global intra | quadratic global intra + inter +# | N=192, M=320, slice_ch=32 | +# v v +# only conv local --- replace quadratic with linear attention +# ---> MLIC++ +# | +# + replace residual blocks with STM +# + HGCP: slice-0 global context from hyperprior +# + Context Reweighting: channel-wise attention +# + 2D RoPE replacing relative position bias +# + GSC: post-training skip predictor +# v +# MLICv2 + + +_CURRENT_SLICE_RE = re.compile(r"^latent_codec\.y\.latent_codec\.y(\d+)\.") +_STACKED_CONTEXT_RE = re.compile( + r"^latent_codec\.y\.latent_codec\.y0\.spatial_context_nonanchor" + r"\.context\.(\d+)\.weight$" +) + + +def _infer_slice_num(keys: Iterable[str]) -> int: + """Infer ``slice_num`` from the per-slice ``latent_codec.y.latent_codec`` + keys of an already-compressai-layout state dict. + + Upstream MLIC++ checkpoints use a different root-level layout; convert them + first with ``examples/convert_mlic_checkpoint.py`` before calling + :meth:`from_state_dict`. + """ + indices: List[int] = [] + for key in keys: + match = _CURRENT_SLICE_RE.match(key) + if match is not None: + indices.append(int(match.group(1))) + return max(indices) + 1 if indices else 10 + + +def _infer_context_window(state_dict: Dict[str, Tensor]) -> int: + index_key = ( + "latent_codec.y.latent_codec.y0." + "spatial_context_nonanchor.relative_position_index" + ) + if index_key in state_dict: + return int(round(state_dict[index_key].size(0) ** 0.5)) + + table_key = ( + "latent_codec.y.latent_codec.y0." + "spatial_context_nonanchor.relative_position_table" + ) + if table_key in state_dict: + side = int(round(state_dict[table_key].size(0) ** 0.5)) + return (side + 1) // 2 + return 5 + + +def _infer_local_kernel(state_dict: Dict[str, Tensor]) -> int: + key = "latent_codec.y.latent_codec.y0.spatial_context_nonanchor.context.0.weight" + if key in state_dict: + return int(state_dict[key].size(-1)) + return 5 + + +def _infer_local_layers(keys: Iterable[str]) -> int: + indices = [] + for key in keys: + match = _STACKED_CONTEXT_RE.match(key) + if match is not None: + indices.append(int(match.group(1))) + return len(indices) if indices else 3 + + +class _SideContextChannelGroupsLatentCodec(ChannelGroupsLatentCodec): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if "y0" not in self.channel_context: + raise ValueError("side-parameter channel groups require channel_context.y0") + + def _get_ctx_params( + self, k: int, side_params: Tensor, y_hat_: List[Tensor] + ) -> Tensor: + if k == 0: + return self.channel_context["y0"](side_params) + support = [y_hat_[i] for i in self.support_slices[k]] + if not support: + return self.channel_context[f"y{k}"](side_params) + return self.channel_context[f"y{k}"]( + self.merge_params(side_params, self.merge_y(*support)) + ) + + +class _BaseMLIC(CompressionModel): + _variant = "mlic" + _analysis_cls = AnalysisTransform + _synthesis_cls = SynthesisTransform + + def __init__( + self, + *, + N: int, + M: int, + slice_num: int, + context_window: int = 5, + local_kernel: int = 5, + local_layers: int = 3, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if slice_num <= 0: + raise ValueError("slice_num must be positive") + if context_window % 2 == 0: + raise ValueError("context_window must be odd") + if M % slice_num != 0: + raise ValueError("M must be divisible by slice_num") + + self.N = int(N) + self.M = int(M) + self.context_window = int(context_window) + self.local_kernel = int(local_kernel) + self.local_layers = int(local_layers) + self.slice_num = int(slice_num) + self.slice_ch = int(M // slice_num) + + self.g_a = self._analysis_cls(N=N, M=M) + self.g_s = self._synthesis_cls(N=N, M=M) + + # Per-slice entropy stack, inlined ELIC-style. The variant only changes + # three things: the inter-slice global context factory, the non-anchor + # spatial context module, and the intra-channel context wrapper (plus + # the MLICv2-only HGCP anchor context and GSC selective predictor). + variant = self._variant + slice_ch = self.slice_ch + global_inter_factory = _select_global_inter_factory(variant) + use_global_inter = global_inter_factory is not None + + def _side_layout(k: int) -> _MlicppSideLayout: + return _MlicppSideLayout( + M=M, + slice_ch=slice_ch, + slice_index=k, + use_global_inter=use_global_inter, + ) + + def _spatial_context() -> nn.Module: + if variant == "mlic": + return StackedCheckerboardConv( + dim=slice_ch, + kernel=local_kernel, + num_layers=local_layers, + ) + if variant == "mlic+": + return WindowCheckerboardAttn(dim=slice_ch, window_size=context_window) + return LocalContext(dim=slice_ch, window_size=context_window) + + def _intra_context(layout: _MlicppSideLayout) -> nn.Module: + if variant in ("mlic", "mlic+"): + return _MlicIntraWrapper(layout) + context = _MlicppIntraWrapper(layout) + if variant == "mlicv2": + return _Mlicv2ContextRefinement(context, dim=2 * slice_ch) + return context + + def _leaf(k: int) -> MultiContextCheckerboardLatentCodec: + layout = _side_layout(k) + anchor_context_ch = 2 * slice_ch if variant == "mlicv2" and k == 0 else 0 + return MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=_MlicppEntropyParameters( + layout, + step="anchor", + anchor_context_ch=anchor_context_ch, + ), + entropy_parameters_nonanchor=_MlicppEntropyParameters( + layout, + step="non_anchor", + ), + spatial_context_anchor=( + _Mlicv2HgcpAnchorContext(M=M, slice_ch=slice_ch) + if variant == "mlicv2" and k == 0 + else None + ), + spatial_context_nonanchor=_spatial_context(), + intra_channel_context_nonanchor=( + _intra_context(layout) if k > 0 else None + ), + selective_predictor=( + GSCModule(slice_ch=slice_ch, side_ch=layout.side_ch) + if variant == "mlicv2" + else None + ), + lrp_anchor=LatentResidualPrediction( + in_dim=M + (k + 1) * slice_ch, + out_dim=slice_ch, + ), + lrp_nonanchor=LatentResidualPrediction( + in_dim=M + (k + 1) * slice_ch, + out_dim=slice_ch, + ), + lrp_input_builder=_build_lrp_input_builder(layout), + lrp_activation=None, + lrp_scale=1.0, + anchor_parity="odd", + ) + + support_slices = [list(range(k)) for k in range(slice_num)] + + self.latent_codec = HyperpriorLatentCodec( + h_a=HyperAnalysis(M=M, N=N), + h_s=HyperSynthesis(M=M, N=N), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), + quantizer="ste", + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=[slice_ch] * slice_num, + channel_context={ + f"y{k}": _MlicppPriorAggregation( + M=M, + slice_ch=slice_ch, + slice_index=k, + global_inter_factory=global_inter_factory, + ) + for k in range(slice_num) + }, + latent_codec={f"y{k}": _leaf(k) for k in range(slice_num)}, + support_slices=support_slices, + ), + }, + ) + + @property + def downsampling_factor(self) -> int: + return 2 ** (4 + 2) + + def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: + y = self.g_a(x) + y_out = self.latent_codec(y) + return { + "x_hat": self.g_s(y_out["y_hat"]), + "likelihoods": y_out["likelihoods"], + } + + def compress(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} + + def decompress( + self, + strings: Sequence[Sequence[bytes]], + shape: Dict[str, Union[List[Tuple[int, ...]], Tuple[int, ...]]], + ) -> Dict[str, Tensor]: + y_out = self.latent_codec.decompress(strings, shape) + return {"x_hat": self.g_s(y_out["y_hat"]).clamp_(0, 1)} + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "_BaseMLIC": + N = state_dict["g_a.analysis_transform.0.conv1.weight"].size(0) + M = state_dict["g_a.analysis_transform.6.weight"].size(0) + slice_num = _infer_slice_num(state_dict.keys()) + kwargs = { + "N": N, + "M": M, + "slice_num": slice_num, + } + if cls._variant in ("mlic+", "mlicpp", "mlicv2"): + kwargs["context_window"] = _infer_context_window(state_dict) + else: + kwargs["local_kernel"] = _infer_local_kernel(state_dict) + kwargs["local_layers"] = _infer_local_layers(state_dict.keys()) + + net = cls(**kwargs) + incompatible_keys = net.load_state_dict(state_dict, strict=False) + allowed_missing = { + key for key in net.state_dict() if key.endswith("relative_position_index") + } + missing_keys = set(incompatible_keys.missing_keys) - allowed_missing + if missing_keys or incompatible_keys.unexpected_keys: + raise RuntimeError( + f"Unexpected incompatibility while loading {cls.__name__} state_dict: " + f"missing={sorted(missing_keys)}, " + f"unexpected={sorted(incompatible_keys.unexpected_keys)}" + ) + return net + + +@register_model("mlic") +class MLIC(_BaseMLIC): + r"""MLIC model from W. Jiang, J. Yang, Y. Zhai, P. Ning, F. Gao, R. Wang: + `"MLIC: Multi-Reference Entropy Model for Learned Image Compression" + `_, ACM Multimedia 2023. + """ + + _variant = "mlic" + + def __init__( + self, + N: int = 192, + M: int = 192, + slice_num: int = 6, + local_kernel: int = 5, + local_layers: int = 3, + **kwargs, + ) -> None: + super().__init__( + N=N, + M=M, + slice_num=slice_num, + local_kernel=local_kernel, + local_layers=local_layers, + **kwargs, + ) + + +@register_model("mlicplus") +class MLICPlus(_BaseMLIC): + r"""MLIC+ model from W. Jiang, J. Yang, Y. Zhai, P. Ning, F. Gao, R. Wang: + `"MLIC: Multi-Reference Entropy Model for Learned Image Compression" + `_, ACM Multimedia 2023. + """ + + _variant = "mlic+" + + def __init__( + self, + N: int = 192, + M: int = 320, + slice_num: int = 10, + context_window: int = 5, + **kwargs, + ) -> None: + super().__init__( + N=N, + M=M, + slice_num=slice_num, + context_window=context_window, + **kwargs, + ) + + +@register_model("mlicpp") +class MLICPlusPlus(_BaseMLIC): + r"""MLIC++ model from W. Jiang, J. Yang, Y. Zhai, F. Gao, R. Wang: + `"MLIC++: Linear Complexity Multi-Reference Entropy Modeling for Learned + Image Compression" `_, ACM Trans. + Multimedia Comput. Commun. Appl. (TOMM), 2025; ICML 2023 Neural + Compression Workshop. + + This implementation uses a containerized hyperprior entropy stack: + ``HyperpriorLatentCodec`` wraps the MLIC++ hyper transforms and a + ``ChannelGroupsLatentCodec`` built from per-slice + ``MultiContextCheckerboardLatentCodec`` leaves. + + Upstream MLIC++ checkpoints from JiangWeibeta/MLIC use a different + root-level key layout; convert them to the compressai layout first with + ``examples/convert_mlic_checkpoint.py`` before calling + :meth:`from_state_dict`. + """ + + _variant = "mlicpp" + + def __init__( + self, + N: int = 192, + M: int = 320, + slice_num: int = 10, + context_window: int = 5, + **kwargs, + ) -> None: + super().__init__( + N=N, + M=M, + slice_num=slice_num, + context_window=context_window, + **kwargs, + ) + + +@register_model("mlicv2") +class MLICv2(_BaseMLIC): + r"""MLICv2 model from W. Jiang, J. Yang, Y. Zhai, F. Gao, R. Wang: + `"MLIC++: Linear Complexity Multi-Reference Entropy Modeling for Learned + Image Compression" `_ follow-up family. + + This variant replaces the MLIC++ analysis/synthesis transforms with STM + blocks and enables HGCP, context reweighting, 2D RoPE, and GSC in the + shared multi-context slice factory. + """ + + _variant = "mlicv2" + _analysis_cls = STMAnalysis + _synthesis_cls = STMSynthesis + + def __init__( + self, + N: int = 192, + M: int = 320, + slice_num: int = 10, + context_window: int = 5, + **kwargs, + ) -> None: + super().__init__( + N=N, + M=M, + slice_num=slice_num, + context_window=context_window, + **kwargs, + ) diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py index 17f53136..0d876a79 100644 --- a/compressai/zoo/__init__.py +++ b/compressai/zoo/__init__.py @@ -37,6 +37,10 @@ dcae, mbt2018, mbt2018_mean, + mlic, + mlicplus, + mlicpp, + mlicv2, saaf, stf, stf_wacnn, @@ -55,6 +59,10 @@ "cheng2020-anchor": cheng2020_anchor, "cheng2020-attn": cheng2020_attn, "dcae": dcae, + "mlic": mlic, + "mlicplus": mlicplus, + "mlicpp": mlicpp, + "mlicv2": mlicv2, "saaf": saaf, "stf": stf, "stf-wacnn": stf_wacnn, diff --git a/compressai/zoo/image.py b/compressai/zoo/image.py index 120c619f..d7ddf2ab 100644 --- a/compressai/zoo/image.py +++ b/compressai/zoo/image.py @@ -83,6 +83,10 @@ def __getattr__(self, item): "cheng2020_anchor", "cheng2020_attn", "dcae", + "mlic", + "mlicplus", + "mlicpp", + "mlicv2", "saaf", "stf", "stf_wacnn", @@ -99,6 +103,10 @@ def __getattr__(self, item): "cheng2020-anchor": Cheng2020Anchor, "cheng2020-attn": Cheng2020Attention, "dcae": _LazyImport("compressai.models.dcae", "DCAE"), + "mlic": _LazyImport("compressai.models.mlic", "MLIC"), + "mlicplus": _LazyImport("compressai.models.mlic", "MLICPlus"), + "mlicpp": _LazyImport("compressai.models.mlic", "MLICPlusPlus"), + "mlicv2": _LazyImport("compressai.models.mlic", "MLICv2"), "saaf": _LazyImport("compressai.models.saaf", "SAAF"), # Resolved lazily so `compressai.zoo` is importable without `timm`. "stf": _LazyImport("compressai.models.stf", "SymmetricalTransFormer"), @@ -535,6 +543,82 @@ def saaf(pretrained: bool = False, progress: bool = True, **kwargs): return SAAF(**kwargs) +def mlic(pretrained: bool = False, progress: bool = True, **kwargs): + r"""MLIC model from W. Jiang, J. Yang, Y. Zhai, P. Ning, F. Gao, R. Wang: + `"MLIC: Multi-Reference Entropy Model for Learned Image Compression" + `_, ACM Multimedia 2023. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained MLIC weights are not yet hosted on S3.") + from compressai.models.mlic import MLIC + + return MLIC(**kwargs) + + +def mlicplus(pretrained: bool = False, progress: bool = True, **kwargs): + r"""MLIC+ model from W. Jiang, J. Yang, Y. Zhai, P. Ning, F. Gao, R. Wang: + `"MLIC: Multi-Reference Entropy Model for Learned Image Compression" + `_, ACM Multimedia 2023. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained MLIC+ weights are not yet hosted on S3.") + from compressai.models.mlic import MLICPlus + + return MLICPlus(**kwargs) + + +def mlicpp(pretrained: bool = False, progress: bool = True, **kwargs): + r"""MLIC++ model from W. Jiang, J. Yang, Y. Zhai, F. Gao, R. Wang: + `"MLIC++: Linear Complexity Multi-Reference Entropy Modeling for Learned + Image Compression" `_, ACM Trans. + Multimedia Comput. Commun. Appl. (TOMM), 2025; ICML 2023 Neural + Compression Workshop. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained MLIC++ weights are not yet hosted on S3.") + from compressai.models.mlic import MLICPlusPlus + + return MLICPlusPlus(**kwargs) + + +def mlicv2(pretrained: bool = False, progress: bool = True, **kwargs): + r"""MLICv2 model from W. Jiang, J. Yang, Y. Zhai, F. Gao, R. Wang. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained MLICv2 weights are not yet hosted on S3.") + from compressai.models.mlic import MLICv2 + + return MLICv2(**kwargs) + + def stf(pretrained: bool = False, progress: bool = True, **kwargs): r"""Symmetrical TransFormer (STF) model from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the Details: Window-based Attention for Image diff --git a/examples/convert_mlic_checkpoint.py b/examples/convert_mlic_checkpoint.py new file mode 100644 index 00000000..ae0bcb3a --- /dev/null +++ b/examples/convert_mlic_checkpoint.py @@ -0,0 +1,281 @@ +"""Convert an MLIC-family checkpoint to compressai layout. + +Loads an MLIC-family checkpoint, instantiates the matching compressai model +through ``from_state_dict``, and optionally writes the converted state dict. +MLIC++ checkpoints from JiangWeibeta/MLIC are translated from the published +root-level layout to the containerized ``latent_codec.*`` layout inside +``convert_upstream_mlicpp_state_dict``. + +The upstream-checkpoint conversion helpers live in this example CLI (not in +``compressai.models.mlic``) so the model module stays a clean compressai-native +definition. ``examples/`` is not an importable package, so tests load +``convert_upstream_mlicpp_state_dict`` by file path. + +Example:: + + python examples/convert_mlic_checkpoint.py \ + --src candidate/MLIC/mlicpp_mse_q5_2960000.pth.tar \ + --variant mlicpp \ + --dst /tmp/mlicpp_compressai.pth \ + --smoke +""" + +from __future__ import annotations + +import argparse +import re + +from pathlib import Path +from typing import Dict, Iterable, List, Type + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.models.mlic import ( + MLIC, + MLICPlus, + MLICPlusPlus, + MLICv2, +) + +_VARIANTS: Dict[str, Type[nn.Module]] = { + "mlic": MLIC, + "mlic+": MLICPlus, + "mlicpp": MLICPlusPlus, + "mlicv2": MLICv2, +} + + +# --------------------------------------------------------------------------- +# Upstream MLIC++ checkpoint conversion +# +# The old fork-script layout stored the hyperprior and per-slice modules under +# a monolithic ``latent_codec``. The compressai model follows the ELIC-style +# container structure: ``HyperpriorLatentCodec`` owns ``h_a`` / ``h_s`` / ``z``, +# while ``latent_codec.y`` owns the channel groups and per-slice checkerboard +# leaves. +# --------------------------------------------------------------------------- + +_CURRENT_SLICE_RE = re.compile(r"^latent_codec\.y\.latent_codec\.y(\d+)\.") +_LEGACY_SLICE_RE = re.compile( + r"^(?:latent_codec\.)?" + r"(?:local_context|channel_context|global_inter_context|" + r"global_intra_context|entropy_parameters_anchor|" + r"entropy_parameters_nonanchor|lrp_anchor|lrp_nonanchor)\.(\d+)\." +) + +_ROOT_TO_CONTAINER_PREFIXES: Dict[str, str] = { + "h_a.": "latent_codec.h_a.", + "h_s.": "latent_codec.h_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} + +_LEGACY_LIST_RENAMES: Dict[str, str] = { + "local_context": "latent_codec.y.latent_codec.y{index}.spatial_context_nonanchor.", + "channel_context": "latent_codec.y.channel_context.y{index}.channel_part.", + "global_inter_context": "latent_codec.y.channel_context.y{index}.global_inter_part.", + "global_intra_context": ( + "latent_codec.y.latent_codec.y{index}.intra_channel_context_nonanchor." + ), + "entropy_parameters_anchor": ( + "latent_codec.y.latent_codec.y{index}.entropy_parameters_anchor." + ), + "entropy_parameters_nonanchor": ( + "latent_codec.y.latent_codec.y{index}.entropy_parameters_nonanchor." + ), + "lrp_anchor": "latent_codec.y.latent_codec.y{index}.lrp_anchor.", + "lrp_nonanchor": "latent_codec.y.latent_codec.y{index}.lrp_nonanchor.", +} + + +def _strip_data_parallel_prefix(key: str) -> str: + if key.startswith("module."): + return key[len("module.") :] + return key + + +def _infer_slice_num(keys: Iterable[str]) -> int: + indices: List[int] = [] + for raw_key in keys: + key = _strip_data_parallel_prefix(raw_key) + for pattern in (_CURRENT_SLICE_RE, _LEGACY_SLICE_RE): + match = pattern.match(key) + if match is not None: + indices.append(int(match.group(1))) + break + return max(indices) + 1 if indices else 10 + + +def _convert_mlicpp_key( + key: str, + *, + slice_num: int, +) -> List[str]: + key = _strip_data_parallel_prefix(key) + + for old, new in _ROOT_TO_CONTAINER_PREFIXES.items(): + if key.startswith(old): + return [new + key[len(old) :]] + + if key.startswith("latent_codec.entropy_bottleneck."): + return [ + "latent_codec.z.entropy_bottleneck." + + key[len("latent_codec.entropy_bottleneck.") :] + ] + + for prefix in ("gaussian_conditional.", "latent_codec.gaussian_conditional."): + if key.startswith(prefix): + suffix = key[len(prefix) :] + return [ + f"latent_codec.y.latent_codec.y{index}.y.gaussian_conditional." + suffix + for index in range(slice_num) + ] + + legacy_key = key + if legacy_key.startswith("latent_codec."): + legacy_key = legacy_key[len("latent_codec.") :] + parts = legacy_key.split(".", 2) + if len(parts) == 3 and parts[0] in _LEGACY_LIST_RENAMES: + name, index, suffix = parts + return [_LEGACY_LIST_RENAMES[name].format(index=index) + suffix] + + return [key] + + +def convert_upstream_mlicpp_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Convert legacy MLIC++ checkpoint keys to the containerized layout. + + The old fork-script layout stored the hyperprior and per-slice modules + under a monolithic ``latent_codec``. The compressai model follows the + ELIC-style container structure: ``HyperpriorLatentCodec`` owns ``h_a`` / + ``h_s`` / ``z``, while ``latent_codec.y`` owns the channel groups and + per-slice checkerboard leaves. + """ + slice_num = _infer_slice_num(state_dict.keys()) + converted: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + for new_key in _convert_mlicpp_key(key, slice_num=slice_num): + converted[new_key] = value + return converted + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream or already-converted MLIC-family checkpoint.", + ) + parser.add_argument( + "--variant", + choices=sorted(_VARIANTS), + default="mlicpp", + help="Model variant to instantiate (default: mlicpp).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." + ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic image.", + ) + parser.add_argument( + "--smoke-size", + type=int, + default=256, + help="Synthetic square image size used by --smoke (default: 256).", + ) + return parser.parse_args() + + +def load_state_dict(path: Path) -> Dict[str, Tensor]: + checkpoint = torch.load(path, map_location="cpu", weights_only=False) + if isinstance(checkpoint, dict): + state_dict = checkpoint.get("state_dict", checkpoint) + else: + state_dict = checkpoint + if not isinstance(state_dict, dict): + raise SystemExit(f"checkpoint does not contain a state dict: {path}") + return dict(state_dict) + + +def make_synthetic_image(size: int) -> Tensor: + if size <= 0: + raise SystemExit("--smoke-size must be positive") + ys, xs = torch.meshgrid( + torch.linspace(0, 1, size), + torch.linspace(0, 1, size), + indexing="ij", + ) + return ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + +def run_smoke(net: nn.Module, size: int) -> None: + img = make_synthetic_image(size) + with torch.no_grad(): + out = net(img) + n_pix = size * size + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + state_dict = load_state_dict(args.src) + if args.variant == "mlicpp": + state_dict = convert_upstream_mlicpp_state_dict(state_dict) + print(f"loaded checkpoint -> {len(state_dict)} compressai keys") + + cls = _VARIANTS[args.variant] + net = cls.from_state_dict(state_dict).eval() + print( + "variant: " + f"{args.variant}, N={net.N}, M={net.M}, slice_num={net.slice_num}, " + f"context_window={getattr(net, 'context_window', None)}, " + f"local_kernel={getattr(net, 'local_kernel', None)}, " + f"local_layers={getattr(net, 'local_layers', None)}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict -> {args.dst}") + + if args.smoke: + run_smoke(net, args.smoke_size) + + +if __name__ == "__main__": + main() diff --git a/tests/test_mlic_layers.py b/tests/test_mlic_layers.py new file mode 100644 index 00000000..b1854dcb --- /dev/null +++ b/tests/test_mlic_layers.py @@ -0,0 +1,288 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Callable, Tuple + +import pytest +import torch +import torch.nn as nn + +import compressai.layers as layers + +try: + from compressai.models._helpers.mlic import ( + AnalysisTransform, + ChannelContext, + EntropyParameters, + HyperAnalysis, + HyperSynthesis, + LatentResidualPrediction, + LinearGlobalInterContext, + LinearGlobalIntraContext, + LocalContext, + StackedCheckerboardConv, + SynthesisTransform, + VanillaGlobalInterContext, + VanillaGlobalIntraContext, + WindowCheckerboardAttn, + ) + from compressai.models._helpers.mlic.utils import ( + build_position_index, + checkerboard_anchor, + checkerboard_merge, + checkerboard_nonanchor, + checkerboard_split, + squeeze_anchor, + squeeze_nonanchor, + unsqueeze_anchor, + unsqueeze_nonanchor, + ) +except ModuleNotFoundError as err: + if err.name != "timm": + raise + pytestmark = pytest.mark.skip(reason="MLIC++ layers require the [attn] extra") + + +def _round_trip_output( + make_module: Callable[[], nn.Module], + inputs: Tuple[torch.Tensor, ...], +) -> torch.Tensor: + torch.manual_seed(0) + module = make_module().eval() + with torch.no_grad(): + expected = module(*inputs) + + clone = make_module().eval() + clone.load_state_dict(module.state_dict()) + with torch.no_grad(): + actual = clone(*inputs) + + assert torch.allclose(actual, expected, atol=1e-6) + return expected + + +def test_mlic_layers_are_deep_import_only(): + assert not hasattr(layers, "LocalContext") + assert not hasattr(layers, "EntropyParameters") + + +class TestMlicContextLayers: + @staticmethod + def test_local_context_forward_shape_and_state_dict_round_trip(): + x = torch.randn(2, 4, 4, 4) + y = _round_trip_output( + lambda: LocalContext(dim=4, window_size=3, num_heads=2), + (x,), + ) + assert y.shape == (2, 8, 4, 4) + + @staticmethod + def test_context_layers_reject_invalid_head_count(): + with pytest.raises(ValueError): + LocalContext(dim=5, num_heads=2) + with pytest.raises(ValueError): + LinearGlobalInterContext(dim=5, num_heads=2) + with pytest.raises(ValueError): + LinearGlobalIntraContext(dim=5, num_heads=2) + with pytest.raises(ValueError): + VanillaGlobalInterContext(in_dim=5, num_heads=2) + with pytest.raises(ValueError): + VanillaGlobalIntraContext(dim=5, num_heads=2) + + @staticmethod + def test_channel_context_forward_shape_and_state_dict_round_trip(): + x = torch.randn(2, 8, 4, 4) + y = _round_trip_output(lambda: ChannelContext(in_dim=8, out_dim=4), (x,)) + assert y.shape == (2, 16, 4, 4) + + @staticmethod + def test_linear_global_inter_context_shape_and_state_dict_round_trip(): + x = torch.randn(2, 4, 4, 4) + y = _round_trip_output( + lambda: LinearGlobalInterContext(dim=4, out_dim=8, num_heads=2), + (x,), + ) + assert y.shape == (2, 8, 4, 4) + + @staticmethod + def test_stacked_checkerboard_conv_shape_and_state_dict_round_trip(): + x = torch.randn(2, 4, 6, 6) + y = _round_trip_output( + lambda: StackedCheckerboardConv(dim=4, kernel=5, num_layers=3), + (x,), + ) + assert y.shape == (2, 8, 6, 6) + + @staticmethod + def test_stacked_checkerboard_conv_rejects_even_kernel_or_layers(): + with pytest.raises(ValueError): + StackedCheckerboardConv(dim=4, kernel=4) + with pytest.raises(ValueError): + StackedCheckerboardConv(dim=4, num_layers=2) + + @staticmethod + def test_window_checkerboard_attention_shape_and_mask(): + x = torch.randn(2, 4, 4, 4) + y = _round_trip_output( + lambda: WindowCheckerboardAttn(dim=4, window_size=3, num_heads=2), + (x,), + ) + assert y.shape == (2, 8, 4, 4) + + module = WindowCheckerboardAttn(dim=4, window_size=3, num_heads=2) + module.update_resolution(4, 4, x.device) + assert module.attn_mask is not None + assert module.attn_mask.shape == (16, 9, 9) + assert torch.any(module.attn_mask == 0) + assert torch.any(module.attn_mask == -100) + assert torch.all((module.attn_mask == 0) | (module.attn_mask == -100)) + + @staticmethod + def test_vanilla_global_inter_context_shape_and_state_dict_round_trip(): + x = torch.randn(2, 4, 4, 4) + y = _round_trip_output( + lambda: VanillaGlobalInterContext(in_dim=4, out_dim=8, num_heads=2), + (x,), + ) + assert y.shape == (2, 8, 4, 4) + + @staticmethod + def test_vanilla_global_intra_context_shape_mask_and_state_dict_round_trip(): + x1 = torch.randn(2, 4, 4, 4) + x2 = torch.randn(2, 4, 4, 4) + y = _round_trip_output( + lambda: VanillaGlobalIntraContext(dim=4, num_heads=2), + (x1, x2), + ) + assert y.shape == (2, 8, 4, 4) + + module = VanillaGlobalIntraContext(dim=4, num_heads=2, local_mask_radius=0) + mask = module._attention_mask(4, 4, x1.device) + assert mask.shape == (8, 8) + assert not torch.any(mask) + + module = VanillaGlobalIntraContext(dim=4, num_heads=2, local_mask_radius=1) + mask = module._attention_mask(4, 4, x1.device) + assert mask.shape == (8, 8) + assert torch.any(mask) + assert not torch.all(mask) + + @staticmethod + def test_linear_global_intra_context_shape_and_state_dict_round_trip(): + x1 = torch.randn(2, 4, 4, 4) + x2 = torch.randn(2, 4, 4, 4) + y = _round_trip_output( + lambda: LinearGlobalIntraContext(dim=4, num_heads=2), + (x1, x2), + ) + assert y.shape == (2, 8, 4, 4) + + +class TestMlicTransforms: + @staticmethod + def test_analysis_transform_shape_and_state_dict_round_trip(): + x = torch.randn(1, 3, 64, 64) + module = AnalysisTransform(N=8, M=16) + assert isinstance(module.analysis_transform[0].act, nn.GELU) + assert not hasattr(module.analysis_transform[0], "leaky_relu") + y = _round_trip_output(lambda: module, (x,)) + assert y.shape == (1, 16, 4, 4) + + @staticmethod + def test_synthesis_transform_shape_and_state_dict_round_trip(): + y = torch.randn(1, 16, 4, 4) + module = SynthesisTransform(N=8, M=16) + assert isinstance(module.synthesis_transform[0].act, nn.GELU) + assert not hasattr(module.synthesis_transform[0], "leaky_relu") + x_hat = _round_trip_output(lambda: module, (y,)) + assert x_hat.shape == (1, 3, 64, 64) + + @staticmethod + def test_hyper_analysis_shape_and_state_dict_round_trip(): + y = torch.randn(1, 16, 4, 4) + z = _round_trip_output(lambda: HyperAnalysis(M=16, N=8), (y,)) + assert z.shape == (1, 8, 1, 1) + + @staticmethod + def test_hyper_synthesis_shape_and_state_dict_round_trip(): + z = torch.randn(1, 8, 1, 1) + params = _round_trip_output(lambda: HyperSynthesis(M=16, N=8), (z,)) + assert params.shape == (1, 32, 4, 4) + + @staticmethod + def test_entropy_parameters_shape_and_state_dict_round_trip(): + params = torch.randn(2, 10, 4, 4) + y = _round_trip_output( + lambda: EntropyParameters(in_dim=10, out_dim=8), (params,) + ) + assert y.shape == (2, 8, 4, 4) + + @staticmethod + def test_lrp_shape_bound_and_state_dict_round_trip(): + params = torch.randn(2, 10, 4, 4) + y = _round_trip_output( + lambda: LatentResidualPrediction(in_dim=10, out_dim=4), + (params,), + ) + assert y.shape == (2, 4, 4, 4) + assert torch.all(y <= 0.5) + assert torch.all(y >= -0.5) + + +class TestMlicCheckerboardUtils: + @staticmethod + def test_checkerboard_split_merge_and_squeeze_layout(): + x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4) + anchor, nonanchor = checkerboard_split(x) + + expected_anchor_squeezed = torch.tensor( + [[[[1.0, 3.0], [4.0, 6.0], [9.0, 11.0], [12.0, 14.0]]]] + ) + expected_nonanchor_squeezed = torch.tensor( + [[[[0.0, 2.0], [5.0, 7.0], [8.0, 10.0], [13.0, 15.0]]]] + ) + + assert torch.equal(checkerboard_merge(anchor, nonanchor), x) + assert torch.equal(checkerboard_anchor(x), anchor) + assert torch.equal(checkerboard_nonanchor(x), nonanchor) + assert torch.equal(squeeze_anchor(x), expected_anchor_squeezed) + assert torch.equal(squeeze_nonanchor(x), expected_nonanchor_squeezed) + assert torch.equal(unsqueeze_anchor(squeeze_anchor(x)), anchor) + assert torch.equal(unsqueeze_nonanchor(squeeze_nonanchor(x)), nonanchor) + + @staticmethod + def test_build_position_index_shape_and_center_value(): + position_index = build_position_index((3, 3)) + + assert position_index.shape == (9, 9) + assert position_index[4, 4].item() == 12 + assert position_index.min().item() == 0 + assert position_index.max().item() == 24 diff --git a/tests/test_mlicv2_layers.py b/tests/test_mlicv2_layers.py new file mode 100644 index 00000000..0bb05366 --- /dev/null +++ b/tests/test_mlicv2_layers.py @@ -0,0 +1,185 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +from typing import Callable, Tuple + +import pytest +import torch +import torch.nn as nn + +import compressai.layers as layers + +from compressai.models._helpers.mlicv2 import ( + ContextReweighting, + GSCModule, + HGCPModule, + RoPE2D, + SimpleTokenMixing, + STMAnalysis, + STMSynthesis, +) + + +def _round_trip_output( + make_module: Callable[[], nn.Module], + inputs: Tuple[torch.Tensor, ...], +) -> torch.Tensor: + torch.manual_seed(0) + module = make_module().eval() + with torch.no_grad(): + expected = module(*inputs) + + clone = make_module().eval() + clone.load_state_dict(module.state_dict()) + with torch.no_grad(): + actual = clone(*inputs) + + assert torch.allclose(actual, expected, atol=1e-6) + return expected + + +def test_mlicv2_layers_are_deep_import_only() -> None: + assert not hasattr(layers, "SimpleTokenMixing") + assert not hasattr(layers, "HGCPModule") + + +class TestMlicv2Transforms: + @staticmethod + def test_simple_token_mixing_shape_and_state_dict_round_trip() -> None: + x = torch.randn(2, 8, 8, 8) + y = _round_trip_output(lambda: SimpleTokenMixing(dim=8), (x,)) + assert y.shape == x.shape + + @staticmethod + def test_stm_analysis_shape_and_state_dict_round_trip() -> None: + x = torch.randn(1, 3, 64, 64) + y = _round_trip_output(lambda: STMAnalysis(N=8, M=16), (x,)) + assert y.shape == (1, 16, 4, 4) + + @staticmethod + def test_stm_synthesis_shape_and_state_dict_round_trip() -> None: + y = torch.randn(1, 16, 4, 4) + x_hat = _round_trip_output(lambda: STMSynthesis(N=8, M=16), (y,)) + assert x_hat.shape == (1, 3, 64, 64) + + +class TestMlicv2Context: + @staticmethod + def test_context_reweighting_shape_attention_and_state_dict_round_trip() -> None: + x = torch.randn(2, 8, 4, 4) + y = _round_trip_output(lambda: ContextReweighting(dim=8), (x,)) + assert y.shape == x.shape + + module = ContextReweighting(dim=8).eval() + with torch.no_grad(): + attention = module.channel_attention(x) + assert attention.shape == (2, 8, 8) + assert torch.allclose( + attention.sum(dim=-1), + torch.ones(2, 8), + atol=1e-6, + ) + + @staticmethod + def test_rope2d_shape_state_dict_and_relative_position_property() -> None: + module = RoPE2D(dim=4, learnable_thetas=False).eval() + x = torch.ones(1, 4, 4, 4) + y = _round_trip_output(lambda: RoPE2D(dim=4, learnable_thetas=False), (x,)) + assert y.shape == x.shape + + rotated = module.rotate(x).reshape(1, 2, 2, 4, 4) + token_a = rotated[:, :, :, 0, 0].reshape(1, 4) + token_b = rotated[:, :, :, 1, 1].reshape(1, 4) + token_c = rotated[:, :, :, 2, 1].reshape(1, 4) + token_d = rotated[:, :, :, 3, 2].reshape(1, 4) + score_ab = (token_a * token_b).sum(dim=1) + score_cd = (token_c * token_d).sum(dim=1) + assert torch.allclose(score_ab, score_cd, atol=1e-6) + + @staticmethod + def test_rope2d_rejects_odd_dim() -> None: + with pytest.raises(ValueError): + RoPE2D(dim=5) + + @staticmethod + def test_hgcp_shape_and_state_dict_round_trip() -> None: + hyper = torch.randn(2, 32, 4, 4) + y_hat = torch.randn(2, 8, 4, 4) + y = _round_trip_output(lambda: HGCPModule(M=16, slice_ch=8), (hyper, y_hat)) + assert y.shape == (2, 16, 4, 4) + + @staticmethod + def test_hgcp_rejects_invalid_head_count() -> None: + with pytest.raises(ValueError): + HGCPModule(M=16, slice_ch=7, num_heads=2) + + @staticmethod + def test_gsc_shape_skip_rate_and_state_dict_round_trip() -> None: + side_params = torch.randn(2, 12, 4, 4) + scales = torch.linspace(0.1, 0.5, steps=2 * 8 * 4 * 4).reshape(2, 8, 4, 4) + means = torch.zeros_like(scales) + + torch.manual_seed(0) + module = GSCModule(slice_ch=8, side_ch=12, threshold=0.3).eval() + with torch.no_grad(): + out = module( + side_params=side_params, + scales=scales, + means=means, + step="anchor", + ) + selective_map = out["selective_map"] + assert selective_map.shape == scales.shape + assert torch.all((selective_map >= 0) & (selective_map <= 1)) + hard_ratio = (selective_map >= 0.5).float().mean().item() + assert 0.1 < hard_ratio < 0.9 + + clone = GSCModule(slice_ch=8, side_ch=12, threshold=0.3).eval() + clone.load_state_dict(module.state_dict()) + with torch.no_grad(): + cloned = clone( + side_params=side_params, + scales=scales, + means=means, + step="anchor", + ) + assert torch.allclose(cloned["selective_map"], selective_map, atol=1e-6) + + @staticmethod + def test_gsc_rejects_invalid_step() -> None: + module = GSCModule(slice_ch=4, side_ch=8) + with pytest.raises(ValueError): + module( + side_params=torch.randn(1, 8, 2, 2), + scales=torch.ones(1, 4, 2, 2), + means=torch.zeros(1, 4, 2, 2), + step="bad", + ) diff --git a/tests/test_models.py b/tests/test_models.py index 66c4013a..5a92edf0 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1106,6 +1106,241 @@ def test_saaf_upstream_state_dict_conversion(self): assert "gaussian_conditional.scale_table" not in converted +class TestMlic: + def test_mlic_forward_and_state_dict_round_trip(self): + pytest.importorskip("timm") + from compressai.models.mlic import MLIC + + model = MLIC(N=8, M=12, slice_num=3, local_kernel=3).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + sd_keys = set(model.state_dict().keys()) + assert "latent_codec.h_a.reduction.0.weight" in sd_keys + assert "latent_codec.h_s.increase.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + assert ( + "latent_codec.y.channel_context.y1.channel_part.fushion.0.weight" in sd_keys + ) + assert not any( + k.startswith("latent_codec.y.channel_context.y1.global_inter_part.") + for k in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y0.spatial_context_nonanchor.context.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y1.intra_channel_context_nonanchor.keys.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y0.y.gaussian_conditional.scale_table" + in sd_keys + ) + assert "h_a.reduction.0.weight" not in sd_keys + assert not hasattr(model, "entropy_bottleneck") + assert not hasattr(model, "gaussian_conditional") + + loaded = MLIC.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert torch.allclose(out["likelihoods"]["y"], out_loaded["likelihoods"]["y"]) + assert torch.allclose(out["likelihoods"]["z"], out_loaded["likelihoods"]["z"]) + assert loaded.N == 8 + assert loaded.M == 12 + assert loaded.slice_num == 3 + assert loaded.local_kernel == 3 + assert loaded.local_layers == 3 + + def test_mlicplus_forward_and_state_dict_round_trip(self): + pytest.importorskip("timm") + from compressai.models.mlic import MLICPlus + + model = MLICPlus(N=8, M=16, slice_num=4, context_window=3).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + sd_keys = set(model.state_dict().keys()) + assert ( + "latent_codec.y.channel_context.y1.global_inter_part.keys.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y0.spatial_context_nonanchor.relative_position_table" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y1.intra_channel_context_nonanchor.keys.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y0.lrp_anchor.lrp_transform.0.weight" + in sd_keys + ) + assert "latent_codec.entropy_bottleneck.quantiles" not in sd_keys + + loaded = MLICPlus.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert torch.allclose(out["likelihoods"]["y"], out_loaded["likelihoods"]["y"]) + assert torch.allclose(out["likelihoods"]["z"], out_loaded["likelihoods"]["z"]) + assert loaded.N == 8 + assert loaded.M == 16 + assert loaded.slice_num == 4 + assert loaded.context_window == 3 + + def test_mlicpp_upstream_state_dict_conversion(self): + convert_upstream_mlicpp_state_dict = _load_convert_fn( + "convert_mlic_checkpoint.py", "convert_upstream_mlicpp_state_dict" + ) + + upstream = { + "h_a.reduction.0.weight": torch.zeros(2), + "h_s.increase.0.weight": torch.zeros(2), + "entropy_bottleneck.quantiles": torch.zeros(2), + "gaussian_conditional.scale_table": torch.zeros(2), + "local_context.0.relative_position_table": torch.zeros(2), + "channel_context.0.fushion.0.weight": torch.zeros(2), + "global_inter_context.1.keys.0.weight": torch.zeros(2), + "global_intra_context.1.keys.0.weight": torch.zeros(2), + "entropy_parameters_anchor.0.fusion.0.weight": torch.zeros(2), + "entropy_parameters_nonanchor.1.fusion.0.weight": torch.zeros(2), + "lrp_anchor.0.lrp_transform.0.weight": torch.zeros(2), + "lrp_nonanchor.1.lrp_transform.0.weight": torch.zeros(2), + } + converted = convert_upstream_mlicpp_state_dict(upstream) + + assert "latent_codec.h_a.reduction.0.weight" in converted + assert "latent_codec.h_s.increase.0.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + assert ( + "latent_codec.y.latent_codec.y0.y.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.y.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y0.spatial_context_nonanchor.relative_position_table" + in converted + ) + assert ( + "latent_codec.y.channel_context.y0.channel_part.fushion.0.weight" + in converted + ) + assert ( + "latent_codec.y.channel_context.y1.global_inter_part.keys.0.weight" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.intra_channel_context_nonanchor.keys.0.weight" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y0.entropy_parameters_anchor.fusion.0.weight" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.entropy_parameters_nonanchor.fusion.0.weight" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y0.lrp_anchor.lrp_transform.0.weight" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.lrp_nonanchor.lrp_transform.0.weight" + in converted + ) + + assert "h_a.reduction.0.weight" not in converted + assert "entropy_bottleneck.quantiles" not in converted + assert "gaussian_conditional.scale_table" not in converted + assert "local_context.0.relative_position_table" not in converted + assert "channel_context.0.fushion.0.weight" not in converted + + +class TestMlicv2: + def test_forward_state_dict_round_trip_and_gsc_skip_rate(self): + pytest.importorskip("timm") + from compressai.models.mlic import MLICv2 + + model = MLICv2(N=8, M=16, slice_num=4, context_window=3).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + sd_keys = set(model.state_dict().keys()) + assert "g_a.analysis_transform.1.0.norm1.weight" in sd_keys + assert "g_s.synthesis_transform.0.0.norm1.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.spatial_context_anchor.hgcp.queries.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y0.selective_predictor.predictor.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.channel_context.y1.global_inter_part.context.keys.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.channel_context.y1.global_inter_part.reweighting.queries.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y1.intra_channel_context_nonanchor.context.keys.0.weight" + in sd_keys + ) + assert ( + "latent_codec.y.latent_codec.y1.intra_channel_context_nonanchor.rope.theta_x" + in sd_keys + ) + assert "latent_codec.entropy_bottleneck.quantiles" not in sd_keys + + predictor = model.latent_codec.y.latent_codec["y0"].selective_predictor + side_params = torch.randn(1, 32, 4, 4) + scales = torch.linspace(0.1, 0.5, steps=1 * 4 * 4 * 4).reshape(1, 4, 4, 4) + means = torch.zeros_like(scales) + with torch.no_grad(): + selective = predictor( + side_params=side_params, + scales=scales, + means=means, + step="anchor", + )["selective_map"] + hard_ratio = (selective >= 0.5).float().mean().item() + assert 0.1 < hard_ratio < 0.9 + + loaded = MLICv2.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert torch.allclose(out["likelihoods"]["y"], out_loaded["likelihoods"]["y"]) + assert torch.allclose(out["likelihoods"]["z"], out_loaded["likelihoods"]["z"]) + assert loaded.N == 8 + assert loaded.M == 16 + assert loaded.slice_num == 4 + assert loaded.context_window == 3 + assert loaded.downsampling_factor == 64 + + class TestCca: def test_cca_forward_and_state_dict_round_trip(self): from compressai.models.cca import CCAModel diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py index ba4ab71b..afe0c829 100644 --- a/tests/test_models_helpers.py +++ b/tests/test_models_helpers.py @@ -192,6 +192,128 @@ def test_infer_max_support_slices_new_path(self): assert infer_max_support_slices(sd, latent_channels=64, num_slices=8) == 2 +class TestMlicModelSliceCodec: + def _make_mlicpp_codec(self): + pytest.importorskip("timm") + from compressai.models.mlic import MLICPlusPlus + + return MLICPlusPlus(N=8, M=8, slice_num=2, context_window=3).latent_codec.y + + def test_mlicpp_layout_uses_side_context_channel_groups(self): + from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + EntropyBottleneckLatentCodec, + MultiContextCheckerboardLatentCodec, + ) + from compressai.models.mlic import MLICPlusPlus + + model = MLICPlusPlus(N=8, M=8, slice_num=2, context_window=3) + codec = model.latent_codec.y + assert isinstance(model.latent_codec.z, EntropyBottleneckLatentCodec) + assert model.latent_codec.z.quantizer == "ste" + assert isinstance(codec, ChannelGroupsLatentCodec) + assert type(codec).__name__ == "_SideContextChannelGroupsLatentCodec" + assert codec.groups == [4, 4] + assert codec.support_slices == [(), (0,)] + assert set(codec.channel_context.keys()) == {"y0", "y1"} + assert set(codec.latent_codec.keys()) == {"y0", "y1"} + assert isinstance(codec.latent_codec["y0"], MultiContextCheckerboardLatentCodec) + assert codec.latent_codec["y0"].anchor_parity == "odd" + + def test_mlicpp_state_dict_paths_match_containerized_layout(self): + codec = self._make_mlicpp_codec() + keys = set(codec.state_dict().keys()) + assert any(k.startswith("channel_context.y1.channel_part.") for k in keys) + assert any(k.startswith("channel_context.y1.global_inter_part.") for k in keys) + assert any( + k.startswith("latent_codec.y0.entropy_parameters_anchor.fusion.") + for k in keys + ) + assert any( + k.startswith("latent_codec.y1.entropy_parameters_nonanchor.fusion.") + for k in keys + ) + assert any( + k.startswith("latent_codec.y0.spatial_context_nonanchor.") for k in keys + ) + assert not any( + k.startswith("latent_codec.y0.intra_channel_context_nonanchor.") + for k in keys + ) + assert any( + k.startswith("latent_codec.y1.intra_channel_context_nonanchor.keys.") + for k in keys + ) + assert any(k.startswith("latent_codec.y1.lrp_anchor.") for k in keys) + assert any(k.startswith("latent_codec.y1.lrp_nonanchor.") for k in keys) + assert any( + k.startswith("latent_codec.y1.y.gaussian_conditional.") for k in keys + ) + + def test_mlicpp_slice_codec_forward_runs(self): + torch.manual_seed(0) + codec = self._make_mlicpp_codec().eval() + y = torch.randn(2, 8, 8, 8) + side_params = torch.randn(2, 16, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == (2, 8, 8, 8) + assert out["likelihoods"]["y"].shape == (2, 8, 8, 8) + + @pytest.mark.parametrize("variant", ["mlic", "mlic+"]) + def test_mlic_family_variant_slice_codecs_forward(self, variant): + pytest.importorskip("timm") + from compressai.models.mlic import MLIC, MLICPlus + + torch.manual_seed(2) + if variant == "mlic": + codec = MLIC(N=8, M=8, slice_num=2, local_kernel=3).latent_codec.y + else: + codec = MLICPlus(N=8, M=8, slice_num=2, context_window=3).latent_codec.y + codec = codec.eval() + y = torch.randn(2, 8, 8, 8) + side_params = torch.randn(2, 16, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == (2, 8, 8, 8) + assert out["likelihoods"]["y"].shape == (2, 8, 8, 8) + + def test_mlicv2_layout_injects_hgcp_context_refinement_and_gsc(self): + pytest.importorskip("timm") + from compressai.models.mlic import MLICv2 + + codec = MLICv2(N=8, M=8, slice_num=2, context_window=3).latent_codec.y + keys = set(codec.state_dict().keys()) + + assert getattr( + codec.latent_codec["y0"].spatial_context_anchor, "requires_side_params" + ) + assert codec.latent_codec["y0"].selective_predictor is not None + assert codec.latent_codec["y1"].selective_predictor is not None + assert any( + k.startswith("latent_codec.y0.spatial_context_anchor.hgcp.") for k in keys + ) + assert any( + k.startswith("latent_codec.y0.selective_predictor.predictor.") for k in keys + ) + assert any( + k.startswith("channel_context.y1.global_inter_part.context.keys.") + for k in keys + ) + assert any( + k.startswith("channel_context.y1.global_inter_part.reweighting.") + for k in keys + ) + + def test_rejects_invalid_variant(self): + from compressai.models._helpers.multi_context_slice import ( + _select_global_inter_factory, + ) + + with pytest.raises(ValueError, match="variant"): + _select_global_inter_factory("mlic++") + + class TestSharedDictionary: def test_dt_shape_and_state_dict_path(self): from compressai.models._helpers.dictionary_context import SharedDictionary diff --git a/tests/test_multi_context_checkerboard.py b/tests/test_multi_context_checkerboard.py new file mode 100644 index 00000000..d113496f --- /dev/null +++ b/tests/test_multi_context_checkerboard.py @@ -0,0 +1,303 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import torch +import torch.nn as nn + +from compressai.latent_codecs import ( + CheckerboardLatentCodec, + GaussianConditionalLatentCodec, + MultiContextCheckerboardLatentCodec, +) +from compressai.layers import CheckerboardMaskedConv2d + + +class _ZeroContext(nn.Module): + """Stand-in spatial context that emits zeros with a fixed channel count. + + Used by the ELIC-equivalence regression: upstream ``CheckerboardLatentCodec`` + feeds the anchor pass an all-zero tensor sized to ``context_prediction.out``; + ``MultiContextCheckerboardLatentCodec`` skips the spatial slot entirely when + ``spatial_context_anchor=None``, so the regression has to supply this zero + context explicitly. + """ + + def __init__(self, out_channels: int) -> None: + super().__init__() + self.out_channels = int(out_channels) + + def forward(self, y: torch.Tensor) -> torch.Tensor: + return y.new_zeros(y.shape[0], self.out_channels, y.shape[2], y.shape[3]) + + +class _ZeroEntropyParameters(nn.Module): + def __init__(self, out_channels: int) -> None: + super().__init__() + self.out_channels = int(out_channels) + + def forward(self, params: torch.Tensor) -> torch.Tensor: + return params.new_zeros( + params.shape[0], + self.out_channels, + params.shape[2], + params.shape[3], + ) + + +class _ConstantResidual(nn.Module): + def __init__(self, out_channels: int, value: float) -> None: + super().__init__() + self.out_channels = int(out_channels) + self.value = float(value) + + def forward(self, params: torch.Tensor) -> torch.Tensor: + return params.new_full( + (params.shape[0], self.out_channels, params.shape[2], params.shape[3]), + self.value, + ) + + +class _ConvSelectivePredictor(nn.Module): + def __init__(self, side_channels: int, channels: int) -> None: + super().__init__() + self.proj = nn.Conv2d(side_channels + 2 * channels, channels, 1) + + def forward( + self, + *, + side_params: torch.Tensor, + scales: torch.Tensor, + means: torch.Tensor, + step: str, + ) -> torch.Tensor: + return torch.sigmoid(self.proj(torch.cat([side_params, scales, means], dim=1))) + + +class TestMultiContextCheckerboardLatentCodec: + class _IntraContext(nn.Module): + def __init__(self, side_ch=8, y_ch=4, out_ch=3): + super().__init__() + self.proj = nn.Conv2d(side_ch + y_ch, out_ch, 1) + + def forward(self, side_params, anchor_y_hat): + return self.proj(torch.cat([side_params, anchor_y_hat], dim=1)) + + @staticmethod + def _scale_table(): + return [0.11, 0.5, 1.0, 2.0, 4.0] + + def _make( + self, + *, + y_ch=4, + side_ch=8, + anchor_in=None, + nonanchor_in=None, + **kwargs, + ): + """Construct a codec with caller-controlled head input widths. + + ``anchor_in`` / ``nonanchor_in`` default to ``side_ch`` because the + codec now omits any ``spatial_context_*=None`` slot from the + entropy-parameters input. Tests that supply spatial / intra-channel + hooks must widen the corresponding head. + """ + if anchor_in is None: + anchor_in = side_ch + if nonanchor_in is None: + nonanchor_in = side_ch + return MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=nn.Conv2d(anchor_in, 2 * y_ch, 1), + entropy_parameters_nonanchor=nn.Conv2d(nonanchor_in, 2 * y_ch, 1), + scale_table=self._scale_table(), + **kwargs, + ) + + def test_default_forward_shapes(self): + codec = self._make().eval() + y = torch.randn(2, 4, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == (2, 4, 8, 8) + assert out["likelihoods"]["y"].shape == (2, 4, 8, 8) + + def test_anchor_skips_spatial_context_when_none(self): + """``spatial_context_anchor=None`` must NOT pad with zeros. + + Anchor head sized to ``side_ch`` only — would crash with channel + mismatch if the leaf still emitted a ``y_ch``-wide zero block. + Locks in the skip semantics required for MLIC++ k=0 anchor wiring. + """ + codec = self._make( + anchor_in=8, # side_ch only + nonanchor_in=8 + 4, # side_ch + spatial_context_nonanchor.out + spatial_context_nonanchor=CheckerboardMaskedConv2d(4, 4, 5, padding=2), + ).eval() + y = torch.randn(2, 4, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == y.shape + + def test_spatial_nonanchor_forward_shapes(self): + codec = self._make( + nonanchor_in=8 + 4, # side_ch + spatial_context_nonanchor.out + spatial_context_nonanchor=CheckerboardMaskedConv2d(4, 4, 5, padding=2), + ).eval() + y = torch.randn(2, 4, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == y.shape + assert out["likelihoods"]["y"].shape == y.shape + + def test_all_hooks_forward_shapes_and_state_dict_paths(self): + def lrp_inputs(side_params, params, y_hat): + return torch.cat([side_params, params, y_hat], dim=1) + + codec = MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=nn.Conv2d(4 + 8, 8, 1), + entropy_parameters_nonanchor=nn.Conv2d(4 + 8 + 3, 8, 1), + scale_table=self._scale_table(), + spatial_context_anchor=nn.Conv2d(4, 4, 1), + spatial_context_nonanchor=CheckerboardMaskedConv2d(4, 4, 5, padding=2), + intra_channel_context_nonanchor=self._IntraContext(), + lrp_anchor=nn.Conv2d(8 + 8 + 4, 4, 1), + lrp_nonanchor=nn.Conv2d(8 + 8 + 4, 4, 1), + lrp_input_builder=lrp_inputs, + selective_predictor=_ConvSelectivePredictor(8, 4), + ).eval() + keys = set(codec.state_dict().keys()) + assert any(k.startswith("entropy_parameters_anchor.") for k in keys) + assert any(k.startswith("entropy_parameters_nonanchor.") for k in keys) + assert any(k.startswith("spatial_context_anchor.") for k in keys) + assert any(k.startswith("spatial_context_nonanchor.") for k in keys) + assert any(k.startswith("intra_channel_context_nonanchor.") for k in keys) + assert any(k.startswith("selective_predictor.") for k in keys) + assert any(k.startswith("lrp_anchor.") for k in keys) + assert any(k.startswith("lrp_nonanchor.") for k in keys) + assert any(k.startswith("y.gaussian_conditional.") for k in keys) + + y = torch.randn(2, 4, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out = codec(y, side_params) + assert out["y_hat"].shape == y.shape + + def test_lrp_activation_can_be_skipped(self): + codec = MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=_ZeroEntropyParameters(8), + entropy_parameters_nonanchor=_ZeroEntropyParameters(8), + scale_table=self._scale_table(), + lrp_anchor=_ConstantResidual(4, 0.25), + lrp_nonanchor=_ConstantResidual(4, 0.25), + lrp_activation=None, + lrp_scale=1.0, + ).eval() + y = torch.zeros(1, 4, 4, 4) + side_params = torch.zeros(1, 8, 4, 4) + + with torch.no_grad(): + out = codec(y, side_params) + + assert torch.allclose(out["y_hat"], torch.full_like(y, 0.25)) + + def test_state_dict_round_trip(self): + torch.manual_seed(13) + kwargs = dict( + nonanchor_in=8 + 4, + spatial_context_nonanchor=CheckerboardMaskedConv2d(4, 4, 5, padding=2), + ) + codec = self._make(**kwargs).eval() + reconstructed = self._make(**kwargs).eval() + reconstructed.load_state_dict(codec.state_dict()) + y = torch.randn(2, 4, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + with torch.no_grad(): + out_a = codec(y, side_params) + out_b = reconstructed(y, side_params) + assert torch.allclose(out_a["y_hat"], out_b["y_hat"]) + assert torch.allclose(out_a["likelihoods"]["y"], out_b["likelihoods"]["y"]) + + def test_compress_decompress_round_trip(self): + torch.manual_seed(17) + codec = self._make( + nonanchor_in=8 + 4, + spatial_context_nonanchor=CheckerboardMaskedConv2d(4, 4, 5, padding=2), + ).eval() + codec.y.gaussian_conditional.update() + y = torch.randn(1, 4, 8, 8) + side_params = torch.randn(1, 8, 8, 8) + with torch.no_grad(): + forward = codec(y, side_params) + compressed = codec.compress(y, side_params) + decompressed = codec.decompress( + compressed["strings"], compressed["shape"], side_params + ) + assert torch.allclose(forward["y_hat"], compressed["y_hat"]) + assert torch.allclose(compressed["y_hat"], decompressed["y_hat"]) + + def test_matches_checkerboard_latent_codec_when_heads_are_shared(self): + torch.manual_seed(19) + y_ch, side_ch = 4, 8 + # Upstream CheckerboardLatentCodec: shared head + masked-conv spatial + # context. Anchor pass feeds an all-zero tensor of width + # ``context_prediction.out_channels`` to the head. + entropy_parameters = nn.Conv2d(y_ch + side_ch, 2 * y_ch, 1) + context_prediction = CheckerboardMaskedConv2d(y_ch, y_ch, 5, padding=2) + base = CheckerboardLatentCodec( + latent_codec={ + "y": GaussianConditionalLatentCodec(scale_table=self._scale_table()) + }, + entropy_parameters=entropy_parameters, + context_prediction=context_prediction, + ).eval() + # Sibling leaf must reproduce that anchor-pass zero slot explicitly, + # because spatial_context_anchor=None is now "skip the slot" not + # "pad with y-shaped zeros". Wider regression for arbitrary + # context_prediction.out_channels != y_ch is enabled by this design. + generalized = MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=entropy_parameters, + entropy_parameters_nonanchor=entropy_parameters, + spatial_context_anchor=_ZeroContext(context_prediction.out_channels), + spatial_context_nonanchor=context_prediction, + latent_codec={ + "y": GaussianConditionalLatentCodec(scale_table=self._scale_table()) + }, + ).eval() + y = torch.randn(2, y_ch, 8, 8) + side_params = torch.randn(2, side_ch, 8, 8) + with torch.no_grad(): + base_out = base(y, side_params) + generalized_out = generalized(y, side_params) + assert torch.allclose(base_out["y_hat"], generalized_out["y_hat"]) + assert torch.allclose( + base_out["likelihoods"]["y"], generalized_out["likelihoods"]["y"] + ) diff --git a/tests/test_multi_context_checkerboard_selective.py b/tests/test_multi_context_checkerboard_selective.py new file mode 100644 index 00000000..e406f0b4 --- /dev/null +++ b/tests/test_multi_context_checkerboard_selective.py @@ -0,0 +1,193 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import List, Optional + +import torch +import torch.nn as nn + +from compressai.latent_codecs import MultiContextCheckerboardLatentCodec +from compressai.layers import CheckerboardMaskedConv2d + + +class _ZeroEntropyParameters(nn.Module): + def __init__(self, out_channels: int) -> None: + super().__init__() + self.out_channels = int(out_channels) + + def forward(self, params: torch.Tensor) -> torch.Tensor: + return params.new_zeros( + params.shape[0], + self.out_channels, + params.shape[2], + params.shape[3], + ) + + +class _StepSelectivePredictor(nn.Module): + def __init__(self, *, anchor_value: float, non_anchor_value: float) -> None: + super().__init__() + self.anchor_value = float(anchor_value) + self.non_anchor_value = float(non_anchor_value) + + def forward( + self, + *, + side_params: torch.Tensor, + scales: torch.Tensor, + means: torch.Tensor, + step: str, + ) -> torch.Tensor: + value = self.anchor_value if step == "anchor" else self.non_anchor_value + return torch.full_like(scales, value) + + +def _scale_table() -> List[float]: + return [0.11, 0.5, 1.0, 2.0, 4.0] + + +def _make( + *, + y_ch: int = 4, + side_ch: int = 8, + anchor_in: Optional[int] = None, + nonanchor_in: Optional[int] = None, + **kwargs, +) -> MultiContextCheckerboardLatentCodec: + if anchor_in is None: + anchor_in = side_ch + if nonanchor_in is None: + nonanchor_in = side_ch + return MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=nn.Conv2d(anchor_in, 2 * y_ch, 1), + entropy_parameters_nonanchor=nn.Conv2d(nonanchor_in, 2 * y_ch, 1), + scale_table=_scale_table(), + **kwargs, + ) + + +def _non_anchor_mask_like(y: torch.Tensor) -> torch.Tensor: + mask = torch.zeros_like(y, dtype=torch.bool) + mask[..., 0::2, 1::2] = True + mask[..., 1::2, 0::2] = True + return mask + + +def _encoded_size(strings: List[List[bytes]]) -> int: + return sum(len(s) for step_strings in strings for s in step_strings) + + +def test_selective_predictor_none_is_identity() -> None: + torch.manual_seed(23) + kwargs = dict( + nonanchor_in=8 + 4, + spatial_context_nonanchor=CheckerboardMaskedConv2d(4, 4, 5, padding=2), + ) + baseline = _make(**kwargs).eval() + explicit_none = _make(selective_predictor=None, **kwargs).eval() + explicit_none.load_state_dict(baseline.state_dict()) + baseline.y.gaussian_conditional.update() + explicit_none.y.gaussian_conditional.update() + + y = torch.randn(1, 4, 8, 8) + side_params = torch.randn(1, 8, 8, 8) + + with torch.no_grad(): + baseline_forward = baseline(y, side_params) + explicit_forward = explicit_none(y, side_params) + assert torch.allclose( + baseline_forward["likelihoods"]["y"], + explicit_forward["likelihoods"]["y"], + ) + assert torch.allclose(baseline_forward["y_hat"], explicit_forward["y_hat"]) + + baseline_compressed = baseline.compress(y, side_params) + explicit_compressed = explicit_none.compress(y, side_params) + assert baseline_compressed["strings"] == explicit_compressed["strings"] + assert torch.allclose(baseline_compressed["y_hat"], explicit_compressed["y_hat"]) + + baseline_decompressed = baseline.decompress( + baseline_compressed["strings"], baseline_compressed["shape"], side_params + ) + explicit_decompressed = explicit_none.decompress( + explicit_compressed["strings"], explicit_compressed["shape"], side_params + ) + assert torch.allclose( + baseline_decompressed["y_hat"], explicit_decompressed["y_hat"] + ) + + +def test_selective_predictor_skip_semantics() -> None: + selective = MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=_ZeroEntropyParameters(8), + entropy_parameters_nonanchor=_ZeroEntropyParameters(8), + scale_table=_scale_table(), + selective_predictor=_StepSelectivePredictor( + anchor_value=1.0, + non_anchor_value=0.0, + ), + ).eval() + full = MultiContextCheckerboardLatentCodec( + entropy_parameters_anchor=_ZeroEntropyParameters(8), + entropy_parameters_nonanchor=_ZeroEntropyParameters(8), + scale_table=_scale_table(), + ).eval() + selective.y.gaussian_conditional.update() + full.y.gaussian_conditional.update() + + y = torch.linspace(-2.0, 2.0, steps=64).reshape(1, 4, 4, 4) + side_params = torch.zeros(1, 8, 4, 4) + non_anchor_mask = _non_anchor_mask_like(y) + + with torch.no_grad(): + out = selective(y, side_params) + + assert torch.all(out["likelihoods"]["y"][non_anchor_mask] == 1) + assert torch.allclose( + out["y_hat"][non_anchor_mask], + torch.zeros_like(out["y_hat"][non_anchor_mask]), + ) + + selective_compressed = selective.compress(y, side_params) + full_compressed = full.compress(y, side_params) + assert selective_compressed["strings"][1] == [b""] + assert _encoded_size(selective_compressed["strings"]) < _encoded_size( + full_compressed["strings"] + ) + + decompressed = selective.decompress( + selective_compressed["strings"], + selective_compressed["shape"], + side_params, + ) + assert torch.allclose(selective_compressed["y_hat"], decompressed["y_hat"]) + assert torch.allclose( + decompressed["y_hat"][non_anchor_mask], + torch.zeros_like(decompressed["y_hat"][non_anchor_mask]), + ) diff --git a/tests/test_zoo.py b/tests/test_zoo.py index 8ebf1c33..a150b66a 100644 --- a/tests/test_zoo.py +++ b/tests/test_zoo.py @@ -45,8 +45,12 @@ cheng2020_attn, mbt2018, mbt2018_mean, + mlic, + mlicplus, + mlicpp, + mlicv2, ) -from compressai.zoo.image import _load_model +from compressai.zoo.image import _load_model, model_architectures class TestLoadModel: @@ -58,6 +62,34 @@ def test_invalid(self): _load_model("mbt2018", "mse", 0) +class TestMlicZoo: + def test_model_architectures_use_lazy_imports(self): + for name in ("mlic", "mlicplus", "mlicpp", "mlicv2"): + assert name in model_architectures + assert type(model_architectures[name]).__name__ == "_LazyImport" + + def test_factories(self): + pytest.importorskip("timm") + from compressai.models.mlic import MLIC, MLICPlus, MLICPlusPlus, MLICv2 + + assert isinstance(mlic(N=8, M=12, slice_num=3, local_kernel=3), MLIC) + assert isinstance(mlicplus(N=8, M=16, slice_num=4, context_window=3), MLICPlus) + assert isinstance( + mlicpp(N=8, M=16, slice_num=4, context_window=3), MLICPlusPlus + ) + assert isinstance(mlicv2(N=8, M=16, slice_num=4, context_window=3), MLICv2) + + def test_pretrained_unavailable(self): + with pytest.raises(RuntimeError, match="Pre-trained MLIC"): + mlic(pretrained=True) + with pytest.raises(RuntimeError, match=r"Pre-trained MLIC\+"): + mlicplus(pretrained=True) + with pytest.raises(RuntimeError, match=r"Pre-trained MLIC\+\+"): + mlicpp(pretrained=True) + with pytest.raises(RuntimeError, match="Pre-trained MLICv2"): + mlicv2(pretrained=True) + + class TestBmshj2018Factorized: def test_params(self): for i in range(1, 6):