from __future__ import annotations
import copy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
from revise.backend import build_default_plugin_registry
from revise.backend import ModeEvaluationPolicy
from revise.backend import ModeValidationPolicy
from revise.backend import build_default_registry
from revise.config import infer_default_profile
from revise.config import load_raw_config
from revise.config import merge_unified_config
from revise.recon.context import PipelineContext
from revise.recon.pipeline import UnifiedReconstructionPipeline
from revise.svc import SVC
from revise.utils import (
build_task_dir,
build_run_dir,
build_run_logger,
collect_package_versions,
fingerprint_paths,
hash_jsonable,
set_global_seed,
write_json,
)
class REVISEPipeline:
    """Unified orchestration API for all REVISE tasks and modes.

    The unified YAML config is loaded once at construction. Each call to
    :meth:`run` merges it with a profile and overrides, resolves the
    runtime plugins, and then either validates the route (dry run) or
    executes the selected reconstruction strategy.
    """

    # Distributions whose installed versions are recorded in provenance.json.
    _PROVENANCE_PACKAGES = (
        "revise-svc",
        "scanpy",
        "anndata",
        "numpy",
        "pandas",
        "scipy",
        "POT",
        "leidenalg",
    )

    def __init__(self, config_path: Optional[str] = None):
        """Load the raw unified configuration.

        Args:
            config_path: Path to the unified YAML file. When None, the
                ``revise.yaml`` located next to this module is used.
        """
        if config_path is None:
            config_path = str(Path(__file__).with_name("revise.yaml"))
        self.config_path = str(self._resolve_config_path(config_path))
        self.raw_config = load_raw_config(self.config_path)
        # Both registries are built lazily on first use (see run() and
        # _resolve_runtime_plugins()).
        self.registry = None
        self.plugin_registry = None

    @staticmethod
    def _resolve_config_path(config_path: str | Path) -> Path:
        """Return the config path to load, with a packaged-file fallback.

        An existing path is returned unchanged. The legacy relative default
        ``"revise/revise.yaml"`` is redirected to the ``revise.yaml`` that
        sits beside this module, when that file exists. Any other path is
        returned unchanged.
        """
        path = Path(config_path)
        if path.exists():
            return path
        # Backward-compatible default used by README examples and root wrapper
        # scripts. In installed PyPI wheels, revise.yaml lives beside this file,
        # not under the caller's current working directory.
        if path.as_posix() == "revise/revise.yaml":
            packaged = Path(__file__).with_name("revise.yaml")
            if packaged.exists():
                return packaged
        return path

    def run(
        self,
        *,
        profile: Optional[str] = None,
        runtime_overrides: Optional[Dict[str, Any]] = None,
        io_overrides: Optional[Dict[str, Any]] = None,
        set_overrides: Optional[Iterable[str]] = None,
        dry_run: bool = False,
    ):
        """Resolve the configuration for one run and execute it.

        Config layers are applied in this order: defaults -> profile ->
        runtime/io overrides -> ``--set`` overrides.

        Args:
            profile: Config profile name. When None, it is inferred with
                ``infer_default_profile`` from the raw config and the runtime
                overrides.
            runtime_overrides: Runtime-section overrides (e.g. from the CLI).
            io_overrides: IO-section overrides (e.g. from the CLI).
            set_overrides: ``--set``-style override expressions, applied last.
            dry_run: When True, only the route and inputs are validated. The
                strategy registry is not built and no reconstruction runs.

        Returns:
            The SVC produced by the reconstruction pipeline. On a dry run,
            this is a placeholder SVC with no expression or spatial data.

        Raises:
            Any exception from validation or from the pipeline propagates to
            the caller. Pipeline failures are first logged to the run logger.
        """
        # 1) Resolve the final runtime config from the single YAML entry.
        #    The caller's overrides are shallow-copied before merging.
        runtime_overrides = dict(runtime_overrides or {})
        io_overrides = dict(io_overrides or {})
        set_overrides = list(set_overrides or [])
        if profile is None:
            profile = infer_default_profile(self.raw_config, runtime_overrides)
        merged_config = merge_unified_config(
            raw_config=self.raw_config,
            profile=profile,
            runtime_overrides=runtime_overrides,
            io_overrides=io_overrides,
            set_overrides=set_overrides,
        )
        # Plugins may rewrite the runtime section. Everything below uses the
        # resolved version.
        runtime = self._resolve_runtime_plugins(merged_config)
        route_key = f"{runtime['platform']}:{runtime['confounding']}"

        # 2) Output directories and the run logger.
        io_cfg = merged_config["io"]
        output_root = io_cfg["output_root"]
        sample_name = io_cfg["sample_name"]
        log_dir = build_task_dir(
            output_root=output_root,
            sample_name=sample_name,
            route_key=route_key,
            io_cfg=io_cfg,
        )
        run_dir = build_run_dir(
            output_root=output_root,
            sample_name=sample_name,
            route_key=route_key,
            io_cfg=io_cfg,
        )
        logger_name = f"REVISEUnified::{sample_name}::{route_key}"
        # Compare as paths, so a str/Path mix or a trailing slash still counts
        # as the same directory. When logs and artifacts share a directory,
        # the run directory's name is appended to keep the logger name
        # run-specific (presumably to avoid reusing a cached logger -- confirm).
        if Path(log_dir) == Path(run_dir):
            logger_name = f"{logger_name}::{Path(run_dir).name}"
        logger = build_run_logger(
            run_name=logger_name,
            run_dir=log_dir,
        )
        logger.info(
            "[framework] start unified run route=%s strategy=%s",
            route_key,
            runtime["strategy"],
        )
        set_global_seed(
            seed=runtime.get("seed"),
            deterministic=bool(runtime.get("deterministic", True)),
        )

        # 3) Build the context and record the merged config up front.
        ctx = PipelineContext(
            merged_config=merged_config,
            raw_config=self.raw_config,
            config_path=self.config_path,
            profile=profile,
            runtime=runtime,
            route_key=route_key,
            run_dir=run_dir,
            logger=logger,
            dry_run=bool(dry_run),
        )
        self._write_initial_metadata(ctx)

        if ctx.dry_run:
            # Dry-run validates route+inputs without importing heavy strategy
            # implementations. This keeps structural checks fast and robust.
            ctx.stage_trace.append("validate_inputs")
            ModeValidationPolicy().validate(ctx)
            ctx.svc = SVC(
                expr=None,
                spatial=None,
                svc_kind=str(runtime.get("svc_kind", "sc")),
                provenance={
                    "dry_run": True,
                    "route": ctx.route,
                    "stage_trace": list(ctx.stage_trace),
                },
                artifacts={},
            )
            self._write_final_metadata(ctx)
            logger.info("[framework] dry-run validated route=%s", route_key)
            return ctx.svc

        # 4) Full run.
        if self.registry is None:
            # Strategy registry is initialized lazily to avoid importing heavy
            # scientific modules during dry-run validation.
            self.registry = build_default_registry()
        strategy = self.registry.get(runtime["strategy"])
        pipeline = UnifiedReconstructionPipeline(
            strategy=strategy,
            validation_policy=ModeValidationPolicy(),
            evaluation_policy=ModeEvaluationPolicy(),
        )
        try:
            svc = pipeline.run(ctx)
        except Exception:
            # Record the failure in this run's log, then let it propagate
            # unchanged to the caller.
            logger.exception("[framework] unified run failed route=%s", route_key)
            raise
        self._write_final_metadata(ctx)
        logger.info("[framework] finished unified run route=%s", route_key)
        return svc

    def _write_initial_metadata(self, ctx: PipelineContext) -> None:
        """Write the exported merged config to ``<run_dir>/merged_config.json``."""
        write_json(Path(ctx.run_dir) / "merged_config.json", self._export_merged_config(ctx))

    def _write_final_metadata(self, ctx: PipelineContext) -> None:
        """Build the run's provenance record and write ``provenance.json``.

        When the context holds an SVC, the same record is also merged into
        ``ctx.svc.provenance``.
        """
        legacy_conf = ctx.legacy_config
        # The getattr defaults tolerate a legacy config that lacks these
        # attributes.
        data_paths = [
            getattr(legacy_conf, "st_file_path", None),
            getattr(legacy_conf, "sc_ref_file_path", None),
            getattr(legacy_conf, "gt_svc_file_path", None),
        ]
        exported_config = self._export_merged_config(ctx)
        svc = ctx.svc
        provenance = {
            "config_path": ctx.config_path,
            "profile": ctx.profile,
            "route": ctx.route,
            "route_key": ctx.route_key,
            "run_dir": str(ctx.run_dir),
            "config_hash": hash_jsonable(exported_config),
            "data_fingerprint": fingerprint_paths(data_paths),
            "packages": collect_package_versions(list(self._PROVENANCE_PACKAGES)),
            "stage_trace": list(ctx.stage_trace),
            "quality_metric_keys": sorted(ctx.quality_metrics.keys()),
            # Use an explicit None check, matching the update below. A
            # truthiness test would misfire if SVC defines __len__/__bool__.
            "svc_summary": svc.summary() if svc is not None else {},
        }
        if svc is not None:
            svc.provenance.update(provenance)
        write_json(Path(ctx.run_dir) / "provenance.json", provenance)

    def _export_merged_config(self, ctx: PipelineContext) -> Dict[str, Any]:
        """Return a deep copy of the merged config suitable for export.

        The internal ``runtime.legacy_mode`` key is removed from the copy.
        ``ctx.merged_config`` itself is left untouched.
        """
        exported = copy.deepcopy(ctx.merged_config)
        runtime = exported.get("runtime")
        if isinstance(runtime, dict):
            runtime.pop("legacy_mode", None)
        return exported

    def _resolve_runtime_plugins(self, merged_config: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve platform/CF/OT plugins before strategy instantiation.

        A payload dict with ``runtime`` and ``merged_config`` is passed through
        three plugins in order: platform adapter (``adapt``), CF strategy
        (``apply``), then OT solver (``solve``). Plugin ids are taken from
        ``platform_adapter`` / ``cf_strategy`` / ``ot_solver`` in the runtime
        section. The first two fall back to ``platform`` / ``confounding``, and
        all three fall back to ``"default"``.

        Side effect: ``merged_config["runtime"]`` is replaced by the resolved
        runtime, which is also returned. The plugin registry is built lazily
        on first call.
        """
        if self.plugin_registry is None:
            self.plugin_registry = build_default_plugin_registry()
        runtime = merged_config["runtime"]
        payload: Dict[str, Any] = {
            "runtime": runtime,
            "merged_config": merged_config,
        }
        platform_adapter_id = runtime.get("platform_adapter") or runtime.get("platform") or "default"
        cf_strategy_id = runtime.get("cf_strategy") or runtime.get("confounding") or "default"
        ot_solver_id = runtime.get("ot_solver") or "default"
        payload = self.plugin_registry.get_platform_adapter(platform_adapter_id).adapt(payload)
        payload = self.plugin_registry.get_cf_strategy(cf_strategy_id).apply(payload)
        payload = self.plugin_registry.get_ot_solver(ot_solver_id).solve(payload)
        # A plugin chain that drops "runtime" keeps the pre-plugin runtime.
        resolved_runtime = payload.get("runtime", runtime)
        merged_config["runtime"] = resolved_runtime
        return resolved_runtime