diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index ab02cd552e..6bad94ec04 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -980,7 +980,8 @@ def run( common parameters shown below will be added and can be passed through the `override` parameter of this method. - ``"output_dir"``: the path to save mlflow tracking outputs locally, default to "<bundle root>/eval". - - ``"tracking_uri"``: uri to save mlflow tracking outputs, default to "/output_dir/mlruns". + - ``"tracking_uri"``: uri to save mlflow tracking outputs, default to a local SQLite database + at "<output_dir>/mlruns.db" with run artifacts kept under "<output_dir>/mlruns". - ``"experiment_name"``: experiment name for this run, default to "monai_experiment". - ``"run_name"``: the name of current run. - ``"save_execute_config"``: whether to save the executed config files. It can be `False`, `/path/to/artifacts` diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index d37d7f1c05..bb32bd1621 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -118,8 +118,10 @@ "configs": { # if no "output_dir" in the bundle config, default to "<bundle root>/eval" "output_dir": "$@bundle_root + '/eval'", - # use URI to support linux, mac and windows os - "tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlruns'", + # MLflow 3.13+ rejects the filesystem (file store) tracking backend, so default tracking + # to a local SQLite database. The handler keeps run artifacts under "<output_dir>/mlruns" + # (next to the db). A URI is used so the path is valid on linux, mac and windows os. 
+ "tracking_uri": "$monai.utils.path_to_sqlite_uri(@output_dir + '/mlruns.db')", "experiment_name": "monai_experiment", "run_name": None, # may fill it at runtime diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index 3078d89f97..3cc6536d7c 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -22,7 +22,16 @@ from torch.utils.data import Dataset from monai.apps.utils import get_logger -from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import +from monai.utils import ( + CommonKeys, + IgniteInfo, + ensure_tuple, + flatten_dict, + min_version, + optional_import, + path_to_sqlite_uri, + path_to_uri, +) Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.") @@ -68,7 +77,14 @@ class MLFlowHandler: tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment variable to have MLflow find a URI from there. in both cases, the URI can either be an HTTP/HTTPS URI for a remote server, a database connection string, or a local path - to log data to a directory. The URI defaults to path `mlruns`. + to log data to a directory. When no ``tracking_uri`` is provided and the + ``MLFLOW_TRACKING_URI`` environment variable is unset, the handler now + defaults to a local SQLite database backend at ``sqlite:///<cwd>/mlruns.db`` with + artifacts stored under ``<cwd>/mlruns``. The default was changed from the filesystem + (file store) backend because MLflow 3.13+ raises an exception for the file store unless + ``MLFLOW_ALLOW_FILE_STORE=true`` is set; SQLite is the backend MLflow recommends and it + does not raise. Any explicitly provided ``tracking_uri`` (including a local file path or + ``file://`` URI) is passed through unchanged. 
for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri. iteration_log: whether to log data to MLFlow when iteration completed, default to `True`. ``iteration_log`` can be also a function and it will be interpreted as an event filter @@ -113,6 +129,11 @@ class MLFlowHandler: optimizer_param_names: parameter names in the optimizer that need to be recorded during running the workflow, default to `'lr'`. close_on_complete: whether to close the mlflow run in `complete` phase in workflow, default to False. + artifact_location: the location to store run artifacts in, passed to MLflow when the experiment is + created. When ``None`` and a local SQLite backend is used (from the ``tracking_uri`` argument + or the ``MLFLOW_TRACKING_URI`` environment variable), it defaults to an ``mlruns`` directory + next to the database file; for other backends ``None`` lets MLflow decide based on the + ``tracking_uri``. Has no effect if the experiment already exists. For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html. @@ -141,6 +162,7 @@ def __init__( artifacts: str | Sequence[Path] | None = None, optimizer_param_names: str | Sequence[str] = "lr", close_on_complete: bool = False, + artifact_location: str | None = None, ) -> None: self.iteration_log = iteration_log self.epoch_log = epoch_log @@ -156,6 +178,31 @@ def __init__( self.experiment_param = experiment_param self.artifacts = ensure_tuple(artifacts) self.optimizer_param_names = ensure_tuple(optimizer_param_names) + # When no tracking_uri is provided, default to a local SQLite backend instead of the + # filesystem (file store) backend. MLflow 3.13+ raises for the file store unless + # `MLFLOW_ALLOW_FILE_STORE=true` is set, while SQLite is the recommended backend and does + # not raise. 
Artifacts cannot live inside a database, so by default they are stored under + # the `./mlruns` directory (where the previous file store default kept them) via the + # experiment `artifact_location`. Any explicitly provided tracking_uri is left unchanged. + self.artifact_location = artifact_location + # Resolve the effective tracking URI from the argument or the `MLFLOW_TRACKING_URI` + # environment variable, so both configure the artifact location the same way. + effective_tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI") + # When neither is set, fall back to the local SQLite default described above. + if not effective_tracking_uri: + tracking_uri = effective_tracking_uri = path_to_sqlite_uri(os.path.join(os.getcwd(), "mlruns.db")) + # For a local SQLite backend, keep run artifacts in an `mlruns` directory next to the + # database file (mirroring the previous file-store layout) unless the caller set + # `artifact_location`. Other backends (e.g. a remote server) are left to MLflow to decide. + # Only `tracking_uri` is passed to the client, so an `MLFLOW_TRACKING_URI` env var is + # still resolved by MLflow itself. 
+ if ( + self.artifact_location is None + and effective_tracking_uri + and effective_tracking_uri.startswith("sqlite:///") + ): + db_path = Path(effective_tracking_uri[len("sqlite:///") :]) + self.artifact_location = path_to_uri(db_path.parent / "mlruns") self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None) self.run_finish_status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED) self.close_on_complete = close_on_complete @@ -245,7 +292,12 @@ def _set_experiment(self): try: experiment = self.client.get_experiment_by_name(self.experiment_name) if not experiment: - experiment_id = self.client.create_experiment(self.experiment_name) + # pass an explicit artifact_location (set for the default SQLite backend, or + # by the caller) so artifacts land in the intended directory; when it is + # None MLflow decides based on the tracking_uri. + experiment_id = self.client.create_experiment( + self.experiment_name, artifact_location=self.artifact_location + ) experiment = self.client.get_experiment(experiment_id) break except MlflowException as e: @@ -336,14 +388,43 @@ def complete(self) -> None: for artifact in artifact_list: self.client.log_artifact(self.cur_run.info.run_id, artifact) + def _dispose_sqlite_store(self) -> None: + """ + Release MLflow's SQLAlchemy engine when a local SQLite tracking backend is used. + + MLflow keeps the SQLite connection open for the lifetime of the client, which on + Windows prevents the database file from being deleted. MLflow exposes no public + client close/dispose API, so this reaches into its internals defensively to release + the engine. It is a no-op for non-SQLite backends. 
+ """ + tracking_uri = getattr(self.client, "tracking_uri", "") + if not isinstance(tracking_uri, str) or not tracking_uri.startswith("sqlite:"): + return + store = getattr(getattr(self.client, "_tracking_client", None), "store", None) + if store is None: + return + dispose = getattr(store, "_dispose_engine", None) + if callable(dispose): + dispose() + else: + engine = getattr(store, "engine", None) + if engine is not None: + engine.dispose() + read_engine = getattr(store, "read_engine", None) + if read_engine is not None: + read_engine.dispose() + def close(self) -> None: """ - Stop current running logger of MLFlow. + Stop current running logger of MLFlow and release local SQLite resources. """ - if self.cur_run: - self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status) - self.cur_run = None + try: + if self.cur_run: + self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status) + self.cur_run = None + finally: + self._dispose_sqlite_store() def epoch_completed(self, engine: Engine) -> None: """ diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 3efc9b5e7f..75501c7f7a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -89,6 +89,7 @@ is_sqrt, issequenceiterable, list_to_dict, + path_to_sqlite_uri, path_to_uri, pprint_edges, progress_bar, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index ed48d4b37d..23d2216d84 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -69,6 +69,7 @@ "save_obj", "label_union", "path_to_uri", + "path_to_sqlite_uri", "pprint_edges", "check_key_duplicates", "CheckKeyDuplicatesYamlLoader", @@ -727,6 +728,21 @@ def path_to_uri(path: PathLike) -> str: return Path(path).absolute().as_uri() +def path_to_sqlite_uri(path: PathLike) -> str: + """ + Convert a database file path to a SQLite connection URI, e.g. for use as an MLflow + ``tracking_uri``. If not an absolute path, it is converted to an absolute path first. 
+ + A forward-slash (POSIX) path is used so the URI is valid on Windows as well as POSIX: + on Windows this yields ``sqlite:///C:/path/db.sqlite`` and on POSIX ``sqlite:////path/db.sqlite``. + + Args: + path: input database file path, can be a string or `Path` object. + + """ + return f"sqlite:///{Path(path).absolute().as_posix()}" + + def pprint_edges(val: Any, n_lines: int = 20) -> str: """ Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines. diff --git a/tests/fl/monai_algo/test_fl_monai_algo.py b/tests/fl/monai_algo/test_fl_monai_algo.py index 2c1a8488cc..a55dcc4560 100644 --- a/tests/fl/monai_algo/test_fl_monai_algo.py +++ b/tests/fl/monai_algo/test_fl_monai_algo.py @@ -26,7 +26,7 @@ from monai.fl.client.monai_algo import MonaiAlgo from monai.fl.utils.constants import ExtraItems from monai.fl.utils.exchange_object import ExchangeObject -from monai.utils import path_to_uri +from monai.utils import path_to_sqlite_uri from tests.test_utils import SkipIfNoModule _root_dir = Path(__file__).resolve().parents[2] @@ -79,7 +79,7 @@ "save_execute_config": f"{_data_dir}/config_executed.json", "trainer": { "_target_": "MLFlowHandler", - "tracking_uri": path_to_uri(_data_dir) + "/mlflow_override", + "tracking_uri": path_to_sqlite_uri(os.path.join(_data_dir, "mlflow_override.db")), "output_transform": "$monai.handlers.from_engine(['loss'], first=True)", "close_on_complete": True, }, @@ -103,7 +103,7 @@ workflow_type="train", logging_file=_logging_file, tracking="mlflow", - tracking_uri=path_to_uri(_data_dir) + "/mlflow_1", + tracking_uri=path_to_sqlite_uri(os.path.join(_data_dir, "mlflow_1.db")), experiment_name="monai_eval1", ), "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), @@ -119,7 +119,7 @@ ], "eval_kwargs": { "tracking": "mlflow", - "tracking_uri": path_to_uri(_data_dir) + "/mlflow_2", + "tracking_uri": path_to_sqlite_uri(os.path.join(_data_dir, "mlflow_2.db")), 
"experiment_name": "monai_eval2", }, "eval_workflow_name": "training", @@ -179,6 +179,38 @@ ] +def _dispose_sqlite_engines(): + """Dispose MLflow's open SQLAlchemy SQLite engines so the test ``.db`` files can be removed. + + MLflow keeps a SQLite connection open for the lifetime of its client; on Windows that + locks the database file and breaks cleanup. ``MLFlowHandler.close()`` releases it, but a + workflow may finish without closing every handler, so dispose defensively here before + deleting the files. Scoped to the test's ``mlflow*.db`` backends so unrelated (e.g. + in-memory) sqlite engines elsewhere in the process are left untouched. + """ + import gc + + try: + from sqlalchemy.engine import Engine + except ImportError: + return + gc.collect() + for obj in gc.get_objects(): + # gc.get_objects() can include dead weakref proxies, whose isinstance() raises + # ReferenceError, so guard the whole inspection (ReferenceError is an Exception). + try: + if not isinstance(obj, Engine): + continue + url = obj.url + db = url.database if url.get_backend_name() == "sqlite" else None + # the test backends are all files named ``mlflow*.db``; match those only so + # unrelated (e.g. in-memory) sqlite engines in the process are left untouched. 
+ if db and os.path.basename(db).startswith("mlflow"): + obj.dispose() + except Exception: + pass + + @SkipIfNoModule("ignite") @SkipIfNoModule("mlflow") class TestFLMonaiAlgo(unittest.TestCase): @@ -202,8 +234,11 @@ def test_train(self, input_params): # test experiment management if "save_execute_config" in algo.train_workflow.parser: - self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override")) - shutil.rmtree(f"{_data_dir}/mlflow_override") + _dispose_sqlite_engines() # release SQLite handles so the db file can be removed on Windows + self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override.db")) + os.remove(f"{_data_dir}/mlflow_override.db") + if os.path.isdir(f"{_data_dir}/mlruns"): + shutil.rmtree(f"{_data_dir}/mlruns") self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json")) os.remove(f"{_data_dir}/config_executed.json") @@ -225,9 +260,12 @@ def test_evaluate(self, input_params): # test experiment management if "save_execute_config" in algo.eval_workflow.parser: + _dispose_sqlite_engines() # release SQLite handles so the db files can be removed on Windows self.assertGreater(len(list(glob.glob(f"{_data_dir}/mlflow_*"))), 0) for f in list(glob.glob(f"{_data_dir}/mlflow_*")): - shutil.rmtree(f) + shutil.rmtree(f) if os.path.isdir(f) else os.remove(f) + if os.path.isdir(f"{_data_dir}/mlruns"): + shutil.rmtree(f"{_data_dir}/mlruns") self.assertGreater(len(list(glob.glob(f"{_data_dir}/eval/config_*"))), 0) for f in list(glob.glob(f"{_data_dir}/eval/config_*")): os.remove(f) diff --git a/tests/handlers/test_handler_mlflow.py b/tests/handlers/test_handler_mlflow.py index 80630e6f5a..8039c1a72e 100644 --- a/tests/handlers/test_handler_mlflow.py +++ b/tests/handlers/test_handler_mlflow.py @@ -17,7 +17,7 @@ import tempfile import unittest from concurrent.futures import ThreadPoolExecutor -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np from ignite.engine import Engine, Events @@ -26,7 +26,7 @@ from 
monai.apps import download_and_extract from monai.bundle import ConfigWorkflow, download from monai.handlers import MLFlowHandler -from monai.utils import optional_import, path_to_uri +from monai.utils import optional_import, path_to_sqlite_uri, path_to_uri from tests.test_utils import skip_if_downloading_fails, skip_if_quick _, has_dataset_tracking = optional_import("mlflow", "2.4.0") @@ -55,7 +55,7 @@ def _train_func(engine, batch): handler = MLFlowHandler( iteration_log=False, epoch_log=True, - tracking_uri=path_to_uri(test_path), + tracking_uri=path_to_sqlite_uri(test_path), state_attributes=["test"], close_on_complete=True, ) @@ -71,7 +71,8 @@ def setUp(self): def tearDown(self): for tmpdir in self.tmpdir_list: if tmpdir and os.path.exists(tmpdir): - shutil.rmtree(tmpdir) + # the SQLite default backend creates a db file rather than a directory + shutil.rmtree(tmpdir) if os.path.isdir(tmpdir) else os.remove(tmpdir) def test_multi_run(self): with tempfile.TemporaryDirectory() as tempdir: @@ -95,7 +96,7 @@ def _update_metric(engine): handler = MLFlowHandler( iteration_log=False, epoch_log=True, - tracking_uri=path_to_uri(test_path), + tracking_uri=path_to_sqlite_uri(test_path), state_attributes=["test"], close_on_complete=True, ) @@ -137,7 +138,7 @@ def _update_metric(engine): handler = MLFlowHandler( iteration_log=False, epoch_log=True, - tracking_uri=path_to_uri(test_path), + tracking_uri=path_to_sqlite_uri(test_path), state_attributes=["test"], experiment_param=experiment_param, artifacts=[artifact_path], @@ -173,7 +174,7 @@ def _update_metric(engine): handler = MLFlowHandler( iteration_log=False, epoch_log=epoch_log, - tracking_uri=path_to_uri(test_path), + tracking_uri=path_to_sqlite_uri(test_path), state_attributes=["test"], experiment_param=experiment_param, close_on_complete=True, @@ -212,7 +213,7 @@ def _update_metric(engine): handler = MLFlowHandler( iteration_log=iteration_log, epoch_log=False, - tracking_uri=path_to_uri(test_path), + 
tracking_uri=path_to_sqlite_uri(test_path), state_attributes=["test"], experiment_param=experiment_param, close_on_complete=True, @@ -242,6 +243,117 @@ def test_multi_thread(self): self.tmpdir_list.append(res) self.assertTrue(len(glob.glob(res)) > 0) + def test_default_tracking_uri_is_sqlite(self): + # when no tracking_uri is provided, the handler should default to a local SQLite backend + # rather than the filesystem (file store) backend, which raises on mlflow 3.13+. + with tempfile.TemporaryDirectory() as tempdir: + cwd = os.getcwd() + os.chdir(tempdir) + handler = None + try: + handler = MLFlowHandler(iteration_log=False, epoch_log=False) + self.assertTrue(handler.client.tracking_uri.startswith("sqlite:///")) + self.assertTrue(handler.client.tracking_uri.endswith("mlruns.db")) + # artifacts should still default to a `./mlruns`-style directory + self.assertIsNotNone(handler.artifact_location) + self.assertTrue(handler.artifact_location.endswith("mlruns")) + finally: + if handler is not None: + handler.close() # release the SQLite handle so Windows can delete the db + os.chdir(cwd) + + def test_explicit_tracking_uri_is_preserved(self): + # an explicitly provided tracking_uri must be passed through unchanged, including file paths. + with tempfile.TemporaryDirectory() as tempdir: + explicit_uri = path_to_uri(os.path.join(tempdir, "mlflow_explicit")) + handler = MLFlowHandler(iteration_log=False, epoch_log=False, tracking_uri=explicit_uri) + self.assertEqual(handler.client.tracking_uri, explicit_uri) + self.assertIsNone(handler.artifact_location) + + def test_remote_tracking_uri_leaves_artifact_location_unset(self): + # a non-local (e.g. remote) tracking_uri must not get a local artifact_location injected, + # so the remote backend keeps deciding where artifacts go. 
+ handler = MLFlowHandler(iteration_log=False, epoch_log=False, tracking_uri="http://localhost:5000") + self.assertEqual(handler.client.tracking_uri, "http://localhost:5000") + self.assertIsNone(handler.artifact_location) + + def test_explicit_sqlite_tracking_uri_colocates_artifacts(self): + # an explicit local SQLite tracking_uri should still co-locate artifacts next to the db. + with tempfile.TemporaryDirectory() as tempdir: + uri = path_to_sqlite_uri(os.path.join(tempdir, "sub", "mlruns.db")) + handler = MLFlowHandler(iteration_log=False, epoch_log=False, tracking_uri=uri) + try: + self.assertEqual(handler.client.tracking_uri, uri) + self.assertIsNotNone(handler.artifact_location) + self.assertTrue(handler.artifact_location.endswith("mlruns")) + finally: + handler.close() # release the SQLite handle so Windows can delete the db + + def test_env_var_sqlite_tracking_uri_colocates_artifacts(self): + # a SQLite `MLFLOW_TRACKING_URI` env var should co-locate artifacts next to the db, the + # same as an explicit `tracking_uri` argument. The env var itself is left for MLflow to + # resolve, so the handler does not pass it to the client. 
+ with tempfile.TemporaryDirectory() as tempdir: + uri = path_to_sqlite_uri(os.path.join(tempdir, "sub", "mlruns.db")) + handler = None + with patch.dict(os.environ, {"MLFLOW_TRACKING_URI": uri}): + try: + handler = MLFlowHandler(iteration_log=False, epoch_log=False) + self.assertTrue(handler.client.tracking_uri.endswith("mlruns.db")) + self.assertIsNotNone(handler.artifact_location) + self.assertTrue(handler.artifact_location.endswith("mlruns")) + # co-located with the db file (the `sub` dir), not a cwd-relative `./mlruns` + self.assertIn("sub", handler.artifact_location) + finally: + if handler is not None: + handler.close() # release the SQLite handle so Windows can delete the db + + def test_explicit_artifact_location_is_used(self): + # an explicitly provided artifact_location should be kept even with the default SQLite + # backend, so callers (e.g. the bundle defaults) can co-locate artifacts with the db. + with tempfile.TemporaryDirectory() as tempdir: + cwd = os.getcwd() + os.chdir(tempdir) + handler = None + try: + art = path_to_uri(os.path.join(tempdir, "artifacts")) + handler = MLFlowHandler(iteration_log=False, epoch_log=False, artifact_location=art) + self.assertEqual(handler.artifact_location, art) + finally: + if handler is not None: + handler.close() # release the SQLite handle so Windows can delete the db + os.chdir(cwd) + + def test_default_sqlite_run_flow(self): + # a basic log/run flow should work with the default SQLite backend (no tracking_uri given). 
+ with tempfile.TemporaryDirectory() as tempdir: + cwd = os.getcwd() + os.chdir(tempdir) + try: + + def _train_func(engine, batch): + return [batch + 1.0] + + engine = Engine(_train_func) + + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + + # close_on_complete=False so cur_run stays available after the run for the metric + # check below; the run is closed explicitly afterwards. + handler = MLFlowHandler(iteration_log=False, epoch_log=True, close_on_complete=False) + handler.attach(engine) + engine.run(range(3), max_epochs=2) + cur_run = handler.client.get_run(handler.cur_run.info.run_id) + self.assertTrue("acc" in cur_run.data.metrics.keys()) + handler.close() + # the default backend should have created a SQLite database file in the cwd + self.assertTrue(os.path.exists(os.path.join(tempdir, "mlruns.db"))) + finally: + os.chdir(cwd) + @skip_if_quick @unittest.skipUnless(has_dataset_tracking, reason="Requires mlflow version >= 2.4.0.") def test_dataset_tracking(self): @@ -271,7 +383,7 @@ def test_dataset_tracking(self): final_id="finalize", ) - tracking_path = os.path.join(bundle_root, "eval") + tracking_path = os.path.join(tempdir, "mlflow_dataset.db") workflow.bundle_root = bundle_root workflow.dataset_dir = data_dir workflow.initialize() @@ -280,7 +392,7 @@ def test_dataset_tracking(self): iteration_log=False, epoch_log=False, dataset_dict={"test": infer_dataset}, - tracking_uri=path_to_uri(tracking_path), + tracking_uri=path_to_sqlite_uri(tracking_path), ) mlflow_handler.attach(workflow.evaluator) workflow.run() diff --git a/tests/integration/test_integration_bundle_run.py b/tests/integration/test_integration_bundle_run.py index 7f366d4745..67f4456259 100644 --- a/tests/integration/test_integration_bundle_run.py +++ b/tests/integration/test_integration_bundle_run.py @@ -29,7 +29,7 @@ from monai.bundle import ConfigParser from 
monai.bundle.utils import DEFAULT_HANDLERS_ID from monai.transforms import LoadImage -from monai.utils import path_to_uri +from monai.utils import path_to_sqlite_uri from tests.test_utils import command_line_tests TESTS_PATH = Path(__file__).parents[1] @@ -175,7 +175,7 @@ def test_shape(self, config_file, expected_shape): "no_epoch": True, # test override config in the settings file "evaluator": { "_target_": "MLFlowHandler", - "tracking_uri": "$monai.utils.path_to_uri(@output_dir) + '/mlflow_override1'", + "tracking_uri": "$monai.utils.path_to_sqlite_uri(@output_dir + '/mlflow_override1.db')", "iteration_log": "@no_epoch", }, }, @@ -208,16 +208,17 @@ def test_shape(self, config_file, expected_shape): command_line_tests(la + ["--args_file", def_args_file] + ["--tracking", settings_file]) loader = LoadImage(image_only=True) self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape) - self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override1")) + self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override1.db")) - tracking_uri = path_to_uri(tempdir) + "/mlflow_override2" # test override experiment management configs + # test override experiment management configs + tracking_uri = path_to_sqlite_uri(os.path.join(tempdir, "mlflow_override2.db")) # here test the script with `google fire` tool as CLI cmd = "-m fire monai.bundle.scripts run --tracking mlflow --evaluator#amp False" cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir} --device {device}" la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] command_line_tests(la) self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape) - self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override2")) + self.assertTrue(os.path.exists(f"{tempdir}/mlflow_override2.db")) # test the saved execution configs self.assertTrue(len(glob(f"{tempdir}/config_*.json")), 2)