diff --git a/src/catm.py b/src/catm.py index bd782512f229ba80790f321b9df5a7dfcc8cc038..0386331cc5c911be20562807eb43c6cc47d4aa60 100644 --- a/src/catm.py +++ b/src/catm.py @@ -15,9 +15,10 @@ import os import sys from config_loader import get_config from template_matching import TemplateMatcher -from global_template_matching_chunk import TemplateMatcherGeneral from clash_resolver import ClashResolver + + # Get the directory where the script is located script_dir = os.path.abspath(os.path.dirname(__file__)) @@ -26,6 +27,14 @@ sys.path.append(script_dir) # import config config = get_config() +if hasattr(config, "testTM"): + if config.testTM: + if config.chunk_size is None: + from global_template_matching import TemplateMatcherGeneral + print('No chunk size specified, assuming the size of tomograms is small') + else: + from global_template_matching_chunk import TemplateMatcherGeneral + print("Running template matching with specified chunk size") __version__ = "0.1.0" __author__ = "Huabin Zhou" diff --git a/src/config.py b/src/config.py index a51f3a8d4473f79812cab2475547e8e593c83c36..7839387a8aefedb76cc24c427ecc5552ed12d25b 100644 --- a/src/config.py +++ b/src/config.py @@ -90,4 +90,4 @@ adjust_ccc = 0.11111111 # max distance between two partles to be considered as the same particle # option's for running only the general template matching testTM = False -adjust_ccc_relion = 0.11111111 +adjust_ccc_relion = 0 diff --git a/src/global_template_matching.py b/src/global_template_matching.py index b8b0d22f00a479076227c106ec3d4d0e7805eded..40d3eb06336398862db645c5d11094e81754e5a7 100644 --- a/src/global_template_matching.py +++ b/src/global_template_matching.py @@ -1,7 +1,7 @@ import numpy as np from scipy.spatial.transform import Rotation as R from multiprocessing import Pool, cpu_count -from utils import rotate_high_res +from utils import rotate from utils import prepare_ctf_volumes, apply_ctf, apply_wedge from scipy.ndimage import binary_dilation import pandas as pd @@ -15,7 +15,7 @@ Works well on small tomograms, no optimzation for large tomograms yet """ -class TemplateMatcher_general: +class TemplateMatcherGeneral: def __init__(self, inputs): self.templates = inputs["templates"] self.contour_level = inputs["contour_level"] @@ -30,8 +30,8 @@ class TemplateMatcher_general: self.sort_score = inputs["sort_score"] self.mpi_nn = inputs["mpi_nn"] self.global_angles = self.generate_radom_angles() - self.rotamer = self.prepare_rotamer() - self.maskamer = self.prepare_maskamer() + #self.rotamer = self.prepare_rotamer() + #self.maskamer = self.prepare_maskamer() def match_worker(self, index): # Implementation of the template matching logic @@ -135,7 +135,8 @@ class TemplateMatcher_general: # cur_rot = [phi, theta, psi] cur_rot = angles[idx] for temp in range(num_templates): - template_rot = self.rotamer[temp][idx] + template = self.templates[temp] + template_rot = rotate(template, cur_rot) # for given template at given angles,fetch from the rotamer if self.ctf is not None: template_rot = apply_ctf(template_rot, ctf_vols[temp]) @@ -147,7 +148,8 @@ class TemplateMatcher_general: mask = binary_dilation(template_rot, structuring_element) else: - mask = self.maskamer[temp][idx] + mask = self.masks[temp] + mask = rotate(mask, cur_rot) mask = apply_wedge(mask, self.missing_wedge) mask[mask < self.contour_level[temp]] = 0 mask[mask >= self.contour_level[temp]] = 1 diff --git a/src/global_template_matching_chunk.py b/src/global_template_matching_chunk.py index e0b83ca5a906535433e7f81fa11ab470cbea1a8d..b542fa45cb8bfe82892c79065ec306d57712a1bc 100644 --- a/src/global_template_matching_chunk.py +++ b/src/global_template_matching_chunk.py @@ -1,7 +1,7 @@ import numpy as np from scipy.spatial.transform import Rotation as R from multiprocessing import Pool, cpu_count -from utils import rotate_high_res, prepare_ctf_volumes, apply_ctf, apply_wedge +from utils import rotate_high_res, prepare_ctf_volumes, apply_ctf, apply_wedge,rotate from scipy.ndimage import binary_dilation import pandas as pd from config_loader import get_config @@ -29,8 +29,8 @@ class TemplateMatcherGeneral: self.sort_score = inputs["sort_score"] self.mpi_nn = inputs["mpi_nn"] self.global_angles = self.generate_random_angles() - self.rotamer = self.prepare_rotamer() - self.maskamer = self.prepare_maskamer() + #self.rotamer = self.prepare_rotamer() + #self.maskamer = self.prepare_maskamer() # Memory map the tomogram to a temporary file self.tomogram_file_path = self.create_temp_file() @@ -144,24 +144,23 @@ class TemplateMatcherGeneral: cur_ccc = np.zeros(self.dims) for temp in range(num_templates): - template_rot = self.rotamer[temp][idx] + template = self.templates[temp] + template_rot =rotate(template,cur_rot) template_rot = ( apply_ctf(template_rot, ctf_vols[temp]) if self.ctf is not None else apply_wedge(template_rot, self.missing_wedge) ) - mask = ( - self.masks[temp] - if self.masks is not None - else binary_dilation(template_rot, np.ones((3, 3, 3))) - ) - mask = ( - apply_wedge(mask, self.missing_wedge) - if self.masks is not None - else mask - ) - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 + template_rot[template_rot < self.contour_level[temp]] = 0 + if self.masks is None: + structuring_element = np.ones((3, 3, 3)) + mask = binary_dilation(template_rot, structuring_element) + else: + mask = self.masks[temp] + mask = rotate(mask, cur_rot) + mask = apply_wedge(mask, self.missing_wedge) + mask[mask < 0.1] = 0 # the mask should be binary + mask[mask >= 0.1] = 1 cur_ccc = self.calculate_correlation(subtomo, template_rot, mask) print(cur_ccc.max()) @@ -236,6 +235,7 @@ class TemplateMatcherGeneral: normalized_template = (template_volume - template_mean) / template_std data_shape = data_volume.shape + padding = 10 chunks = [range(0, s, cs) for s, cs in zip(data_shape, chunk_size)] correlation_result = np.zeros(data_shape, dtype=np.float32) diff --git a/src/template_matching.py b/src/template_matching.py index bf4f953e0debc86f69f8e9b4375837efc632688c..108c6affdbf02f9a58ac33e070e7fcd263a22c3e 100644 --- a/src/template_matching.py +++ b/src/template_matching.py @@ -9,14 +9,7 @@ from scipy.ndimage import binary_dilation from lxml import etree from config_loader import get_config -config = get_config() # cannot be self.config, pickle needed - -""" -Just to be clear, there are a few things you need to be aware of: -1. The angles are in degrees and in the order of phi, theta, psi, which is intrinsic ZXZ convention. -2. The template and template2 are the same size. -""" - +config = get_config() # cannot be self.config, pickle needed for MPI class TemplateMatcher: def __init__(self, inputs): @@ -40,7 +33,7 @@ class TemplateMatcher: self.mpi_nn = inputs["mpi_nn"] if self.local_search_angles is False: self.global_angles = self.generate_radom_angles() - self.rotamer = self.prepare_rotamer() + #self.rotamer = self.prepare_rotamer() else: print("Perform Local Refinement!") @@ -201,7 +194,9 @@ class TemplateMatcher: template = self.templates[temp] template_rot = rotate(template, (phi, theta, psi)) else: - template_rot = self.rotamer[temp][i] + template = self.templates[temp] + template_rot = rotate(template, (phi, theta, psi)) + #template_rot = self.rotamer[temp][i] # reserve for furture development # for given template at given angles,fetch from the rotamer if self.ctf is not None: template_rot = apply_ctf(template_rot, ctf_vols[temp])