# Source code for revise.framework

from __future__ import annotations

import copy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional

from revise.backend import build_default_plugin_registry
from revise.backend import ModeEvaluationPolicy
from revise.backend import ModeValidationPolicy
from revise.backend import build_default_registry
from revise.config import infer_default_profile
from revise.config import load_raw_config
from revise.config import merge_unified_config
from revise.recon.context import PipelineContext
from revise.recon.pipeline import UnifiedReconstructionPipeline
from revise.svc import SVC
from revise.utils import (
    build_task_dir,
    build_run_dir,
    build_run_logger,
    collect_package_versions,
    fingerprint_paths,
    hash_jsonable,
    set_global_seed,
    write_json,
)


class REVISEPipeline:
    """Unified orchestration API for all REVISE tasks and modes.

    An instance loads one raw YAML config at construction time.  Each call to
    :meth:`run` merges that config with the caller's overrides, resolves the
    runtime plugins, prepares the log/run directories and logger, and then
    either validates the route (dry-run) or executes the reconstruction
    pipeline.  The strategy and plugin registries are created lazily on first
    use and cached on the instance.
    """

    # Distributions whose versions are recorded in ``provenance.json``.
    _PROVENANCE_PACKAGES = (
        "revise-svc",
        "scanpy",
        "anndata",
        "numpy",
        "pandas",
        "scipy",
        "POT",
        "leidenalg",
    )

    def __init__(self, config_path: Optional[str | Path] = None):
        """Load the raw configuration used by every subsequent run.

        Args:
            config_path: Path to the unified YAML config.  When ``None`` the
                ``revise.yaml`` file located next to this module is used.
        """
        if config_path is None:
            config_path = str(Path(__file__).with_name("revise.yaml"))
        self.config_path = str(self._resolve_config_path(config_path))
        self.raw_config = load_raw_config(self.config_path)
        # Both registries are built on demand (see ``run`` and
        # ``_resolve_runtime_plugins``).
        self.registry = None
        self.plugin_registry = None

    @staticmethod
    def _resolve_config_path(config_path: str | Path) -> Path:
        """Return the config path to load, falling back to the packaged file.

        An existing path is returned unchanged.  The literal relative path
        ``revise/revise.yaml`` is mapped to the ``revise.yaml`` next to this
        module when that file exists.  In every other case the input path is
        returned as-is (existence is not enforced here).
        """
        path = Path(config_path)
        if path.exists():
            return path

        # Backward-compatible default used by README examples and root wrapper
        # scripts. In installed PyPI wheels, revise.yaml lives beside this
        # file, not under the caller's current working directory.
        if path.as_posix() == "revise/revise.yaml":
            packaged = Path(__file__).with_name("revise.yaml")
            if packaged.exists():
                return packaged
        return path

    def run(
        self,
        *,
        profile: Optional[str] = None,
        runtime_overrides: Optional[Dict[str, Any]] = None,
        io_overrides: Optional[Dict[str, Any]] = None,
        set_overrides: Optional[Iterable[str]] = None,
        dry_run: bool = False,
    ):
        """Resolve the configuration and execute (or dry-run) one route.

        Args:
            profile: Profile name to merge.  When ``None`` it is inferred via
                ``infer_default_profile`` from the raw config and the runtime
                overrides.
            runtime_overrides: Runtime overrides forwarded to
                ``merge_unified_config``; copied before use.
            io_overrides: I/O overrides forwarded to ``merge_unified_config``;
                copied before use.
            set_overrides: ``--set`` style override strings; materialised into
                a list before use.
            dry_run: When true, only input validation runs and a placeholder
                ``SVC`` (``expr=None``, ``spatial=None``) is returned.

        Returns:
            The ``SVC`` produced by the pipeline, or the placeholder ``SVC``
            for a dry-run.
        """
        # 1) Resolve final runtime config from single YAML entry:
        #    defaults -> profile -> CLI runtime/io overrides -> --set overrides.
        runtime_overrides = dict(runtime_overrides or {})
        io_overrides = dict(io_overrides or {})
        set_overrides = list(set_overrides or [])
        if profile is None:
            profile = infer_default_profile(self.raw_config, runtime_overrides)

        merged_config = merge_unified_config(
            raw_config=self.raw_config,
            profile=profile,
            runtime_overrides=runtime_overrides,
            io_overrides=io_overrides,
            set_overrides=set_overrides,
        )
        runtime = self._resolve_runtime_plugins(merged_config)
        route_key = f"{runtime['platform']}:{runtime['confounding']}"

        output_root = merged_config["io"]["output_root"]
        sample_name = merged_config["io"]["sample_name"]
        log_dir = build_task_dir(
            output_root=output_root,
            sample_name=sample_name,
            route_key=route_key,
            io_cfg=merged_config["io"],
        )
        run_dir = build_run_dir(
            output_root=output_root,
            sample_name=sample_name,
            route_key=route_key,
            io_cfg=merged_config["io"],
        )

        logger_name = f"REVISEUnified::{sample_name}::{route_key}"
        if log_dir == run_dir:
            # Logs land in the run directory itself, so the run directory name
            # is appended to the logger name.
            logger_name = f"{logger_name}::{Path(run_dir).name}"
        logger = build_run_logger(
            run_name=logger_name,
            run_dir=log_dir,
        )
        logger.info("[framework] start unified run route=%s strategy=%s", route_key, runtime["strategy"])
        set_global_seed(seed=runtime.get("seed"), deterministic=bool(runtime.get("deterministic", True)))

        ctx = PipelineContext(
            merged_config=merged_config,
            raw_config=self.raw_config,
            config_path=self.config_path,
            profile=profile,
            runtime=runtime,
            route_key=route_key,
            run_dir=run_dir,
            logger=logger,
            dry_run=bool(dry_run),
        )
        self._write_initial_metadata(ctx)

        if ctx.dry_run:
            # Dry-run validates route+inputs without importing heavy strategy
            # implementations. This keeps structural checks fast and robust.
            ctx.stage_trace.append("validate_inputs")
            ModeValidationPolicy().validate(ctx)
            ctx.svc = SVC(
                expr=None,
                spatial=None,
                svc_kind=str(runtime.get("svc_kind", "sc")),
                provenance={
                    "dry_run": True,
                    "route": ctx.route,
                    "stage_trace": list(ctx.stage_trace),
                },
                artifacts={},
            )
            self._write_final_metadata(ctx)
            logger.info("[framework] dry-run validated route=%s", route_key)
            return ctx.svc

        if self.registry is None:
            # Strategy registry is initialized lazily to avoid importing heavy
            # scientific modules during dry-run validation.
            self.registry = build_default_registry()
        strategy = self.registry.get(runtime["strategy"])
        pipeline = UnifiedReconstructionPipeline(
            strategy=strategy,
            validation_policy=ModeValidationPolicy(),
            evaluation_policy=ModeEvaluationPolicy(),
        )
        svc = pipeline.run(ctx)
        self._write_final_metadata(ctx)
        logger.info("[framework] finished unified run route=%s", route_key)
        return svc

    def _write_initial_metadata(self, ctx: PipelineContext) -> None:
        """Write the exported merged config to ``merged_config.json`` in the run dir."""
        write_json(Path(ctx.run_dir) / "merged_config.json", self._export_merged_config(ctx))

    @staticmethod
    def _svc_summary(svc: Any) -> Dict[str, Any]:
        """Return ``svc.summary()``, or ``{}`` when no SVC is attached.

        The check is an identity test on purpose: an SVC object that happens
        to evaluate as falsy must still contribute its summary, matching the
        ``is not None`` guard used for the provenance update.
        """
        if svc is None:
            return {}
        return svc.summary()

    def _write_final_metadata(self, ctx: PipelineContext) -> None:
        """Assemble run provenance, attach it to the SVC, and write ``provenance.json``.

        The record contains the config path/profile/route, a hash of the
        exported config, a fingerprint of the input file paths read from
        ``ctx.legacy_config`` (missing attributes become ``None``), installed
        package versions, the stage trace, the sorted quality-metric keys and
        the SVC summary.
        """
        legacy_conf = ctx.legacy_config
        st_file = getattr(legacy_conf, "st_file_path", None)
        sc_file = getattr(legacy_conf, "sc_ref_file_path", None)
        gt_file = getattr(legacy_conf, "gt_svc_file_path", None)
        exported_config = self._export_merged_config(ctx)
        provenance = {
            "config_path": ctx.config_path,
            "profile": ctx.profile,
            "route": ctx.route,
            "route_key": ctx.route_key,
            "run_dir": str(ctx.run_dir),
            "config_hash": hash_jsonable(exported_config),
            "data_fingerprint": fingerprint_paths([st_file, sc_file, gt_file]),
            "packages": collect_package_versions(list(self._PROVENANCE_PACKAGES)),
            "stage_trace": list(ctx.stage_trace),
            "quality_metric_keys": sorted(ctx.quality_metrics.keys()),
            "svc_summary": self._svc_summary(ctx.svc),
        }
        if ctx.svc is not None:
            ctx.svc.provenance.update(provenance)
        write_json(Path(ctx.run_dir) / "provenance.json", provenance)

    def _export_merged_config(self, ctx: PipelineContext) -> Dict[str, Any]:
        """Return a deep copy of ``ctx.merged_config`` without ``runtime.legacy_mode``.

        The context's own config is never mutated.
        """
        exported = copy.deepcopy(ctx.merged_config)
        runtime = exported.get("runtime")
        if isinstance(runtime, dict):
            runtime.pop("legacy_mode", None)
        return exported

    def _resolve_runtime_plugins(self, merged_config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve platform/CF/OT plugins before strategy instantiation.

        A payload holding the runtime section and the full merged config is
        passed through the platform adapter, the CF strategy and the OT
        solver, in that order.  Each plugin id falls back to the related
        runtime key and finally to ``"default"``.

        Args:
            merged_config: Merged configuration; its ``"runtime"`` entry is
                replaced in place with the resolved runtime.

        Returns:
            The resolved runtime mapping (the payload's ``"runtime"`` entry,
            or the original runtime if the plugins removed that key).
        """
        if self.plugin_registry is None:
            self.plugin_registry = build_default_plugin_registry()

        runtime = merged_config["runtime"]
        payload: Dict[str, Any] = {
            "runtime": runtime,
            "merged_config": merged_config,
        }

        platform_adapter_id = runtime.get("platform_adapter") or runtime.get("platform") or "default"
        cf_strategy_id = runtime.get("cf_strategy") or runtime.get("confounding") or "default"
        ot_solver_id = runtime.get("ot_solver") or "default"

        payload = self.plugin_registry.get_platform_adapter(platform_adapter_id).adapt(payload)
        payload = self.plugin_registry.get_cf_strategy(cf_strategy_id).apply(payload)
        payload = self.plugin_registry.get_ot_solver(ot_solver_id).solve(payload)

        resolved_runtime = payload.get("runtime", runtime)
        merged_config["runtime"] = resolved_runtime
        return resolved_runtime