From c60b93999a812d0159c4e17d8d221c3120f849c4 Mon Sep 17 00:00:00 2001 From: "menglingda.mld" Date: Tue, 26 May 2026 17:27:13 +0800 Subject: [PATCH] feat: add GenericLruCache, ConcurrentHashMap, MurmurHash, and Preconditions utilities --- LICENSE | 36 + NOTICE | 3 + src/paimon/common/utils/concurrent_hash_map.h | 92 ++ .../common/utils/concurrent_hash_map_test.cpp | 157 +++ src/paimon/common/utils/generic_lru_cache.h | 329 +++++++ .../common/utils/generic_lru_cache_test.cpp | 893 ++++++++++++++++++ src/paimon/common/utils/murmurhash_utils.h | 265 ++++++ .../common/utils/murmurhash_utils_test.cpp | 113 +++ src/paimon/common/utils/preconditions.h | 102 ++ .../common/utils/preconditions_test.cpp | 74 ++ 10 files changed, 2064 insertions(+) create mode 100644 src/paimon/common/utils/concurrent_hash_map.h create mode 100644 src/paimon/common/utils/concurrent_hash_map_test.cpp create mode 100644 src/paimon/common/utils/generic_lru_cache.h create mode 100644 src/paimon/common/utils/generic_lru_cache_test.cpp create mode 100644 src/paimon/common/utils/murmurhash_utils.h create mode 100644 src/paimon/common/utils/murmurhash_utils_test.cpp create mode 100644 src/paimon/common/utils/preconditions.h create mode 100644 src/paimon/common/utils/preconditions_test.cpp diff --git a/LICENSE b/LICENSE index 511feeb..976412a 100644 --- a/LICENSE +++ b/LICENSE @@ -363,6 +363,42 @@ This product includes code from Apache ORC. Copyright: 2013 and onwards The Apache Software Foundation. Home page: https://orc.apache.org/ License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from xxHash. + +* MMH_rotl32 utility in src/paimon/common/utils/murmurhash_utils.h + +Copyright: 2012-2023 Yann Collet +Home page: https://www.xxhash.com +License: https://opensource.org/license/bsd-2-clause + +BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + -------------------------------------------------------------------------------- This product includes code from cppjieba. diff --git a/NOTICE b/NOTICE index 9160245..dbfb7d3 100644 --- a/NOTICE +++ b/NOTICE @@ -14,5 +14,8 @@ This product includes software from RocksDB project (Apache 2.0 and BSD 3-clause Copyright (c) 2011-present, Facebook, Inc. All rights reserved. Copyright (c) 2011 The LevelDB Authors. All rights reserved. +This product includes software from xxHash project (BSD 2-clause) +Copyright (C) 2012-2023 Yann Collet + This product includes software from cppjieba project (MIT) Copyright 2013 \ No newline at end of file diff --git a/src/paimon/common/utils/concurrent_hash_map.h b/src/paimon/common/utils/concurrent_hash_map.h new file mode 100644 index 0000000..79a21f0 --- /dev/null +++ b/src/paimon/common/utils/concurrent_hash_map.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paimon/common/utils/murmurhash_utils.h" +#include "tbb/concurrent_hash_map.h" + +namespace paimon { +template > +class ConcurrentHashMap { + private: + using HashMap = tbb::concurrent_hash_map; + + public: + ConcurrentHashMap() = default; + ~ConcurrentHashMap() = default; + + // No copying allowed + ConcurrentHashMap(const ConcurrentHashMap&) = delete; + void operator=(const ConcurrentHashMap&) = delete; + ConcurrentHashMap(ConcurrentHashMap&&) = delete; + ConcurrentHashMap& operator=(ConcurrentHashMap&&) = delete; + + std::optional Find(const Key& key) const { + typename HashMap::const_accessor accessor; + if (hash_map_.find(accessor, key)) { + return accessor->second; + } + return std::nullopt; + } + + void Insert(const Key& key, const T& value) { + typename HashMap::accessor accessor; + hash_map_.insert(accessor, key); + accessor->second = value; + } + + void Erase(const Key& key) { + typename HashMap::accessor accessor; + if (hash_map_.find(accessor, key)) { + hash_map_.erase(accessor); + } + } + + size_t Size() const { + return hash_map_.size(); + } + + private: + HashMap hash_map_; +}; + +class VectorStringHashCompare { + public: + size_t hash(const std::vector& key) const { + int32_t ret = MurmurHashUtils::DEFAULT_SEED; + for (const auto& s : key) { + ret = MurmurHashUtils::HashUnsafeBytes(reinterpret_cast(s.data()), + /*offset=*/0, s.size(), ret); + } + return ret; + } + + bool equal(const std::vector& a, const std::vector& b) const { + return a == b; + } +}; +} // namespace paimon diff --git a/src/paimon/common/utils/concurrent_hash_map_test.cpp b/src/paimon/common/utils/concurrent_hash_map_test.cpp new file mode 100644 index 0000000..d3462af --- /dev/null +++ b/src/paimon/common/utils/concurrent_hash_map_test.cpp @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "paimon/common/utils/concurrent_hash_map.h" + +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +TEST(ConcurrentHashMapTest, TestSimple) { + ConcurrentHashMap hash_map; + ASSERT_EQ(hash_map.Find(10), std::nullopt); + hash_map.Insert(1, "a"); + hash_map.Insert(2, "b"); + hash_map.Insert(3, "c"); + + ASSERT_EQ(hash_map.Find(1).value(), "a"); + ASSERT_EQ(hash_map.Find(2).value(), "b"); + ASSERT_EQ(hash_map.Find(3).value(), "c"); + ASSERT_EQ(hash_map.Find(10), std::nullopt); + ASSERT_EQ(hash_map.Size(), 3); + + hash_map.Erase(2); + ASSERT_EQ(hash_map.Find(2), std::nullopt); + ASSERT_EQ(hash_map.Size(), 2); + + // non-exist key + hash_map.Erase(4); + ASSERT_EQ(hash_map.Find(1).value(), "a"); + ASSERT_EQ(hash_map.Find(3).value(), "c"); + ASSERT_EQ(hash_map.Size(), 2); +} + +TEST(ConcurrentHashMapTest, TestVectorStringHashCompare) { + ConcurrentHashMap, int32_t, VectorStringHashCompare> hash_map; + ASSERT_EQ(hash_map.Find({}), std::nullopt); + + hash_map.Insert({"a", "b"}, 1); + hash_map.Insert({"a", "c"}, 2); + hash_map.Insert({"b", "c"}, 3); + hash_map.Insert({}, 4); + + ASSERT_EQ(hash_map.Find({"a", "b"}).value(), 1); + ASSERT_EQ(hash_map.Find({"a", "c"}).value(), 2); + ASSERT_EQ(hash_map.Find({"b", "c"}).value(), 3); + ASSERT_EQ(hash_map.Find({}), 4); + ASSERT_EQ(hash_map.Find({"non"}), std::nullopt); + ASSERT_EQ(hash_map.Size(), 4); +} + +TEST(ConcurrentHashMapTest, TestMultiThreadInsertAndFindAndDelete) { + int32_t map_size = 1000; + auto insert_task = [&](ConcurrentHashMap& hash_map) { + for (int32_t i = 0; i < map_size; i++) { + usleep(paimon::test::RandomNumber(0, 9)); + hash_map.Insert(i, std::to_string(i + 1)); + } + }; + auto find_task = [&](ConcurrentHashMap& hash_map) { + int32_t found = 0, not_found = 0; + for (int32_t i = 0; i < map_size; i++) { + usleep(paimon::test::RandomNumber(0, 9)); + auto value = hash_map.Find(i); + if (value) { + ASSERT_EQ(value.value(), std::to_string(i + 1)); + found++; + } else { + not_found++; + } + } + ASSERT_EQ(found + not_found, map_size); + }; + + auto delete_task = [&](ConcurrentHashMap& hash_map) { + for (int32_t i = 0; i < map_size; i++) { + usleep(paimon::test::RandomNumber(0, 9)); + hash_map.Erase(i); + } + }; + + { + ConcurrentHashMap hash_map; + // insert + insert_task(hash_map); + // multi-thread find and delete + std::thread thread1(find_task, std::ref(hash_map)); + std::thread thread2(delete_task, std::ref(hash_map)); + + thread1.join(); + thread2.join(); + + // check final states + ASSERT_EQ(hash_map.Size(), 0); + } + { + ConcurrentHashMap hash_map; + // multi-thread insert and find + std::thread thread1(insert_task, std::ref(hash_map)); + std::thread thread2(find_task, std::ref(hash_map)); + + thread1.join(); + thread2.join(); + + // check final states + ASSERT_EQ(hash_map.Size(), map_size); + for (int32_t i = 0; i < map_size; i++) { + auto value = hash_map.Find(i); + ASSERT_TRUE(value); + ASSERT_EQ(value.value(), std::to_string(i + 1)); + } + } + { + ConcurrentHashMap hash_map; + // multi-thread insert and find and delete + std::thread thread1(insert_task, std::ref(hash_map)); + std::thread thread2(find_task, std::ref(hash_map)); + std::thread thread3(delete_task, std::ref(hash_map)); + + thread1.join(); + thread2.join(); + thread3.join(); + + // check final states + ASSERT_TRUE(hash_map.Size() >= 0 && hash_map.Size() <= static_cast(map_size)); + for (int32_t i = 0; i < map_size; i++) { + auto value = hash_map.Find(i); + if (value) { + ASSERT_EQ(value.value(), std::to_string(i + 1)); + } + } + } +} + +} // namespace paimon::test diff --git a/src/paimon/common/utils/generic_lru_cache.h b/src/paimon/common/utils/generic_lru_cache.h new file mode 100644 index 0000000..7ae2f63 --- /dev/null +++ b/src/paimon/common/utils/generic_lru_cache.h @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmt/format.h" +#include "paimon/result.h" +#include "paimon/traits.h" + +namespace paimon { +/// A generic LRU cache with support for weight-based eviction, time-based expiration, +/// and removal callbacks. +/// +/// Uses std::list + unordered_map for O(1) get/put/evict: +/// - list stores entries in LRU order (most recently used at front) +/// - map stores key -> list::iterator for O(1) lookup +/// +/// @tparam K Key type +/// @tparam V Value type +/// @tparam Hash Hash function for K (default: std::hash) +/// @tparam KeyEqual Equality function for K (default: std::equal_to) +/// +/// @note Thread-safe: all public methods are protected by a read-write lock. +template , + typename KeyEqual = std::equal_to> +class GenericLruCache { + public: + /// Cause of a cache entry removal, passed to the removal callback. + enum class RemovalCause { + EXPLICIT, // Removed by Invalidate() or InvalidateAll() + SIZE, // Evicted because total weight exceeded max_weight + EXPIRED, // Evicted because the entry expired (expireAfterAccess) + REPLACED // Replaced by a new value for the same key via Put() + }; + + using WeighFunc = std::function; + using RemovalCallback = std::function; + + /// Configuration options for the cache. + struct Options { + /// Maximum total weight of all entries. Entries are evicted (LRU) when exceeded. + int64_t max_weight = INT64_MAX; + + /// Time in milliseconds after last access before an entry expires. + /// A value < 0 disables expiration. + int64_t expire_after_access_ms = -1; + + /// Function to compute the weight of an entry. If nullptr, each entry has weight 1. + WeighFunc weigh_func = nullptr; + + /// Callback invoked when an entry is removed (evicted, invalidated, or replaced). + /// If nullptr, no callback is invoked. + RemovalCallback removal_callback = nullptr; + }; + + explicit GenericLruCache(Options options) : options_(std::move(options)) {} + + /// Look up a key in the cache. On hit, promotes the entry to the front (most recently + /// used) and updates its access time. Returns std::nullopt on miss or if the entry + /// has expired. + std::optional GetIfPresent(const K& key) { + std::unique_lock lock(mutex_); + return FindPromoteOrExpire(key); + } + + /// Look up a key. On miss, load via the supplier, insert into cache, and return. + /// If the supplier returns an error, the error is propagated and nothing is cached. + Result Get(const K& key, std::function(const K&)> supplier) { + { + std::unique_lock lock(mutex_); + auto cached = FindPromoteOrExpire(key); + if (cached.has_value()) { + return std::move(cached.value()); + } + } + + // Cache miss: load via supplier outside the lock + PAIMON_ASSIGN_OR_RAISE(V value, supplier(key)); + int64_t weight = ComputeWeight(key, value); + if (weight > options_.max_weight) { + return value; + } + + std::unique_lock lock(mutex_); + // Double-check: another thread may have inserted while we were loading + auto cached = FindPromoteOrExpire(key); + if (cached.has_value()) { + return std::move(cached.value()); + } + + InsertEntry(key, value, weight); + EvictIfNeeded(); + return value; + } + + /// Insert or update an entry. If the key already exists, the old value is replaced + /// and the REPLACED callback is invoked. Triggers eviction if needed. + /// @return Status::Invalid if the entry's weight exceeds max_weight, Status::OK otherwise. + Status Put(const K& key, V value) { + int64_t weight = ComputeWeight(key, value); + if (weight > options_.max_weight) { + return Status::Invalid( + fmt::format("Entry weight {} exceeds cache max weight {}, entry will not be cached", + weight, options_.max_weight)); + } + + std::unique_lock lock(mutex_); + auto it = lru_map_.find(key); + if (it != lru_map_.end()) { + if (ValuesEqual(it->second->value, value)) { + Promote(it->second); + return Status::OK(); + } + ReplaceEntry(it->second, std::move(value), weight); + } else { + InsertEntry(key, std::move(value), weight); + } + + EvictIfNeeded(); + return Status::OK(); + } + + /// Remove a specific entry. Invokes the EXPLICIT removal callback if the key exists. + void Invalidate(const K& key) { + std::unique_lock lock(mutex_); + auto it = lru_map_.find(key); + if (it != lru_map_.end()) { + RemoveEntry(it->second, RemovalCause::EXPLICIT); + } + } + + /// Remove all entries. Each entry's EXPLICIT removal callback is invoked. + void InvalidateAll() { + std::unique_lock lock(mutex_); + while (!lru_list_.empty()) { + RemoveEntry(std::prev(lru_list_.end()), RemovalCause::EXPLICIT); + } + current_weight_ = 0; + } + + /// @return The number of entries currently in the cache. + size_t Size() const { + std::shared_lock lock(mutex_); + return lru_map_.size(); + } + + /// @return The current total weight of all entries. + int64_t GetCurrentWeight() const { + std::shared_lock lock(mutex_); + return current_weight_; + } + + /// @return The maximum weight configured for this cache. + int64_t GetMaxWeight() const { + return options_.max_weight; + } + + private: + struct CacheEntry { + K key; + V value; + int64_t weight; + std::chrono::steady_clock::time_point last_access_time; + }; + + using EntryList = std::list; + using EntryMap = std::unordered_map; + + /// Look up a key, promote if found and not expired, or remove if expired. + /// Must be called with mutex_ held. + /// @return The value if found and valid, std::nullopt otherwise. + std::optional FindPromoteOrExpire(const K& key) { + auto it = lru_map_.find(key); + if (it == lru_map_.end()) { + return std::nullopt; + } + auto list_it = it->second; + if (IsExpired(list_it->last_access_time)) { + RemoveEntry(list_it, RemovalCause::EXPIRED); + return std::nullopt; + } + Promote(list_it); + return list_it->value; + } + + /// Move an entry to the front of the LRU list and update its access time. + void Promote(typename EntryList::iterator list_it) { + list_it->last_access_time = std::chrono::steady_clock::now(); + lru_list_.splice(lru_list_.begin(), lru_list_, list_it); + } + + /// Insert a new entry at the front of the LRU list. + void InsertEntry(const K& key, V value, int64_t weight) { + lru_list_.emplace_front( + CacheEntry{key, std::move(value), weight, std::chrono::steady_clock::now()}); + lru_map_[key] = lru_list_.begin(); + current_weight_ += weight; + } + + /// Compare two values for equality. For pointers, compares the underlying + /// pointer first, then dereferences and compares the pointed-to objects. + /// For other types, uses operator==. + static bool ValuesEqual(const V& lhs, const V& rhs) { + if constexpr (is_pointer::value) { + if (lhs == rhs) { + return true; + } + if (!lhs || !rhs) { + return false; + } + return *lhs == *rhs; + } else { + return lhs == rhs; + } + } + + /// Replace the value of an existing entry, invoke the REPLACED callback for the old value, + /// and promote the entry to the front. + void ReplaceEntry(typename EntryList::iterator list_it, V new_value, int64_t new_weight) { + current_weight_ -= list_it->weight; + + K key = list_it->key; + V old_value = std::move(list_it->value); + list_it->value = std::move(new_value); + list_it->weight = new_weight; + list_it->last_access_time = std::chrono::steady_clock::now(); + current_weight_ += new_weight; + lru_list_.splice(lru_list_.begin(), lru_list_, list_it); + + InvokeCallback(key, old_value, RemovalCause::REPLACED); + } + + /// Remove an entry from the cache and invoke the removal callback. + void RemoveEntry(typename EntryList::iterator list_it, RemovalCause cause) { + lru_map_.erase(list_it->key); + K key = std::move(list_it->key); + V value = std::move(list_it->value); + current_weight_ -= list_it->weight; + lru_list_.erase(list_it); + InvokeCallback(key, value, cause); + } + + /// Evict expired entries from the tail, then evict by weight if still over capacity. + void EvictIfNeeded() { + EvictExpired(); + while (current_weight_ > options_.max_weight && !lru_list_.empty()) { + RemoveEntry(std::prev(lru_list_.end()), RemovalCause::SIZE); + } + } + + /// Evict expired entries from the tail of the LRU list. + /// Since the tail has the oldest access time, we can stop as soon as we find + /// a non-expired entry. + void EvictExpired() { + if (options_.expire_after_access_ms < 0) { + return; + } + auto now = std::chrono::steady_clock::now(); + while (!lru_list_.empty()) { + auto it = std::prev(lru_list_.end()); + if (!IsExpired(it->last_access_time, now)) { + break; + } + RemoveEntry(it, RemovalCause::EXPIRED); + } + } + + /// Compute the weight of an entry using the configured weigh function. + int64_t ComputeWeight(const K& key, const V& value) const { + if (options_.weigh_func) { + return options_.weigh_func(key, value); + } + return 1; + } + + /// Check if an entry has expired based on its last access time. + bool IsExpired( + const std::chrono::steady_clock::time_point& last_access_time, + const std::chrono::steady_clock::time_point& now = std::chrono::steady_clock::now()) const { + if (options_.expire_after_access_ms < 0) { + return false; + } + auto elapsed = + std::chrono::duration_cast(now - last_access_time); + return elapsed.count() >= options_.expire_after_access_ms; + } + + /// Invoke the removal callback if one is configured. + void InvokeCallback(const K& key, const V& value, RemovalCause cause) { + if (options_.removal_callback) { + options_.removal_callback(key, value, cause); + } + } + + Options options_; + int64_t current_weight_ = 0; + EntryList lru_list_; + EntryMap lru_map_; + mutable std::shared_mutex mutex_; +}; + +} // namespace paimon diff --git a/src/paimon/common/utils/generic_lru_cache_test.cpp b/src/paimon/common/utils/generic_lru_cache_test.cpp new file mode 100644 index 0000000..318f115 --- /dev/null +++ b/src/paimon/common/utils/generic_lru_cache_test.cpp @@ -0,0 +1,893 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "paimon/common/utils/generic_lru_cache.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +class GenericLruCacheTest : public ::testing::Test { + public: + using StringIntCache = GenericLruCache; + using StringStringCache = GenericLruCache; + using IntIntCache = GenericLruCache; + using StringSharedPtrCache = GenericLruCache>; + using RemovalCause = StringIntCache::RemovalCause; + + struct RemovalRecord { + std::string key; + std::string value; + RemovalCause cause; + }; +}; + +// ==================== Basic Operations ==================== + +TEST_F(GenericLruCacheTest, ConstructorAndDefaults) { + { + StringIntCache::Options options; + StringIntCache cache(options); + + ASSERT_EQ(cache.Size(), 0); + ASSERT_EQ(cache.GetCurrentWeight(), 0); + ASSERT_EQ(cache.GetMaxWeight(), INT64_MAX); + } + { + StringIntCache::Options options; + options.max_weight = 42; + StringIntCache cache(options); + ASSERT_EQ(cache.GetMaxWeight(), 42); + } +} + +// ==================== GetIfPresent ==================== + +TEST_F(GenericLruCacheTest, GetIfPresentMissAndHit) { + StringIntCache::Options options; + StringIntCache cache(options); + + auto result = cache.GetIfPresent("nonexistent"); + ASSERT_FALSE(result.has_value()); + + ASSERT_OK(cache.Put("key1", 100)); + result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 100); +} + +TEST_F(GenericLruCacheTest, GetIfPresentExpired) { + std::vector removals; + StringIntCache::Options options; + options.expire_after_access_ms = 50; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + ASSERT_EQ(cache.Size(), 1); + + std::this_thread::sleep_for(std::chrono::milliseconds(80)); + + auto result = cache.GetIfPresent("key1"); + ASSERT_FALSE(result.has_value()); + ASSERT_EQ(cache.Size(), 0); + + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "key1"); + ASSERT_EQ(removals[0].cause, RemovalCause::EXPIRED); +} + +TEST_F(GenericLruCacheTest, GetIfPresentPromotesEntry) { + std::vector removals; + StringIntCache::Options options; + options.max_weight = 2; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + + // Access "a" to promote it + auto result = cache.GetIfPresent("a"); + ASSERT_TRUE(result.has_value()); + + // Insert "c": should evict "b" (LRU), not "a" + ASSERT_OK(cache.Put("c", 3)); + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "b"); + ASSERT_EQ(removals[0].cause, RemovalCause::SIZE); +} + +// ==================== Get with supplier ==================== + +TEST_F(GenericLruCacheTest, GetCacheMissLoadsViaSupplier) { + StringIntCache::Options options; + StringIntCache cache(options); + + int32_t supplier_calls = 0; + auto supplier = [&](const std::string& key) -> Result { + supplier_calls++; + return 42; + }; + + ASSERT_OK_AND_ASSIGN(auto value, cache.Get("key1", supplier)); + ASSERT_EQ(value, 42); + ASSERT_EQ(supplier_calls, 1); + ASSERT_EQ(cache.Size(), 1); +} + +TEST_F(GenericLruCacheTest, GetCacheHitSkipsSupplier) { + StringIntCache::Options options; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + + int32_t supplier_calls = 0; + auto supplier = [&](const std::string& key) -> Result { + supplier_calls++; + return 999; + }; + + ASSERT_OK_AND_ASSIGN(auto value, cache.Get("key1", supplier)); + ASSERT_EQ(value, 100); + ASSERT_EQ(supplier_calls, 0); +} + +TEST_F(GenericLruCacheTest, GetSupplierError) { + StringIntCache::Options options; + StringIntCache cache(options); + + auto supplier = [](const std::string& key) -> Result { + return Status::IOError("load failed"); + }; + + ASSERT_NOK_WITH_MSG(cache.Get("key1", supplier), "load failed"); + ASSERT_EQ(cache.Size(), 0); +} + +TEST_F(GenericLruCacheTest, GetWeightExceedsMaxReturnsWithoutCaching) { + StringStringCache::Options options; + options.max_weight = 5; + options.weigh_func = [](const std::string& key, const std::string& value) -> int64_t { + return static_cast(value.size()); + }; + StringStringCache cache(options); + + auto supplier = [](const std::string& key) -> Result { + return std::string("this_is_a_very_long_value"); + }; + + ASSERT_OK_AND_ASSIGN(auto value, cache.Get("key1", supplier)); + ASSERT_EQ(value, "this_is_a_very_long_value"); + ASSERT_EQ(cache.Size(), 0); +} + +TEST_F(GenericLruCacheTest, GetTriggersEviction) { + std::vector removals; + StringIntCache::Options options; + options.max_weight = 2; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + ASSERT_EQ(cache.Size(), 2); + + auto supplier = [](const std::string& key) -> Result { return 3; }; + ASSERT_OK_AND_ASSIGN(auto value, cache.Get("c", supplier)); + ASSERT_EQ(value, 3); + ASSERT_EQ(cache.Size(), 2); + + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "a"); + ASSERT_EQ(removals[0].cause, RemovalCause::SIZE); +} + +// ==================== Put ==================== + +TEST_F(GenericLruCacheTest, PutNewEntry) { + StringIntCache::Options options; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + ASSERT_EQ(cache.Size(), 1); + ASSERT_EQ(cache.GetCurrentWeight(), 1); + + auto result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 100); +} + +TEST_F(GenericLruCacheTest, PutReplaceWithDifferentValue) { + std::vector removals; + StringIntCache::Options options; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + ASSERT_OK(cache.Put("key1", 200)); + ASSERT_EQ(cache.Size(), 1); + + auto result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 200); + + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "key1"); + ASSERT_EQ(removals[0].value, "100"); + ASSERT_EQ(removals[0].cause, RemovalCause::REPLACED); +} + +TEST_F(GenericLruCacheTest, PutReplaceWithSameValuePromotes) { + std::vector removals; + StringIntCache::Options options; + options.max_weight = 2; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + + // Put same value for "a" — should promote without REPLACED callback + ASSERT_OK(cache.Put("a", 1)); + ASSERT_TRUE(removals.empty()); + + // Insert "c": should evict "b" (LRU after "a" was promoted) + ASSERT_OK(cache.Put("c", 3)); + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "b"); + ASSERT_EQ(removals[0].cause, RemovalCause::SIZE); +} + +TEST_F(GenericLruCacheTest, PutWeightExceedsMaxReturnsInvalid) { + StringStringCache::Options options; + options.max_weight = 5; + options.weigh_func = [](const std::string& key, const std::string& value) -> int64_t { + return static_cast(value.size()); + }; + StringStringCache cache(options); + + ASSERT_NOK_WITH_MSG(cache.Put("key1", "this_is_too_long"), + "Entry weight 16 exceeds cache max weight 5, entry will not be cached"); + ASSERT_EQ(cache.Size(), 0); +} + +TEST_F(GenericLruCacheTest, PutTriggersWeightEviction) { + std::vector removals; + StringStringCache::Options options; + options.max_weight = 10; + options.weigh_func = [](const std::string& key, const std::string& value) -> int64_t { + return static_cast(value.size()); + }; + options.removal_callback = [&](const std::string& key, const std::string& value, auto cause) { + removals.push_back({key, value, static_cast(cause)}); + }; + StringStringCache cache(options); + + ASSERT_OK(cache.Put("a", "aaaa")); // weight 4 + ASSERT_OK(cache.Put("b", "bbbbb")); // weight 5, total 9 + ASSERT_EQ(cache.GetCurrentWeight(), 9); + + // Insert "c" with weight 5: total would be 14 > 10, evict "a" (4), total becomes 10 + ASSERT_OK(cache.Put("c", "ccccc")); + ASSERT_EQ(cache.Size(), 2); + ASSERT_EQ(cache.GetCurrentWeight(), 10); + + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "a"); + ASSERT_EQ(removals[0].cause, RemovalCause::SIZE); +} + +TEST_F(GenericLruCacheTest, PutMultipleEvictions) { + std::vector evicted_keys; + StringStringCache::Options options; + options.max_weight = 10; + options.weigh_func = [](const std::string& key, const std::string& value) -> int64_t { + return static_cast(value.size()); + }; + options.removal_callback = [&](const std::string& key, const std::string& value, auto cause) { + evicted_keys.push_back(key); + }; + StringStringCache cache(options); + + ASSERT_OK(cache.Put("a", "aaa")); // weight 3 + ASSERT_OK(cache.Put("b", "bbb")); // weight 3 + ASSERT_OK(cache.Put("c", "ccc")); // weight 3, total 9 + + // Insert "d" with weight 9: total would be 18 > 10, evict a(3), b(3), c(3) then add d(9) + ASSERT_OK(cache.Put("d", "ddddddddd")); + ASSERT_EQ(cache.Size(), 1); + ASSERT_EQ(cache.GetCurrentWeight(), 9); + + ASSERT_EQ(evicted_keys.size(), 3); + ASSERT_EQ(evicted_keys[0], "a"); + ASSERT_EQ(evicted_keys[1], "b"); + ASSERT_EQ(evicted_keys[2], "c"); +} + +// ==================== Invalidate ==================== + +TEST_F(GenericLruCacheTest, InvalidateExistingKey) { + std::vector removals; + StringIntCache::Options options; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + ASSERT_EQ(cache.Size(), 1); + + cache.Invalidate("key1"); + ASSERT_EQ(cache.Size(), 0); + ASSERT_EQ(cache.GetCurrentWeight(), 0); + + ASSERT_EQ(removals.size(), 1); + ASSERT_EQ(removals[0].key, "key1"); + ASSERT_EQ(removals[0].cause, RemovalCause::EXPLICIT); +} + +TEST_F(GenericLruCacheTest, InvalidateNonExistentKey) { + StringIntCache::Options options; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + cache.Invalidate("nonexistent"); + ASSERT_EQ(cache.Size(), 1); +} + +// ==================== InvalidateAll ==================== + +TEST_F(GenericLruCacheTest, InvalidateAllClearsEverything) { + std::vector removals; + StringIntCache::Options options; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + ASSERT_OK(cache.Put("c", 3)); + ASSERT_EQ(cache.Size(), 3); + + cache.InvalidateAll(); + ASSERT_EQ(cache.Size(), 0); + ASSERT_EQ(cache.GetCurrentWeight(), 0); + + ASSERT_EQ(removals.size(), 3); + for (const auto& record : removals) { + ASSERT_EQ(record.cause, RemovalCause::EXPLICIT); + } +} + +TEST_F(GenericLruCacheTest, InvalidateAllOnEmptyCache) { + StringIntCache::Options options; + StringIntCache cache(options); + + cache.InvalidateAll(); + ASSERT_EQ(cache.Size(), 0); + ASSERT_EQ(cache.GetCurrentWeight(), 0); +} + +// ==================== Weight Function ==================== + +TEST_F(GenericLruCacheTest, DefaultWeightIsOne) { + StringIntCache::Options options; + options.max_weight = 3; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + ASSERT_OK(cache.Put("c", 3)); + ASSERT_EQ(cache.GetCurrentWeight(), 3); + ASSERT_EQ(cache.Size(), 3); + + // Adding one more should evict the LRU entry + ASSERT_OK(cache.Put("d", 4)); + ASSERT_EQ(cache.Size(), 3); + ASSERT_EQ(cache.GetCurrentWeight(), 3); + ASSERT_FALSE(cache.GetIfPresent("a").has_value()); +} + +TEST_F(GenericLruCacheTest, WeightUpdatedOnReplace) { + StringStringCache::Options options; + options.max_weight = 100; + options.weigh_func = [](const std::string& key, const std::string& value) -> int64_t { + return static_cast(value.size()); + }; + StringStringCache cache(options); + + ASSERT_OK(cache.Put("a", std::string(30, 'x'))); + ASSERT_EQ(cache.GetCurrentWeight(), 30); + + // Replace with larger value + ASSERT_OK(cache.Put("a", std::string(70, 'y'))); + ASSERT_EQ(cache.GetCurrentWeight(), 70); + ASSERT_EQ(cache.Size(), 1); +} + +// ==================== Expiration ==================== + +TEST_F(GenericLruCacheTest, ExpirationDisabledByDefault) { + StringIntCache::Options options; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + auto result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 100); +} + +TEST_F(GenericLruCacheTest, ExpirationOnGet) { + StringIntCache::Options options; + options.expire_after_access_ms = 50; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + + // Access before expiration + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + auto result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + + // Wait for expiration + std::this_thread::sleep_for(std::chrono::milliseconds(80)); + result = cache.GetIfPresent("key1"); + ASSERT_FALSE(result.has_value()); + ASSERT_EQ(cache.Size(), 0); +} + +TEST_F(GenericLruCacheTest, ExpirationOnGetWithSupplier) { + StringIntCache::Options options; + options.expire_after_access_ms = 50; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + + std::this_thread::sleep_for(std::chrono::milliseconds(80)); + + int32_t supplier_calls = 0; + auto supplier = [&](const std::string& key) -> Result { + supplier_calls++; + return 200; + }; + + ASSERT_OK_AND_ASSIGN(auto value, cache.Get("key1", supplier)); + ASSERT_EQ(value, 200); + ASSERT_EQ(supplier_calls, 1); +} + +TEST_F(GenericLruCacheTest, AccessResetsExpirationTimer) { + StringIntCache::Options options; + options.expire_after_access_ms = 100; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + + // Access at 40ms to reset the timer + std::this_thread::sleep_for(std::chrono::milliseconds(40)); + auto result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + + // At 80ms from last access (40ms from the GetIfPresent), should still be valid + std::this_thread::sleep_for(std::chrono::milliseconds(40)); + result = cache.GetIfPresent("key1"); + ASSERT_TRUE(result.has_value()); + + // Wait for full expiration from last access + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + result = cache.GetIfPresent("key1"); + ASSERT_FALSE(result.has_value()); +} + +TEST_F(GenericLruCacheTest, ExpiredEntriesEvictedOnPut) { + std::vector removals; + StringIntCache::Options options; + options.expire_after_access_ms = 50; + options.max_weight = 100; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + + std::this_thread::sleep_for(std::chrono::milliseconds(80)); + + // Put triggers EvictIfNeeded which calls EvictExpired + ASSERT_OK(cache.Put("c", 3)); + + // "a" and "b" should have been expired + int32_t expired_count = 0; + for (const auto& record : removals) { + if (record.cause == RemovalCause::EXPIRED) { + expired_count++; + } + } + ASSERT_EQ(expired_count, 2); + ASSERT_EQ(cache.Size(), 1); +} + +// ==================== Removal Callback ==================== + +TEST_F(GenericLruCacheTest, NoCallbackConfigured) { + StringIntCache::Options options; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("key1", 100)); + cache.Invalidate("key1"); + ASSERT_EQ(cache.Size(), 0); +} + +TEST_F(GenericLruCacheTest, AllRemovalCauses) { + std::vector removals; + StringIntCache::Options options; + options.max_weight = 2; + options.expire_after_access_ms = 50; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + removals.push_back({key, std::to_string(value), static_cast(cause)}); + }; + StringIntCache cache(options); + + // REPLACED: put same key with different value + ASSERT_OK(cache.Put("r", 1)); + ASSERT_OK(cache.Put("r", 2)); + ASSERT_EQ(removals.back().cause, RemovalCause::REPLACED); + + // EXPLICIT: invalidate + cache.Invalidate("r"); + ASSERT_EQ(removals.back().cause, RemovalCause::EXPLICIT); + + // SIZE: evict due to weight + ASSERT_OK(cache.Put("s1", 10)); + ASSERT_OK(cache.Put("s2", 20)); + ASSERT_OK(cache.Put("s3", 30)); + ASSERT_EQ(removals.back().cause, RemovalCause::SIZE); + + // EXPIRED: wait and access + cache.InvalidateAll(); + removals.clear(); + ASSERT_OK(cache.Put("e", 99)); + std::this_thread::sleep_for(std::chrono::milliseconds(80)); + ASSERT_FALSE(cache.GetIfPresent("e").has_value()); + ASSERT_EQ(removals.back().cause, RemovalCause::EXPIRED); +} + +// ==================== valuesequal with shared_ptr ==================== + +TEST_F(GenericLruCacheTest, SharedPtrSamePointerNoReplace) { + using Cause = StringSharedPtrCache::RemovalCause; + std::vector causes; + StringSharedPtrCache::Options options; + options.removal_callback = [&](const std::string& key, const std::shared_ptr& value, + auto cause) { causes.push_back(static_cast(cause)); }; + StringSharedPtrCache cache(options); + + auto ptr = std::make_shared(42); + ASSERT_OK(cache.Put("key1", ptr)); + + // Put same pointer — ValuesEqual returns true, should promote without REPLACED + ASSERT_OK(cache.Put("key1", ptr)); + ASSERT_TRUE(causes.empty()); + ASSERT_EQ(cache.Size(), 1); +} + +TEST_F(GenericLruCacheTest, SharedPtrDifferentPointerSameValueNoReplace) { + using Cause = StringSharedPtrCache::RemovalCause; + std::vector causes; + StringSharedPtrCache::Options options; + options.removal_callback = [&](const std::string& key, const std::shared_ptr& value, + auto cause) { causes.push_back(static_cast(cause)); }; + StringSharedPtrCache cache(options); + + auto ptr1 = std::make_shared(42); + auto ptr2 = std::make_shared(42); + ASSERT_NE(ptr1.get(), ptr2.get()); + + ASSERT_OK(cache.Put("key1", ptr1)); + // Different pointer but same dereferenced value — ValuesEqual returns true + ASSERT_OK(cache.Put("key1", ptr2)); + ASSERT_TRUE(causes.empty()); +} + +TEST_F(GenericLruCacheTest, SharedPtrDifferentValueReplaces) { + using Cause = StringSharedPtrCache::RemovalCause; + std::vector causes; + StringSharedPtrCache::Options options; + options.removal_callback = [&](const std::string& key, const std::shared_ptr& value, + auto cause) { causes.push_back(static_cast(cause)); }; + StringSharedPtrCache cache(options); + + ASSERT_OK(cache.Put("key1", std::make_shared(1))); + ASSERT_OK(cache.Put("key1", std::make_shared(2))); + + ASSERT_EQ(causes.size(), 1); + ASSERT_EQ(causes[0], Cause::REPLACED); +} + +TEST_F(GenericLruCacheTest, SharedPtrNullptrComparison) { + using Cause = StringSharedPtrCache::RemovalCause; + std::vector causes; + StringSharedPtrCache::Options options; + options.removal_callback = [&](const std::string& key, const std::shared_ptr& value, + auto cause) { causes.push_back(static_cast(cause)); }; + StringSharedPtrCache cache(options); + + // Put nullptr + ASSERT_OK(cache.Put("key1", nullptr)); + + // Put nullptr again — same value, should not replace + ASSERT_OK(cache.Put("key1", nullptr)); + ASSERT_TRUE(causes.empty()); + + // Put non-null — different from nullptr, should replace + ASSERT_OK(cache.Put("key1", std::make_shared(1))); + ASSERT_EQ(causes.size(), 1); + ASSERT_EQ(causes[0], Cause::REPLACED); + + // Put nullptr again — different from non-null, should replace + causes.clear(); + ASSERT_OK(cache.Put("key1", nullptr)); + ASSERT_EQ(causes.size(), 1); + ASSERT_EQ(causes[0], Cause::REPLACED); +} + +// ==================== Custom Hash and KeyEqual ==================== + +TEST_F(GenericLruCacheTest, CustomHashAndKeyEqual) { + struct CaseInsensitiveHash { + size_t operator()(const std::string& str) const { + std::string lower = str; + for (auto& ch : lower) { + ch = static_cast(std::tolower(ch)); + } + return std::hash{}(lower); + } + }; + struct CaseInsensitiveEqual { + bool operator()(const std::string& lhs, const std::string& rhs) const { + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); i++) { + if (std::tolower(lhs[i]) != std::tolower(rhs[i])) return false; + } + return true; + } + }; + + using CICache = GenericLruCache; + CICache::Options options; + CICache cache(options); + + ASSERT_OK(cache.Put("Hello", 1)); + ASSERT_EQ(cache.Size(), 1); + + // "hello" should match "Hello" with case-insensitive comparison + auto result = cache.GetIfPresent("hello"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 1); + + // Put with different case should replace + ASSERT_OK(cache.Put("HELLO", 2)); + ASSERT_EQ(cache.Size(), 1); + + result = cache.GetIfPresent("Hello"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 2); +} + +// ==================== Thread Safety ==================== + +TEST_F(GenericLruCacheTest, ConcurrentPutAndGet) { + IntIntCache::Options options; + options.max_weight = 10000; + IntIntCache cache(options); + + constexpr int32_t num_threads = 8; + constexpr int32_t ops_per_thread = 200; + + std::vector threads; + std::atomic errors{0}; + + for (int32_t t = 0; t < num_threads; t++) { + threads.emplace_back([&, t]() { + for (int32_t i = 0; i < ops_per_thread; i++) { + int32_t key = t * ops_per_thread + i; + auto status = cache.Put(key, key * 10); + if (!status.ok()) { + errors++; + } + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + ASSERT_EQ(errors.load(), 0); + ASSERT_EQ(static_cast(cache.Size()), num_threads * ops_per_thread); + + // Concurrent reads + threads.clear(); + for (int32_t t = 0; t < num_threads; t++) { + threads.emplace_back([&, t]() { + for (int32_t i = 0; i < ops_per_thread; i++) { + int32_t key = t * ops_per_thread + i; + auto result = cache.GetIfPresent(key); + if (!result.has_value() || result.value() != key * 10) { + errors++; + } + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + ASSERT_EQ(errors.load(), 0); +} + +TEST_F(GenericLruCacheTest, ConcurrentGetWithSupplier) { + IntIntCache::Options options; + options.max_weight = 10000; + IntIntCache cache(options); + + constexpr int32_t num_threads = 8; + constexpr int32_t ops_per_thread = 100; + + std::atomic supplier_calls{0}; + std::vector threads; + + for (int32_t t = 0; t < num_threads; t++) { + threads.emplace_back([&, t]() { + for (int32_t i = 0; i < ops_per_thread; i++) { + int32_t key = t * ops_per_thread + i; + auto supplier = [&, key](const int&) -> Result { + supplier_calls++; + return key * 10; + }; + auto result = cache.Get(key, supplier); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(result.value(), key * 10); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + ASSERT_EQ(static_cast(cache.Size()), num_threads * ops_per_thread); +} + +TEST_F(GenericLruCacheTest, ConcurrentInvalidate) { + IntIntCache::Options options; + IntIntCache cache(options); + + for (int32_t i = 0; i < 100; i++) { + ASSERT_OK(cache.Put(i, i)); + } + + std::vector threads; + for (int32_t t = 0; t < 4; t++) { + threads.emplace_back([&, t]() { + for (int32_t i = t * 25; i < (t + 1) * 25; i++) { + cache.Invalidate(i); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + ASSERT_EQ(cache.Size(), 0); +} + +// ==================== Edge Cases ==================== + +TEST_F(GenericLruCacheTest, PutAndGetSingleEntry) { + StringIntCache::Options options; + options.max_weight = 1; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("only", 42)); + ASSERT_EQ(cache.Size(), 1); + + auto result = cache.GetIfPresent("only"); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), 42); + + // Adding another entry should evict the first + ASSERT_OK(cache.Put("new", 99)); + ASSERT_EQ(cache.Size(), 1); + ASSERT_FALSE(cache.GetIfPresent("only").has_value()); + ASSERT_TRUE(cache.GetIfPresent("new").has_value()); +} + +TEST_F(GenericLruCacheTest, ReplaceUpdatesWeight) { + StringStringCache::Options options; + options.max_weight = 100; + options.weigh_func = [](const std::string& key, const std::string& value) -> int64_t { + return static_cast(value.size()); + }; + StringStringCache cache(options); + + ASSERT_OK(cache.Put("a", std::string(50, 'x'))); + ASSERT_EQ(cache.GetCurrentWeight(), 50); + + // Replace with smaller value + ASSERT_OK(cache.Put("a", std::string(20, 'y'))); + ASSERT_EQ(cache.GetCurrentWeight(), 20); + + // Replace with larger value + ASSERT_OK(cache.Put("a", std::string(80, 'z'))); + ASSERT_EQ(cache.GetCurrentWeight(), 80); +} + +TEST_F(GenericLruCacheTest, EvictionOrderIsLru) { + std::vector evicted_keys; + StringIntCache::Options options; + options.max_weight = 3; + options.removal_callback = [&](const std::string& key, const int& value, auto cause) { + if (static_cast(cause) == RemovalCause::SIZE) { + evicted_keys.push_back(key); + } + }; + StringIntCache cache(options); + + ASSERT_OK(cache.Put("a", 1)); + ASSERT_OK(cache.Put("b", 2)); + ASSERT_OK(cache.Put("c", 3)); + + // Access "a" and "b" to make "c" the LRU + cache.GetIfPresent("a"); + cache.GetIfPresent("b"); + + // Insert "d": should evict "c" (LRU) + ASSERT_OK(cache.Put("d", 4)); + ASSERT_EQ(evicted_keys.size(), 1); + ASSERT_EQ(evicted_keys[0], "c"); +} + +} // namespace paimon::test diff --git a/src/paimon/common/utils/murmurhash_utils.h b/src/paimon/common/utils/murmurhash_utils.h new file mode 100644 index 0000000..cebf41b --- /dev/null +++ b/src/paimon/common/utils/murmurhash_utils.h @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * xxHash - Extremely Fast Hash algorithm + * Header File + * Copyright (C) 2012-2023 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +// MMH_rotl32 utility is adapted from xxHash +// https://github.com/Cyan4973/xxHash/blob/dev/xxhash.h + +#pragma once +#include +#include +#include +#include + +#include "paimon/common/memory/memory_segment.h" +#include "paimon/memory/bytes.h" + +namespace paimon { + +#ifdef __has_builtin +#define MMH_HAS_BUILTIN(x) __has_builtin(x) +#else +#define MMH_HAS_BUILTIN(x) 0 +#endif +/*! + * @internal + * @def MMH_rotl32(x,r) + * @brief 32-bit rotate left. + * + * @param x The 32-bit integer to be rotated. + * @param r The number of bits to rotate. + * @pre + * @p r > 0 && @p r < 32 + * @note + * @p x and @p r may be evaluated multiple times. + * @return The rotated result. + */ +#if !defined(NO_CLANG_BUILTIN) && MMH_HAS_BUILTIN(__builtin_rotateleft32) && \ + MMH_HAS_BUILTIN(__builtin_rotateleft64) +#define MMH_rotl32 __builtin_rotateleft32 +/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */ +#elif defined(_MSC_VER) +#define MMH_rotl32(x, r) _rotl(x, r) +#else +#define MMH_rotl32(x, r) (((x) << (r)) | ((x) >> (32 - (r)))) +#endif + +class MurmurHashUtils { + public: + static constexpr int32_t DEFAULT_SEED = 42; + + MurmurHashUtils() = delete; + + /// Hash unsafe bytes, length must be aligned to 4 bytes. + /// + /// @param base base unsafe object + /// @param offset offset for unsafe object + /// @param length_in_bytes length in bytes + /// @return hash code + static int32_t HashUnsafeBytesByWords(const void* base, int64_t offset, + int32_t length_in_bytes) { + return HashUnsafeBytesByWords(base, offset, length_in_bytes, DEFAULT_SEED); + } + + /// Hash bytes. + static int32_t HashBytesPositive(const std::shared_ptr& bytes) { + return HashBytes(bytes) & 0x7fffffff; + } + + /// Hash bytes. + static int32_t HashBytes(const std::shared_ptr& bytes) { + return HashUnsafeBytes(reinterpret_cast(bytes->data()), 0, bytes->size(), + DEFAULT_SEED); + } + + static int32_t HashUnsafeBytes(const void* base, int64_t offset, int32_t length_in_bytes, + int32_t seed) { + assert(length_in_bytes >= 0); + int32_t length_aligned = length_in_bytes - length_in_bytes % 4; + int32_t h1 = HashUnsafeBytesByInt(base, offset, length_aligned, seed); + for (int32_t i = length_aligned; i < length_in_bytes; i++) { + int32_t half_word = GetByte(base, offset + i); + int32_t k1 = MixK1(half_word); + h1 = MixH1(h1, k1); + } + return Fmix(h1, length_in_bytes); + } + + /// Hash unsafe bytes. + /// + /// @param base base unsafe object + /// @param offset offset for unsafe object + /// @param length_in_bytes length in bytes + /// @return hash code + static int32_t HashUnsafeBytes(const void* base, int64_t offset, int32_t length_in_bytes) { + return HashUnsafeBytes(base, offset, length_in_bytes, DEFAULT_SEED); + } + + /// Hash bytes in MemorySegment, length must be aligned to 4 bytes. + /// + /// @param segment segment. + /// @param offset offset for MemorySegment + /// @param length_in_bytes length in MemorySegment + /// @return hash code + static int32_t HashBytesByWords(const MemorySegment& segment, int32_t offset, + int32_t length_in_bytes) { + return HashBytesByWords(segment, offset, length_in_bytes, DEFAULT_SEED); + } + + /// Hash bytes in MemorySegment. + /// + /// @param segment segment. + /// @param offset offset for MemorySegment + /// @param length_in_bytes length in MemorySegment + /// @return hash code + static int32_t HashBytes(const MemorySegment& segment, int32_t offset, + int32_t length_in_bytes) { + return HashBytes(segment, offset, length_in_bytes, DEFAULT_SEED); + } + + private: + static int32_t HashUnsafeBytesByWords(const void* base, int64_t offset, int32_t length_in_bytes, + int32_t seed) { + int32_t h1 = HashUnsafeBytesByInt(base, offset, length_in_bytes, seed); + return Fmix(h1, length_in_bytes); + } + + static int32_t HashBytesByWords(const MemorySegment& segment, int32_t offset, + int32_t length_in_bytes, int32_t seed) { + int32_t h1 = HashBytesByInt(segment, offset, length_in_bytes, seed); + return Fmix(h1, length_in_bytes); + } + + static int32_t HashBytes(const MemorySegment& segment, int32_t offset, int32_t length_in_bytes, + int32_t seed) { + int32_t length_aligned = length_in_bytes - length_in_bytes % 4; + int32_t h1 = HashBytesByInt(segment, offset, length_aligned, seed); + for (int32_t i = length_aligned; i < length_in_bytes; i++) { + int32_t k1 = MixK1(segment.Get(offset + i)); + h1 = MixH1(h1, k1); + } + return Fmix(h1, length_in_bytes); + } + + static int32_t HashUnsafeBytesByInt(const void* base, int64_t offset, int32_t length_in_bytes, + int32_t seed) { + assert(length_in_bytes % 4 == 0); + int32_t h1 = seed; + for (int32_t i = 0; i < length_in_bytes; i += 4) { + int32_t half_word = GetInt(base, offset + i); + int32_t k1 = MixK1(half_word); + h1 = MixH1(h1, k1); + } + return h1; + } + + static int32_t HashBytesByInt(const MemorySegment& segment, int32_t offset, + int32_t length_in_bytes, int32_t seed) { + assert(length_in_bytes % 4 == 0); + int32_t h1 = seed; + for (int32_t i = 0; i < length_in_bytes; i += 4) { + auto half_word = segment.GetValue(offset + i); + int32_t k1 = MixK1(half_word); + h1 = MixH1(h1, k1); + } + return h1; + } + + static int32_t MixK1(uint32_t k1) { + k1 *= C1; + k1 = MMH_rotl32(k1, 15); + k1 *= C2; + return k1; + } + + static int32_t MixH1(uint32_t h1, uint32_t k1) { + h1 ^= k1; + h1 = MMH_rotl32(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + static int32_t Fmix(uint32_t h1, uint32_t length) { + h1 ^= length; + return Fmix(h1); + } + + static int32_t GetInt(const void* base, int64_t offset) { + int32_t value; + std::memcpy(&value, static_cast(base) + offset, sizeof(int32_t)); + return value; + } + + static char GetByte(const void* base, int64_t offset) { + char value; + std::memcpy(&value, static_cast(base) + offset, sizeof(char)); + return value; + } + + public: + static int32_t Fmix(uint32_t h) { + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + h ^= h >> 16; + return h; + } + + private: + static constexpr int32_t C1 = 0xcc9e2d51; + static constexpr int32_t C2 = 0x1b873593; +}; + +} // namespace paimon diff --git a/src/paimon/common/utils/murmurhash_utils_test.cpp b/src/paimon/common/utils/murmurhash_utils_test.cpp new file mode 100644 index 0000000..b4a475f --- /dev/null +++ b/src/paimon/common/utils/murmurhash_utils_test.cpp @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "paimon/common/utils/murmurhash_utils.h" + +#include + +#include "gtest/gtest.h" +#include "paimon/common/memory/memory_segment_utils.h" +#include "paimon/memory/memory_pool.h" + +namespace paimon::test { +TEST(MurmurHashUtilsTest, TestCompatibleWithJava) { + { + std::vector bytes_value = {3, 10, 20, 30, 40, 50, 67, 89, 111, 51, + 33, 67, 70, 25, 48, 10, 54, 100, 43, 21}; + int32_t num_bytes = bytes_value.size(); + uint32_t expect = 0xb39f33e6; + auto pool = GetDefaultPool(); + std::shared_ptr bytes = Bytes::AllocateBytes(num_bytes, pool.get()); + memcpy(bytes->data(), bytes_value.data(), num_bytes); + MemorySegment segment = MemorySegment::Wrap(bytes); + ASSERT_EQ(expect, + MurmurHashUtils::HashUnsafeBytesByWords(bytes_value.data(), 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashUnsafeBytes(bytes_value.data(), 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashBytesByWords(segment, 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashBytes(segment, 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashBytes(bytes)); + } + { + std::vector bytes_value = {3, 10, 20, 30, 40, 50, 67, 89, + 111, 51, 33, 67, 70, 25, 48, 10}; + int32_t num_bytes = bytes_value.size(); + uint32_t expect = 0x46bdab5b; + auto pool = GetDefaultPool(); + std::shared_ptr bytes = Bytes::AllocateBytes(num_bytes, pool.get()); + memcpy(bytes->data(), bytes_value.data(), num_bytes); + MemorySegment segment = MemorySegment::Wrap(bytes); + ASSERT_EQ(expect, + MurmurHashUtils::HashUnsafeBytesByWords(bytes_value.data(), 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashUnsafeBytes(bytes_value.data(), 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashBytesByWords(segment, 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashBytes(segment, 0, num_bytes)); + ASSERT_EQ(expect, MurmurHashUtils::HashBytes(bytes)); + } + { + // test multi segments with MemorySegmentUtil + std::vector bytes_value1 = {3, 10, 20, 30, 40, 50, 67, 89, 111, 51, 33, 67}; + std::vector bytes_value2 = {70, 25, 48, 10}; + int32_t num_bytes = bytes_value1.size() + bytes_value2.size(); + uint32_t expect = 0x46bdab5b; + auto pool = GetDefaultPool(); + std::shared_ptr bytes1 = Bytes::AllocateBytes(bytes_value1.size(), pool.get()); + memcpy(bytes1->data(), bytes_value1.data(), bytes_value1.size()); + MemorySegment segment1 = MemorySegment::Wrap(bytes1); + std::shared_ptr bytes2 = Bytes::AllocateBytes(bytes_value2.size(), pool.get()); + memcpy(bytes2->data(), bytes_value2.data(), bytes_value2.size()); + MemorySegment segment2 = MemorySegment::Wrap(bytes2); + ASSERT_EQ(expect, + MemorySegmentUtils::HashByWords({segment1, segment2}, 0, num_bytes, pool.get())); + ASSERT_EQ(expect, MemorySegmentUtils::Hash({segment1, segment2}, 0, num_bytes, pool.get())); + } + { + // test not aligned bytes (not aligned with 4 bytes) + std::vector bytes_value = {3, 10, 120, 130, 240, 50, 167, 189, 211, + 51, 233, 167, 170, 25, 148, 10, 226, 19}; + int32_t num_bytes = bytes_value.size(); + uint32_t expect_bytes = 0x94a92b77; + auto pool = GetDefaultPool(); + std::shared_ptr bytes = Bytes::AllocateBytes(num_bytes, pool.get()); + memcpy(bytes->data(), bytes_value.data(), num_bytes); + MemorySegment segment = MemorySegment::Wrap(bytes); + + ASSERT_EQ(expect_bytes, MurmurHashUtils::HashUnsafeBytes(bytes_value.data(), 0, num_bytes)); + ASSERT_EQ(expect_bytes, MurmurHashUtils::HashBytes(segment, 0, num_bytes)); + ASSERT_EQ(expect_bytes, MurmurHashUtils::HashBytes(bytes)); + } + { + // test multi segments with MemorySegmentUtil + std::vector bytes_value1 = {3, 10, 120, 130, 240, 50, 167, 189, 211, 51, 233, 167}; + std::vector bytes_value2 = {170, 25, 148, 10, 226, 19}; + int32_t num_bytes = bytes_value1.size() + bytes_value2.size(); + uint32_t expect_bytes = 0x94a92b77; + + auto pool = GetDefaultPool(); + std::shared_ptr bytes1 = Bytes::AllocateBytes(bytes_value1.size(), pool.get()); + memcpy(bytes1->data(), bytes_value1.data(), bytes_value1.size()); + MemorySegment segment1 = MemorySegment::Wrap(bytes1); + std::shared_ptr bytes2 = Bytes::AllocateBytes(bytes_value2.size(), pool.get()); + memcpy(bytes2->data(), bytes_value2.data(), bytes_value2.size()); + MemorySegment segment2 = MemorySegment::Wrap(bytes2); + + ASSERT_EQ(expect_bytes, + MemorySegmentUtils::Hash({segment1, segment2}, 0, num_bytes, pool.get())); + } +} +} // namespace paimon::test diff --git a/src/paimon/common/utils/preconditions.h b/src/paimon/common/utils/preconditions.h new file mode 100644 index 0000000..f1dfe39 --- /dev/null +++ b/src/paimon/common/utils/preconditions.h @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include + +#include "fmt/base.h" +#include "fmt/core.h" +#include "fmt/format.h" +#include "fmt/ranges.h" +#include "paimon/status.h" + +namespace paimon { +/// A collection of static utility methods to validate input. +/// +/// This class is based on Google Guava's Preconditions class, and partly +/// takes code from that class. We add this code to the Paimon code base in order +/// to reduce external dependencies. +class Preconditions { + public: + Preconditions() = delete; + + public: + /// Checks the given boolean condition, and return Status::Invalid() if + /// the condition is not met (evaluates to `false`). + /// + /// @param condition The condition to check + /// @param format_str The message template for the status + /// if the check fails. The template substitutes its `%%s` + /// placeholders with the error message arguments. + /// @param args The arguments for the error message, to be + /// inserted into the message template for the `%%s` placeholders. + /// @return Status::Invalid(), if the condition is violated. + template + static Status CheckState(bool condition, const std::string& format_str, Args&&... args) { + if (!condition) { + return Status::Invalid( + fmt::format(fmt::runtime(format_str), std::forward(args)...)); + } + return Status::OK(); + } + + /// Ensures that the given object reference is not null. Upon violation, + /// Status::Invalid() with no message is returned. + /// + /// @param reference The object reference + /// @return Status::Invalid(), if the passed reference was null. + template + static Status CheckNotNull(const T& reference) { + if (!reference) { + return Status::Invalid("check null failed"); + } + return Status::OK(); + } + /// Ensures that the given object reference is not null. Upon violation, + /// Status::Invalid() with the given message is returned. + /// + /// The error message is constructed from a template and an arguments + /// array, after a similar fashion as fmt::format(String, + /// Object...)}, but supporting only `%%s` as a placeholder. + /// + /// @param format_str The message template for the + /// Status::Invalid() that is thrown if the check fails. The template + /// substitutes its `%%s` placeholders with the error message arguments. + /// @param args The arguments for the error message, to be + /// inserted into the message template for the `%%s` placeholders. + /// @return Status::Invalid() returned, if the passed reference was null. + template + static Status CheckNotNull(const T& reference, const std::string& format_str, Args&&... args) { + if (!reference) { + return Status::Invalid( + fmt::format(fmt::runtime(format_str), std::forward(args)...)); + } + return Status::OK(); + } + + static Status CheckArgument(bool condition) { + if (!condition) { + return Status::Invalid("invalid argument"); + } + return Status::OK(); + } +}; +} // namespace paimon diff --git a/src/paimon/common/utils/preconditions_test.cpp b/src/paimon/common/utils/preconditions_test.cpp new file mode 100644 index 0000000..a5c9c09 --- /dev/null +++ b/src/paimon/common/utils/preconditions_test.cpp @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "paimon/common/utils/preconditions.h" + +#include "gtest/gtest.h" +#include "paimon/testing/utils/testharness.h" + +namespace paimon::test { + +// Test case: Test CheckState with a valid condition +TEST(PreconditionsTest, CheckStateValidCondition) { + ASSERT_OK(Preconditions::CheckState(true, "Condition failed: {}", "Some error details")); +} + +// Test case: Test CheckState with an invalid condition +TEST(PreconditionsTest, CheckStateInvalidCondition) { + ASSERT_NOK_WITH_MSG( + Preconditions::CheckState(false, "Condition failed: {}", "Some error details"), + "Condition failed: Some error details"); +} + +// Test case: Test CheckNotNull with a non-null reference +TEST(PreconditionsTest, CheckNotNullNonNullReference) { + int32_t x = 10; + ASSERT_OK(Preconditions::CheckNotNull(x)); +} + +// Test case: Test CheckNotNull with a non-null reference with msg +TEST(PreconditionsTest, CheckNotNullNonNullReference2) { + int32_t x = 10; + ASSERT_OK(Preconditions::CheckNotNull(x, "Condition failed")); +} + +// Test case: Test CheckNotNull with a null reference (null pointer) +TEST(PreconditionsTest, CheckNotNullNullReference) { + int* ptr = nullptr; + ASSERT_NOK_WITH_MSG(Preconditions::CheckNotNull(ptr), "check null failed"); +} + +// Test case: Test CheckNotNull with a null reference and a custom message +TEST(PreconditionsTest, CheckNotNullNullReferenceWithMessage) { + int* ptr = nullptr; + ASSERT_NOK_WITH_MSG(Preconditions::CheckNotNull(ptr, "Custom error: pointer is null"), + "Custom error: pointer is null"); +} + +// Test case: Test CheckArgument with a valid condition +TEST(PreconditionsTest, CheckArgumentValidCondition) { + ASSERT_OK(Preconditions::CheckArgument(true)); +} + +// Test case: Test CheckArgument with an invalid condition +TEST(PreconditionsTest, CheckArgumentInvalidCondition) { + ASSERT_NOK_WITH_MSG(Preconditions::CheckArgument(false), "invalid argument"); +} + +} // namespace paimon::test