import os
import os.path
import shutil
import numpy as np
import ot
import pandas as pd
import scanpy as sc
from tqdm import tqdm
from revise.backend.runners.benchmark_svc import BenchmarkSVC
from revise.backend.kernels import GeneImputeKernel as GeneImpute
from revise.backend.kernels import GeneUncertaintyKernel as GeneUncertainty
from revise.backend.ops.distance import bhattacharyya_distance
from revise.backend.ops.meta import get_subcluster
from revise.backend.ops.meta import merge_subcluster
from revise.backend.ops.shaver import get_prune_adata
[docs]
class ScSVCImpute(BenchmarkSVC):
    """
    Single-cell SVC imputation for benchmark CFs: gene panel/gene dropout.

    Gene imputation is run once per HVG selection strategy (in-panel and
    all-panel). For every cell type, spatial cells are coupled to
    single-cell reference subclusters with unbalanced optimal transport and
    the coupling is handed to the gene-imputation kernel.

    Results are collected in ``self.svc`` by :meth:`local_refinement`.
    """

    def __init__(self, st_adata, sc_ref_adata, config, real_st_adata, logger):
        """
        Validate and preprocess the inputs, then create the kernels.

        All arguments are forwarded unchanged to ``BenchmarkSVC.__init__``.
        """
        super().__init__(st_adata, sc_ref_adata, config, real_st_adata, logger)
        self._adata_validate()
        self._adata_processing()
        self._adata_processing_impute()
        self.gene_uncertainty = GeneUncertainty(self.config, self.logger)
        self.gene_impute = GeneImpute(self.config, self.logger)
        # Filled by local_refinement(): result name -> imputed AnnData.
        self.svc = {}

    def _adata_processing_impute(self) -> None:
        """
        Filter the spatial and reference data before imputation.

        Spatial data: ``obs["cell_id"]`` becomes the observation names when
        that column exists; cells with fewer than ``config.prep_min_counts``
        in ``obs["transcript_counts"]`` are dropped; genes are filtered with
        ``min_cells=config.prep_min_cells``.

        Reference data: ``obs`` is reduced to the cell-type column, the same
        gene filter is applied, and the label "Mono/Macro" is rewritten to
        "Mono_Macro".
        """
        if "cell_id" in self.st_adata.obs.columns:
            self.st_adata.obs_names = self.st_adata.obs["cell_id"]

        enough_counts = self.st_adata.obs['transcript_counts'] >= self.config.prep_min_counts
        # Subsetting an AnnData yields a view; copy it so the in-place gene
        # filter below works on an independent object.
        self.st_adata = self.st_adata[enough_counts, :].copy()
        sc.pp.filter_genes(self.st_adata, min_cells=self.config.prep_min_cells)

        cell_type_col = self.config.cell_type_col
        self.sc_ref_adata.obs = self.sc_ref_adata.obs[[cell_type_col]].copy()
        sc.pp.filter_genes(self.sc_ref_adata, min_cells=self.config.prep_min_cells)
        # Assign the result back rather than calling replace(inplace=True) on
        # the selected column: with pandas Copy-on-Write the in-place form
        # only changes a temporary and leaves ``obs`` untouched.
        self.sc_ref_adata.obs[cell_type_col] = self.sc_ref_adata.obs[cell_type_col].replace(
            {"Mono/Macro": "Mono_Macro"}
        )

    def local_refinement(self, *args, **kwargs) -> None:
        """
        Reconstruct expression profiles using gene imputation.

        1. Evaluates gene uncertainty comparing in-panel vs all-panel
           strategies (or loads the comparison table from ``result_dir``).
        2. Generates subclustered single-cell data for both strategies (or
           loads the two ``.h5ad`` files from ``result_dir``).
        3. Runs :meth:`local_impute` on each subclustered reference.

        Positional and keyword arguments are accepted but not used.

        Results are stored in:
            - self.svc["sc_svc_impute_all_panel"]: all-panel strategy
            - self.svc["sc_svc_impute_in_panel"]: in-panel strategy

        Raises:
            AssertionError: if the spatial and reference data share no genes.
        """
        overlap_genes = list(self.st_adata.var_names.intersection(self.sc_ref_adata.var_names))
        # Explicit raise (same exception type as the former ``assert``) so the
        # check is not stripped when Python runs with -O.
        if not overlap_genes:
            raise AssertionError("overlap genes not found")

        result_dir = self.config.result_dir

        # Gene comparison table: reuse a cached/previous CSV when available.
        gene_compare_file = os.path.join(result_dir, "compare_in_vs_all_panel_moranI.csv")
        self._materialize_cached_gene_compare(gene_compare_file)
        if os.path.exists(gene_compare_file):
            compare_df = pd.read_csv(gene_compare_file, index_col=0)
        else:
            compare_df = self.gene_uncertainty.run(self.sc_ref_adata, overlap_genes)
            compare_df.to_csv(gene_compare_file)
        # Keep only the rows whose 'test' flag is False.
        compare_df = compare_df[~compare_df['test']]

        # Subclustered references: both files must exist to be reused,
        # otherwise both are rebuilt and written.
        in_panel_file = os.path.join(result_dir, "adata_sc_in_panel.h5ad")
        all_panel_file = os.path.join(result_dir, "adata_sc_all_panel.h5ad")
        self._materialize_cached_subcluster(in_panel_file, all_panel_file)
        if os.path.exists(in_panel_file) and os.path.exists(all_panel_file):
            self.logger.info(f"Load {in_panel_file} and {all_panel_file}")
            adata_sc_in_panel = sc.read(in_panel_file)
            adata_sc_all_panel = sc.read(all_panel_file)
        else:
            self.logger.info(f"Build {in_panel_file} and {all_panel_file}")
            adata_sc_all_panel, adata_sc_in_panel = get_subcluster(
                self.sc_ref_adata,
                compare_df,
                celltype_col=self.config.cell_type_col,
            )
            adata_sc_in_panel.write(in_panel_file)
            adata_sc_all_panel.write(all_panel_file)

        subcluster_key = f"leiden_{self.config.rec_subcluster_resolution}"
        self.svc["sc_svc_impute_all_panel"] = self.local_impute(adata_sc_all_panel, subcluster_key)
        self.svc["sc_svc_impute_in_panel"] = self.local_impute(adata_sc_in_panel, subcluster_key)

    def _materialize_cached_gene_compare(self, target_file: str) -> None:
        """
        Optionally hydrate the compare CSV from a cache to avoid re-running
        the uncertainty kernel.

        The file named by the ``REVISE_GENE_COMPARE_CACHE`` environment
        variable is copied to ``target_file``. Nothing happens when
        ``target_file`` already exists or the variable is unset/empty; a
        warning is logged when the variable points to a missing file.
        """
        if os.path.exists(target_file):
            return

        cache_file = os.environ.get("REVISE_GENE_COMPARE_CACHE")
        if not cache_file:
            return

        if not os.path.exists(cache_file):
            self.logger.warning("REVISE_GENE_COMPARE_CACHE does not exist: %s", cache_file)
            return

        shutil.copyfile(cache_file, target_file)
        self.logger.info("Loaded cached compare file for impute benchmark: %s -> %s", cache_file, target_file)

    def _materialize_cached_subcluster(self, in_panel_file: str, all_panel_file: str) -> None:
        """
        Optionally hydrate subcluster AnnData files from the compare cache
        directory.

        The sources are ``adata_sc_in_panel.h5ad`` and
        ``adata_sc_all_panel.h5ad`` located next to the file named by
        ``REVISE_GENE_COMPARE_CACHE``. Copies happen only when both sources
        exist, and only for targets that are missing.
        """
        targets = (in_panel_file, all_panel_file)
        if all(os.path.exists(path) for path in targets):
            return

        cache_file = os.environ.get("REVISE_GENE_COMPARE_CACHE")
        if not cache_file:
            return

        cache_dir = os.path.dirname(cache_file)
        src_in_panel = os.path.join(cache_dir, "adata_sc_in_panel.h5ad")
        src_all_panel = os.path.join(cache_dir, "adata_sc_all_panel.h5ad")
        sources = (src_in_panel, src_all_panel)
        if not all(os.path.exists(path) for path in sources):
            return

        for source, target in zip(sources, targets):
            if not os.path.exists(target):
                shutil.copyfile(source, target)

        self.logger.info(
            "Loaded cached subcluster files for impute benchmark: %s , %s -> %s",
            src_in_panel,
            src_all_panel,
            self.config.result_dir,
        )

    def local_impute(
        self,
        adata_sc,
        sc_subcluster: str
    ):
        """
        Perform local imputation for each cell type using a subclustered
        reference.

        For every cell type found in ``adata_sc``:

        1. Builds per-subcluster mean profiles on the genes shared with the
           spatial data and computes their Bhattacharyya distance to each
           spatial cell.
        2. Solves an unbalanced Sinkhorn problem between spatial cells and
           subclusters; both marginals are total counts normalized to sum
           to one, and the cost is the distance scaled by its maximum.
        3. Merges the reference by subcluster and runs the gene-imputation
           kernel with the coupling as neighbor weights.

        Cell types without any spatial cell are skipped with a warning.

        Args:
            adata_sc: Subclustered single-cell reference AnnData.
            sc_subcluster: Column name in ``adata_sc.obs`` containing the
                subcluster labels.

        Returns:
            AnnData: Concatenation of the per-cell-type imputed spatial
            data, pruned when ``config.rec_impute_prune_flag`` is set.

        Raises:
            ValueError: if no cell type of the reference has spatial cells.
        """
        cell_type_col = self.config.cell_type_col
        adata_sp = self.st_adata.copy()
        adata_sc = adata_sc.copy()
        cts = list(adata_sc.obs[cell_type_col].unique())

        adata_sp_cts = []
        for select_ct in tqdm(cts, "Imputation by cell type"):
            self.logger.info(f"Conducting cell type: {select_ct} ........")
            ct_adata_sc = adata_sc[adata_sc.obs[cell_type_col] == select_ct].copy()
            ct_adata_sp = adata_sp[adata_sp.obs[cell_type_col] == select_ct].copy()

            # Nothing to impute for a cell type absent from the spatial data;
            # going on would run the priors and the OT solver on empty arrays.
            if ct_adata_sp.n_obs == 0:
                self.logger.warning(
                    "No spatial cells for cell type %s; skipping imputation", select_ct
                )
                continue

            overlap_genes = ct_adata_sc.var_names.intersection(ct_adata_sp.var_names)
            ct_adata_sc_overlap = ct_adata_sc[:, overlap_genes].copy()
            ct_adata_sp_overlap = ct_adata_sp[:, overlap_genes].copy()

            # Mean expression per subcluster: one-hot membership with every
            # column divided by the subcluster size, then genes x subclusters.
            dums = pd.get_dummies(ct_adata_sc_overlap.obs[sc_subcluster],
                                  dtype=ct_adata_sc_overlap.X.dtype)
            ncats = dums.sum(axis=0)
            dums /= ncats.to_numpy()
            profiles = ct_adata_sc_overlap.X.T @ dums.to_numpy()
            profiles = pd.DataFrame(profiles, index=ct_adata_sc_overlap.var.index, columns=dums.columns)

            # Accept both sparse and dense expression matrices.
            sp_expr = ct_adata_sp_overlap.X
            sp_dense = sp_expr.toarray() if hasattr(sp_expr, "toarray") else np.asarray(sp_expr)
            dist = bhattacharyya_distance(profiles.values.T, sp_dense)

            # Marginals: total counts per subcluster and per spatial cell.
            cell_profile_mapping = pd.get_dummies(ct_adata_sc_overlap.obs[sc_subcluster])
            cell_profile_mapping /= cell_profile_mapping.sum(axis=1).to_numpy()[:, None]
            type_prior = np.array(ct_adata_sc_overlap.X.sum(axis=1)).flatten() @ cell_profile_mapping
            spot_prior = pd.Series(np.array(ct_adata_sp_overlap.X.sum(axis=1)).flatten(),
                                   index=ct_adata_sp_overlap.obs.index)
            spot_prior /= spot_prior.sum()
            type_prior /= type_prior.sum()

            T_matrix = ot.unbalanced.sinkhorn_unbalanced(
                spot_prior.values,
                type_prior.values,
                dist.T / dist.max(),
                reg=self.config.rec_impute_pot_reg,
                reg_m=self.config.rec_impute_pot_reg_m,
                reg_type=self.config.rec_impute_pot_reg_type,
                verbose=True,
                numItermax=5000
            )
            T_matrix = pd.DataFrame(T_matrix, index=spot_prior.index, columns=type_prior.index)

            ct_adata_sc = merge_subcluster(
                ct_adata_sc,
                subcluster=sc_subcluster,
                mode=self.config.rec_merge_subcluster_method
            )
            # Recompute the shared genes against the merged reference.
            overlap_genes = ct_adata_sp.var_names.intersection(ct_adata_sc.var_names)
            adata_sp_impute = self.gene_impute.run(
                ct_adata_sp, ct_adata_sc,
                genes_to_predict=overlap_genes,
                neighbor_weights=T_matrix,
            )
            adata_sp_cts.append(adata_sp_impute)

        if not adata_sp_cts:
            raise ValueError(
                "no cell type of the reference has spatial cells; nothing to impute"
            )

        adata_sp_impute = sc.concat(adata_sp_cts)
        if self.config.rec_impute_prune_flag:
            adata_sp_impute = get_prune_adata(adata_sp_impute)
        return adata_sp_impute