[18]:
import os
# Destination for every benchmark figure and summary table written below.
output_dir = "output/seg_benchmark"
os.makedirs(name=output_dir, exist_ok=True)  # no-op when the folder already exists

Segmentation Benchmark

[19]:
# --- Benchmark configuration -------------------------------------------------
source_path = "../REVISE/results"
task = "seg"
patient_id = "P2CRC"

# Results and inputs are both laid out as <root>/<task>/<patient_id>.
# NOTE: these paths are built once here; assigning a new `task` later does not
# rebuild them.
result_path = f"{source_path}/{task}/{patient_id}"
data_path = f"../REVISE/data/{task}/{patient_id}"
REVISE_result_path = result_path  # same value as result_path, kept under its old name

methods = ['raw', 'sp_SVC']          # metric-file variants compared in every plot
parts = ["part3", "part1", "part2"]  # read from cut_<part> folders; column order of the seg figure
metrics = ["PCC", "SSIM", "MSE"]     # per-gene metric columns read from the CSVs
spot_sizes = [1,2,3,4]

Plot segmentation benchmark

[ ]:
# Gene subset handed to get_genes: "HVG", "HEG", or any other value for all genes.
gene_type = "All"
# Genes kept for "HVG"/"HEG"; with "All" it only shows up in cache/output file names.
gene_num = 50
[ ]:
import os
import pandas as pd
import scanpy as sc
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

def get_genes(result_path, data_path, part, spot_size, method, gene_type = "HVG", gene_num = 50, test_genes = None):
    """Return the genes used to evaluate one (part, spot_size, method) setting.

    The selection is cached as one gene per line in
    ``<result_path>/<part>/<spot_size>_<method>/select_gene/<gene_type>_genes_<gene_num>.txt``.
    When that file exists it is returned as-is, so delete it to regenerate the
    list after changing the selection logic.

    Parameters
    ----------
    result_path : str
        Root of the results tree; the cache folder is created below it.
    data_path : str
        Root of the data tree holding ``cut_<part>/spot_<spot_size>/xenium_spot.h5ad``.
    part, spot_size, method
        Identify the setting; only used to build paths.
    gene_type : str
        "HVG": top ``gene_num`` highly variable genes (scanpy, flavor="seurat_v3");
        "HEG": top ``gene_num`` genes by total expression summed over all spots;
        anything else: every gene (after the ``test_genes`` filter).
    gene_num : int
        Number of genes kept for "HVG"/"HEG".
    test_genes : iterable of str, optional
        If given, selection is restricted to these genes, keeping their order.

    Returns
    -------
    list of str
    """
    spot_path = os.path.join(data_path, f"cut_{part}", f"spot_{spot_size}")
    save_path = os.path.join(result_path, part, f"{spot_size}_{method}", "select_gene")
    os.makedirs(save_path, exist_ok=True)

    gene_file = f"{save_path}/{gene_type}_genes_{gene_num}.txt"
    if os.path.exists(gene_file):
        with open(gene_file, "r") as f:
            return f.read().splitlines()

    adata = sc.read(f"{spot_path}/xenium_spot.h5ad")
    if test_genes is not None:
        overlap_genes = [gene for gene in test_genes if gene in adata.var_names]
        # Materialise the subset so the in-place scanpy calls below operate on a
        # real AnnData rather than a view.
        adata = adata[:, overlap_genes].copy()

    if gene_type == "HVG":
        sc.pp.filter_genes(adata, min_cells=1)
        # flavor="seurat_v3" models raw counts, so it must not see normalised data.
        sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=gene_num)
        genes = adata.var[adata.var['highly_variable']].index.tolist()
    elif gene_type == "HEG":
        # Highest total expression first. to_df() handles dense and sparse X and
        # leaves adata.var untouched (sorting var in place would misalign it with X).
        totals = adata.to_df().sum(axis=0)
        genes = totals.nlargest(gene_num).index.tolist()
    else:
        # Plain list, matching the type returned from the cache branch.
        genes = list(adata.var_names)

    with open(gene_file, "w") as f:
        f.write("\n".join(genes))

    return genes

def get_merge_df(result_path, data_path, part, metric, spot_sizes, methods, gene_type=None, gene_num=None):
    """Collect per-gene values of one metric into a long-format frame for plotting.

    For every spot size and method, reads
    ``<result_path>/cut_<part>/spot_<spot_size>/<metric csv>``, keeps the genes
    chosen by ``get_genes`` and stacks the values.

    Parameters
    ----------
    result_path, data_path : str
        Results root and data root (forwarded to ``get_genes``).
    part : str
        Dataset part to read.
    metric : str
        Column of the metric CSV to extract (e.g. "PCC").
    spot_sizes : iterable
        Spot sizes to include, in output order.
    methods : iterable of str
        Each must be "raw" (``raw_metrics_normalized.csv``) or
        "sp_SVC" (``metrics_normalized.csv``).
    gene_type, gene_num : optional
        Gene-selection settings for ``get_genes``. When omitted they fall back to
        the notebook-level ``gene_type`` / ``gene_num`` variables, as before.

    Returns
    -------
    pandas.DataFrame
        Columns Method, Value, Spot_size, Part, Metric with a 0..n-1 index
        (an empty frame when there is nothing to collect).

    Raises
    ------
    ValueError
        If a method has no known metric file.
    """
    # Keep old call sites working: they relied on these notebook globals.
    if gene_type is None:
        gene_type = globals()["gene_type"]
    if gene_num is None:
        gene_num = globals()["gene_num"]

    metric_file_names = {
        "raw": "raw_metrics_normalized.csv",
        "sp_SVC": "metrics_normalized.csv",
    }

    frames = []
    for spot_size in tqdm(spot_sizes, desc="spot_sizes"):
        for method in methods:
            if method not in metric_file_names:
                # Previously an unknown method silently reused a stale path.
                raise ValueError(
                    f"Unknown method {method!r}; expected one of {list(metric_file_names)}")
            metric_file = f"{result_path}/cut_{part}/spot_{spot_size}/{metric_file_names[method]}"
            df_metric = pd.read_csv(metric_file, index_col=0).set_index('Gene')
            genes = get_genes(result_path, data_path, part, spot_size, method,
                              gene_type=gene_type, gene_num=gene_num, test_genes=df_metric.index)
            frames.append(pd.DataFrame({
                'Method': method,
                'Value': df_metric.loc[genes, metric].to_numpy(),
                'Spot_size': spot_size,
                'Part': part,
                'Metric': metric,
            }))

    # A single concat avoids the quadratic cost of growing a frame in the loop.
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
[ ]:
# Average each metric over the selected genes and write one CSV per
# (part, spot size): rows are methods, columns are metrics.
# `parts`, `metrics`, `methods` and `spot_sizes` come from the configuration cell.
save_dir = f"{output_dir}/{gene_type}_mean"
os.makedirs(save_dir, exist_ok=True)

# Metric CSV read for each method inside <result_path>/cut_<part>/spot_<size>/.
metric_file_names = {
    "raw": "raw_metrics_normalized.csv",
    "sp_SVC": "metrics_normalized.csv",
}

for part in tqdm(parts, desc="parts"):
    for spot_size in tqdm(spot_sizes, desc="spot_sizes"):
        method_means = {}
        for method in methods:
            if method not in metric_file_names:
                # Fail loudly instead of reusing a stale `metric_file`.
                raise ValueError(f"No metric file known for method {method!r}")
            metric_file = f"{result_path}/cut_{part}/spot_{spot_size}/{metric_file_names[method]}"
            df = pd.read_csv(metric_file, index_col=0).set_index('Gene')
            genes = get_genes(result_path, data_path, part, spot_size, method,
                              gene_type=gene_type, gene_num=gene_num, test_genes=df.index)
            method_means[method] = df.loc[genes, metrics].mean(axis=0)
        # Index = metrics, columns = methods (in `methods` order); transposed on save.
        merge_df = pd.DataFrame(method_means)
        merge_df.T.to_csv(f"{save_dir}/{part}_{spot_size}_{gene_type}_{gene_num}.csv")

spot_sizes: 100%|██████████| 4/4 [00:50<00:00, 12.73s/it]
spot_sizes: 100%|██████████| 4/4 [00:30<00:00,  7.54s/it]
spot_sizes: 100%|██████████| 4/4 [00:25<00:00,  6.48s/it]
parts: 100%|██████████| 3/3 [01:46<00:00, 35.67s/it]
[21]:
# Gene subset for the figures below, set explicitly so this section does not
# depend on whatever value an earlier cell left in the kernel.
gene_type = "All"

[ ]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Type 42 (TrueType) fonts in PDF output; no AFM font metrics for PostScript.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.useafm': False})

def plot_comp_seg(merge_df, metric, part, ax, method_order=None, spot_size_order=None):
    """Draw per-gene boxplots of one metric, grouped by spot size and method.

    Parameters
    ----------
    merge_df : pandas.DataFrame
        Long-format table with 'Spot_size', 'Value' and 'Method' columns
        (as returned by ``get_merge_df``).
    metric : str
        Metric name; used as the y-label and to choose the y-range
        ("MSE" is shown on a log scale).
    part : str
        Dataset part; currently unused (the panel carries no title).
    ax : matplotlib Axes
        Axes to draw into.
    method_order : list of str, optional
        Hue order; defaults to the notebook-level ``methods`` list.
    spot_size_order : list, optional
        Order of the x groups; defaults to [1, 2, 3, 4].
    """
    if method_order is None:
        method_order = methods
    if spot_size_order is None:
        spot_size_order = [1, 2, 3, 4]

    # Fixed colour per method position; wraps if there are more methods than colours.
    base_colors = ['#80a9c8', '#e89786', '#8ccfd9', '#b39a94']
    palette = [base_colors[k % len(base_colors)] for k in range(len(method_order))]

    sns.boxplot(
        data=merge_df,
        x='Spot_size',
        y='Value',
        hue='Method',
        order=spot_size_order,
        hue_order=method_order,
        palette=palette,
        width=0.4,
        fliersize=2,
        showfliers=False,
        ax=ax
    )
    ax.set_ylabel(metric, fontsize=10)
    ax.xaxis.set_visible(False)

    if metric == "PCC":
        ax.set_ylim(0.2, 1.02)
    elif metric == "SSIM":
        ax.set_ylim(0.0, 1.02)
    elif metric == "MSE":
        ax.set_ylim(1e-5, 0.01)
        ax.set_yscale('log')

    # The grid figure is meant to be legend-free.
    if ax.get_legend() is not None:
        ax.legend_.remove()
    ax.set_aspect('auto')


# Grid of boxplots: one row per metric, one column per part. The grid size
# follows the configuration instead of a hard-coded 3x3 (same 12x9 figure today).
fig, axes = plt.subplots(len(metrics), len(parts),
                         figsize=(4 * len(parts), 3 * len(metrics)), squeeze=False)

for i, part in enumerate(parts):
    for j, metric in enumerate(metrics):

        merge_df = get_merge_df(result_path, data_path, part, metric, spot_sizes, methods=methods)

        ax = axes[j, i]
        plot_comp_seg(merge_df, metric, part, ax)

plt.tight_layout()
plt.savefig(f"{output_dir}/{task}_{gene_type}_{gene_num}.pdf", dpi=300)
plt.show()
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 119.05it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 278.60it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 278.65it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 275.21it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 278.63it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 277.66it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 276.19it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 274.75it/s]
spot_sizes: 100%|██████████| 4/4 [00:00<00:00, 277.77it/s]
../_images/benchmark_seg_benchmark_8_1.png

bin2cell benchmark

[ ]:
# bin2cell section: a single spot size (8); output files are tagged with this task.
# NOTE(review): `result_path` / `data_path` were built from task="seg" in the
# configuration cell and are not recomputed here, so the cells below still read
# from .../seg/<patient_id>. Confirm that is where the spot_8 results live.
task = "bin2cell"
spot_sizes = [8]

import matplotlib as mpl
import matplotlib.pyplot as plt

# Type 42 (TrueType) fonts in PDF output; no AFM font metrics for PostScript.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.useafm': False})
[ ]:
def get_merge_df(result_path, data_path, parts, metric, methods, gene_type, gene_num, spot_size=8):
    """Collect per-gene values of one metric for every part at a single spot size.

    NOTE: this redefines the earlier ``get_merge_df`` with a different signature;
    once this cell has run, the name refers to this version.

    Parameters
    ----------
    result_path, data_path : str
        Results root and data root (forwarded to ``get_genes``).
    parts : iterable of str
        Dataset parts to include, in output order.
    metric : str
        Column of the metric CSV to extract.
    methods : iterable of str
        Each must be "raw" (``raw_metrics_normalized.csv``) or
        "sp_SVC" (``metrics_normalized.csv``).
    gene_type, gene_num
        Gene-selection settings for ``get_genes``.
    spot_size : int
        Spot-size folder to read (``spot_<spot_size>``); defaults to 8.

    Returns
    -------
    pandas.DataFrame
        Columns Method, Value, Spot_size, Part, Metric with a 0..n-1 index
        (an empty frame when there is nothing to collect).

    Raises
    ------
    ValueError
        If a method has no known metric file.
    """
    metric_file_names = {
        "raw": "raw_metrics_normalized.csv",
        "sp_SVC": "metrics_normalized.csv",
    }

    frames = []
    for part in parts:
        for method in methods:
            if method not in metric_file_names:
                # Previously an unknown method silently reused a stale path.
                raise ValueError(
                    f"Unknown method {method!r}; expected one of {list(metric_file_names)}")
            metric_file = f"{result_path}/cut_{part}/spot_{spot_size}/{metric_file_names[method]}"
            df_metric = pd.read_csv(metric_file, index_col=0).set_index('Gene')

            genes = get_genes(result_path, data_path, part, spot_size, method,
                              gene_type=gene_type, gene_num=gene_num, test_genes=df_metric.index)

            frames.append(pd.DataFrame({
                'Method': method,
                'Value': df_metric.loc[genes, metric].to_numpy(),
                'Spot_size': spot_size,
                'Part': part,
                'Metric': metric,
            }))

    # A single concat avoids the quadratic cost of growing a frame in the loop.
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()


def plot_comp_seg(merge_df, metric, ax):
    """Boxplot one metric per part, with methods side by side in each group.

    NOTE: this redefines the earlier ``plot_comp_seg``; this version has no
    ``part`` argument and puts the parts on the x axis.

    Parameters
    ----------
    merge_df : pandas.DataFrame
        Long-format table with 'Part', 'Value' and 'Method' columns.
    metric : str
        Metric name; used as the y-label and to choose the y-range
        ("MSE" is shown on a log scale).
    ax : matplotlib Axes
        Axes to draw into.
    """
    method_colors = {"raw": '#80a9c8', "sp_SVC": '#e89786'}
    sns.boxplot(data=merge_df, x='Part', y='Value', hue='Method',
                palette=method_colors, width=0.4, fliersize=2,
                showfliers=False, ax=ax)

    ax.set_ylabel(metric, fontsize=10)
    ax.xaxis.set_visible(False)

    # Fixed y-range per known metric; MSE additionally switches to a log axis.
    y_limits = {"PCC": (0.2, 1.02), "SSIM": (0.0, 1.02), "MSE": (1e-5, 0.01)}
    if metric in y_limits:
        ax.set_ylim(*y_limits[metric])
        if metric == "MSE":
            ax.set_yscale('log')

    legend = ax.get_legend()
    if legend is not None:
        ax.legend_.remove()
    ax.set_aspect('auto')


# One panel per metric; each panel compares methods across all parts. The panel
# count follows `metrics` instead of a hard-coded 3 (same 9x3 figure today).
fig, axes = plt.subplots(1, len(metrics), figsize=(3 * len(metrics), 3), squeeze=False)

for j, metric in enumerate(metrics):
    merge_df = get_merge_df(result_path, data_path, parts, metric, methods=methods,
                            gene_type=gene_type, gene_num=gene_num)

    ax = axes[0, j]
    plot_comp_seg(merge_df, metric, ax)

plt.tight_layout()
plt.savefig(f"{output_dir}/{task}_{gene_type}_{gene_num}.pdf", dpi=300)
plt.show()
../_images/benchmark_seg_benchmark_11_0.png
[25]:
# Average each metric over the selected genes at the bin2cell spot size and
# write one CSV per part: rows are methods, columns are metrics.
spot_sizes = [8]

# Rebuild the output folder from the current gene_type rather than relying on
# `save_dir` left over from an earlier cell.
save_dir = f"{output_dir}/{gene_type}_mean"
os.makedirs(save_dir, exist_ok=True)

# Metric CSV read for each method inside <result_path>/cut_<part>/spot_<size>/.
metric_file_names = {
    "raw": "raw_metrics_normalized.csv",
    "sp_SVC": "metrics_normalized.csv",
}

for part in tqdm(parts, desc="parts"):
    for spot_size in tqdm(spot_sizes, desc="spot_sizes"):
        method_means = {}
        for method in methods:
            if method not in metric_file_names:
                # Fail loudly instead of reusing a stale `metric_file`.
                raise ValueError(f"No metric file known for method {method!r}")
            metric_file = f"{result_path}/cut_{part}/spot_{spot_size}/{metric_file_names[method]}"
            df = pd.read_csv(metric_file, index_col=0).set_index('Gene')
            genes = get_genes(result_path, data_path, part, spot_size, method,
                              gene_type=gene_type, gene_num=gene_num, test_genes=df.index)
            method_means[method] = df.loc[genes, metrics].mean(axis=0)
        # Index = metrics, columns = methods (in `methods` order); transposed on save.
        merge_df = pd.DataFrame(method_means)
        merge_df.T.to_csv(f"{save_dir}/{part}_{spot_size}_{gene_type}_{gene_num}.csv")

spot_sizes: 100%|██████████| 1/1 [00:00<00:00, 81.41it/s]
spot_sizes: 100%|██████████| 1/1 [00:00<00:00, 85.57it/s]
spot_sizes: 100%|██████████| 1/1 [00:00<00:00, 85.37it/s]
parts: 100%|██████████| 3/3 [00:00<00:00, 64.29it/s]