From 3ed052083c5e2d5872d2db5b098d12c5248e1b77 Mon Sep 17 00:00:00 2001 From: Huabin Zhou <huabin.zhou@utsouthwestern.edu> Date: Sun, 18 Aug 2024 14:47:21 -0500 Subject: [PATCH] roll back to skip the rotamer and maskmer generation prior the template matching, which accelaratate the program since MPI wasn't involved in these two separated processes . --- src/catm.py | 11 ++++++++- src/config.py | 2 +- src/global_template_matching.py | 14 +++++++----- src/global_template_matching_chunk.py | 32 +++++++++++++-------------- src/template_matching.py | 15 +++++-------- 5 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/catm.py b/src/catm.py index bd78251..0386331 100644 --- a/src/catm.py +++ b/src/catm.py @@ -15,9 +15,10 @@ import os import sys from config_loader import get_config from template_matching import TemplateMatcher -from global_template_matching_chunk import TemplateMatcherGeneral from clash_resolver import ClashResolver + + # Get the directory where the script is located script_dir = os.path.abspath(os.path.dirname(__file__)) @@ -26,6 +27,14 @@ sys.path.append(script_dir) # import config config = get_config() +if hasattr(config, "testTM"): + if config.testTM: + if config.chunk_size is None: + from global_template_matching import TemplateMatcherGeneral + print('No chunk size specified, assuming the size of tomograms is small') + else: + from global_template_matching_chunk import TemplateMatcherGeneral + print("Running template matching with specified chunk size") __version__ = "0.1.0" __author__ = "Huabin Zhou" diff --git a/src/config.py b/src/config.py index a51f3a8..7839387 100644 --- a/src/config.py +++ b/src/config.py @@ -90,4 +90,4 @@ adjust_ccc = 0.11111111 # max distance between two partles to be considered as the same particle # option's for running only the general template matching testTM = False -adjust_ccc_relion = 0.11111111 +adjust_ccc_relion = 0 diff --git a/src/global_template_matching.py b/src/global_template_matching.py index b8b0d22..40d3eb0 100644 --- a/src/global_template_matching.py +++ b/src/global_template_matching.py @@ -1,7 +1,7 @@ import numpy as np from scipy.spatial.transform import Rotation as R from multiprocessing import Pool, cpu_count -from utils import rotate_high_res +from utils import rotate from utils import prepare_ctf_volumes, apply_ctf, apply_wedge from scipy.ndimage import binary_dilation import pandas as pd @@ -15,7 +15,7 @@ Works well on small tomograms, no optimzation for large tomograms yet """ -class TemplateMatcher_general: +class TemplateMatcherGeneral: def __init__(self, inputs): self.templates = inputs["templates"] self.contour_level = inputs["contour_level"] @@ -30,8 +30,8 @@ class TemplateMatcher_general: self.sort_score = inputs["sort_score"] self.mpi_nn = inputs["mpi_nn"] self.global_angles = self.generate_radom_angles() - self.rotamer = self.prepare_rotamer() - self.maskamer = self.prepare_maskamer() + #self.rotamer = self.prepare_rotamer() + #self.maskamer = self.prepare_maskamer() def match_worker(self, index): # Implementation of the template matching logic @@ -135,7 +135,8 @@ class TemplateMatcher_general: # cur_rot = [phi, theta, psi] cur_rot = angles[idx] for temp in range(num_templates): - template_rot = self.rotamer[temp][idx] + template = self.templates[temp] + template_rot = rotate(template, cur_rot) # for given template at given angles,fetch from the rotamer if self.ctf is not None: template_rot = apply_ctf(template_rot, ctf_vols[temp]) @@ -147,7 +148,8 @@ class TemplateMatcher_general: mask = binary_dilation(template_rot, structuring_element) else: - mask = self.maskamer[temp][idx] + mask = self.masks[temp] + mask = rotate(mask, cur_rot) mask = apply_wedge(mask, self.missing_wedge) mask[mask < self.contour_level[temp]] = 0 mask[mask >= self.contour_level[temp]] = 1 diff --git a/src/global_template_matching_chunk.py b/src/global_template_matching_chunk.py index e0b83ca..b542fa4 100644 --- a/src/global_template_matching_chunk.py +++ b/src/global_template_matching_chunk.py @@ -1,7 +1,7 @@ import numpy as np from scipy.spatial.transform import Rotation as R from multiprocessing import Pool, cpu_count -from utils import rotate_high_res, prepare_ctf_volumes, apply_ctf, apply_wedge +from utils import rotate_high_res, prepare_ctf_volumes, apply_ctf, apply_wedge,rotate from scipy.ndimage import binary_dilation import pandas as pd from config_loader import get_config @@ -29,8 +29,8 @@ class TemplateMatcherGeneral: self.sort_score = inputs["sort_score"] self.mpi_nn = inputs["mpi_nn"] self.global_angles = self.generate_random_angles() - self.rotamer = self.prepare_rotamer() - self.maskamer = self.prepare_maskamer() + #self.rotamer = self.prepare_rotamer() + #self.maskamer = self.prepare_maskamer() # Memory map the tomogram to a temporary file self.tomogram_file_path = self.create_temp_file() @@ -144,24 +144,23 @@ class TemplateMatcherGeneral: cur_ccc = np.zeros(self.dims) for temp in range(num_templates): - template_rot = self.rotamer[temp][idx] + template = self.templates[temp] + template_rot =rotate(template,cur_rot) template_rot = ( apply_ctf(template_rot, ctf_vols[temp]) if self.ctf is not None else apply_wedge(template_rot, self.missing_wedge) ) - mask = ( - self.masks[temp] - if self.masks is not None - else binary_dilation(template_rot, np.ones((3, 3, 3))) - ) - mask = ( - apply_wedge(mask, self.missing_wedge) - if self.masks is not None - else mask - ) - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 + template_rot[template_rot < self.contour_level[temp]] = 0 + if self.masks is None: + structuring_element = np.ones((3, 3, 3)) + mask = binary_dilation(template_rot, structuring_element) + else: + mask = self.masks[temp] + mask = rotate(mask, cur_rot) + mask = apply_wedge(mask, self.missing_wedge) + mask[mask < 0.1] = 0 # the mask should be binary + mask[mask >= 0.1] = 1 cur_ccc = self.calculate_correlation(subtomo, template_rot, mask) print(cur_ccc.max()) @@ -236,6 +235,7 @@ class TemplateMatcherGeneral: normalized_template = (template_volume - template_mean) / template_std data_shape = data_volume.shape + padding = 10 chunks = [range(0, s, cs) for s, cs in zip(data_shape, chunk_size)] correlation_result = np.zeros(data_shape, dtype=np.float32) diff --git a/src/template_matching.py b/src/template_matching.py index bf4f953..108c6af 100644 --- a/src/template_matching.py +++ b/src/template_matching.py @@ -9,14 +9,7 @@ from scipy.ndimage import binary_dilation from lxml import etree from config_loader import get_config -config = get_config() # cannot be self.config, pickle needed - -""" -Just to be clear, there are a few things you need to be aware of: -1. The angles are in degrees and in the order of phi, theta, psi, which is intrinsic ZXZ convention. -2. The template and template2 are the same size. -""" - +config = get_config() # cannot be self.config, pickle needed for MPI class TemplateMatcher: def __init__(self, inputs): @@ -40,7 +33,7 @@ class TemplateMatcher: self.mpi_nn = inputs["mpi_nn"] if self.local_search_angles is False: self.global_angles = self.generate_radom_angles() - self.rotamer = self.prepare_rotamer() + #self.rotamer = self.prepare_rotamer() else: print("Perform Local Refinement!") @@ -201,7 +194,9 @@ class TemplateMatcher: template = self.templates[temp] template_rot = rotate(template, (phi, theta, psi)) else: - template_rot = self.rotamer[temp][i] + template = self.templates[temp] + template_rot = rotate(template, (phi, theta, psi)) + #template_rot = self.rotamer[temp][i] # reserve for furture development # for given template at given angles,fetch from the rotamer if self.ctf is not None: template_rot = apply_ctf(template_rot, ctf_vols[temp]) -- GitLab