from __future__ import annotations
from pathlib import Path
from typing import Any
from typing import Dict
from revise.backend.contracts import EvaluationPolicy
from revise.backend.contracts import InputValidationPolicy
from revise.backend.contracts import LocalRefinementStrategy
from revise.svc import SVC
class UnifiedReconstructionPipeline:
    """Run a reconstruction through a fixed, ordered sequence of stages.

    The pipeline owns the lifecycle (Template Method): the stage order, the
    ``ctx.stage_trace`` bookkeeping and a log line per stage. The injected
    ``strategy`` implements what each stage actually does. After
    ``finalize_svc`` the pipeline writes the SVC's ``legacy_outputs`` under
    ``ctx.run_dir``. When ``evaluation_policy`` allows it, it then computes
    quality metrics.
    """

    # Declared lifecycle order; ``run`` invokes the stages in exactly this order.
    STAGE_ORDER = [
        "validate_inputs",
        "global_anchoring",
        "prepare_local_units",
        "build_graph",
        "build_ot_problem",
        "solve_ot",
        "update_expression",
        "finalize_svc",
        "evaluate_if_needed",
    ]

    def __init__(
        self,
        strategy: LocalRefinementStrategy,
        validation_policy: InputValidationPolicy,
        evaluation_policy: EvaluationPolicy,
    ) -> None:
        """Store the collaborators used by the stages.

        Args:
            strategy: Provides the per-stage work, from ``global_anchoring``
                through ``finalize_svc``.
            validation_policy: Kept as ``self.validation_policy``. None of the
                methods visible in this class read it.
            evaluation_policy: Its ``should_evaluate(ctx)`` gates
                ``evaluate_if_needed``.
        """
        self.strategy = strategy
        self.validation_policy = validation_policy
        self.evaluation_policy = evaluation_policy

    def run(self, ctx):
        """Execute every stage on ``ctx`` and return the resulting ``ctx.svc``.

        If ``ctx.dry_run`` is truthy, only ``validate_inputs`` runs. ``ctx.svc``
        is then set to a placeholder ``SVC`` with ``expr=None`` and
        ``spatial=None``. Its provenance records the stage trace so far and
        ``dry_run=True``.
        """
        # Template Method: this order is fixed across all modes/routes.
        # Strategy implementations can customize internals per stage but must
        # not change lifecycle topology.
        # NOTE(review): ``validate_inputs`` is not defined in the visible class
        # body, and ``self.validation_policy`` is never read here. Confirm it is
        # provided elsewhere (a subclass or a definition outside this view).
        self.validate_inputs(ctx)
        if ctx.dry_run:
            ctx.svc = SVC(
                expr=None,
                spatial=None,
                svc_kind=str(ctx.runtime.get("svc_kind", "sc")),
                provenance={"stage_trace": list(ctx.stage_trace), "dry_run": True},
                artifacts={},
            )
            return ctx.svc
        self.global_anchoring(ctx)
        self.prepare_local_units(ctx)
        self.build_graph(ctx)
        self.build_ot_problem(ctx)
        self.solve_ot(ctx)
        self.update_expression(ctx)
        self.finalize_svc(ctx)
        self.evaluate_if_needed(ctx)
        return ctx.svc

    def _record_stage(self, ctx, stage: str) -> None:
        """Append ``stage`` to ``ctx.stage_trace`` and log it at INFO level."""
        ctx.stage_trace.append(stage)
        ctx.logger.info("[pipeline] stage=%s", stage)

    @staticmethod
    def _is_benchmark_mode(ctx) -> bool:
        """Return True when ``ctx.runtime["mode"]`` is the string ``"benchmark"``."""
        return str(ctx.runtime.get("mode")) == "benchmark"

    @staticmethod
    def _legacy_outputs(ctx) -> Dict[str, Any]:
        """Return a shallow copy of ``ctx.svc.artifacts["legacy_outputs"]``.

        Returns an empty dict when there is no SVC, or when the entry is
        missing or ``None``.
        """
        if ctx.svc is None:
            return {}
        return dict(ctx.svc.artifacts.get("legacy_outputs") or {})

    def global_anchoring(self, ctx) -> None:
        """Record the stage, then delegate to ``strategy.global_anchoring``."""
        self._record_stage(ctx, "global_anchoring")
        self.strategy.global_anchoring(ctx)

    def prepare_local_units(self, ctx) -> None:
        """Record the stage, then delegate to ``strategy.prepare_local_units``."""
        self._record_stage(ctx, "prepare_local_units")
        self.strategy.prepare_local_units(ctx)

    def build_graph(self, ctx) -> None:
        """Record the stage, then delegate to ``strategy.build_graph``."""
        self._record_stage(ctx, "build_graph")
        self.strategy.build_graph(ctx)

    def build_ot_problem(self, ctx) -> None:
        """Record the stage, then delegate to ``strategy.build_ot_problem``."""
        self._record_stage(ctx, "build_ot_problem")
        self.strategy.build_ot_problem(ctx)

    def solve_ot(self, ctx) -> None:
        """Record the stage, then delegate to ``strategy.solve_ot``."""
        self._record_stage(ctx, "solve_ot")
        self.strategy.solve_ot(ctx)

    def update_expression(self, ctx) -> None:
        """Record the stage, then delegate to ``strategy.update_expression``."""
        self._record_stage(ctx, "update_expression")
        self.strategy.update_expression(ctx)

    def finalize_svc(self, ctx) -> None:
        """Build ``ctx.svc`` via the strategy, then persist its outputs."""
        self._record_stage(ctx, "finalize_svc")
        ctx.svc = self.strategy.finalize_svc(ctx)
        self._persist_outputs(ctx)

    def evaluate_if_needed(self, ctx) -> None:
        """Score each output in ``legacy_outputs`` against ``ctx.real_st_adata``.

        Returns early, logging why, when the evaluation policy declines, when
        there are no outputs, or when ``ctx.real_st_adata`` is ``None``.
        Otherwise:

        1. Each output and the ground truth are restricted to their shared
           ``obs`` index and passed to ``compute_metric``.
        2. The result is written as CSV under ``ctx.run_dir``: to
           ``metrics_normalized.csv`` in benchmark mode, or to
           ``metrics/<key>_metrics_normalized.csv`` otherwise. Legacy
           non-benchmark runs also write it to ``metrics_normalized.csv``.
        3. The result is stored in ``ctx.quality_metrics[key]``.

        Finally, ``ctx.quality_metrics`` is copied onto ``ctx.svc.quality_metrics``.
        """
        self._record_stage(ctx, "evaluate_if_needed")
        if not self.evaluation_policy.should_evaluate(ctx):
            ctx.logger.info("[pipeline] evaluation skipped by policy")
            return
        outputs = self._legacy_outputs(ctx)
        if not outputs:
            ctx.logger.warning("[pipeline] no outputs available for evaluation")
            return
        if ctx.real_st_adata is None:
            ctx.logger.warning("[pipeline] real_st_adata is missing; benchmark evaluation skipped")
            return
        # Deferred import, placed after the early returns: runs with nothing
        # to score never import the metrics module.
        from revise.analysis.metrics import compute_metric

        benchmark_mode = self._is_benchmark_mode(ctx)
        if (benchmark_mode or ctx.legacy_mode) and len(outputs) > 1:
            # These sinks use one fixed filename for every key, so each
            # evaluated output overwrites the previous one's file there.
            ctx.logger.warning(
                "[pipeline] %d outputs share metrics_normalized.csv; only the "
                "last evaluated output's metrics remain there",
                len(outputs),
            )
        for key, adata in outputs.items():
            common_index = adata.obs.index.intersection(ctx.real_st_adata.obs.index)
            if common_index.empty:
                ctx.logger.warning("[pipeline] no shared cells for %s; skip metric", key)
                continue
            pred = adata[common_index, :].copy()
            gt = ctx.real_st_adata[common_index, :].copy()
            metrics_df = compute_metric(pred, gt, ctx.logger, adata_process=False, gene_list=None, normalize=True)
            if benchmark_mode:
                out_file = Path(ctx.run_dir) / "metrics_normalized.csv"
            else:
                metrics_dir = Path(ctx.run_dir) / "metrics"
                metrics_dir.mkdir(parents=True, exist_ok=True)
                out_file = metrics_dir / f"{key}_metrics_normalized.csv"
            metrics_df.to_csv(out_file)
            ctx.quality_metrics[key] = metrics_df
            # Backward-compatible sink in legacy mode for non-benchmark runs.
            if ctx.legacy_mode and not benchmark_mode:
                metrics_df.to_csv(Path(ctx.run_dir) / "metrics_normalized.csv")
        if ctx.svc is not None:
            ctx.svc.quality_metrics = dict(ctx.quality_metrics)

    def _persist_outputs(self, ctx) -> None:
        """Write every ``legacy_outputs`` entry as an ``.h5ad`` under ``ctx.run_dir``.

        Skipped when there is no SVC, when ``io.save_outputs`` is false in
        ``ctx.merged_config`` (the default is true), or when there are no
        outputs. Layout:

        - non-benchmark runs: ``artifacts/<key>.h5ad``;
        - benchmark runs not in legacy mode: ``<key>.h5ad``;
        - legacy mode, in either case: additionally the compatibility
          filenames written by ``_emit_legacy_files``.
        """
        if ctx.svc is None:
            return
        io_config = ctx.merged_config.get("io") or {}
        if not bool(io_config.get("save_outputs", True)):
            return
        outputs = self._legacy_outputs(ctx)
        if not outputs:
            return
        run_dir = Path(ctx.run_dir)
        if not self._is_benchmark_mode(ctx):
            # Canonical artifact sink for unified runs.
            artifacts_dir = run_dir / "artifacts"
            artifacts_dir.mkdir(parents=True, exist_ok=True)
            for key, adata in outputs.items():
                adata.write_h5ad(artifacts_dir / f"{key}.h5ad")
        elif not ctx.legacy_mode:
            for key, adata in outputs.items():
                adata.write_h5ad(run_dir / f"{key}.h5ad")
        if ctx.legacy_mode:
            self._emit_legacy_files(ctx, outputs)

    def _emit_legacy_files(self, ctx, outputs: Dict[str, Any]) -> None:
        """Write outputs under their legacy filenames directly in ``ctx.run_dir``.

        Only keys listed in the compatibility map are written, in map order.
        Any other keys in ``outputs`` are ignored here.
        """
        compatibility_map = {
            "sp_svc": "sp_SVC.h5ad",
            "sc_svc_dec": "sc_SVC.h5ad",
            "sc_svc_expr": "sc_SVC_expr.h5ad",
            "sc_svc_spatial": "sc_SVC_spatial.h5ad",
            "sc_svc_impute_in_panel": "sc_SVC_impute_in_panel.h5ad",
            "sc_svc_impute_all_panel": "sc_SVC_impute_all_panel.h5ad",
        }
        for key, filename in compatibility_map.items():
            if key in outputs:
                outputs[key].write_h5ad(Path(ctx.run_dir) / filename)