Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/methods/ss_opm/ss_opm/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
__merge__: ../../../api/comp_method.yaml
name: ss_opm
label: SS-OPM
summary: 1st place solution of the Kaggle Open Problems Multimodal Single-Cell Integration challenge.
description: |
Encoder-decoder MLP method using SVD-based dimensionality reduction for both inputs and
targets, followed by batch-median correction. The encoder maps (optionally augmented)
cell embeddings to a latent space; multiple decoder blocks predict target expression in
the SVD-compressed space. The method was the winning solution of the NeurIPS 2021
Open Problems Multimodal Single-Cell Integration Kaggle competition.
references:
doi:
- 10.1101/2022.04.11.487796
links:
repository: https://github.com/shu65/open-problems-multimodal
info:
preferred_normalization: log_cp10k
resources:
- path: main.nf
type: nextflow_script
entrypoint: run_wf
dependencies:
- name: methods/ss_opm_train
- name: methods/ss_opm_predict
runners:
- type: nextflow
18 changes: 18 additions & 0 deletions src/methods/ss_opm/ss_opm/main.nf
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
workflow run_wf {
take: input_ch
main:
output_ch = input_ch
| ss_opm_train.run(
fromState: ["input_train_mod1", "input_train_mod2", "input_test_mod1"],
toState: ["input_model": "output"]
)
| ss_opm_predict.run(
fromState: ["input_test_mod1", "input_model"],
toState: ["output": "output"]
)
| map { tup ->
[tup[0], [output: tup[1].output]]
}

emit: output_ch
}
20 changes: 20 additions & 0 deletions src/methods/ss_opm/ss_opm_predict/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
__merge__: ../../../api/comp_method_predict.yaml
name: ss_opm_predict
resources:
- type: python_script
path: script.py
engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1
setup:
- type: docker
run: pip install --no-cache-dir --no-deps git+https://github.com/shu65/open-problems-multimodal.git
- type: python
packages:
- pyarrow
- fastparquet
runners:
- type: executable
- type: nextflow
directives:
label: [highmem, hightime, midcpu, highsharedmem, gpu]
137 changes: 137 additions & 0 deletions src/methods/ss_opm/ss_opm_predict/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import sys
import os
import gc
import pickle
import numpy as np
import pandas as pd
import scipy.sparse
import anndata as ad
from ss_opm.model.encoder_decoder.encoder_decoder import EncoderDecoder

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}', flush=True)

## VIASH START
par = {
'input_test_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/test_mod1.h5ad',
'input_model': 'output/models/ss_opm',
'output': 'output/prediction.h5ad',
}
meta = {
'name': 'ss_opm_predict',
'resources_dir': 'src/methods/ss_opm/ss_opm_predict',
}
## VIASH END

def build_metadata(adata, task_type):
"""Build a metadata DataFrame compatible with ss_opm from an H5AD AnnData.

Mirrors the function in the train script; only used for the input preprocessing
path (targets are None at predict time, so group IDs are not critical).
"""
obs = pd.DataFrame(index=adata.obs_names)

obs['batch'] = adata.obs['batch'].values
obs['day'] = adata.obs['batch'].str.extract(r'd(\d+)', expand=False).astype(float).fillna(0).values

X = adata.layers['normalized']
if scipy.sparse.issparse(X):
X_dense = X.toarray()
else:
X_dense = np.asarray(X, dtype=float)

obs['nonzero_ratio'] = (X_dense != 0).mean(axis=1)
obs['nonzero_q25'] = np.percentile(X_dense, 25, axis=1)
obs['nonzero_q50'] = np.percentile(X_dense, 50, axis=1)
obs['nonzero_q75'] = np.percentile(X_dense, 75, axis=1)
obs['mean'] = X_dense.mean(axis=1)
obs['std'] = X_dense.std(axis=1)

# Group: 0 for all test cells (group is not used during input-only preprocessing)
obs['group'] = 0
obs['cell_type'] = 'hidden'
obs['donor'] = 0
obs['technology'] = 'unknown'

if task_type == 'cite':
for ct in ['HSC', 'EryP', 'NeuP', 'MasP', 'MkP', 'BP', 'MoP']:
obs[f'cell_ratio_{ct}'] = 1.0 / 7
obs['cell_count'] = float(adata.n_obs)
for i in range(8):
obs[f'batch_sv{i}'] = 0.0

return obs


def to_sparse_csr(X):
if scipy.sparse.issparse(X):
return X.tocsr()
return scipy.sparse.csr_matrix(X)


# ---- Load task info ----
with open(os.path.join(par['input_model'], 'task_info.pickle'), 'rb') as f:
task_info = pickle.load(f)
task_type = task_info['task_type']
mod2 = task_info['mod2']
dataset_id = task_info['dataset_id']
print(f'Task type: {task_type}, mod2: {mod2}', flush=True)

# ---- Load test data ----
print('Loading test data...', flush=True)
input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])
test_inputs = to_sparse_csr(input_test_mod1.layers['normalized'])
test_metadata = build_metadata(input_test_mod1, task_type)

# ---- Load model and preprocessing artifacts ----
print('Loading model...', flush=True)
with open(os.path.join(par['input_model'], 'pre_post_process.pickle'), 'rb') as f:
pre_post_process = pickle.load(f)

model = EncoderDecoder(params=None)
# PyTorch >=2.6 defaults weights_only=True, which blocks custom classes.
# Patch torch.load to use weights_only=False for trusted local model files.
import torch as _torch
_orig_torch_load = _torch.load
_torch.load = lambda *a, **kw: _orig_torch_load(*a, **{**kw, 'weights_only': False})
model.load(os.path.join(par['input_model'], 'model'))
_torch.load = _orig_torch_load
model.params['device'] = device

mod2_var = pd.read_parquet(os.path.join(par['input_model'], 'mod2_var.parquet'))

# ---- Preprocess test inputs ----
print('Preprocessing test data...', flush=True)
preprocessed_test_inputs, _ = pre_post_process.preprocess(
inputs_values=test_inputs,
targets_values=None,
metadata=test_metadata,
)

# ---- Predict ----
print('Predicting...', flush=True)
y_pred = model.predict(
x=test_inputs,
preprocessed_x=preprocessed_test_inputs,
metadata=test_metadata,
)
gc.collect()

# ---- Write output ----
print('Writing output...', flush=True)
# Prediction must be a sparse matrix to be compatible with all metrics.
if not scipy.sparse.issparse(y_pred):
y_pred = scipy.sparse.csr_matrix(y_pred)

output = ad.AnnData(
layers={"normalized": y_pred},
obs=input_test_mod1.obs,
var=mod2_var,
uns={
"dataset_id": dataset_id,
"method_id": meta["name"],
},
)
output.write_h5ad(par['output'], compression="gzip")
print('Done!', flush=True)
20 changes: 20 additions & 0 deletions src/methods/ss_opm/ss_opm_train/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
__merge__: ../../../api/comp_method_train.yaml
name: ss_opm_train
resources:
- type: python_script
path: script.py
engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1
setup:
- type: docker
run: pip install --no-cache-dir --no-deps git+https://github.com/shu65/open-problems-multimodal.git
- type: python
packages:
- pyarrow
- fastparquet
runners:
- type: executable
- type: nextflow
directives:
label: [highmem, hightime, midcpu, highsharedmem, gpu]
Loading
Loading