From 3856052ae36a40b9f747a9b59682434d9cb90702 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 1 Jun 2026 09:49:10 -0400 Subject: [PATCH 1/2] Add max_tries argument to stan::services::util::initialize --- src/stan/services/util/initialize.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/stan/services/util/initialize.hpp b/src/stan/services/util/initialize.hpp index a6467ae394d..c52c331b449 100644 --- a/src/stan/services/util/initialize.hpp +++ b/src/stan/services/util/initialize.hpp @@ -36,8 +36,8 @@ namespace util { * * When at least some of the initialization is random, it will * randomly initialize until it finds a set of unconstrained - * parameters that are valid or it hits MAX_INIT_TRIES = - * 100 (hard-coded). + * parameters that are valid or it hits max_tries, + * which defaults to 100. * * Valid initialization is defined as a finite, non-NaN value for the * evaluation of the log probability density function and all its @@ -58,6 +58,8 @@ namespace util { * be printed to the logger * @param[in,out] logger logger for messages * @param[in,out] init_writer init writer (on the unconstrained scale) + * @param[in] max_tries The maximum number of times a random initialization + * will be re-tried to achieve a finite log density and gradient. Default 100. * @throws exception passed through from the model if the model has a * fatal error (not a std::domain_error) * @throws std::domain_error if the model can not be initialized and @@ -70,7 +72,8 @@ template initialize(Model& model, const InitContext& init, RNG& rng, double init_radius, bool print_timing, stan::callbacks::logger& logger, - stan::callbacks::writer& init_writer) { + stan::callbacks::writer& init_writer, + int max_tries = 100) { std::vector unconstrained; std::vector disc_vector; @@ -86,7 +89,7 @@ std::vector initialize(Model& model, const InitContext& init, RNG& rng, bool is_initialized_with_zero = init_radius == 0.0; int MAX_INIT_TRIES - = is_fully_initialized || is_initialized_with_zero ? 1 : 100; + = is_fully_initialized || is_initialized_with_zero ? 1 : max_tries; int num_init_tries = 0; for (; num_init_tries < MAX_INIT_TRIES; num_init_tries++) { std::stringstream msg; From bd81713f4edb2756d75a42af1e9c4a52a8d4cf62 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 1 Jun 2026 10:07:42 -0400 Subject: [PATCH 2/2] Add test --- .../unit/services/util/initialize_test.cpp | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/src/test/unit/services/util/initialize_test.cpp b/src/test/unit/services/util/initialize_test.cpp index e268d9c20bc..39b1def24f6 100644 --- a/src/test/unit/services/util/initialize_test.cpp +++ b/src/test/unit/services/util/initialize_test.cpp @@ -5,11 +5,14 @@ #include #include #include +#include +#include #include #include #include #include #include +#include class ServicesUtilInitialize : public testing::Test { public: @@ -577,3 +580,100 @@ TEST_F(ServicesUtilInitialize, model_throws_in_write_array__full_init) { EXPECT_EQ(2, logger.call_count_error()); EXPECT_EQ(100, logger.find_warn("throwing within write_array")); } + +namespace test { +// Mock model that returns non-finite lp +class mock_inf_model : public stan::model::prob_grad { + public: + mock_inf_model() + : stan::model::prob_grad(1), + templated_log_prob_calls(0), + transform_inits_calls(0), + write_array_calls(0), + log_prob_return_value(-std::numeric_limits::infinity()) {} + + void reset() { + templated_log_prob_calls = 0; + transform_inits_calls = 0; + write_array_calls = 0; + log_prob_return_value = 0.0; + } + + template + T__ log_prob(std::vector& params_r__, std::vector& params_i__, + std::ostream* pstream__ = 0) const { + ++templated_log_prob_calls; + return log_prob_return_value; + } + + void transform_inits(const stan::io::var_context& context__, + std::vector& params_i__, + std::vector& params_r__, + std::ostream* pstream__) const { + ++transform_inits_calls; + for (size_t n = 0; n < params_r__.size(); ++n) { + params_r__[n] = n; + } + } + + void get_dims(std::vector >& dimss__, + bool include_tparams = true, bool include_gqs = true) const { + dimss__.resize(0); + std::vector scalar_dim; + dimss__.push_back(scalar_dim); + } + + void constrained_param_names(std::vector& param_names__, + bool include_tparams__ = true, + bool include_gqs__ = true) const { + param_names__.push_back("theta"); + } + + void get_param_names(std::vector& names, + bool include_tparams = true, + bool include_gqs = true) const { + constrained_param_names(names); + } + + void unconstrained_param_names(std::vector& param_names__, + bool include_tparams__ = true, + bool include_gqs__ = true) const { + param_names__.clear(); + for (size_t n = 0; n < num_params_r__; ++n) { + std::stringstream param_name; + param_name << "param_" << n; + param_names__.push_back(param_name.str()); + } + } + template + void write_array(RNG& base_rng__, std::vector& params_r__, + std::vector& params_i__, std::vector& vars__, + bool include_tparams__ = true, bool include_gqs__ = true, + std::ostream* pstream__ = 0) const { + ++write_array_calls; + vars__.resize(0); + for (size_t i = 0; i < params_r__.size(); ++i) + vars__.push_back(params_r__[i]); + } + + mutable int templated_log_prob_calls; + mutable int transform_inits_calls; + mutable int write_array_calls; + double log_prob_return_value; +}; +} // namespace test + +TEST_F(ServicesUtilInitialize, model_errors_retries) { + test::mock_inf_model error_model; + + double init_radius = 1.2; + bool print_timing = false; + EXPECT_THROW_MSG(stan::services::util::initialize( + error_model, empty_context, rng, init_radius, + print_timing, logger, init, 23), + std::domain_error, "Initialization failed."); + EXPECT_EQ(23, error_model.templated_log_prob_calls); + EXPECT_EQ(72, logger.call_count()); + EXPECT_EQ(2, logger.call_count_error()); + EXPECT_EQ(1, logger.find_error("after 23 attempts")); +}