# Source code for revise.analysis.services

from __future__ import annotations

from typing import Dict

import scanpy as sc

from revise.utils import ensure_gseapy_stub

# Some environments only run reconstruction/metrics and do not have gseapy installed.
# Stub it out before importing revise.analysis.bio so this module stays importable.
ensure_gseapy_stub()

from revise.analysis.bio import conclusions_write
from revise.analysis.bio import get_degs
from revise.analysis.bio import plot_volcano
from revise.analysis.metrics import compute_clustering_metrics
from revise.svc import SVC


class ScSVCAnalysisService:
    """Downstream sc-SVC analysis that consumes unified SVC only.

    Wraps the expression (``svc.expr``) and spatial (``svc.spatial``) AnnData
    objects of an SVC result and offers DEG, dot-plot, pathway and
    volcano-plot helpers, all grouped by ``cluster_col``.
    """

    def __init__(self, svc: SVC, cluster_col: str = "SVC_cluster") -> None:
        """Validate *svc* and precompute the DEG table over all clusters.

        Args:
            svc: Unified SVC result; both ``expr`` and ``spatial`` must be set.
            cluster_col: Column of ``svc.expr.obs`` holding cluster labels.

        Raises:
            ValueError: if ``svc.expr`` or ``svc.spatial`` is None, or if
                ``cluster_col`` is not a column of ``svc.expr.obs``.
        """
        # Mirrors the module-level ensure_gseapy_stub() call.
        ensure_gseapy_stub()
        if svc.expr is None or svc.spatial is None:
            raise ValueError("ScSVCAnalysisService requires both expr and spatial AnnData")
        if cluster_col not in svc.expr.obs:
            raise ValueError(f"cluster_col '{cluster_col}' not found in svc.expr.obs")
        self.svc = svc
        self.cluster_col = cluster_col
        self.sc_SVC_adata_spatial = svc.spatial
        self.sc_SVC_adata_expr = svc.expr
        # Keep legacy notebook contract: precomputed DEG table on init.
        self.sc_SVC_degs = get_degs(
            self.sc_SVC_adata_expr, groupby=self.cluster_col, method="t-test", fc_threshold=None
        )

    # ------------------------------------------------------------------ helpers

    def _all_clusters(self) -> list:
        """Return every cluster label defined by ``cluster_col``.

        Categorical columns yield their categories (in category order); other
        columns yield their distinct non-missing values in their original
        dtype, so the result can be passed straight back to ``isin``.
        """
        labels = self.sc_SVC_adata_expr.obs[self.cluster_col]
        if hasattr(labels, "cat"):
            return labels.cat.categories.tolist()
        # Keep the original dtype: stringifying here would make ``isin`` miss
        # every row of a numeric cluster column and select nothing.
        return labels.dropna().unique().tolist()

    def _select_clusters(self, cluster_nums):
        """Return ``sc_SVC_adata_expr`` restricted to rows whose cluster is in *cluster_nums*."""
        mask = self.sc_SVC_adata_expr.obs[self.cluster_col].isin(cluster_nums)
        return self.sc_SVC_adata_expr[mask]

    @staticmethod
    def _relabel_clusters(labels, mapping):
        """Return *labels* with values remapped through *mapping* (``Series.replace``).

        A categorical input is remapped on its plain values and re-categorized,
        because ``replace`` on a categorical that changes its categories is
        deprecated in pandas >= 2.2.
        """
        if hasattr(labels, "cat"):
            return labels.astype(object).replace(mapping).astype("category")
        return labels.replace(mapping)

    # -------------------------------------------------------------- public API

    def get_cm_df(self, sub_cell_type_col: str):
        """Return a cluster x sub-cell-type cell-count table.

        Rows are ``cluster_col`` labels, columns are ``sub_cell_type_col``
        labels, and cells hold counts (0 where a combination is absent).
        """
        obs = self.sc_SVC_adata_expr.obs
        # observed=False pins the pre-pandas-3 default so unused categorical
        # labels still show up as all-zero rows/columns.
        grouped = obs.groupby([self.cluster_col, sub_cell_type_col], observed=False).size()
        return grouped.unstack(fill_value=0)

    def get_svc_degs(self, target_cluster=None, fc_threshold=None):
        """Compute t-test DEGs (via ``get_degs``) grouped by ``cluster_col``.

        Args:
            target_cluster: list/tuple/set of cluster labels to restrict to,
                or None to use every cell.
            fc_threshold: Forwarded to ``get_degs``.

        Returns:
            Whatever ``get_degs`` returns.

        Raises:
            ValueError: if *target_cluster* is given but is not a
                list/tuple/set, or is empty.
        """
        if target_cluster is None:
            select_adata = self.sc_SVC_adata_expr
        else:
            if not isinstance(target_cluster, (list, tuple, set)):
                raise ValueError("target_cluster must be list/tuple/set when provided")
            if not target_cluster:
                raise ValueError("target_cluster should not be empty")
            select_adata = self._select_clusters(target_cluster)
        return get_degs(select_adata, groupby=self.cluster_col, method="t-test", fc_threshold=fc_threshold)

    def get_dot_plot(self, cluster_nums, marker_dict, normalize: bool = False):
        """Draw ``sc.pl.dotplot`` of *marker_dict* for the selected clusters.

        Args:
            cluster_nums: list/tuple/set of cluster labels, or None for all.
            marker_dict: Passed to ``sc.pl.dotplot`` as its ``var_names``.
            normalize: If True, apply ``normalize_total(target_sum=1e4)`` and
                ``log1p`` to a copy of the selection before plotting.

        Raises:
            ValueError: if *cluster_nums* is not a list/tuple/set or is empty.
        """
        if cluster_nums is None:
            cluster_nums = self._all_clusters()
        if not isinstance(cluster_nums, (list, tuple, set)):
            raise ValueError("cluster_nums should be list/tuple/set")
        if not cluster_nums:
            raise ValueError("cluster_nums should not be empty")
        # Always copy: normalization must never touch the shared matrix.
        select_adata = self._select_clusters(cluster_nums).copy()
        if normalize:
            sc.pp.normalize_total(select_adata, target_sum=1e4)
            sc.pp.log1p(select_adata)
        sc.pl.dotplot(select_adata, marker_dict, groupby=self.cluster_col, dendrogram=False)

    def get_pathway_conclusion(
        self,
        cluster_nums,
        fc_threshold,
        pathway_num,
        gene_num,
        geneset_file,
        normalize: bool = False,
    ):
        """Run DEGs on the selected clusters and write pathway conclusions.

        Args:
            cluster_nums: Cluster labels to include, or None for all.
            fc_threshold: Forwarded to ``get_degs``.
            pathway_num: Forwarded to ``conclusions_write``.
            gene_num: Forwarded to ``conclusions_write``.
            geneset_file: Gene-set file for ``conclusions_write``; required.
            normalize: If True, normalize_total (1e4) + log1p a copy first.

        Returns:
            Whatever ``conclusions_write`` returns.

        Raises:
            ValueError: if *geneset_file* is empty.
        """
        if not geneset_file:
            raise ValueError("geneset_file cannot be empty")
        if cluster_nums is None:
            cluster_nums = self._all_clusters()
        select_adata = self._select_clusters(cluster_nums)
        if normalize:
            # Copy only when mutating; otherwise the view is enough.
            select_adata = select_adata.copy()
            sc.pp.normalize_total(select_adata, target_sum=1e4)
            sc.pp.log1p(select_adata)
        deg_df = get_degs(select_adata, groupby=self.cluster_col, method="t-test", fc_threshold=fc_threshold)
        return conclusions_write(
            deg_df,
            geneset_file,
            gene_num=gene_num,
            pathway_num=pathway_num,
            print_flag=True,
            conclusion_file_name=None,
        )

    def get_volcano_plot(
        self,
        cluster_nums,
        target_group,
        replace_cols=None,
        fc_threshold=None,
        log_fold_changes=10,
        logfc_threshold=1,
        padj_threshold=1e-6,
        top_k=10,
    ):
        """Compute DEGs on the selected clusters and draw a volcano plot for one group.

        Args:
            cluster_nums: Cluster labels to include, or None for all.
            target_group: Value of the DEG table's ``group`` column to plot.
            replace_cols: Optional mapping used to relabel clusters (e.g. to
                merge several clusters) before computing DEGs.
            fc_threshold: Forwarded to ``get_degs``.
            log_fold_changes: Rows with ``|logfoldchanges|`` above this are
                dropped before plotting.
            logfc_threshold, padj_threshold, top_k: Forwarded to ``plot_volcano``.

        Returns:
            The filtered DEG table that was plotted.
        """
        if cluster_nums is None:
            cluster_nums = self._all_clusters()
        select_adata = self._select_clusters(cluster_nums).copy()
        if replace_cols is not None:
            # Assign back rather than ``obs[col].replace(..., inplace=True)``:
            # that chained in-place call is a silent no-op under pandas
            # copy-on-write.
            select_adata.obs[self.cluster_col] = self._relabel_clusters(
                select_adata.obs[self.cluster_col], replace_cols
            )
        select_deg_df = get_degs(select_adata, groupby=self.cluster_col, method="t-test", fc_threshold=fc_threshold)
        select_deg_df = select_deg_df[select_deg_df["group"] == target_group]
        select_deg_df = select_deg_df[select_deg_df["logfoldchanges"].abs() <= log_fold_changes].copy()
        plot_volcano(
            select_deg_df,
            logfc_threshold=logfc_threshold,
            padj_threshold=padj_threshold,
            top_k=top_k,
            save_file_name=None,
        )
        return select_deg_df
class SpSVCAnalysisService:
    """Basic downstream metrics for sp-SVC outputs.

    Operates on the spatial AnnData (``svc.spatial``) of an SVC result.
    """

    def __init__(self, svc: SVC) -> None:
        """Keep a reference to *svc* after checking that it carries spatial data.

        Raises:
            ValueError: if ``svc.spatial`` is None.
        """
        spatial = svc.spatial
        if spatial is None:
            raise ValueError("SpSVCAnalysisService requires svc.spatial")
        self.svc = svc

    def clustering_metrics(self, pred_col: str, ref_col: str) -> Dict[str, float]:
        """Score predicted clusters against a reference labelling.

        Args:
            pred_col: Column holding the predicted cluster labels.
            ref_col: Column holding the reference labels.

        Returns:
            ``{"ari": ..., "nmi": ...}`` as plain floats, taken from
            ``compute_clustering_metrics`` on ``svc.spatial``.
        """
        scores = compute_clustering_metrics(self.svc.spatial, pred_col, ref_col)
        ari, nmi = scores
        return dict(ari=float(ari), nmi=float(nmi))