import os
import matplotlib.pyplot as plt
import numpy as np
import ot
import scanpy as sc
from scipy import sparse
from tqdm import tqdm
from revise.backend.runners.application_svc import ApplicationSVC
from revise.backend.kernels import GraphAggregateKernel as GraphAggregate
from revise.analysis.metrics import compute_clustering_metrics
from revise.backend.ops.shaver import trim_sp_adata
from revise.backend.ops.topology import get_adjacency_graph
def _legacy_dense_topk(adjacent_matrix, row_index, n_neighbors):
    """Pick the largest entries of one adjacency row via a dense copy.

    The requested row is densified, the ``n_neighbors`` largest entries are
    located with ``np.argpartition`` and then ordered by decreasing value.

    Args:
        adjacent_matrix: Sparse matrix supporting row indexing and
            ``toarray()``.
        row_index: Index of the row to inspect.
        n_neighbors: Number of entries to select.

    Returns:
        ``(indices, values)`` where ``indices`` is an ``int32`` array of
        column positions and ``values`` holds the matching row entries, or
        ``(None, None)`` when the row contains no non-zero entry.
    """
    dense_row = adjacent_matrix[row_index].toarray().ravel()
    if not np.count_nonzero(dense_row):
        return None, None

    # The partition pivot cannot exceed the row length.
    pivot = min(n_neighbors, dense_row.size) - 1
    candidates = np.argpartition(-dense_row, kth=pivot)[:n_neighbors]

    # argpartition only separates the top group; rank it explicitly.
    ranking = np.argsort(-dense_row[candidates])
    top_idx = candidates[ranking]
    return top_idx.astype(np.int32, copy=False), dense_row[top_idx].copy()
def _sparse_exact_topk(adjacent_csr, row_index, n_neighbors):
    """Select the top entries of a CSR row without densifying when possible.

    The stored entries of the row are read straight from the CSR buffers.
    When the selection is unambiguous (all stored values positive, enough of
    them, and no ties among or at the edge of the selected values) the result
    is computed on the sparse data. Otherwise the call is delegated to
    ``_legacy_dense_topk``.

    Args:
        adjacent_csr: CSR matrix exposing ``indptr``, ``indices``, ``data``
            and ``shape``.
        row_index: Index of the row to inspect.
        n_neighbors: Number of entries to select.

    Returns:
        ``(indices, values, used_dense_fallback)``. ``indices`` and ``values``
        are ``None`` when the row has no non-zero stored entry; the flag is
        ``True`` only when the dense implementation produced the result.
    """
    lo = adjacent_csr.indptr[row_index]
    hi = adjacent_csr.indptr[row_index + 1]
    cols = adjacent_csr.indices[lo:hi]
    weights = adjacent_csr.data[lo:hi]
    if weights.size == 0:
        return None, None, False

    # Explicitly stored zeros do not count as neighbours.
    stored_nonzero = weights != 0
    if not stored_nonzero.all():
        cols = cols[stored_nonzero]
        weights = weights[stored_nonzero]
        if weights.size == 0:
            return None, None, False

    k = min(n_neighbors, adjacent_csr.shape[1])
    if weights.size >= k and (weights > 0).all():
        top = np.argpartition(-weights, kth=k - 1)[:k]
        top_weights = weights[top]

        # Tied weights could be ordered differently than in the dense
        # np.argpartition path, so such rows are left to the original
        # implementation below.
        cutoff = top_weights.min()
        cutoff_fully_selected = (
            np.count_nonzero(weights == cutoff)
            == np.count_nonzero(top_weights == cutoff)
        )
        if cutoff_fully_selected and np.unique(top_weights).size == top_weights.size:
            top = top[np.argsort(-top_weights)]
            return (
                cols[top].astype(np.int32, copy=False),
                weights[top].copy(),
                False,
            )

    dense_idx, dense_values = _legacy_dense_topk(adjacent_csr, row_index, n_neighbors)
    return dense_idx, dense_values, True
def _compute_topk_expression(adjacent_matrix, expression_matrix, n_neighbors, dtype, progress=True):
    """Collect, for every row, its top adjacency weights and neighbour ids.

    For each row of ``adjacent_matrix`` the ``n_neighbors`` strongest entries
    are selected with ``_sparse_exact_topk``. Their weights and column indices
    are written to per-row output arrays, and the mean expression of each
    selected neighbour is accumulated per rank position.

    Args:
        adjacent_matrix: Sparse adjacency matrix (anything with ``tocsr()``).
        expression_matrix: Dense or scipy-sparse expression matrix with one
            row per adjacency row.
        n_neighbors: Number of neighbours kept per row.
        dtype: Requested dtype for the floating-point outputs. Floating and
            complex dtypes are used as given; integer/bool dtypes are promoted
            to ``float64`` because the stored quantities (adjacency weights,
            expression means) are fractional and would otherwise be truncated.
        progress: When true, wrap the row loop in a ``tqdm`` progress bar.

    Returns:
        A 4-tuple ``(cost_matrix, neighbor_margin_expr, neighbor_idx_matrix,
        dense_fallback_rows)``:

        * ``cost_matrix`` -- ``(n_obs, n_neighbors)`` selected adjacency
          weights (rows without neighbours stay zero).
        * ``neighbor_margin_expr`` -- length ``n_neighbors``; entry ``j`` is
          the sum over rows of the mean expression of the rank-``j``
          neighbour.
        * ``neighbor_idx_matrix`` -- ``(n_obs, n_neighbors)`` ``int32`` column
          indices of the selected neighbours.
        * ``dense_fallback_rows`` -- number of rows resolved by the dense
          fallback path.
    """
    # Guard against integer/bool expression matrices: the caller passes
    # ``X.dtype`` and truncating weights in (0, 1) to integers would wipe out
    # the graph information.
    dtype = np.dtype(dtype)
    if not np.issubdtype(dtype, np.inexact):
        dtype = np.dtype(np.float64)

    adjacent_csr = adjacent_matrix.tocsr()
    n_obs = adjacent_csr.shape[0]
    cost_matrix = np.zeros((n_obs, n_neighbors), dtype=dtype)
    neighbor_margin_expr = np.zeros(n_neighbors, dtype=dtype)
    neighbor_idx_matrix = np.zeros((n_obs, n_neighbors), dtype=np.int32)

    if sparse.issparse(expression_matrix):
        # Row means are taken on one densified row at a time so the full
        # matrix is never materialised and the per-row summation matches the
        # dense computation.
        expression_csr = expression_matrix.tocsr()
        row_expr_mean = np.empty(n_obs, dtype=dtype)
        for i in range(n_obs):
            row_expr_mean[i] = np.mean(expression_csr[i].toarray(), axis=1).ravel()[0]
    else:
        row_expr_mean = np.mean(np.asarray(expression_matrix), axis=1)
    row_expr_mean = row_expr_mean.astype(dtype, copy=False)

    dense_fallback_rows = 0
    iterator = range(n_obs)
    if progress:
        iterator = tqdm(iterator, desc="TopK expression")

    # NOTE(review): the assignments below require the helper to return
    # exactly ``n_neighbors`` entries, i.e. n_neighbors must not exceed the
    # number of adjacency columns -- confirm against callers.
    for i in iterator:
        idx, values, used_dense_fallback = _sparse_exact_topk(adjacent_csr, i, n_neighbors)
        if idx is None:
            # Row without neighbours: outputs keep their zero initialisation.
            continue
        if used_dense_fallback:
            dense_fallback_rows += 1
        neighbor_margin_expr += row_expr_mean[idx]
        cost_matrix[i] = values.astype(dtype, copy=False)
        neighbor_idx_matrix[i] = idx

    return cost_matrix, neighbor_margin_expr, neighbor_idx_matrix, dense_fallback_rows
class SpSVC(ApplicationSVC):
    """
    sp-SVC class for application usage.

    This class reconstructs single-cell resolution expression profiles
    from spatial transcriptomics data using optimal transport-based
    graph aggregation for each cell type.
    """

    # Cell types with fewer observations than this are passed through
    # without optimal-transport smoothing.
    _MIN_SPOTS_FOR_SMOOTHING = 50

    def __init__(self, st_adata, sc_ref_adata, config, logger):
        """
        Restrict both datasets to their shared genes and create the aggregator.

        Args:
            st_adata: Spatial dataset; must expose ``var_names`` and support
                column slicing.
            sc_ref_adata: Single-cell reference dataset with the same
                interface.
            config: Configuration object read throughout this class.
            logger: Logger used for progress messages.
        """
        # The fourth positional argument of the parent constructor is
        # deliberately passed as None.
        super().__init__(st_adata, sc_ref_adata, config, None, logger)
        self._adata_validate()
        self.overlap_genes = list(self.st_adata.var_names.intersection(self.sc_ref_adata.var_names))
        self.st_adata = self.st_adata[:, self.overlap_genes]
        self.sc_ref_adata = self.sc_ref_adata[:, self.overlap_genes]
        self.svc = {}
        self.graph_aggregate = GraphAggregate(config, logger)

    def local_refinement(self):
        """
        Reconstruct single-cell resolution expression profiles.

        This method performs the following steps:
            1. Trims spatial data by removing low-expression genes
            2. For each cell type, constructs an adjacency graph
            3. Uses optimal transport to find neighbor relationships
            4. Aggregates neighbor expressions using graph-based smoothing
            5. Optionally generates UMAP plots for visualization

        Cell types with fewer than ``_MIN_SPOTS_FOR_SMOOTHING`` observations
        are kept as they are. The reconstructed data is stored in
        self.svc["sp_svc"].
        """
        if self.config.plot_flag:
            self.logger.info("Plotting Raw ...")
            self._umap_plot(self.st_adata, prefix="Raw")

        svc_recon_adata = self.st_adata.copy()
        self.logger.info(f"before trim: {svc_recon_adata.X.data.shape}")
        # Only the trimmed data is used here; the second return value is ignored.
        svc_recon_adata, _ = trim_sp_adata(svc_recon_adata, self.sc_ref_adata, "Level1")
        self.logger.info(f"after trim: {svc_recon_adata.X.data.shape}")
        svc_recon_adata.obsm = self.st_adata.obsm.copy()

        cell_type_adata_list = []
        cell_types = svc_recon_adata.obs[self.config.cell_type_col].unique().tolist()
        for cell_type in tqdm(cell_types, desc="Reconstructing"):
            svc_recon_adata_cell_type = svc_recon_adata[svc_recon_adata.obs[self.config.cell_type_col] == cell_type]
            # Snapshot used further down to verify the expressions were not
            # modified before aggregation.
            raw_st_adata_cell_type = svc_recon_adata_cell_type.copy()
            self.logger.info(f"begin OT smoothing for cell type: {cell_type}, adata shape: {svc_recon_adata_cell_type.shape}")

            if svc_recon_adata_cell_type.shape[0] < self._MIN_SPOTS_FOR_SMOOTHING:
                self.logger.info(f"cell type: {cell_type}, has too few spots, skip OT smoothing")
                cell_type_adata_list.append(svc_recon_adata_cell_type)
                continue

            adjacent_matrix = get_adjacency_graph(
                svc_recon_adata_cell_type,
                data_type="sc",
                neighbors_method=self.config.rec_graph_method,
                alpha=self.config.rec_graph_alpha,
                gene_neighbor_num=self.config.rec_graph_exp_neighbor_num,
                spatial_neighbor_num=self.config.rec_graph_spatial_neighbor_num,
            )
            svc_recon_adata_cell_type.obsp["joint_connectivities"] = adjacent_matrix

            cost_matrix, neighbor_margin_expr, neighbor_idx_matrix, dense_fallback_rows = _compute_topk_expression(
                adjacent_matrix=adjacent_matrix,
                expression_matrix=svc_recon_adata_cell_type.X,
                n_neighbors=self.config.rec_graph_n_neighbors,
                dtype=svc_recon_adata_cell_type.X.dtype,
                progress=True,
            )
            if dense_fallback_rows:
                self.logger.info(
                    f"TopK expression used dense fallback for {dense_fallback_rows}/"
                    f"{adjacent_matrix.shape[0]} rows to preserve exact legacy ordering."
                )

            # Marginals for the transport problem: per-observation total
            # expression and per-rank accumulated neighbour expression.
            mu = np.ravel(svc_recon_adata_cell_type.X.sum(axis=1))
            nu = neighbor_margin_expr

            # Turn adjacency weights w into costs 1 / (1 - w).
            # NOTE(review): a weight of exactly 1 makes the denominator zero
            # and yields an infinite cost -- confirm get_adjacency_graph never
            # produces such weights.
            cost_matrix = 1 - cost_matrix
            cost_matrix = 1 / cost_matrix

            T_transform = ot.unbalanced.sinkhorn_unbalanced(
                nu,
                mu,
                cost_matrix.T / cost_matrix.max(),
                reg=self.config.rec_pot_reg,
                reg_m=self.config.rec_pot_reg_m,
                reg_type=self.config.rec_pot_reg_type,
                verbose=True,
                numItermax=5000,
            )

            # Ensure expressions are unchanged before aggregation
            if sparse.issparse(svc_recon_adata_cell_type.X) and sparse.issparse(raw_st_adata_cell_type.X):
                assert (svc_recon_adata_cell_type.X != raw_st_adata_cell_type.X).nnz == 0
            else:
                assert np.array_equal(
                    np.asarray(svc_recon_adata_cell_type.X),
                    np.asarray(raw_st_adata_cell_type.X),
                )

            svc_recon_adata_cell_type = self.graph_aggregate.run(
                adata=svc_recon_adata_cell_type,
                neighbor_idx_matrix=neighbor_idx_matrix,
                coupling_matrix=T_transform,
            )
            cell_type_adata_list.append(svc_recon_adata_cell_type)

        self.svc["sp_svc"] = sc.concat(cell_type_adata_list)
        self.svc["sp_svc"].X = sparse.csr_matrix(self.svc["sp_svc"].X)

        if self.config.plot_flag:
            self.logger.info("Plotting spSVC...")
            self._umap_plot(self.svc["sp_svc"], prefix="sp_SVC")

    def _umap_plot(self, adata, prefix):
        """
        Generate UMAP visualization plots.

        Args:
            adata: AnnData object to plot
            prefix: Prefix string for output file names

        This method performs preprocessing (filtering, normalization, PCA),
        computes clustering at multiple resolutions, and generates UMAP
        and spatial scatter plots saved to the result directory. The input
        object is copied first, so the caller's data is not modified.
        """
        adata = adata.copy()
        sc.pp.filter_cells(adata, min_genes=self.config.plot_min_genes)
        sc.pp.filter_genes(adata, min_cells=self.config.plot_min_cells)

        sample_size = self.config.plot_sample_size
        if sample_size > 0:
            if adata.shape[0] >= sample_size:
                self.logger.info(f"Downsampling to {self.config.plot_sample_size} cells for plotting ...")
                # A private generator seeded like the former global
                # np.random.seed call draws the same sample without
                # reseeding the process-wide NumPy RNG.
                rng = np.random.RandomState(sample_size)
                indices = rng.choice(adata.shape[0], sample_size, replace=False)
                adata = adata[indices, :].copy()
            else:
                # Sampling without replacement cannot draw more cells than
                # remain after filtering; plot everything instead of failing.
                self.logger.info(
                    f"Only {adata.shape[0]} cells left after filtering; "
                    f"skipping downsampling to {sample_size} cells."
                )

        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        if adata.shape[1] > 2000:
            self.logger.info(f"Highly variable genes filtering for {adata.shape[1]} genes.")
            # NOTE(review): scanpy documents flavor="seurat_v3" as expecting
            # raw counts, but it runs here on log-normalised data -- confirm
            # this is intended.
            sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3", subset=True)
            adata = adata[:, adata.var.highly_variable].copy()

        sc.tl.pca(adata, svd_solver='arpack')
        sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)

        for res in self.config.plot_cluster_resolution:
            cluster_key = f"leiden_res_{res}"
            sc.tl.leiden(adata, resolution=res, key_added=cluster_key)
            n_clusters = len(adata.obs[cluster_key].cat.categories)
            self.logger.info(f"Number of clusters for leiden resolution {res}: {n_clusters}")
            ari, nmi = compute_clustering_metrics(adata, cluster_key, self.config.cell_type_col)
            self.logger.info(f"ari: {ari}, nmi: {nmi}")

        sc.tl.umap(adata)
        umap_resolution = [f"leiden_res_{res}" for res in self.config.plot_cluster_resolution]
        umap_resolution = [self.config.cell_type_col] + umap_resolution
        sc.pl.umap(adata, color=umap_resolution, show=False)
        plt.savefig(os.path.join(self.config.result_dir, f"{prefix}_umap.png"))
        plt.close()

        for res in self.config.plot_cluster_resolution:
            sc.pl.scatter(adata, x='x', y='y', color=f'leiden_res_{res}', show=False)
            plt.savefig(os.path.join(self.config.result_dir, f"{prefix}_resolution_{res}_scatter.png"))
            plt.close()