Source code for revise.backend.runners.sc_svc_application

import scanpy as sc

from revise.backend.runners.application_svc import ApplicationSVC
from revise.backend.kernels import GraphClusterKernel as GraphCluster
from revise.analysis.bio import get_degs
from revise.analysis.bio import conclusions_write
from revise.analysis.bio import plot_volcano


"""
The sc-SVC application is organized into two classes:
1. ScSVC: reconstructs sc-SVC. `local_refinement` returns a cell-type-specific sc-SVC,
   including `sc_SVC_adata_spatial` and `sc_SVC_adata_expr` for that cell type (h5ad format).
2. ScSVCAnalysis: performs downstream bioinformatics analyses based on
   `sc_SVC_adata_spatial` and `sc_SVC_adata_expr`.
"""


class ScSVC(ApplicationSVC):
    """
    sc-SVC class for application usage.

    This class handles single-cell resolution spatial transcriptomics data,
    filtering cells and genes based on transcript counts and preparing data
    for downstream annotation and reconstruction.
    """

    def __init__(self, st_adata, sc_ref_adata, config, logger):
        """
        Set up the runner and its graph-clustering kernel.

        Args:
            st_adata: Spatial transcriptomics AnnData, forwarded to the parent.
            sc_ref_adata: Single-cell reference AnnData, forwarded to the parent.
            config: Configuration object, forwarded to the parent and to the
                graph-clustering kernel.
            logger: Logger used for progress messages.
        """
        # The fourth positional argument of the parent constructor is
        # deliberately passed as None (its meaning is defined by ApplicationSVC).
        super().__init__(st_adata, sc_ref_adata, config, None, logger)
        self._adata_validate()
        # Keep an untouched copy of the reference as it is after validation.
        self.sc_ref_adata_raw = self.sc_ref_adata.copy()
        self.graph_cluster = GraphCluster(self.config, self.logger)
        # obs column that stores the final sc-SVC cluster label.
        self.cluster_col = "SVC_cluster"

    def local_refinement(self, select_ct, sub_cell_type_col, resolutions, select_res=None):
        """
        Build the sc-SVC for one cell type.

        Steps, as implemented below:
          1. Keep the cells whose ``obs['Level1']`` equals ``select_ct`` in both
             the spatial data and the single-cell reference.
          2. Annotate the spatial subset against the reference subset using
             ``sub_cell_type_col`` as the cell-type column.
          3. Run the graph-clustering kernel over ``resolutions``.
          4. Copy the ``leiden_<resolution>`` labels of the chosen resolution
             into ``obs[self.cluster_col]``.
          5. Annotate the reference subset against the clustered spatial data
             using ``self.cluster_col`` as the cell-type column.

        Args:
            select_ct: Value of ``obs['Level1']`` selecting the cell type.
            sub_cell_type_col: obs column holding the sub-cell-type labels.
            resolutions: Resolutions handed to the graph-clustering kernel.
            select_res: Resolution to use. When None, the ``best_res`` returned
                by the graph-clustering kernel is used.

        Returns:
            tuple: ``(sc_SVC_adata, ct_adata_sc)`` - the clustered spatial
            AnnData and the result of annotating the reference subset.
        """
        # NOTE(review): `self.annotate_method` is not set in this class; it is
        # presumably provided by ApplicationSVC - confirm.
        ct_adata_sp = self.st_adata[self.st_adata.obs['Level1'] == select_ct]
        ct_adata_sc = self.sc_ref_adata[self.sc_ref_adata.obs['Level1'] == select_ct]

        sp_kwargs = {**vars(self.config), "cell_type_col": sub_cell_type_col}
        ct_adata_sp = self.annotate_method.run(ct_adata_sp, ct_adata_sc, **sp_kwargs)

        sc_SVC_adata, merge_df, best_res = self.graph_cluster.run(
            ct_adata_sp, resolutions, sub_cell_type_col
        )

        if select_res is None:
            self.logger.info(f"User does not input select_res, use best_res {best_res} based on spatial alignment score")
            select_res = best_res
        else:
            self.logger.info(f"Use resolution {select_res} from user input")

        sc_SVC_adata.obs[self.cluster_col] = sc_SVC_adata.obs[f'leiden_{select_res}'].astype('category')

        # Report the cluster count of the resolution actually in use. Looking
        # it up by `best_res` would log the wrong number whenever the caller
        # supplied their own `select_res`.
        matched = merge_df.loc[merge_df['resolution'] == select_res, 'cluster_num']
        if len(matched) > 0:
            sp_cluster_num = matched.values[0]
        else:
            # The resolution is not listed in merge_df (e.g. a type mismatch in
            # the user input): count the labels directly instead of failing.
            sp_cluster_num = sc_SVC_adata.obs[self.cluster_col].nunique()
        self.logger.info(f"resolution {select_res} got cluster number {sp_cluster_num}")

        sc_kwargs = {**vars(self.config), "cell_type_col": self.cluster_col}
        ct_adata_sc = self.annotate_method.run(ct_adata_sc, sc_SVC_adata, **sc_kwargs)
        return sc_SVC_adata, ct_adata_sc
class ScSVCAnalysis:
    """
    Downstream analyses on a reconstructed sc-SVC.

    Holds the spatial and expression AnnData objects of one sc-SVC together
    with the name of the obs column that stores the cluster labels, and
    exposes DEG, plotting and pathway helpers built on top of them.
    """

    # Gene-set libraries used by `get_pathway_conclusion` when none are given.
    _DEFAULT_GENESETS = (
        "MSigDB_Hallmark_2020",
        "KEGG_2021_Human",
        "GO_Biological_Process_2025",
    )

    def __init__(self, sc_SVC_adata_spatial, sc_SVC_adata_expr, cluster_col):
        """
        Store the inputs and compute DEGs over all clusters.

        Args:
            sc_SVC_adata_spatial: Spatial AnnData of the sc-SVC (stored only).
            sc_SVC_adata_expr: Expression AnnData of the sc-SVC.
            cluster_col: obs column of ``sc_SVC_adata_expr`` holding the
                cluster labels.
        """
        self.sc_SVC_adata_spatial = sc_SVC_adata_spatial
        self.sc_SVC_adata_expr = sc_SVC_adata_expr
        self.cluster_col = cluster_col
        self.sc_SVC_degs = get_degs(
            self.sc_SVC_adata_expr, groupby=self.cluster_col, method='t-test', fc_threshold=None
        )

    # ------------------------------------------------------------------ helpers

    def _resolve_cluster_nums(self, cluster_nums):
        """Return ``cluster_nums``, or every category of the cluster column when it is None."""
        if cluster_nums is None:
            print("Using all clusters")
            return self.sc_SVC_adata_expr.obs[self.cluster_col].cat.categories.tolist()
        return cluster_nums

    @staticmethod
    def _check_cluster_nums(cluster_nums):
        """Assert that ``cluster_nums`` is a non-empty list, tuple or set."""
        # Kept as asserts so callers still see AssertionError on bad input.
        assert isinstance(cluster_nums, (list, tuple, set)), f"target_cluster {cluster_nums} should be a list, tuple, or set"
        assert len(cluster_nums) > 0, "target_cluster should not be empty"

    @staticmethod
    def _normalize(adata):
        """Total-count normalize (target 1e4) and log1p-transform ``adata`` in place."""
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        return adata

    def _select_clusters(self, cluster_nums, normalize=False):
        """Return an independent copy of the cells in ``cluster_nums``, optionally normalized."""
        keep = self.sc_SVC_adata_expr.obs[self.cluster_col].isin(cluster_nums)
        # Copy so that normalization / relabeling never acts on a view of the
        # shared expression object.
        subset = self.sc_SVC_adata_expr[keep].copy()
        if normalize:
            self._normalize(subset)
        return subset

    # --------------------------------------------------------------- public API

    def get_cm_df(self, sub_cell_type_col):
        """
        Count cells per (cluster, sub-cell-type) pair.

        Args:
            sub_cell_type_col: obs column holding the sub-cell-type labels.

        Returns:
            DataFrame indexed by cluster with one column per sub-cell-type;
            combinations absent from the grouped result are filled with 0.
        """
        pair_counts = self.sc_SVC_adata_expr.obs.groupby([self.cluster_col, sub_cell_type_col]).size()
        return pair_counts.unstack(fill_value=0)

    def get_svc_degs(self, target_cluster=None, fc_threshold=None, normalize=False):
        """
        Compute DEGs (t-test) grouped by the cluster column.

        Args:
            target_cluster: Clusters to restrict to (list, tuple or set).
                When None or empty, all cells are used.
            fc_threshold: Forwarded to ``get_degs``.
            normalize: When True, run the test on a normalized + log1p copy
                of the data instead of the stored values.

        Returns:
            The result of ``get_degs``.

        Raises:
            AssertionError: If a non-empty ``target_cluster`` is not a list,
                tuple or set.
        """
        if not target_cluster:
            adata = self.sc_SVC_adata_expr
            if normalize:
                adata = self._normalize(adata.copy())
        else:
            self._check_cluster_nums(target_cluster)
            adata = self._select_clusters(target_cluster, normalize)
        return get_degs(adata, groupby=self.cluster_col, method='t-test', fc_threshold=fc_threshold)

    def get_dot_plot(self, cluster_nums, marker_dict, normalize=False):
        """
        Draw a scanpy dot plot of ``marker_dict`` for the selected clusters.

        Args:
            cluster_nums: Clusters to show; None selects all clusters.
            marker_dict: Marker genes passed to ``sc.pl.dotplot``.
            normalize: When True, normalize + log1p the selected cells first.

        Raises:
            AssertionError: If ``cluster_nums`` is not a non-empty list,
                tuple or set.
        """
        cluster_nums = self._resolve_cluster_nums(cluster_nums)
        self._check_cluster_nums(cluster_nums)
        select_adata = self._select_clusters(cluster_nums, normalize)
        sc.pl.dotplot(select_adata, marker_dict, groupby=self.cluster_col, dendrogram=False)

    def get_violin_plot(self, cluster_nums, normalize=False, keys="LPL"):
        """
        Draw a scanpy violin plot for the selected clusters.

        Args:
            cluster_nums: Clusters to show; None selects all clusters.
            normalize: When True, normalize + log1p the selected cells first.
            keys: Gene(s) passed to ``sc.pl.violin``. Defaults to ``"LPL"``,
                the gene that used to be hard-coded here.

        Raises:
            AssertionError: If ``cluster_nums`` is not a non-empty list,
                tuple or set.
        """
        cluster_nums = self._resolve_cluster_nums(cluster_nums)
        self._check_cluster_nums(cluster_nums)
        select_adata = self._select_clusters(cluster_nums, normalize)
        sc.pl.violin(select_adata, keys, groupby=self.cluster_col)

    def get_pathway_conclusion(
        self,
        cluster_nums,
        fc_threshold,
        pathway_num,
        gene_num,
        geneset_file=None,
        normalize=False,
    ):
        """
        Compute DEGs for the selected clusters and summarize enriched pathways.

        Args:
            cluster_nums: Clusters to analyse; None selects all clusters.
            fc_threshold: Forwarded to ``get_degs``.
            pathway_num: Forwarded to ``conclusions_write``.
            gene_num: Forwarded to ``conclusions_write``.
            geneset_file: Gene-set libraries for ``conclusions_write``. None
                selects MSigDB Hallmark 2020, KEGG 2021 Human and
                GO Biological Process 2025.
            normalize: When True, the DEGs are computed on a normalized +
                log1p copy of the data.

        Returns:
            The value returned by ``conclusions_write``.

        Raises:
            NotImplementedError: If ``geneset_file`` is empty.
        """
        if geneset_file is None:
            # Fresh list per call; avoids a shared mutable default argument.
            geneset_file = list(self._DEFAULT_GENESETS)
        cluster_nums = self._resolve_cluster_nums(cluster_nums)
        if len(geneset_file) == 0:
            raise NotImplementedError("Got empty geneset_file!")

        # `normalize` has to reach the DEG computation; normalizing a
        # throw-away subset here would leave the result unchanged.
        deg_df = self.get_svc_degs(cluster_nums, fc_threshold, normalize=normalize)
        all_pathway = conclusions_write(
            deg_df,
            geneset_file,
            gene_num=gene_num,
            pathway_num=pathway_num,
            print_flag=True,
            conclusion_file_name=None,
        )
        return all_pathway

    def get_volcano_plot(self, cluster_nums, target_group, replace_cols=None, fc_threshold=None,
                         log_fold_changes=10, logfc_threshold=1, padj_threshold=1e-6, top_k=10):
        """
        Compute DEGs for the selected clusters and draw a volcano plot for one group.

        Args:
            cluster_nums: Clusters to include; None selects all clusters.
            target_group: Value of the DEG table's ``group`` column to plot.
            replace_cols: Optional mapping passed to ``Series.replace`` to
                relabel clusters before the DEG test.
            fc_threshold: Forwarded to ``get_degs``.
            log_fold_changes: Rows with ``|logfoldchanges|`` above this value
                are dropped before plotting.
            logfc_threshold: Forwarded to ``plot_volcano``.
            padj_threshold: Forwarded to ``plot_volcano``.
            top_k: Forwarded to ``plot_volcano``.

        Returns:
            The filtered DEG table that was plotted.
        """
        cluster_nums = self._resolve_cluster_nums(cluster_nums)
        select_adata = self._select_clusters(cluster_nums)

        if replace_cols is not None:
            # Relabel on an independent copy and assign the result back: a
            # chained in-place replace on a view's column is not guaranteed to
            # propagate (pandas copy-on-write) and could touch the parent data.
            labels = select_adata.obs[self.cluster_col]
            was_categorical = labels.dtype.name == "category"
            if was_categorical:
                # Replace on plain objects, then rebuild the categories.
                labels = labels.astype(object)
            labels = labels.replace(replace_cols)
            if was_categorical:
                labels = labels.astype("category")
            select_adata.obs[self.cluster_col] = labels

        select_deg_df = get_degs(
            select_adata, groupby=self.cluster_col, method='t-test', fc_threshold=fc_threshold
        )
        select_deg_df = select_deg_df[select_deg_df['group'] == target_group].reset_index(drop=True)
        # Drop extreme fold changes (index is intentionally not reset again).
        select_deg_df = select_deg_df[select_deg_df['logfoldchanges'].abs() <= log_fold_changes]
        plot_volcano(
            select_deg_df,
            logfc_threshold=logfc_threshold,
            padj_threshold=padj_threshold,
            top_k=top_k,
            save_file_name=None,
        )
        return select_deg_df