diff --git a/src/methods/ss_opm/ss_opm/config.vsh.yaml b/src/methods/ss_opm/ss_opm/config.vsh.yaml new file mode 100644 index 0000000..beec3fa --- /dev/null +++ b/src/methods/ss_opm/ss_opm/config.vsh.yaml @@ -0,0 +1,26 @@ +__merge__: ../../../api/comp_method.yaml +name: ss_opm +label: SS-OPM +summary: 1st place solution of the Kaggle Open Problems Multimodal Single-Cell Integration challenge. +description: | + Encoder-decoder MLP method using SVD-based dimensionality reduction for both inputs and + targets, followed by batch-median correction. The encoder maps (optionally augmented) + cell embeddings to a latent space; multiple decoder blocks predict target expression in + the SVD-compressed space. The method was the winning solution of the NeurIPS 2022 + Open Problems Multimodal Single-Cell Integration Kaggle competition. +references: + doi: + - 10.1101/2022.04.11.487796 +links: + repository: https://github.com/shu65/open-problems-multimodal +info: + preferred_normalization: log_cp10k +resources: + - path: main.nf + type: nextflow_script + entrypoint: run_wf +dependencies: + - name: methods/ss_opm_train + - name: methods/ss_opm_predict +runners: + - type: nextflow diff --git a/src/methods/ss_opm/ss_opm/main.nf b/src/methods/ss_opm/ss_opm/main.nf new file mode 100644 index 0000000..9f30a8b --- /dev/null +++ b/src/methods/ss_opm/ss_opm/main.nf @@ -0,0 +1,18 @@ +workflow run_wf { + take: input_ch + main: + output_ch = input_ch + | ss_opm_train.run( + fromState: ["input_train_mod1", "input_train_mod2", "input_test_mod1"], + toState: ["input_model": "output"] + ) + | ss_opm_predict.run( + fromState: ["input_test_mod1", "input_model"], + toState: ["output": "output"] + ) + | map { tup -> + [tup[0], [output: tup[1].output]] + } + + emit: output_ch +} diff --git a/src/methods/ss_opm/ss_opm_predict/config.vsh.yaml b/src/methods/ss_opm/ss_opm_predict/config.vsh.yaml new file mode 100644 index 0000000..bd80d73 --- /dev/null +++
def build_metadata(adata, task_type):
    """Build the per-cell metadata DataFrame passed to ss_opm for test cells.

    Args:
        adata: AnnData-like object exposing ``obs['batch']``, ``obs_names``
            and ``layers['normalized']`` (sparse or dense, cells x features).
        task_type: When equal to ``'cite'``, the CITE-only columns
            (``cell_ratio_*``, ``cell_count``, ``batch_sv*``) are added.

    Returns:
        A ``pandas.DataFrame`` indexed by ``adata.obs_names`` with one row
        per cell.

    Notes:
        * ``cell_count`` is the number of cells in the cell's own batch. The
          train script builds metadata for these same test cells with
          per-batch counts, so using the total cell count here would feed a
          different value at prediction time than was seen during fitting.
        * ``group`` is a constant 0 here, whereas the train script assigns
          one group per batch.
        * The ``nonzero_q*`` columns are percentiles over *all* entries of a
          row, zeros included; the column names are kept as-is.
    """
    batch = adata.obs['batch']
    obs = pd.DataFrame(index=adata.obs_names)

    obs['batch'] = batch.values
    # Day is parsed from the first 'd<digits>' occurrence in the batch name;
    # batches without such a pattern get day 0.
    obs['day'] = batch.str.extract(r'd(\d+)', expand=False).astype(float).fillna(0).values

    X = adata.layers['normalized']
    if scipy.sparse.issparse(X):
        X_dense = X.toarray()
    else:
        X_dense = np.asarray(X, dtype=float)

    obs['nonzero_ratio'] = (X_dense != 0).mean(axis=1)
    # One percentile pass instead of three separate ones; same values.
    q25, q50, q75 = np.percentile(X_dense, [25, 50, 75], axis=1)
    obs['nonzero_q25'] = q25
    obs['nonzero_q50'] = q50
    obs['nonzero_q75'] = q75
    obs['mean'] = X_dense.mean(axis=1)
    obs['std'] = X_dense.std(axis=1)

    # Placeholder values: no group, cell-type, donor or technology
    # information is derived for test cells.
    obs['group'] = 0
    obs['cell_type'] = 'hidden'
    obs['donor'] = 0
    obs['technology'] = 'unknown'

    if task_type == 'cite':
        # Uniform cell-type ratios over the seven listed cell types.
        for ct in ['HSC', 'EryP', 'NeuP', 'MasP', 'MkP', 'BP', 'MoP']:
            obs[f'cell_ratio_{ct}'] = 1.0 / 7
        # Per-batch cell count, consistent with the train script.
        batch_counts = batch.value_counts()
        obs['cell_count'] = batch.map(batch_counts).astype(float).values
        # Batch singular vectors are zero-filled.
        for i in range(8):
            obs[f'batch_sv{i}'] = 0.0

    return obs


def to_sparse_csr(X):
    """Return ``X`` as a SciPy CSR matrix, converting dense input if needed."""
    if scipy.sparse.issparse(X):
        return X.tocsr()
    return scipy.sparse.csr_matrix(X)
# ---- Load trained weights ----
# PyTorch >=2.6 defaults torch.load to weights_only=True, which blocks custom
# classes. Temporarily force weights_only=False for this trusted, locally
# produced model directory. The override is undone in a `finally` so that a
# failing model.load() cannot leave torch.load patched for the rest of the
# process.
_orig_torch_load = torch.load
torch.load = lambda *a, **kw: _orig_torch_load(*a, **{**kw, 'weights_only': False})
try:
    model.load(os.path.join(par['input_model'], 'model'))
finally:
    torch.load = _orig_torch_load
# Run on whatever device is available in this process.
model.params['device'] = device

# Feature annotation of the target modality, saved by the train component.
mod2_var = pd.read_parquet(os.path.join(par['input_model'], 'mod2_var.parquet'))

# ---- Preprocess test inputs ----
print('Preprocessing test data...', flush=True)
# No targets exist at prediction time, so the second return value is ignored.
preprocessed_test_inputs, _ = pre_post_process.preprocess(
    inputs_values=test_inputs,
    targets_values=None,
    metadata=test_metadata,
)

# ---- Predict ----
print('Predicting...', flush=True)
y_pred = model.predict(
    x=test_inputs,
    preprocessed_x=preprocessed_test_inputs,
    metadata=test_metadata,
)
gc.collect()

# ---- Write output ----
print('Writing output...', flush=True)
# Prediction must be a sparse matrix to be compatible with all metrics.
if not scipy.sparse.issparse(y_pred):
    y_pred = scipy.sparse.csr_matrix(y_pred)

output = ad.AnnData(
    layers={"normalized": y_pred},
    obs=input_test_mod1.obs,
    var=mod2_var,
    uns={
        "dataset_id": dataset_id,
        "method_id": meta["name"],
    },
)
output.write_h5ad(par['output'], compression="gzip")
print('Done!', flush=True)
'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod2.h5ad', + 'input_test_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/test_mod1.h5ad', + 'output': 'output/models/ss_opm', +} +meta = { + 'name': 'ss_opm_train', + 'resources_dir': 'src/methods/ss_opm/ss_opm_train', +} +## VIASH END + +# Monkey-patch the source utility modules so ALL callers (pre_post_processing, +# IterativeSVDImputator, etc.) use safe versions that handle all-zero/all-NaN rows. +import ss_opm.utility.nonzero_median_normalize as _mnm_module +import ss_opm.utility.row_normalize as _rn_module +import ss_opm.pre_post_processing.pre_post_processing as _pp_module + +def _safe_median_normalize(values, ignore_zero=True, log=False): + """Median-normalize rows, substituting 1 (identity) when the median is 0 or NaN.""" + arr = np.asarray(values.toarray() if hasattr(values, 'toarray') else values, dtype=float).copy() + for_median = arr.copy() + if ignore_zero: + for_median[for_median == 0] = np.nan + med = np.nanquantile(for_median, q=0.5, axis=1) + # Use 1 as fallback so rows with zero/undefined median are left unchanged + med = np.where((med == 0) | ~np.isfinite(med), 1.0, med) + if log: + return arr - med[:, None] + else: + return arr / med[:, None] + +def _safe_row_normalize(v): + """Row-standardize; rows with std=0 are mean-subtracted only (result is zeros).""" + mu = np.mean(v, axis=1) + sigma = np.std(v, axis=1) + sigma = np.where(sigma == 0, 1.0, sigma) + return (v - mu[:, None]) / sigma[:, None] + +# Patch the source modules so every 'from X import Y' binding stays in sync +_mnm_module.median_normalize = _safe_median_normalize +_rn_module.row_normalize = _safe_row_normalize +# Also patch the names already bound inside pre_post_processing's namespace +_pp_module.median_normalize = _safe_median_normalize +_pp_module.row_normalize = _safe_row_normalize + +# The SVD decomposer components are stored as float64 tensors inside +# 
def build_metadata(adata, task_type):
    """Build a metadata DataFrame compatible with ss_opm from an H5AD AnnData.

    The original ss_opm model expects metadata columns derived from the Kaggle
    competition dataset (technology, donor, day, cell_type, plus per-cell
    stats). Here the columns that can be derived from ``adata`` are computed
    and the rest are filled with constant placeholders.

    Args:
        adata: AnnData-like object exposing ``obs['batch']``, ``obs_names``
            and ``layers['normalized']`` (sparse or dense, cells x features).
        task_type: When equal to ``'cite'``, the CITE-only columns
            (``cell_ratio_*``, ``cell_count``, ``batch_sv*``) are added.

    Returns:
        A ``pandas.DataFrame`` indexed by ``adata.obs_names`` with one row
        per cell.

    Notes:
        The ``nonzero_q*`` columns are percentiles over *all* entries of a
        row, zeros included; the column names are kept as-is.
    """
    batch = adata.obs['batch']
    obs = pd.DataFrame(index=adata.obs_names)

    obs['batch'] = batch.values

    # Extract day from batch name (format: s{site}d{day}, e.g. 's1d1' -> 1);
    # batch names without a 'd<digits>' part get day 0.
    obs['day'] = batch.str.extract(r'd(\d+)', expand=False).astype(float).fillna(0).values

    # Per-cell statistics from the normalized expression layer.
    X = adata.layers['normalized']
    if scipy.sparse.issparse(X):
        X_dense = X.toarray()
    else:
        X_dense = np.asarray(X, dtype=float)

    obs['nonzero_ratio'] = (X_dense != 0).mean(axis=1)
    # One percentile pass instead of three separate ones; same values.
    q25, q50, q75 = np.percentile(X_dense, [25, 50, 75], axis=1)
    obs['nonzero_q25'] = q25
    obs['nonzero_q50'] = q50
    obs['nonzero_q75'] = q75
    obs['mean'] = X_dense.mean(axis=1)
    obs['std'] = X_dense.std(axis=1)

    # Group ID: one group per unique batch, numbered in order of first
    # appearance (proxy for the original donor+day+technology groups).
    unique_batches = batch.unique().tolist()
    batch_to_group = {b: i for i, b in enumerate(unique_batches)}
    obs['group'] = batch.map(batch_to_group).astype(int).values

    # No cell-type, donor or technology labels are derived in this format.
    obs['cell_type'] = 'hidden'
    obs['donor'] = 0
    obs['technology'] = 'unknown'

    if task_type == 'cite':
        # Uniform cell-type ratios over the seven listed cell types.
        for ct in ['HSC', 'EryP', 'NeuP', 'MasP', 'MkP', 'BP', 'MoP']:
            obs[f'cell_ratio_{ct}'] = 1.0 / 7
        # Number of cells in each cell's own batch.
        batch_counts = batch.value_counts()
        obs['cell_count'] = batch.map(batch_counts).astype(float).values
        # Batch singular vectors are zero-filled.
        for i in range(8):
            obs[f'batch_sv{i}'] = 0.0

    return obs


def to_sparse_csr(X):
    """Return ``X`` as a SciPy CSR matrix, converting dense input if needed."""
    if scipy.sparse.issparse(X):
        return X.tocsr()
    return scipy.sparse.csr_matrix(X)
# ---- Load the paired training modalities ----
print('Loading data...', flush=True)
input_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])

mod1 = input_train_mod1.uns['modality']
mod2 = input_train_mod2.uns['modality']
dataset_id = input_train_mod1.uns['dataset_id']
print(f'Modalities: {mod1} -> {mod2}', flush=True)

# Any pairing that involves ADT is treated as the 'cite' task, everything
# else as 'multi'.
if 'ADT' in (mod1, mod2):
    task_type = 'cite'
else:
    task_type = 'multi'
print(f'Task type: {task_type}', flush=True)

train_inputs = to_sparse_csr(input_train_mod1.layers['normalized'])
train_targets = to_sparse_csr(input_train_mod2.layers['normalized'])
n_vars_mod1 = train_inputs.shape[1]
n_vars_mod2 = train_targets.shape[1]

train_metadata = build_metadata(input_train_mod1, task_type)

# The predict component needs the target feature annotation to build its output.
mod2_var = input_train_mod2.var.copy()

# Only the extracted matrices and metadata are needed from here on.
del input_train_mod1, input_train_mod2
gc.collect()

# ---- Optionally load the test inputs so preprocessing can be fitted on them too ----
test_inputs = None
test_metadata = None
if par.get('input_test_mod1'):
    print('Loading test data for SVD fitting...', flush=True)
    input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])
    test_inputs = to_sparse_csr(input_test_mod1.layers['normalized'])
    test_metadata = build_metadata(input_test_mod1, task_type)
    del input_test_mod1
    gc.collect()

# ---- Provide the mask files PrePostProcessing looks for in data_dir ----
# The original pipeline ships pre-computed feature/target masks; here they are
# replaced by synthetic ones written to a fresh temporary directory: the
# (inputs x targets) pair mask is all True and the per-input mask is all False.
# NOTE(review): the temporary directory is never removed by this script.
data_dir = tempfile.mkdtemp()
if task_type == 'cite':
    pair_mask_path = os.path.join(data_dir, 'cite_inputs_targets_pair3g.npz')
    np.savez(pair_mask_path, mask=np.ones((n_vars_mod1, n_vars_mod2), dtype=bool))
    input_mask_path = os.path.join(data_dir, 'cite_inputs_mask2.npz')
    np.savez(input_mask_path, mask=np.zeros((n_vars_mod1,), dtype=bool))

# ---- Default hyper-parameters for preprocessing and for the model ----
pre_post_process_params = PrePostProcessing.get_params(
    task_type=task_type,
    data_dir=data_dir,
    device=device,
    seed=SEED,
)
model_params = EncoderDecoder.get_params(
    task_type=task_type,
    device=device,
)

# ---- Fit preprocessing ----
print('Fitting preprocessing...', flush=True)
pre_post_process = PrePostProcessing(pre_post_process_params)

# When no test set was supplied, the training data doubles as the "test" side.
if test_inputs is None:
    _test_inputs_for_svd = train_inputs
else:
    _test_inputs_for_svd = test_inputs
if test_metadata is None:
    _test_metadata_for_svd = train_metadata
else:
    _test_metadata_for_svd = test_metadata

pre_post_process.fit_preprocess(
    inputs_values=train_inputs,
    targets_values=train_targets,
    metadata=train_metadata,
    test_inputs_values=_test_inputs_for_svd,
    test_metadata=_test_metadata_for_svd,
)

# ---- Preprocess training data ----
print('Preprocessing training data...', flush=True)
preprocessed_inputs, preprocessed_targets = pre_post_process.preprocess(
    inputs_values=train_inputs,
    targets_values=train_targets,
    metadata=train_metadata,
)

# ---- Train model ----
# Plain NumPy outputs are cast down to float32 before training; any other
# container type is passed through untouched.
if isinstance(preprocessed_inputs, np.ndarray):
    preprocessed_inputs = preprocessed_inputs.astype(np.float32)
if isinstance(preprocessed_targets, np.ndarray):
    preprocessed_targets = preprocessed_targets.astype(np.float32)

print('Training model...', flush=True)
model = EncoderDecoder(model_params)
model.fit(
    x=train_inputs,
    preprocessed_x=preprocessed_inputs,
    y=train_targets,
    preprocessed_y=preprocessed_targets,
    metadata=train_metadata,
    pre_post_process=pre_post_process,
)
gc.collect()

# ---- Persist everything the predict component reads back ----
print('Saving model...', flush=True)
output_dir = par['output']
os.makedirs(output_dir, exist_ok=True)

# Model weights live in their own sub-directory.
model_dir = os.path.join(output_dir, 'model')
os.makedirs(model_dir, exist_ok=True)
model.save(model_dir)

# Fitted preprocessing object.
with open(os.path.join(output_dir, 'pre_post_process.pickle'), 'wb') as f:
    pickle.dump(pre_post_process, f)

# Target feature annotation.
mod2_var.to_parquet(os.path.join(output_dir, 'mod2_var.parquet'))

# Small task descriptor.
task_info = {'task_type': task_type, 'mod2': mod2, 'dataset_id': dataset_id}
with open(os.path.join(output_dir, 'task_info.pickle'), 'wb') as f:
    pickle.dump(task_info, f)

print('Done!', flush=True)
ed97015..7dbd6b7 100755 --- a/src/workflows/run_benchmark/run_test.sh +++ b/src/workflows/run_benchmark/run_test.sh @@ -15,17 +15,17 @@ if [ ! -d "$OUTPUT_DIR" ]; then mkdir -p "$OUTPUT_DIR" fi # run benchmark -export NXF_VER=23.04.2 +export NXF_VER=25.10.5 nextflow run . \ - -main-script target/nextflow/predict_modality/workflows/run_benchmark/main.nf \ + -main-script target/nextflow/workflows/run_benchmark/main.nf \ -profile docker \ -resume \ -entry auto \ -with-trace \ -c common/nextflow_helpers/labels_ci.config \ --input_states "$DATASETS_DIR/**/state.yaml" \ - --rename_keys 'input_train_mod1:output_train_mod1,input_train_mod2:output_train_mod2,input_test_mod1:output_test_mod1,input_test_mod2:output_test_mod2' \ + --rename_keys 'input_train_mod1:output_train_mod1;input_train_mod2:output_train_mod2;input_test_mod1:output_test_mod1;input_test_mod2:output_test_mod2' \ --settings '{"output_scores": "scores.yaml", "output_dataset_info": "dataset_info.yaml", "output_method_configs": "method_configs.yaml", "output_metric_configs": "metric_configs.yaml", "output_task_info": "task_info.yaml"}' \ --publish_dir "$OUTPUT_DIR" \ --output_state 'state.yaml' \ No newline at end of file