Skip to content

fix: miscalculation of num_steps when using num_epoch and lmdb#5488

Open
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/epoch
Open

fix: miscalculation of num_steps when using num_epoch and lmdb#5488
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/epoch

Conversation

@OutisLi
Copy link
Copy Markdown
Collaborator

@OutisLi OutisLi commented Jun 3, 2026

Summary by CodeRabbit

  • Bug Fixes

    • More accurate batch-count reporting for mixed and grouped batching modes.
    • Dataset index now reflects dataset-level batch totals (sampler-driven) instead of reader-only estimates.
    • Distributed sampler now reports per-rank batch counts using a precomputed global total.
    • Training step resolution updated to use actual dataloader batch counts.
  • Tests

    • Added tests validating dataset total/batch index behavior and distributed sampler length caching.

Copilot AI review requested due to automatic review settings June 3, 2026 06:04
@dosubot dosubot Bot added the bug label Jun 3, 2026
@github-actions github-actions Bot added the Python label Jun 3, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adjust LMDB batch counting to reflect auto-probability expansion and align distributed/training step calculations, with added regression tests.

Changes:

  • Update total_batch/index semantics to track sampler-expanded batch counts.
  • Use DataLoader length for LR step calculations when training on LmdbDataset.
  • Add tests covering total_batch alignment and distributed sampler length with auto-prob expansion.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
source/tests/pt/test_lmdb_dataloader.py Adds tests validating expanded batch counts and distributed sampler __len__() behavior.
deepmd/pt/utils/lmdb_dataset.py Changes index/total_batch to be derived from the batch sampler length (including expansion).
deepmd/pt/train/training.py Uses len(training_dataloader) to compute batch counts for LR scheduling with LMDB datasets.
deepmd/dpmodel/utils/lmdb_data.py Updates total_batch calculation and adjusts distributed sampler __len__() to include expansion via SameNlocBatchSampler.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread deepmd/pt/utils/lmdb_dataset.py
Comment thread deepmd/dpmodel/utils/lmdb_data.py Outdated
Comment thread source/tests/pt/test_lmdb_dataloader.py
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jun 3, 2026

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8c9ad03d-da8f-4326-988d-abaa066c9805

📥 Commits

Reviewing files that changed from the base of the PR and between 48ecc68 and e64bfae.

📒 Files selected for processing (3)
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/utils/lmdb_dataset.py
  • source/tests/pt/test_lmdb_dataloader.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/utils/lmdb_dataset.py

📝 Walkthrough

Walkthrough

Batch-count computation moved to sampler/dataloader-derived totals: LmdbDataReader exposes total_batch and index uses it; LmdbDataset.total_batch returns len(self._batch_sampler) and index adapts when block targets are set; DistributedSameNlocBatchSampler caches global expanded batch count and divides by world size; Trainer uses dataloader lengths to resolve step counts. Tests validate these behaviors.

Changes

Batch Count Estimation Alignment

Layer / File(s) Summary
Reader-level batch counting & distributed sampler caching
deepmd/dpmodel/utils/lmdb_data.py
Adds LmdbDataReader.total_batch (mixed vs non-mixed logic) and makes index return [total_batch]. DistributedSameNlocBatchSampler.__init__ caches len(SameNlocBatchSampler(..., shuffle=False, block_targets=...)) as _total_batches, and __len__ returns ceil(_total_batches / world_size).
Dataset-level batch alignment
deepmd/pt/utils/lmdb_dataset.py
LmdbDataset.total_batch now returns len(self._batch_sampler); LmdbDataset.index returns the reader index only when _block_targets is falsy, otherwise returns [self.total_batch].
Trainer batch count resolution
deepmd/pt/train/training.py
Trainer now resolves single-task and per-task total batches from len(self.training_dataloader) / len(self.training_dataloader[model_key]) instead of reading training_data.total_batch.
Batch alignment validation tests
source/tests/pt/test_lmdb_dataloader.py
Adds tests verifying ds.total_batch == len(ds._batch_sampler), that ds.index follows ds.total_batch when auto-prob expands batching, distributed sampler length equals ceil(len(ds._batch_sampler)/2) for world_size=2, and that SameNlocBatchSampler constructor is called only once when distributed sampler caches the global total.

Sequence Diagram(s)

sequenceDiagram
  participant LmdbDataReader
  participant SameNlocBatchSampler
  participant DistributedSameNlocBatchSampler
  participant Trainer
  LmdbDataReader->>SameNlocBatchSampler: instantiate (shuffle=False, block_targets) to measure expanded batches
  SameNlocBatchSampler-->>LmdbDataReader: return expanded batch count
  LmdbDataReader->>DistributedSameNlocBatchSampler: provide total_batch or allow sampler-based estimation
  DistributedSameNlocBatchSampler->>DistributedSameNlocBatchSampler: cache _total_batches and compute ceil(_total_batches/world_size)
  DistributedSameNlocBatchSampler-->>Trainer: len() returns per-rank batch count
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5413: Related LMDB batching/length semantics changes; both PRs affect batch-size computation and indexing logic in deepmd/dpmodel/utils/lmdb_data.py.

Suggested reviewers

  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 31.25% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly addresses the main fix: resolving miscalculation of num_steps when using num_epoch with LMDB datasets, which is reflected in the training.py changes.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/utils/lmdb_data.py`:
- Around line 1311-1320: The current __len__() returns ceil(total /
self._world_size) for every rank which mismatches the strided partitioning in
__iter__(); change __len__() to compute a rank-aware batch count by getting
total = len(SameNlocBatchSampler(self._reader, shuffle=False,
block_targets=self._block_targets)), then compute base = total //
self._world_size and remainder = total % self._world_size and return base + (1
if self._rank < remainder else 0) so that __len__() matches the actual number of
batches produced by the __iter__() strided partitioning (alternatively, adjust
_partition_batches() to pad/repeat batches so every rank emits
ceil(total/world_size) and keep current __len__()).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8ecc555b-6099-4600-a3c5-6d331cce1d8d

📥 Commits

Reviewing files that changed from the base of the PR and between 27a18b6 and 48ecc68.

📒 Files selected for processing (4)
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/train/training.py
  • deepmd/pt/utils/lmdb_dataset.py
  • source/tests/pt/test_lmdb_dataloader.py

Comment thread deepmd/dpmodel/utils/lmdb_data.py Outdated
@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 3, 2026

Codecov Report

❌ Patch coverage is 81.25000% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.36%. Comparing base (27a18b6) to head (e64bfae).

Files with missing lines Patch % Lines
deepmd/pt/train/training.py 0.00% 2 Missing ⚠️
deepmd/dpmodel/utils/lmdb_data.py 90.00% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5488   +/-   ##
=======================================
  Coverage   81.36%   81.36%           
=======================================
  Files         868      868           
  Lines       96567    96571    +4     
  Branches     4233     4235    +2     
=======================================
+ Hits        78570    78575    +5     
+ Misses      16697    16696    -1     
  Partials     1300     1300           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

expected_len = math.ceil(len(ds._batch_sampler) / 2)
assert len(dist_sampler) == expected_len
assert len(dist_sampler) == expected_len
assert calls == 1
@OutisLi OutisLi requested a review from njzjz June 4, 2026 06:11
bs = self._reader.get_batch_size_for_nloc(nloc)
total += (len(indices) + bs - 1) // bs
return math.ceil(total / self._world_size)
return math.ceil(self._total_batches / self._world_size)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__len__ should match this sampler's rank-specific iterator. _partition_batches yields all_batches[self._rank :: self._world_size], whose length is ceil((total - rank) / world_size) (or 0 if rank >= total), not always ceil(total / world_size). With total_batches=3, world_size=2, rank 1 yields only one batch but reports two. This can make len(DataLoader) overestimate num_steps on nonzero ranks.

Could we compute the length from the same strided partition formula, e.g. (self._total_batches + self._world_size - 1 - self._rank) // self._world_size?

— OpenClaw 2026.5.28 (model: custom-chat-jinzhezeng-group/gpt-5.5)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants