# Skip to content
# Snippets Groups Projects
# Code owners
# Assign users and groups as approvers for specific file changes. Learn more.
# test_ccc.py 2.22 KiB
import numpy as np
from scipy.signal import fftconvolve


def calculate_correlation(data_volume, template_volume, mask):
    """
    Compute the locally normalized cross-correlation of a template with a volume.

    For every voxel ``i`` of ``data_volume`` the template is placed with its
    centre (index ``ts // 2`` along each axis, ``ts`` = template shape) on
    ``i``. The result is the Pearson correlation between the template and the
    overlapped data, restricted to the voxels selected by ``mask``::

        ncc[i] = sum_q t_n[q] * data[i - ts//2 + q] / (n * sigma_i)

    Here ``t_n`` is the template standardised over the mask (and zero outside
    it), ``n`` is the number of mask voxels and ``sigma_i`` is the standard
    deviation of the data under the mask at position ``i``. Because ``t_n`` has
    zero mean over the mask, the local data mean cancels and does not need to
    be subtracted explicitly. Data beyond the volume border is treated as zero.

    Parameters
    ----------
    data_volume : array_like
        N-dimensional search volume.
    template_volume : array_like
        Template with the same number of dimensions as ``data_volume``.
    mask : array_like
        Same shape as ``template_volume``. Non-zero entries select the voxels
        that take part in the correlation.

    Returns
    -------
    numpy.ndarray
        Float64 array with the shape of ``data_volume``. Values lie in
        [-1, 1] up to floating-point error. Voxels whose masked neighbourhood
        has zero variance are set to 0.

    Raises
    ------
    ValueError
        If the template and data differ in dimensionality, the mask shape
        differs from the template shape, the mask selects no voxels, or the
        template is constant within the mask.
    """
    # Work in float64: avoids truncating the normalised template and
    # overflowing data**2 for integer inputs.
    data = np.asarray(data_volume, dtype=np.float64)
    template = np.asarray(template_volume, dtype=np.float64)
    mask = np.asarray(mask).astype(bool)

    if template.ndim != data.ndim:
        raise ValueError(
            f"template has {template.ndim} dimensions but data has {data.ndim}"
        )
    if mask.shape != template.shape:
        raise ValueError(
            f"mask shape {mask.shape} does not match template shape "
            f"{template.shape}"
        )

    num_elements = np.count_nonzero(mask)
    if num_elements == 0:
        raise ValueError("mask does not select any voxels")

    # Standardise the template over the mask and force it to zero outside,
    # so it has zero mean over exactly the correlated region.
    template_values = template[mask]
    template_std = template_values.std()
    if template_std == 0:
        raise ValueError("template is constant within the mask")
    normalized_template = np.where(
        mask, (template - template_values.mean()) / template_std, 0.0
    )

    # Convolution with a flipped kernel is correlation. With mode="same" and
    # an un-padded kernel, output voxel i sees the kernel origin at
    # i - ts // 2. Flipping both the mask and the template therefore makes all
    # three convolutions below use the same window.
    mask_kernel = np.flip(mask.astype(np.float64))
    local_mean = fftconvolve(data, mask_kernel, mode="same") / num_elements
    local_mean_square = (
        fftconvolve(data**2, mask_kernel, mode="same") / num_elements
    )

    # Var = E[x^2] - E[x]^2 can dip slightly below zero from round-off.
    local_variance = np.maximum(local_mean_square - local_mean**2, 0.0)
    local_std = np.sqrt(local_variance)

    numerator = fftconvolve(data, np.flip(normalized_template), mode="same")

    # Divide only where the local std is positive; flat regions stay 0
    # instead of producing inf/NaN (and divide-by-zero warnings).
    correlation_result = np.zeros_like(numerator)
    np.divide(
        numerator,
        num_elements * local_std,
        out=correlation_result,
        where=local_std > 0,
    )
    return correlation_result


# Example usage. Guarded so that importing this module (e.g. during pytest
# collection) does not run a large computation as a side effect.
if __name__ == "__main__":
    rng = np.random.default_rng(0)  # seeded so runs are reproducible
    chunk = rng.random((100, 100, 100))  # Large data volume
    template = rng.random((20, 20, 20))  # Smaller template
    mask = np.ones(template.shape)  # Uniform mask for simplicity

    # Compute cross-correlation
    cc_result = calculate_correlation(chunk, template, mask)
    print("Cross-correlation computed successfully.")