Source code for revise.backend.runners.sp_svc_application

import os

import matplotlib.pyplot as plt
import numpy as np
import ot
import scanpy as sc
from scipy import sparse
from tqdm import tqdm

from revise.backend.runners.application_svc import ApplicationSVC
from revise.backend.kernels import GraphAggregateKernel as GraphAggregate
from revise.analysis.metrics import compute_clustering_metrics
from revise.backend.ops.shaver import trim_sp_adata
from revise.backend.ops.topology import get_adjacency_graph


def _legacy_dense_topk(adjacent_matrix, row_index, n_neighbors):
    """Pick the strongest entries of one adjacency row through a dense copy.

    This is the reference ("legacy") selection: the row is densified, the
    ``n_neighbors`` largest entries are found with ``np.argpartition`` and then
    ordered by descending weight with ``np.argsort``. Because the whole dense
    row is considered, zero entries can be selected when the row has fewer
    than ``n_neighbors`` non-zero weights.

    Args:
        adjacent_matrix: Sparse matrix supporting ``[row_index].toarray()``.
        row_index: Row to inspect.
        n_neighbors: Number of entries to keep (capped at the row width).

    Returns:
        ``(indices, weights)``, with ``indices`` as int32 column positions and
        ``weights`` the matching values in descending order, or
        ``(None, None)`` when the row contains no non-zero entry.
    """
    dense_row = adjacent_matrix[row_index].toarray().ravel()
    if not np.any(dense_row):
        return None, None

    k = min(n_neighbors, dense_row.size)
    negated = -dense_row
    # Unordered top-k first, then sort only those k candidates.
    candidates = np.argpartition(negated, kth=k - 1)[:k]
    ranked = candidates[np.argsort(negated[candidates])]
    return ranked.astype(np.int32, copy=False), dense_row[ranked]


def _sparse_exact_topk(adjacent_csr, row_index, n_neighbors):
    """Top-k neighbour selection on a CSR row, matching ``_legacy_dense_topk``.

    A fast path works directly on the row's stored entries. It is used only
    when its answer is guaranteed to equal the dense reference:

    * every stored non-zero weight is strictly positive,
    * there are at least ``take`` of them,
    * the selected weights are pairwise distinct, and
    * the smallest selected weight does not also occur among the rejected ones.

    Otherwise the row is delegated to ``_legacy_dense_topk``.

    Args:
        adjacent_csr: CSR matrix (``indptr``/``indices``/``data`` are read).
        row_index: Row to inspect.
        n_neighbors: Number of neighbours requested.

    Returns:
        ``(indices, weights, used_dense_fallback)``. ``indices`` and
        ``weights`` are ``None`` when the row has no non-zero entry.
        ``used_dense_fallback`` is True when the dense reference path
        produced the result.
    """
    lo = adjacent_csr.indptr[row_index]
    hi = adjacent_csr.indptr[row_index + 1]
    cols = adjacent_csr.indices[lo:hi]
    weights = adjacent_csr.data[lo:hi]

    # Explicitly stored zeros are not neighbours.
    keep = weights != 0
    if not keep.all():
        cols = cols[keep]
        weights = weights[keep]
    if weights.size == 0:
        return None, None, False

    take = min(n_neighbors, adjacent_csr.shape[1])
    if weights.size >= take and np.all(weights > 0):
        picked = np.argpartition(-weights, kth=take - 1)[:take]
        picked_weights = weights[picked]
        # With ties the dense path's argpartition/argsort order may differ
        # from ours, so only accept a uniquely determined selection.
        cutoff = picked_weights.min()
        cutoff_unique = (
            np.count_nonzero(weights == cutoff)
            == np.count_nonzero(picked_weights == cutoff)
        )
        if cutoff_unique and np.unique(picked_weights).size == picked_weights.size:
            picked = picked[np.argsort(-picked_weights)]
            chosen_cols = cols[picked].astype(np.int32, copy=False)
            chosen_weights = weights[picked].copy()
            return chosen_cols, chosen_weights, False

    dense_idx, dense_vals = _legacy_dense_topk(adjacent_csr, row_index, n_neighbors)
    return dense_idx, dense_vals, True


def _compute_topk_expression(adjacent_matrix, expression_matrix, n_neighbors, dtype, progress=True):
    """Select the top-``n_neighbors`` graph neighbours of every row.

    For each row of ``adjacent_matrix`` the strongest non-zero weights are
    chosen with ``_sparse_exact_topk``. The mean expression of each selected
    neighbour is accumulated per rank into ``neighbor_margin_expr``.

    Args:
        adjacent_matrix: Sparse (or CSR-convertible) weighted adjacency matrix.
        expression_matrix: Sparse or dense expression matrix. Its rows are
            assumed to be aligned with the rows of ``adjacent_matrix``.
        n_neighbors: Number of neighbours kept per row. Must lie in
            ``[1, adjacent_matrix.shape[1]]``.
        dtype: Dtype of the returned weight/expression arrays. Integer and
            boolean dtypes are promoted to ``float64`` so fractional weights
            and means are not truncated.
        progress: Show a tqdm progress bar over rows.

    Returns:
        cost_matrix: ``(n_obs, n_neighbors)`` selected weights, descending per
            row; rows without neighbours stay zero.
        neighbor_margin_expr: ``(n_neighbors,)`` sum over rows of the mean
            expression of the neighbour at each rank.
        neighbor_idx_matrix: ``(n_obs, n_neighbors)`` int32 neighbour indices;
            rows without neighbours stay zero.
        dense_fallback_rows: Number of rows resolved by the dense legacy path.

    Raises:
        ValueError: If ``n_neighbors`` is outside ``[1, n_columns]``.
    """
    adjacent_csr = adjacent_matrix.tocsr()
    if not adjacent_csr.has_canonical_format:
        # The dense legacy path (``toarray``) sums duplicate stored entries,
        # whereas the sparse fast path would see them separately. Merge them
        # so both paths agree. Work on a copy because ``tocsr()`` may return
        # the caller's own matrix.
        adjacent_csr = adjacent_csr.copy()
        adjacent_csr.sum_duplicates()
    n_obs, n_cols = adjacent_csr.shape

    if not 1 <= n_neighbors <= n_cols:
        raise ValueError(
            f"n_neighbors must be between 1 and {n_cols} (number of graph columns), "
            f"got {n_neighbors}"
        )

    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
        # Adjacency weights and row means are fractional; an integer buffer
        # would silently truncate them (e.g. when X holds raw counts).
        dtype = np.dtype(np.float64)

    cost_matrix = np.zeros((n_obs, n_neighbors), dtype=dtype)
    neighbor_margin_expr = np.zeros(n_neighbors, dtype=dtype)
    neighbor_idx_matrix = np.zeros((n_obs, n_neighbors), dtype=np.int32)

    if sparse.issparse(expression_matrix):
        expression_csr = expression_matrix.tocsr()
        row_expr_mean = np.empty(n_obs, dtype=dtype)
        # Densify a bounded block of rows (about 4M cells) per step instead
        # of one row per Python iteration. np.mean is still applied to
        # dense rows, as in the former per-row loop.
        rows_per_chunk = max(1, (1 << 22) // max(1, expression_csr.shape[1]))
        for start in range(0, n_obs, rows_per_chunk):
            stop = min(start + rows_per_chunk, n_obs)
            row_expr_mean[start:stop] = np.mean(expression_csr[start:stop].toarray(), axis=1)
    else:
        row_expr_mean = np.mean(np.asarray(expression_matrix), axis=1)
    row_expr_mean = row_expr_mean.astype(dtype, copy=False)

    dense_fallback_rows = 0
    iterator = range(n_obs)
    if progress:
        iterator = tqdm(iterator, desc="TopK expression")
    for i in iterator:
        idx, values, used_dense_fallback = _sparse_exact_topk(adjacent_csr, i, n_neighbors)
        if idx is None:
            # Isolated row: leave its cost/index rows at zero.
            continue
        if used_dense_fallback:
            dense_fallback_rows += 1

        neighbor_margin_expr += row_expr_mean[idx]
        cost_matrix[i] = values.astype(dtype, copy=False)
        neighbor_idx_matrix[i] = idx

    return cost_matrix, neighbor_margin_expr, neighbor_idx_matrix, dense_fallback_rows


class SpSVC(ApplicationSVC):
    """
    sp-SVC class for application usage.

    This class reconstructs single-cell resolution expression profiles from
    spatial transcriptomics data using optimal transport-based graph
    aggregation for each cell type.
    """

    def __init__(self, st_adata, sc_ref_adata, config, logger):
        """Restrict both datasets to their shared genes and set up the aggregator.

        Args:
            st_adata: Spatial AnnData object.
            sc_ref_adata: Single-cell reference AnnData object.
            config: Run configuration (read via attribute access).
            logger: Logger used for progress messages.
        """
        super().__init__(st_adata, sc_ref_adata, config, None, logger)
        self._adata_validate()
        self.overlap_genes = list(self.st_adata.var_names.intersection(self.sc_ref_adata.var_names))
        self.st_adata = self.st_adata[:, self.overlap_genes]
        self.sc_ref_adata = self.sc_ref_adata[:, self.overlap_genes]
        self.svc = {}
        self.graph_aggregate = GraphAggregate(config, logger)

    def local_refinement(self):
        """
        Reconstruct single-cell resolution expression profiles.

        This method performs the following steps:
        1. Trims spatial data by removing low-expression genes
        2. For each cell type, constructs an adjacency graph
        3. Uses optimal transport to find neighbor relationships
        4. Aggregates neighbor expressions using graph-based smoothing
        5. Optionally generates UMAP plots for visualization

        Cell types with fewer than 50 spots are passed through unsmoothed.
        The reconstructed data (rows grouped by cell type, with X stored as
        CSR) is stored in self.svc["sp_svc"].
        """
        if self.config.plot_flag:
            self.logger.info("Plotting Raw ...")
            self._umap_plot(self.st_adata, prefix="Raw")

        svc_recon_adata = self.st_adata.copy()
        self.logger.info(f"before trim: {svc_recon_adata.X.data.shape}")
        # NOTE(review): the third argument is hard-coded to "Level1" rather
        # than self.config.cell_type_col -- confirm this is intended.
        svc_recon_adata, _celltype_genes = trim_sp_adata(svc_recon_adata, self.sc_ref_adata, "Level1")
        self.logger.info(f"after trim: {svc_recon_adata.X.data.shape}")
        svc_recon_adata.obsm = self.st_adata.obsm.copy()

        cell_type_col = self.config.cell_type_col
        cell_type_adata_list = []
        cell_types = svc_recon_adata.obs[cell_type_col].unique().tolist()
        for cell_type in tqdm(cell_types, desc="Reconstructing"):
            svc_recon_adata_cell_type = svc_recon_adata[svc_recon_adata.obs[cell_type_col] == cell_type]
            self.logger.info(
                f"begin OT smoothing for cell type: {cell_type}, adata shape: {svc_recon_adata_cell_type.shape}"
            )
            if svc_recon_adata_cell_type.shape[0] < 50:
                self.logger.info(f"cell type: {cell_type}, has too few spots, skip OT smoothing")
                cell_type_adata_list.append(svc_recon_adata_cell_type)
                continue

            # Snapshot used below to check that X is untouched before aggregation.
            raw_st_adata_cell_type = svc_recon_adata_cell_type.copy()

            adjacent_matrix = get_adjacency_graph(
                svc_recon_adata_cell_type,
                data_type="sc",
                neighbors_method=self.config.rec_graph_method,
                alpha=self.config.rec_graph_alpha,
                gene_neighbor_num=self.config.rec_graph_exp_neighbor_num,
                spatial_neighbor_num=self.config.rec_graph_spatial_neighbor_num,
            )
            svc_recon_adata_cell_type.obsp["joint_connectivities"] = adjacent_matrix
            cost_matrix, neighbor_margin_expr, neighbor_idx_matrix, dense_fallback_rows = _compute_topk_expression(
                adjacent_matrix=adjacent_matrix,
                expression_matrix=svc_recon_adata_cell_type.X,
                n_neighbors=self.config.rec_graph_n_neighbors,
                dtype=svc_recon_adata_cell_type.X.dtype,
                progress=True,
            )
            if dense_fallback_rows:
                self.logger.info(
                    f"TopK expression used dense fallback for {dense_fallback_rows}/"
                    f"{adjacent_matrix.shape[0]} rows to preserve exact legacy ordering."
                )

            # OT marginals: per-spot total expression vs per-rank neighbour mass.
            mu = np.ravel(svc_recon_adata_cell_type.X.sum(axis=1))
            nu = neighbor_margin_expr
            # Turn similarity weights w into costs 1 / (1 - w).
            # NOTE(review): a weight of exactly 1 gives an infinite cost, and
            # cost_matrix.max() would then be inf (NaN after the division
            # below); confirm get_adjacency_graph weights stay below 1.
            cost_matrix = 1 - cost_matrix
            cost_matrix = 1 / cost_matrix
            T_transform = ot.unbalanced.sinkhorn_unbalanced(
                nu,
                mu,
                cost_matrix.T / cost_matrix.max(),
                reg=self.config.rec_pot_reg,
                reg_m=self.config.rec_pot_reg_m,
                reg_type=self.config.rec_pot_reg_type,
                verbose=True,
                numItermax=5000,
            )

            # Ensure expressions are unchanged before aggregation
            if sparse.issparse(svc_recon_adata_cell_type.X) and sparse.issparse(raw_st_adata_cell_type.X):
                assert (svc_recon_adata_cell_type.X != raw_st_adata_cell_type.X).nnz == 0
            else:
                assert np.array_equal(
                    np.asarray(svc_recon_adata_cell_type.X), np.asarray(raw_st_adata_cell_type.X)
                )

            svc_recon_adata_cell_type = self.graph_aggregate.run(
                adata=svc_recon_adata_cell_type,
                neighbor_idx_matrix=neighbor_idx_matrix,
                coupling_matrix=T_transform,
            )
            cell_type_adata_list.append(svc_recon_adata_cell_type)

        self.svc["sp_svc"] = sc.concat(cell_type_adata_list)
        self.svc["sp_svc"].X = sparse.csr_matrix(self.svc["sp_svc"].X)
        if self.config.plot_flag:
            self.logger.info("Plotting spSVC...")
            self._umap_plot(self.svc["sp_svc"], prefix="sp_SVC")

    def _umap_plot(self, adata, prefix):
        """
        Generate UMAP visualization plots.

        Args:
            adata: AnnData object to plot (copied; the input is not modified).
            prefix: Prefix string for output file names.

        This method filters cells and genes and optionally downsamples. When
        there are more than 2000 genes, it selects highly variable genes on
        counts (required by flavor="seurat_v3"). It then normalizes,
        log-transforms and runs PCA. Leiden clustering is computed at every
        configured resolution (logging ARI/NMI against the cell-type column),
        and UMAP and spatial scatter plots are saved to the result directory.
        """
        adata = adata.copy()
        sc.pp.filter_cells(adata, min_genes=self.config.plot_min_genes)
        sc.pp.filter_genes(adata, min_cells=self.config.plot_min_cells)

        sample_size = self.config.plot_sample_size
        if sample_size > 0:
            if sample_size < adata.shape[0]:
                self.logger.info(f"Downsampling to {sample_size} cells for plotting ...")
                # Same draw as np.random.seed(sample_size) followed by
                # np.random.choice(...), without reseeding the global RNG.
                rng = np.random.RandomState(sample_size)
                indices = rng.choice(adata.shape[0], sample_size, replace=False)
                adata = adata[indices, :].copy()
            else:
                # Sampling without replacement cannot exceed the population.
                self.logger.info(
                    f"plot_sample_size={sample_size} >= {adata.shape[0]} cells; plotting all cells."
                )

        select_hvg = adata.shape[1] > 2000
        if select_hvg:
            self.logger.info(f"Highly variable genes filtering for {adata.shape[1]} genes.")
            # flavor="seurat_v3" expects raw counts, so flag genes before
            # normalize_total/log1p and subset afterwards. This keeps library
            # sizes computed over all genes.
            sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3", subset=False)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        if select_hvg:
            adata = adata[:, adata.var.highly_variable].copy()

        sc.tl.pca(adata, svd_solver='arpack')
        sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)
        for res in self.config.plot_cluster_resolution:
            key = f"leiden_res_{res}"
            sc.tl.leiden(adata, resolution=res, key_added=key)
            n_clusters = len(adata.obs[key].cat.categories)
            self.logger.info(f"Number of clusters for leiden resolution {res}: {n_clusters}")
            ari, nmi = compute_clustering_metrics(adata, key, self.config.cell_type_col)
            self.logger.info(f"ari: {ari}, nmi: {nmi}")

        sc.tl.umap(adata)
        umap_resolution = [f"leiden_res_{res}" for res in self.config.plot_cluster_resolution]
        umap_resolution = [self.config.cell_type_col] + umap_resolution
        sc.pl.umap(adata, color=umap_resolution, show=False)
        plt.savefig(os.path.join(self.config.result_dir, f"{prefix}_umap.png"))
        plt.close()

        # Spatial scatter per resolution; assumes obs has 'x' and 'y' columns.
        for res in self.config.plot_cluster_resolution:
            sc.pl.scatter(adata, x='x', y='y', color=f'leiden_res_{res}', show=False)
            plt.savefig(os.path.join(self.config.result_dir, f"{prefix}_resolution_{res}_scatter.png"))
            plt.close()