Source code for revise.recon.pipeline

from __future__ import annotations

from pathlib import Path
from typing import Any
from typing import Dict
from revise.backend.contracts import EvaluationPolicy
from revise.backend.contracts import InputValidationPolicy
from revise.backend.contracts import LocalRefinementStrategy
from revise.svc import SVC


class UnifiedReconstructionPipeline:
    """Fixed-order driver for the reconstruction lifecycle.

    The pipeline owns the stage ordering and bookkeeping (stage trace,
    logging, persistence, optional evaluation) while delegating the work of
    each stage to the injected strategy and policies.
    """

    # Stage names in the order ``run`` executes them; each name is also the
    # label appended to ``ctx.stage_trace`` by the corresponding method.
    STAGE_ORDER = [
        "validate_inputs",
        "global_anchoring",
        "prepare_local_units",
        "build_graph",
        "build_ot_problem",
        "solve_ot",
        "update_expression",
        "finalize_svc",
        "evaluate_if_needed",
    ]

    # Output key -> file name written directly under ``ctx.run_dir`` when the
    # context is in legacy mode.
    _LEGACY_FILENAMES = {
        "sp_svc": "sp_SVC.h5ad",
        "sc_svc_dec": "sc_SVC.h5ad",
        "sc_svc_expr": "sc_SVC_expr.h5ad",
        "sc_svc_spatial": "sc_SVC_spatial.h5ad",
        "sc_svc_impute_in_panel": "sc_SVC_impute_in_panel.h5ad",
        "sc_svc_impute_all_panel": "sc_SVC_impute_all_panel.h5ad",
    }

    def __init__(
        self,
        strategy: LocalRefinementStrategy,
        validation_policy: InputValidationPolicy,
        evaluation_policy: EvaluationPolicy,
    ) -> None:
        """Store the collaborators used by the stage methods.

        Args:
            strategy: Object implementing the per-stage hooks
                (``prepare_context``, ``global_anchoring``, ...,
                ``finalize_svc``).
            validation_policy: Object whose ``validate(ctx)`` is called first.
            evaluation_policy: Object whose ``should_evaluate(ctx)`` gates the
                final evaluation stage.
        """
        self.strategy = strategy
        self.validation_policy = validation_policy
        self.evaluation_policy = evaluation_policy

    def run(self, ctx):
        """Execute every stage in order and return ``ctx.svc``.

        When ``ctx.dry_run`` is truthy only ``validate_inputs`` runs; a
        placeholder ``SVC`` without expression or spatial data is stored on
        the context and returned instead.
        """
        # Template Method: this order is fixed across all modes/routes.
        # Strategy implementations can customize internals per stage but must
        # not change lifecycle topology.
        self.validate_inputs(ctx)
        if ctx.dry_run:
            ctx.svc = SVC(
                expr=None,
                spatial=None,
                svc_kind=str(ctx.runtime.get("svc_kind", "sc")),
                provenance={"stage_trace": list(ctx.stage_trace), "dry_run": True},
                artifacts={},
            )
            return ctx.svc

        self.global_anchoring(ctx)
        self.prepare_local_units(ctx)
        self.build_graph(ctx)
        self.build_ot_problem(ctx)
        self.solve_ot(ctx)
        self.update_expression(ctx)
        self.finalize_svc(ctx)
        self.evaluate_if_needed(ctx)
        return ctx.svc

    def _record_stage(self, ctx, stage: str) -> None:
        """Append ``stage`` to ``ctx.stage_trace`` and log it."""
        ctx.stage_trace.append(stage)
        ctx.logger.info("[pipeline] stage=%s", stage)

    def validate_inputs(self, ctx) -> None:
        """Run the validation policy, then let the strategy prepare ``ctx``."""
        self._record_stage(ctx, "validate_inputs")
        self.validation_policy.validate(ctx)
        self.strategy.prepare_context(ctx)

    def global_anchoring(self, ctx) -> None:
        """Record the stage and delegate to ``strategy.global_anchoring``."""
        self._record_stage(ctx, "global_anchoring")
        self.strategy.global_anchoring(ctx)

    def prepare_local_units(self, ctx) -> None:
        """Record the stage and delegate to ``strategy.prepare_local_units``."""
        self._record_stage(ctx, "prepare_local_units")
        self.strategy.prepare_local_units(ctx)

    def build_graph(self, ctx) -> None:
        """Record the stage and delegate to ``strategy.build_graph``."""
        self._record_stage(ctx, "build_graph")
        self.strategy.build_graph(ctx)

    def build_ot_problem(self, ctx) -> None:
        """Record the stage and delegate to ``strategy.build_ot_problem``."""
        self._record_stage(ctx, "build_ot_problem")
        self.strategy.build_ot_problem(ctx)

    def solve_ot(self, ctx) -> None:
        """Record the stage and delegate to ``strategy.solve_ot``."""
        self._record_stage(ctx, "solve_ot")
        self.strategy.solve_ot(ctx)

    def update_expression(self, ctx) -> None:
        """Record the stage and delegate to ``strategy.update_expression``."""
        self._record_stage(ctx, "update_expression")
        self.strategy.update_expression(ctx)

    def finalize_svc(self, ctx) -> None:
        """Store the strategy's final result on ``ctx.svc`` and persist it."""
        self._record_stage(ctx, "finalize_svc")
        ctx.svc = self.strategy.finalize_svc(ctx)
        self._persist_outputs(ctx)

    def evaluate_if_needed(self, ctx) -> None:
        """Compute and save metrics for each output when the policy allows it.

        Evaluation is skipped (with a log message) when the evaluation policy
        declines, when ``ctx.svc`` carries no ``"legacy_outputs"`` artifacts,
        or when ``ctx.real_st_adata`` is ``None``. Outputs sharing no
        observation index with ``ctx.real_st_adata`` are skipped individually.
        Results are stored in ``ctx.quality_metrics`` and mirrored onto
        ``ctx.svc.quality_metrics``.
        """
        self._record_stage(ctx, "evaluate_if_needed")
        if not self.evaluation_policy.should_evaluate(ctx):
            ctx.logger.info("[pipeline] evaluation skipped by policy")
            return

        # Imported lazily so the metrics module is only loaded when an
        # evaluation is actually requested.
        from revise.analysis.metrics import compute_metric

        outputs = dict(ctx.svc.artifacts.get("legacy_outputs", {})) if ctx.svc else {}
        if not outputs:
            ctx.logger.warning("[pipeline] no outputs available for evaluation")
            return
        if ctx.real_st_adata is None:
            ctx.logger.warning("[pipeline] real_st_adata is missing; benchmark evaluation skipped")
            return

        benchmark_mode = str(ctx.runtime.get("mode")) == "benchmark"
        run_dir = Path(ctx.run_dir)
        for key, adata in outputs.items():
            common_index = adata.obs.index.intersection(ctx.real_st_adata.obs.index)
            if common_index.empty:
                ctx.logger.warning("[pipeline] no shared cells for %s; skip metric", key)
                continue

            pred = adata[common_index, :].copy()
            gt = ctx.real_st_adata[common_index, :].copy()
            metrics_df = compute_metric(
                pred,
                gt,
                ctx.logger,
                adata_process=False,
                gene_list=None,
                normalize=True,
            )

            if benchmark_mode:
                # NOTE: every key targets the same file in benchmark mode, so
                # a later output overwrites the metrics of an earlier one.
                out_file = run_dir / "metrics_normalized.csv"
            else:
                out_file = run_dir / "metrics" / f"{key}_metrics_normalized.csv"
            # Ensure the destination exists for both branches; previously the
            # benchmark branch wrote into run_dir without creating it.
            out_file.parent.mkdir(parents=True, exist_ok=True)
            metrics_df.to_csv(out_file)
            ctx.quality_metrics[key] = metrics_df

            # Backward-compatible sink in legacy mode for non-benchmark runs.
            if ctx.legacy_mode and not benchmark_mode:
                metrics_df.to_csv(run_dir / "metrics_normalized.csv")

        if ctx.svc is not None:
            ctx.svc.quality_metrics = dict(ctx.quality_metrics)

    def _persist_outputs(self, ctx) -> None:
        """Write the ``"legacy_outputs"`` artifacts of ``ctx.svc`` to disk.

        Nothing is written when ``ctx.svc`` is ``None``, when
        ``merged_config["io"]["save_outputs"]`` is falsy, or when there are no
        outputs. Non-benchmark runs write ``<run_dir>/artifacts/<key>.h5ad``;
        benchmark runs outside legacy mode write ``<run_dir>/<key>.h5ad``.
        Legacy mode additionally emits the legacy-named files.
        """
        if ctx.svc is None:
            return
        # ``or {}`` tolerates an ``io`` section that is present but null.
        io_config = ctx.merged_config.get("io") or {}
        if not bool(io_config.get("save_outputs", True)):
            return
        outputs = dict(ctx.svc.artifacts.get("legacy_outputs", {}))
        if not outputs:
            return

        run_dir = Path(ctx.run_dir)
        benchmark_mode = str(ctx.runtime.get("mode")) == "benchmark"
        if not benchmark_mode:
            # Canonical artifact sink for unified runs.
            self._write_outputs(run_dir / "artifacts", outputs)
        elif not ctx.legacy_mode:
            self._write_outputs(run_dir, outputs)

        if ctx.legacy_mode:
            self._emit_legacy_files(ctx, outputs)

    @staticmethod
    def _write_outputs(target_dir: Path, outputs: Dict[str, Any]) -> None:
        """Write each output as ``<target_dir>/<key>.h5ad``, creating the directory."""
        target_dir.mkdir(parents=True, exist_ok=True)
        for key, adata in outputs.items():
            adata.write_h5ad(target_dir / f"{key}.h5ad")

    def _emit_legacy_files(self, ctx, outputs: Dict[str, Any]) -> None:
        """Write known outputs under their legacy file names in ``ctx.run_dir``.

        Only keys listed in ``_LEGACY_FILENAMES`` are written; other outputs
        are ignored here.
        """
        pending = [
            (outputs[key], filename)
            for key, filename in self._LEGACY_FILENAMES.items()
            if key in outputs
        ]
        if not pending:
            return
        run_dir = Path(ctx.run_dir)
        # The benchmark + legacy path reaches this point without any earlier
        # directory creation, so make sure the run directory exists.
        run_dir.mkdir(parents=True, exist_ok=True)
        for adata, filename in pending:
            adata.write_h5ad(run_dir / filename)