Source code for revise.backend.runners.sp_svc_benchmark

import numpy as np
import ot
import scanpy as sc
import scipy
from tqdm import tqdm

from revise.backend.runners.benchmark_svc import BenchmarkSVC
from revise.backend.kernels import SegEvaluateKernel as SegEvaluate
from revise.backend.ops.topology import get_adjacency_graph


class SpSVC(BenchmarkSVC):
    """sp-SVC class for benchmark CFs: segmentation/bin2cell.

    This class reconstructs single-cell resolution expression profiles from
    spatial transcriptomics data, with special handling for segmentation
    errors (diminishing, expanding, unchanged cells).

    The result of :meth:`local_refinement` is stored in ``self.svc["sp_svc"]``.
    """

    # Cell types with fewer "replace" cells than this keep their expression
    # unchanged instead of going through OT smoothing.
    MIN_REPLACE_CELLS_FOR_OT = 50

    def __init__(self, st_adata, sc_ref_adata, config, real_st_adata, logger):
        """Initialise the runner.

        All arguments are forwarded unchanged to ``BenchmarkSVC.__init__``.
        Afterwards the data is validated and pre-processed via
        ``_adata_validate`` / ``_adata_processing`` (not defined in the visible
        part of this class; presumably inherited from ``BenchmarkSVC``), a
        ``SegEvaluate`` kernel is created from ``self.config`` and
        ``self.logger``, and ``self.svc`` is initialised as an empty dict.
        """
        super().__init__(st_adata, sc_ref_adata, config, real_st_adata, logger)
        self._adata_validate()
        self._adata_processing()
        self.seg_evaluate = SegEvaluate(self.config, self.logger)
        self.svc = {}

    def local_refinement(self):
        """Reconstruct expression profiles with segmentation-aware smoothing.

        1. Evaluate segmentation errors (only when ``obs`` has a ``seg_error``
           column) and rely on the boolean ``obs["no_effect"]`` column to know
           which cells need correction.
        2. Split each cell type into ``replace`` (``no_effect`` is False) and
           ``candidate`` (``no_effect`` is True) groups.
        3. Use optimal transport between the two groups to obtain smoothed
           expressions for the ``replace`` cells. Cell types with fewer than
           ``MIN_REPLACE_CELLS_FOR_OT`` replace cells, or with no candidate
           cells at all, keep their original expression.
        4. Merge corrected and unchanged cells to form ``self.svc["sp_svc"]``.

        Raises:
            KeyError: if ``st_adata.obs`` has no ``no_effect`` column.
        """
        if "seg_error" in self.st_adata.obs.columns:
            self.st_adata = self.seg_evaluate.run(self.st_adata, self.logger)
        else:
            self.logger.warning(
                "'seg_error' not found in st_adata.obs; skipping segmentation evaluation."
            )

        # Everything below indexes obs["no_effect"]; fail early with a clear
        # message instead of a bare pandas KeyError in the middle of the loop.
        if "no_effect" not in self.st_adata.obs.columns:
            raise KeyError(
                "st_adata.obs has no 'no_effect' column; it is required to split "
                "cells into 'replace' and 'candidate' groups."
            )

        cell_type_col = self.config.cell_type_col
        cell_types = self.st_adata.obs[cell_type_col].unique().tolist()

        cell_type_adata_list = []
        for cell_type in tqdm(cell_types, desc="Reconstruting"):
            svc_adata_cell_type = self.st_adata[self.st_adata.obs[cell_type_col] == cell_type]
            no_effect = svc_adata_cell_type.obs["no_effect"]
            svc_replace_adata = svc_adata_cell_type[~no_effect]
            svc_candidate_adata = svc_adata_cell_type[no_effect]

            if svc_replace_adata.shape[0] < self.MIN_REPLACE_CELLS_FOR_OT:
                self.logger.info(f"cell type: {cell_type} has too few spots, skip OT smoothing")
                svc_replace_adata.layers["ot_smooth"] = svc_replace_adata.X.copy()
            elif svc_candidate_adata.shape[0] == 0:
                # Nothing to borrow expression from: running the solver on
                # all-zero marginals would only waste time and emit NaNs.
                self.logger.info(
                    f"cell type: {cell_type} has no candidate cells, skip OT smoothing"
                )
                svc_replace_adata.layers["ot_smooth"] = svc_replace_adata.X.copy()
            else:
                svc_replace_adata.layers["ot_smooth"] = self._ot_smooth_replace_cells(
                    svc_replace_adata, svc_candidate_adata
                )
            cell_type_adata_list.append(svc_replace_adata)

        svc_recon_adata = sc.concat(cell_type_adata_list)
        svc_recon_adata.X = svc_recon_adata.layers["ot_smooth"].copy()
        svc_no_effect = self.st_adata[self.st_adata.obs["no_effect"]]
        self.svc["sp_svc"] = sc.concat([svc_recon_adata, svc_no_effect])

    def _ot_smooth_replace_cells(self, replace_adata, candidate_adata):
        """Return the OT-smoothed expression of ``replace_adata`` as a CSR matrix.

        Side effect: stores the replace-by-candidate block of the adjacency
        graph in ``replace_adata.obsm["cross_connectivities"]``.
        """
        cfg = self.config
        n_replace = replace_adata.shape[0]
        n_candidate = candidate_adata.shape[0]

        # Build adjacency on ordered data so that rows [0, n_replace) are the
        # replace cells and the following n_candidate rows are the candidates.
        ordered = sc.concat([replace_adata, candidate_adata])
        adjacency = get_adjacency_graph(
            ordered,
            data_type="sc",
            neighbors_method=cfg.rec_graph_method,
            alpha=cfg.rec_graph_alpha,
            gene_neighbor_num=cfg.rec_graph_exp_neighbor_num,
            spatial_neighbor_num=cfg.rec_graph_spatial_neighbor_num,
        )
        cross_adj = adjacency[:n_replace, n_replace:n_replace + n_candidate].tocsr()
        replace_adata.obsm["cross_connectivities"] = cross_adj

        # csr_matrix(...) accepts both dense and sparse X.
        candidate_X = scipy.sparse.csr_matrix(candidate_adata.X)
        replace_X = scipy.sparse.csr_matrix(replace_adata.X)

        # Graph weights and mean expressions are fractional, so never store
        # them in an integer buffer (integer X used to break the in-place add
        # below). For float32/float64 X this is the dtype of X itself.
        work_dtype = np.result_type(replace_X.dtype, np.float32)

        cost_matrix, neighbor_idx, nu = self._rank_candidates(
            cross_adj, candidate_X, cfg.rec_graph_n_neighbors, work_dtype
        )
        mu = np.ravel(replace_X.mean(axis=1))

        # Scale the cost to [0, 1]; guard against an all-zero matrix.
        cm_max = float(cost_matrix.max())
        if cm_max <= 0:
            cm_max = 1.0

        # NOTE(review): the matrix used as OT cost holds graph weights taken
        # from get_adjacency_graph; confirm that using them directly as a
        # cost (rather than a similarity) is intended.
        transport = ot.unbalanced.sinkhorn_unbalanced(
            nu,
            mu,
            cost_matrix.T / cm_max,
            reg=cfg.rec_pot_reg,
            reg_m=cfg.rec_pot_reg_m,
            reg_type=cfg.rec_pot_reg_type,
            verbose=True,
            numItermax=5000,
        )

        return self._blend_with_candidates(
            replace_X,
            candidate_X,
            cost_matrix,
            neighbor_idx,
            transport,
            float(cfg.rec_alpha),
            work_dtype,
        )

    @staticmethod
    def _rank_candidates(cross_adj, candidate_X, n_slots, dtype):
        """Pick, per replace cell, the ``n_slots`` strongest candidate links.

        Returns ``(cost_matrix, neighbor_idx, nu_slots)`` where row ``i`` of the
        first two arrays lists the weights / candidate indices in descending
        weight order, and ``nu_slots[j]`` accumulates the mean expression of
        the rank-``j`` candidate over all replace cells.
        """
        n_replace = cross_adj.shape[0]
        cost_matrix = np.zeros((n_replace, n_slots), dtype=dtype)
        neighbor_idx = np.zeros((n_replace, n_slots), dtype=np.int32)
        nu_slots = np.zeros(n_slots, dtype=dtype)

        # Same for every row: there cannot be more slots than candidates.
        take = min(n_slots, cross_adj.shape[1])

        for i in tqdm(range(n_replace), desc="TopK expression"):
            row = cross_adj.getrow(i).toarray().ravel()
            if np.count_nonzero(row) == 0:
                continue  # cell has no link to any candidate
            top = np.argpartition(-row, kth=take - 1)[:take]
            top = top[np.argsort(-row[top])]
            cost_matrix[i, :take] = row[top]
            neighbor_idx[i, :take] = top
            # NOTE(review): when a cell has fewer than `take` non-zero links,
            # zero-weight candidates still contribute to nu_slots here (they
            # are filtered out later when blending) -- confirm this is wanted.
            nu_slots[:take] += candidate_X[top].toarray().mean(axis=1).ravel()

        return cost_matrix, neighbor_idx, nu_slots

    @staticmethod
    def _blend_with_candidates(
        replace_X, candidate_X, cost_matrix, neighbor_idx, transport, alpha, dtype
    ):
        """Mix each replace cell with its transport-weighted candidate neighbours.

        For cell ``i`` the result is ``(1 - alpha) * own + alpha * borrowed``,
        where ``borrowed`` is the average of the selected candidates weighted by
        column ``i`` of ``transport`` (normalised to sum to one). Cells without
        a usable neighbour or with zero / non-finite transport mass keep their
        original expression.
        """
        n_replace = replace_X.shape[0]
        smoothed = scipy.sparse.lil_matrix(replace_X.shape, dtype=dtype)

        for i in range(n_replace):
            valid = cost_matrix[i] > 0
            if not np.any(valid):
                smoothed[i] = replace_X.getrow(i)
                continue

            weights = transport[valid, i]
            total = weights.sum()
            if not (np.isfinite(total) and total > 0):
                # Without any transported mass there is nothing to borrow;
                # blending anyway would shrink the cell by (1 - alpha) or
                # propagate NaNs from a diverged solver.
                smoothed[i] = replace_X.getrow(i)
                continue
            weights = weights / total

            neighbours = candidate_X[neighbor_idx[i][valid]]
            borrowed = np.asarray(neighbours.T @ weights).ravel()
            own = replace_X.getrow(i).toarray().ravel()
            smoothed[i] = scipy.sparse.csr_matrix((1.0 - alpha) * own + alpha * borrowed)

        return smoothed.tocsr()