From 40174ac8140f6d69e7557f51f2327782e1936eed Mon Sep 17 00:00:00 2001
From: Huabin Zhou <huabin.zhou@utsouthwestern.edu>
Date: Tue, 13 Aug 2024 17:07:33 -0500
Subject: [PATCH] update README.md and src files

---
 README.md                                  |  34 +-
 src/__init.py__                            |   0
 src/catching.egg-info/PKG-INFO             | 100 +++++
 src/catching.egg-info/SOURCES.txt          |  22 +
 src/catching.egg-info/dependency_links.txt |   1 +
 src/catching.egg-info/entry_points.txt     |   2 +
 src/catching.egg-info/requires.txt         |  33 ++
 src/catching.egg-info/top_level.txt        |  14 +
 src/catm.egg-info/PKG-INFO                 |  65 +++
 src/catm.egg-info/SOURCES.txt              |  22 +
 src/catm.egg-info/dependency_links.txt     |   1 +
 src/catm.egg-info/entry_points.txt         |   2 +
 src/catm.egg-info/requires.txt             |  32 ++
 src/catm.egg-info/top_level.txt            |  14 +
 src/catm.py                                |  75 ++++
 src/clash_resolver.py                      | 450 ++++++++++++++++++++
 src/config.py                              |  93 ++++
 src/config_loader.py                       |  58 +++
 src/file_handler.py                        | 335 +++++++++++++++
 src/filter_tomograms.py                    |  70 +++
 src/geo_utils.py                           |  31 ++
 src/global_template_matching.py            | 283 ++++++++++++
 src/global_template_matching_chunk.py      | 346 +++++++++++++++
 src/misc.py                                |  20 +
 src/plot_nucleosome_with_z.py              |  29 ++
 src/slurm-tm.sh                            |  25 ++
 src/template_matching.py                   | 473 +++++++++++++++++++++
 src/test_ccc.py                            |  62 +++
 src/test_config.py                         |  85 ++++
 src/test_multiprocessing.py                |  24 ++
 src/utils.py                               | 362 ++++++++++++++++
 31 files changed, 3154 insertions(+), 9 deletions(-)
 create mode 100644 src/__init.py__
 create mode 100644 src/catching.egg-info/PKG-INFO
 create mode 100644 src/catching.egg-info/SOURCES.txt
 create mode 100644 src/catching.egg-info/dependency_links.txt
 create mode 100644 src/catching.egg-info/entry_points.txt
 create mode 100644 src/catching.egg-info/requires.txt
 create mode 100644 src/catching.egg-info/top_level.txt
 create mode 100644 src/catm.egg-info/PKG-INFO
 create mode 100644 src/catm.egg-info/SOURCES.txt
 create mode 100644 src/catm.egg-info/dependency_links.txt
 create mode 100644 src/catm.egg-info/entry_points.txt
 create mode 100644 src/catm.egg-info/requires.txt
 create mode 100644 src/catm.egg-info/top_level.txt
 create mode 100644 src/catm.py
 create mode 100644 src/clash_resolver.py
 create mode 100644 src/config.py
 create mode 100644 src/config_loader.py
 create mode 100644 src/file_handler.py
 create mode 100644 src/filter_tomograms.py
 create mode 100644 src/geo_utils.py
 create mode 100644 src/global_template_matching.py
 create mode 100644 src/global_template_matching_chunk.py
 create mode 100644 src/misc.py
 create mode 100644 src/plot_nucleosome_with_z.py
 create mode 100644 src/slurm-tm.sh
 create mode 100644 src/template_matching.py
 create mode 100644 src/test_ccc.py
 create mode 100644 src/test_config.py
 create mode 100644 src/test_multiprocessing.py
 create mode 100644 src/utils.py

diff --git a/README.md b/README.md
index 2bcb5cc..886eef1 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,16 @@
-# Context Awared Template Matching (CATM) for Cryo-ET subtomogram averaging
+# Context-Aware Template Matching (CATM) for cryo-ET data analysis

 # CATM version 0.1
 [](https://zenodo.org/badge/latestdoi/)
 ##
 # Introduction
-This software is designed for the using
+This software is designed to assign molecular models in crowded cryo-ET tomograms. The main idea is to take pre-picked
+coordinates from other software as input and calculate cross-correlation coefficients (CCCs) against the templates
+provided. The software then combines these scores with the geometric restraints between the models to achieve highly
+accurate assignments. It also includes a clash resolver (CRer) module to optimize the poses between the models.

 ## Installation
-python version at least 3.5 is required. If you download the package as a zip file from github, please rename the folder IsoNet-master to IsoNet.
+Python version 3.5 or later is required; creating a separate environment is recommended.

 1. Create a new virtual environment (optional)

 ```
 conda create -n catm python
 ```

 2. Install the package

 ```
 pip install .
 ```

 ```
 cd test
-
-CATM
-
+CATM # run the main program
 ```
-
 ## Usage
-### 1. Prepare your data, required files are:
-- a stack of subtomograms in mrc format
+- The program currently supports three different modes: "guided-TM-CRer", "guided-localTM-CRer", and "traditional TM"
+### 1. Prepare your data; the required files are:
+- a tomogram file in MRC format, usually produced by back-projection and low-pass filtered
+- a list of candidate particles from deep-learning-based pickers, manual picking, or other sources; csv, coords, and star formats are supported
+- a 3D CTF model from Warp or Relion (optional); if not provided, the missing wedge info should be given
+- one or a few template files; for now, each template needs to be a cubic volume
+- corresponding mask files for each template (optional); they need to have the same dimensions as the templates if provided. If not provided, a tight mask will be generated automatically (recommended)
+### 2. Determine the contour level for template matching and the clash resolver
+- run the counter_level_check.py script for your template; this only needs to be done once per similar dataset
+```
+python counter_level_check.py path_to_template.mrc [ctf_model or missing wedge info]
+```
+- check the contour level in Chimera/ChimeraX and choose the level with the fewest artifacts
+### 3. Copy the config.py file to your work dir and adjust the parameters for your run
+- There are three different modes the program currently supports, "Context",
+### 4.
+
+
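+For reference, a typical session might look like this (the paths are illustrative; `CATM` is the command from the test step above, registered as the `catm` console script when you run `pip install .`):
+```
+cd /path/to/your/work_dir            # contains the tomogram and picked coordinates
+cp /path/to/CATM/src/config.py .     # copy the template configuration next to your data
+# edit config.py: tomogram, templates, contour_level, df (coordinates), ctf_model_file ...
+CATM                                 # runs template matching, then the clash resolver
+```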
diff --git a/src/__init.py__ b/src/__init.py__
new file mode 100644
index 0000000..e69de29
diff --git a/src/catching.egg-info/PKG-INFO b/src/catching.egg-info/PKG-INFO
new file mode 100644
index 0000000..ebb8763
--- /dev/null
+++ b/src/catching.egg-info/PKG-INFO
@@ -0,0 +1,100 @@
+Metadata-Version: 2.1
+Name: catching
+Version: 0.1.0
+Summary: A software for template matching and clash resolving in cryo-EM
+Author-email: Huabin Zhou <huabin.zhou@utsouthwestern.edu>
+License: Copyright (c) 2024 The University of Texas Southwestern Medical Center.
+
+        All rights reserved.
+
+        Redistribution and use in source and binary forms, with or without modification, are permitted for academic research use only (subject to the limitations in the disclaimer below) provided that the following conditions are met:
+
+        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+        * Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+        ANY USE OR REDISTRIBUTION OF THIS SOFTWARE FOR COMMERCIAL PURPOSES, WHETHER IN SOURCE OR BINARY FORM, WITH OR WITHOUT MODIFICATION, IS EXPRESSLY PROHIBITED; ANY USE OR REDISTRIBUTION BY A FOR-PROFIT ENTITY SHALL COMPRISE USE OR REDISTRIBUTION FOR COMMERCIAL PURPOSES.
+
+        NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE, AND ANY ACCOMPANYING DOCUMENTATION, IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE OR ANY OF ITS ACCOMPANYING DOCUMENTATION, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+Project-URL: Source, https://github.com/mwaskom/seaborn
+Project-URL: Docs, http://seaborn.pydata.org
+Classifier: Intended Audience :: Science/Research
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Topic :: Cryo-EM Data Processing
+Classifier: Operating System :: OS Independent
+Classifier: Framework :: Matplotlib
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE.md
+Requires-Dist: numpy
+Requires-Dist: pandas
+Requires-Dist: mrcfile
+Requires-Dist: scikit-learn
+Requires-Dist: scipy
+Requires-Dist: lxml
+Requires-Dist: starfile
+Provides-Extra: stats
+Requires-Dist: scipy>=1.7; extra == "stats"
+Requires-Dist: statsmodels>=0.12; extra == "stats"
+Provides-Extra: dev
+Requires-Dist: matplotlib; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: pytest-cov; extra == "dev"
+Requires-Dist: pytest-xdist; extra == "dev"
+Requires-Dist: flake8; extra == "dev"
+Requires-Dist: mypy; extra == "dev"
+Requires-Dist: pandas-stubs; extra == "dev"
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: flit; extra == "dev"
+Provides-Extra: docs
+Requires-Dist: numpydoc; extra == "docs"
+Requires-Dist: nbconvert; extra == "docs"
+Requires-Dist: ipykernel; extra == "docs"
+Requires-Dist: sphinx<6.0.0; extra == "docs"
+Requires-Dist: sphinx-copybutton; extra == "docs"
+Requires-Dist: sphinx-issues; extra == "docs"
+Requires-Dist: sphinx-design; extra == "docs"
+Requires-Dist: pyyaml; extra == "docs"
+Requires-Dist: pydata_sphinx_theme==0.10.0rc2; extra == "docs"
+
+# Context-Aware Template maCHING (CATCHing) for subtomogram averaging
+
+# CATCHing version 0.1
+[](https://zenodo.org/badge/latestdoi/222662248)
+##
+# Introduction
+This software is designed to assign molecular models in crowded cryo-ET tomograms.
+
+## Installation
+python version at least 3.5 is required. If you download the package as a zip file from github, please rename the folder IsoNet-master to IsoNet.
+
+1. Create a new virtual environment (optional)
+
+```
+conda create -n catching python
+```
+
+2. Install the package
+
+```
+pip install .
+```
+
+3. Test the installation
+```
+catching_test
+```
+
+
+
+## Usage
+### 1.
Prepare your data + diff --git a/src/catching.egg-info/SOURCES.txt b/src/catching.egg-info/SOURCES.txt new file mode 100644 index 0000000..cac7fb4 --- /dev/null +++ b/src/catching.egg-info/SOURCES.txt @@ -0,0 +1,22 @@ +LICENSE.md +README.md +pyproject.toml +src/basic.py +src/catm.py +src/clash_resolver.py +src/file_handler.py +src/old_utils.py +src/run.py +src/template_matching.py +src/test_config.py +src/test_multiprocessing.py +src/utils.py +src/catching.egg-info/PKG-INFO +src/catching.egg-info/SOURCES.txt +src/catching.egg-info/dependency_links.txt +src/catching.egg-info/entry_points.txt +src/catching.egg-info/requires.txt +src/catching.egg-info/top_level.txt +src/misc/angluar_assement.py +src/templates/generate_spherical_mask.py +src/templates/rotate_tomo.py \ No newline at end of file diff --git a/src/catching.egg-info/dependency_links.txt b/src/catching.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/catching.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/catching.egg-info/entry_points.txt b/src/catching.egg-info/entry_points.txt new file mode 100644 index 0000000..00fe888 --- /dev/null +++ b/src/catching.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +catm = catm:main diff --git a/src/catching.egg-info/requires.txt b/src/catching.egg-info/requires.txt new file mode 100644 index 0000000..536fc49 --- /dev/null +++ b/src/catching.egg-info/requires.txt @@ -0,0 +1,33 @@ +numpy +pandas +mrcfile +scikit-learn +scipy +lxml +starfile + +[dev] +matplotlib +pytest +pytest-cov +pytest-xdist +flake8 +mypy +pandas-stubs +pre-commit +flit + +[docs] +numpydoc +nbconvert +ipykernel +sphinx<6.0.0 +sphinx-copybutton +sphinx-issues +sphinx-design +pyyaml +pydata_sphinx_theme==0.10.0rc2 + +[stats] +scipy>=1.7 +statsmodels>=0.12 diff --git a/src/catching.egg-info/top_level.txt b/src/catching.egg-info/top_level.txt new file mode 100644 index 0000000..33e42a6 --- /dev/null +++ b/src/catching.egg-info/top_level.txt @@ -0,0 +1,14 @@ +basic +catm +clash_resolver +csv +file_handler +misc +old_utils +results +run +template_matching +templates +test_config +test_multiprocessing +utils diff --git a/src/catm.egg-info/PKG-INFO b/src/catm.egg-info/PKG-INFO new file mode 100644 index 0000000..7695a24 --- /dev/null +++ b/src/catm.egg-info/PKG-INFO @@ -0,0 +1,65 @@ +Metadata-Version: 2.1 +Name: catm +Version: 0.0.0 +Summary: A software for template matching and clash resolving in cryo-EM +Author-email: Huabin Zhou <huabin.zhou@utsouthwestern.edu> +License: Copyright (c) 2024 The University of Texas Southwestern Medical Center. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted for academic research use only (subject to the limitations in the disclaimer below) provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + + * Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
+ + ANY USE OR REDISTRIBUTION OF THIS SOFTWARE FOR COMMERCIAL PURPOSES, WHETHER IN SOURCE OR BINARY FORM, WITH OR WITHOUT MODIFICATION, IS EXPRESSLY PROHIBITED; ANY USE OR REDISTRIBUTION BY A FOR-PROFIT ENTITY SHALL COMPRISE USE OR REDISTRIBUTION FOR COMMERCIAL PURPOSES. + + NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE, AND ANY ACCOMPANYING DOCUMENTATION, IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE OR ANY OF ITS ACCOMPANYING DOCUMENTATION, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Project-URL: Source, https://github.com/mwaskom/seaborn +Project-URL: Docs, http://seaborn.pydata.org +Classifier: Intended Audience :: Science/Research +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: License :: OSI Approved :: BSD License +Classifier: Topic :: Cryo-EM Data Processing +Classifier: Operating System :: OS Independent +Classifier: Framework :: Matplotlib +Requires-Python: >=3.8 +Description-Content-Type: text/markdown +License-File: LICENSE.md +Requires-Dist: numpy +Requires-Dist: pandas +Requires-Dist: mrcfile +Requires-Dist: scikit-learn +Requires-Dist: scipy +Requires-Dist: lxml +Provides-Extra: stats +Requires-Dist: scipy>=1.7; extra == "stats" +Requires-Dist: statsmodels>=0.12; extra == "stats" +Provides-Extra: dev +Requires-Dist: matplotlib; extra == "dev" +Requires-Dist: pytest; extra == "dev" +Requires-Dist: pytest-cov; extra == "dev" +Requires-Dist: pytest-xdist; extra == "dev" +Requires-Dist: flake8; extra == "dev" +Requires-Dist: mypy; extra == "dev" +Requires-Dist: pandas-stubs; extra == "dev" +Requires-Dist: pre-commit; extra == "dev" +Requires-Dist: flit; extra == "dev" +Provides-Extra: docs +Requires-Dist: numpydoc; extra == "docs" +Requires-Dist: nbconvert; extra == "docs" +Requires-Dist: ipykernel; extra == "docs" +Requires-Dist: sphinx<6.0.0; extra == "docs" +Requires-Dist: sphinx-copybutton; extra == "docs" +Requires-Dist: sphinx-issues; extra == "docs" +Requires-Dist: sphinx-design; extra == "docs" +Requires-Dist: pyyaml; extra == "docs" +Requires-Dist: pydata_sphinx_theme==0.10.0rc2; extra == "docs" diff --git a/src/catm.egg-info/SOURCES.txt b/src/catm.egg-info/SOURCES.txt new file mode 100644 index 0000000..48df523 --- /dev/null +++ b/src/catm.egg-info/SOURCES.txt @@ -0,0 +1,22 @@ +LICENSE.md +README.md +pyproject.toml +src/basic.py +src/catm.py +src/clash_resolver.py +src/config.py +src/file_handler.py +src/old_utils.py +src/run.py +src/template_matching.py +src/test_config.py +src/utils.py +src/catm.egg-info/PKG-INFO +src/catm.egg-info/SOURCES.txt +src/catm.egg-info/dependency_links.txt +src/catm.egg-info/entry_points.txt +src/catm.egg-info/requires.txt 
+src/catm.egg-info/top_level.txt +src/misc/angluar_assement.py +src/templates/generate_spherical_mask.py +src/templates/rotate_tomo.py \ No newline at end of file diff --git a/src/catm.egg-info/dependency_links.txt b/src/catm.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/catm.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/catm.egg-info/entry_points.txt b/src/catm.egg-info/entry_points.txt new file mode 100644 index 0000000..00fe888 --- /dev/null +++ b/src/catm.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +catm = catm:main diff --git a/src/catm.egg-info/requires.txt b/src/catm.egg-info/requires.txt new file mode 100644 index 0000000..3568fc1 --- /dev/null +++ b/src/catm.egg-info/requires.txt @@ -0,0 +1,32 @@ +numpy +pandas +mrcfile +scikit-learn +scipy +lxml + +[dev] +matplotlib +pytest +pytest-cov +pytest-xdist +flake8 +mypy +pandas-stubs +pre-commit +flit + +[docs] +numpydoc +nbconvert +ipykernel +sphinx<6.0.0 +sphinx-copybutton +sphinx-issues +sphinx-design +pyyaml +pydata_sphinx_theme==0.10.0rc2 + +[stats] +scipy>=1.7 +statsmodels>=0.12 diff --git a/src/catm.egg-info/top_level.txt b/src/catm.egg-info/top_level.txt new file mode 100644 index 0000000..da01f31 --- /dev/null +++ b/src/catm.egg-info/top_level.txt @@ -0,0 +1,14 @@ +basic +catm +clash_resolver +config +csv +file_handler +misc +old_utils +results +run +template_matching +templates +test_config +utils diff --git a/src/catm.py b/src/catm.py new file mode 100644 index 0000000..c787a1f --- /dev/null +++ b/src/catm.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 13 12:39:08 2021 +To do the template matching,the first stp is taking the input from the user: +1. Read in the configuration file +2. run the template matching +3. resolve the clashes and optimize the particles positions/poses +4. 
output the results with xml, csv and relion star file +""" + +from file_handler import parse_input +import time +import os +import sys +from config_loader import get_config +from template_matching import TemplateMatcher +from global_template_matching_chunk import TemplateMatcherGeneral +from clash_resolver import ClashResolver + +# Get the directory where the script is located +script_dir = os.path.abspath(os.path.dirname(__file__)) + +# Add the script directory to the system path +sys.path.append(script_dir) +# import config +config = get_config() + + +__version__ = "0.1.0" +__author__ = "Huabin Zhou" + + +def main(): + star_time = time.time() + user_inputs = parse_input() + if hasattr(config, "testTM"): # this will run only the global TM + if config.testTM: + print("Running general template matching!") + matcher = TemplateMatcherGeneral(user_inputs) + matcher.run_multiprocessing() + end_time = time.time() + total_time = end_time - star_time + print(f"Total runtime: {total_time:.2f} seconds") + print("Done!") + exit(0) + + if config.bypass_TM is False: + print("Running Context-aware template matching!") + matcher = TemplateMatcher(user_inputs) + # obj_xml = matcher.match_worker(4) # just for debugging the coorelation + obj_xml = matcher.run_multiprocessing() + print("Template matching finished, saving results") + else: + print("Bypassing template matching") + # for debugging, read the xml file + from lxml.etree import XMLParser, parse + + ps = XMLParser(huge_tree=True) + tree = parse(user_inputs["output_path"] + "match_raw_results.xml", parser=ps) + obj_xml = tree.getroot() + if config.bypass_CR is False: + print("Running clash resolving") + if config.bypass_optimizer: + print("Bypassing clash resolving optimizer") + resolver = ClashResolver(user_inputs, obj_xml) + resolver.resolve_clashes() + end_time = time.time() + total_time = end_time - star_time + print(f"Total runtime: {total_time:.2f} seconds") + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/src/clash_resolver.py b/src/clash_resolver.py new file mode 100644 index 0000000..d446cad --- /dev/null +++ b/src/clash_resolver.py @@ -0,0 +1,450 @@ +from file_handler import find_xml, save_xml_one_loop, coord_to_relion +from file_handler import write_mrc +from utils import rotate_high_res, try_add_obj, plot_obj, find_info +from sklearn.neighbors import NearestNeighbors +import numpy as np +from multiprocessing import Pool, cpu_count +import copy +import starfile +import pandas as pd +from lxml import etree +from functools import partial +from config_loader import get_config + +config = get_config() + + +class ClashResolver: + def __init__(self, inputs, obj_xml): + self.tomogram = inputs["tomogram"] + self.coords = inputs["coords"] + self.templates = inputs["templates"] + self.missing_wedge = inputs["missing_wedge"] + self.shrinkage = inputs["shrinkage_factor"] + self.search_depth = inputs["search_depth"] + self.output_path = inputs["output_path"] + self.write_model_file = inputs["write_model_file"] + self.obj_xml = obj_xml + self.box_sizes = [ + self.templates[i].shape[0] for i in range(len(self.templates)) + ] + + # Initialize attributes to store results + self.coord_save = [] + self.rot_save = [] + self.ccc_save = [] + self.model_save = [] + self.index = [] + self.resolved_particle = 0 + self.optimized_particle = 0 + self.pre_assigned_volume = inputs["pre_assigned_volume"] + + def resolve_clashes(self): + obj_xml = self.obj_xml + # if hasattr(config, "adjust_ccc") and config.adjust_ccc is not None: + # 
print("Adjusting CCC values for first half Relion by ", config.adjust_ccc) + # obj_xml = adjust_ccc_values_second_half( + # obj_xml, len(self.coords), config.adjust_ccc + # ) + print(obj_xml) + dims = self.tomogram.shape + if self.pre_assigned_volume is not None: + vol_array = self.pre_assigned_volume + vol_array.setflags("write=1") + else: + vol_array = np.zeros(dims, dtype=np.float32) + print("Total number of particles: ", len(self.obj_xml)) + # this is the main stream of the program,loop through all the particles + for ptl in range(len(self.obj_xml)): + if ptl % 100 == 0 and ptl > 0: + print("processing particle: ", ptl) + print("have got particles: ", len(self.coord_save)) + run_though = 0 + for pose in range(min(len(obj_xml[ptl]), self.search_depth)): + coord, angles, ccc, mod = find_xml(obj_xml, ptl, pose) + obj_rot = rotate_high_res(self.templates[mod], angles) + vol_array, score = try_add_obj( + vol_array, obj_rot, coord, self.shrinkage, ccc + ) + if score > 0: # -1 means clash + self.coord_save.append(coord) + self.rot_save.append(angles) + self.ccc_save.append(score) + self.model_save.append(mod) + self.index.append(ptl) + run_though = 1 + break + if pose == 0 and score <= 0: + self.optimized_particle += 1 + if ( + run_though == 0 + and config.bypass_optimizer is False + and len(self.coord_save) > 0 + ): + # current particle has clash, doing refinement + coord, angles, ccc, mod = find_xml(obj_xml, ptl, 0) + """ + # Here it can be simplfied by fit the current_point to coord_save to find the NN, TODO in the future + self.coord_save.append(coord) # temporarily save the coord + self.rot_save.append(angles) + self.ccc_save.append(ccc) + self.model_save.append(mod) + self.index.append(ptl) + nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit( + self.coord_save + ) + distances, indices = nbrs.kneighbors(self.coord_save) + current_point = len(self.coord_save) - 1 + nbr = indices[current_point][1] + del self.coord_save[-1] # remove the last one, which is the current one + del self.rot_save[-1] + del self.ccc_save[-1] + del self.model_save[-1] + del self.index[-1] + """ + nbrs = NearestNeighbors(n_neighbors=1, algorithm="ball_tree").fit( + self.coord_save + ) + distances, indices = nbrs.kneighbors( + [coord] + ) # Provide the new coordinate as a list of one item + nbr = indices[0][0] + # Get the nearest neighbor index for the provided point + # if the distance is too close, pass, continue to the loop + if ( + distances[0][0] <= config.distance_tolerance + or distances[0][0] >= 12 + ): + continue + """ + # following code adapted from the original code, sometime there is error in the kd-tree that the order of the element is flip for the last pair, wired + if nbr >= indices[current_point][0]: + nbr = indices[current_point][0] + """ + nbr_coord = self.coord_save[nbr] + nbr_angles = self.rot_save[nbr] + nbr_mod = self.model_save[nbr] + nbr_obj_rot = rotate_high_res(self.templates[nbr_mod], nbr_angles) + nbr_idx = int(self.index[nbr]) + + # remove the neighbor from the list for now + del self.coord_save[nbr] + del self.rot_save[nbr] + del self.ccc_save[nbr] + del self.model_save[nbr] + del self.index[nbr] + # erase the particle from the volume array + vol_array = self.clean_particle(vol_array, nbr_obj_rot, nbr_coord) + + ## to make the program faster, we can retrieve all the candidate into lists + curinfo = find_info(self.obj_xml, ptl, self.search_depth) + xx1 = int(curinfo[0][0]) + yy1 = int(curinfo[1][0]) + zz1 = int(curinfo[2][0]) + nbrinfo = find_info(self.obj_xml, nbr_idx, 
self.search_depth) + ext = ( + max(self.box_sizes) * 2 + ) # this is the size expansion of the the box + xxl = max(0, xx1 - ext) + yyl = max(0, yy1 - ext) + zzl = max(0, zz1 - ext) + xxr = min(xx1 + ext, dims[2]) + yyr = min(yy1 + ext, dims[1]) + zzr = min(zz1 + ext, dims[0]) + vol_test = vol_array[zzl:zzr, yyl:yyr, xxl:xxr] + offset2 = [min(ext, zz1), min(ext, yy1), min(ext, xx1)] + oom, ppm = self.run_multiprocessing(vol_test, curinfo, nbrinfo, offset2) + # plot the best one back + vol_array = self.plot_back( + vol_array, ptl, nbr_idx, oom, ppm, nbr, self.shrinkage + ) + self.save_results(vol_array) + + def save_results(self, vol_array): + prefix = config.prefix + # save xml file + results_xml = save_xml_one_loop( + self.ccc_save, self.rot_save, self.coord_save, self.model_save + ) + tree = etree.ElementTree(results_xml) + tree.write(self.output_path + prefix + ".xml", pretty_print=True) + # save csv file + coords_array = np.array(self.coord_save) + rots_array = np.array(self.rot_save) + ccc_array = np.array(self.ccc_save).reshape(-1, 1) # Reshape for concatenation + model_array = np.array(self.model_save).reshape(-1, 1) + combined_array = np.hstack((coords_array, rots_array, ccc_array, model_array)) + df = pd.DataFrame( + combined_array, + columns=["x", "y", "z", "phi", "theta", "psi", "ccc", "model"], + ) + df.to_csv(self.output_path + prefix + ".csv", index=False) + # save star file + star = coord_to_relion(self.coord_save, self.rot_save, self.ccc_save) + starfile.write(star, self.output_path + prefix + ".star", overwrite=True) + # plot the annotation over the tomogram + df2 = coord_to_relion(self.coord_save, self.rot_save) + # this will write the coordinates and ZYZ rotation to coords file + df2.to_csv( + self.output_path + prefix + ".coords", + sep=" ", + index=False, + header=False, + ) + if self.write_model_file: + write_mrc(vol_array, self.output_path + prefix + ".models.mrc") + + """ + tomo = np.zeros(self.tomogram.shape, dtype=np.float32) + if self.write_overlap_files: + overlaped = copy.deepcopy(self.tomogram) + for i in range(len(self.coord_save)): + object_rot = rotate_high_res(self.templates[0], self.rot_save[i]) + overlapped, tomo = self.plot_obj_two_arrays( + overlaped, + tomo, + object_rot, + self.coord_save[i], + self.shrinkage, + ) + + write_mrc(overlaped, self.output_path + prefix + "-overlapped.mrc") + # print("Write the assignments to the tomogram!") + """ + """ + write_mrc(tomo, self.output_path + suffix + "-plot.mrc") + best = np.zeros(self.tomogram.shape, dtype=np.float32) + object_rot = rotate_high_res(self.templates[0], self.rot_save[0]) + best = plot_obj(best, object_rot, self.coord_save[0], self.shrinkage) + write_mrc(best, self.output_path + suffix + "-best.mrc") + worst = np.zeros(self.tomogram.shape, dtype=np.float32) + object_rot = rotate_high_res(self.templates[0], self.rot_save[-1]) + worst = plot_obj(worst, object_rot, self.coord_save[-1], self.shrinkage) + write_mrc(worst, self.output_path + suffix + "-worst.mrc") + """ + # print some metrics + print("Total number of particles: ", len(self.obj_xml)) + print("Total number of resolved particles: ", self.resolved_particle) + print("Total number of optimized particles: ", self.optimized_particle) + print("Total number of particles got: ", len(self.coord_save)) + print( + "Total number of particles lost: ", len(self.obj_xml) - len(self.coord_save) + ) + + def plot_obj_two_arrays(self, vol_array, tomo, obj_rot, coord, shrink): + dims = vol_array.shape + offset = int(obj_rot.shape[0] // 2) + obj_voxel = 
np.nonzero(obj_rot > shrink) + x_vox = obj_voxel[2] + int(coord[0]) - offset + y_vox = obj_voxel[1] + int(coord[1]) - offset + z_vox = obj_voxel[0] + int(coord[2]) - offset + # vol_array[np.array(z_vox),np.array(y_vox),np.array(x_vox)]=obj_rot[ + # np.array(obj_voxel[0]),np.array(obj_voxel[1]),np.array(obj_voxel[2])] + for idx in range(x_vox.size): + xx = x_vox[idx] + yy = y_vox[idx] + zz = z_vox[idx] + aa = obj_voxel[2][idx] + bb = obj_voxel[1][idx] + cc = obj_voxel[0][idx] + if 0 <= xx < dims[2] and 0 <= yy < dims[1] and 0 < zz < dims[0]: + vol_array[zz, yy, xx] = obj_rot[cc, bb, aa] + tomo[zz, yy, xx] = obj_rot[cc, bb, aa] + return vol_array, tomo + + def run_multiprocessing(self, vol_test, curinfo, nbrinfo, offset2): + addscore = [] + oo = [] + pp = [] + b = list(range(min(self.search_depth, len(curinfo[0])))) + ncpu = cpu_count() + pool2 = Pool(ncpu - 2) + # run2 = 0 # options for expand the searching + # templates_copy = copy.deepcopy(self.templates) + func1 = partial( + self.max_neg, + vol_test, + self.templates, + curinfo, + nbrinfo, + offset2, + self.shrinkage, + self.search_depth, + ) + for o, p, ad in pool2.map(func1, b): + oo.append(o) + pp.append(p) + addscore.append(ad) + pool2.close() + pool2.join() + + """ + # by pass the multiprocessing + for o in b: + oo1, pp1, ss1 = self.max_neg(vol_test, curinfo, nbrinfo, offset2, o) + oo.append(oo1) + pp.append(pp1) + addscore.append(ss1) + """ + addscore = [k for sub in addscore for k in sub] + oo = [k for sub in oo for k in sub] + pp = [k for sub in pp for k in sub] + # maxccc = max(addscore) + maxindex = addscore.index(max(addscore)) + # select the best conf and plot back + oom = oo[maxindex] + ppm = pp[maxindex] + return oom, ppm + + @staticmethod + def max_neg( + vol_test, templates, curinfo, nbsinfo, offset2, shrinkage, search_depth, o + ): + # it has to be a static method, otherwise it can't be pickled self.obj_xml + oo1 = [] + pp1 = [] + ss1 = [] + vol_array2 = copy.deepcopy(vol_test) + xx1 = int(curinfo[0][o]) + yy1 = int(curinfo[1][o]) + zz1 = int(curinfo[2][o]) + phi1 = curinfo[3][o] + theta1 = curinfo[4][o] + psi1 = curinfo[5][o] + ccc1 = curinfo[6][o] + mod = int(float(curinfo[7][o])) + # if int(vol_array.shape[1])<30: + xx1 = xx1 - int(curinfo[0][0]) + offset2[2] + yy1 = yy1 - int(curinfo[1][0]) + offset2[1] + zz1 = zz1 - int(curinfo[2][0]) + offset2[0] + obj_rot = rotate_high_res(templates[mod], [phi1, theta1, psi1]) + # apply wedge + # obj_rot = apply_wedge(obj_rot, mAng) + # add current obj, if clash, score -1,if ok, score=CCC + vol_array2, s1 = try_add_obj( + vol_array2, obj_rot, [xx1, yy1, zz1], shrinkage, ccc1 + ) + + # add nbs obj to the previous volume, if clash, score -1, if ok, score = CC + for p in range(min(search_depth, len(nbsinfo[0]))): + # if run2 == 1 and o < self.search_depth and p < self.search_depth: + # continue + vol_array3 = copy.deepcopy(vol_test) + xx2 = int(nbsinfo[0][p]) + yy2 = int(nbsinfo[1][p]) + zz2 = int(nbsinfo[2][p]) + phi2 = nbsinfo[3][p] + theta2 = nbsinfo[4][p] + psi2 = nbsinfo[5][p] + ccc2 = nbsinfo[6][p] + # if int(vol_array.shape[1])<70: + xx2 = xx2 - int(curinfo[0][0]) + offset2[2] + yy2 = yy2 - int(curinfo[1][0]) + offset2[1] + zz2 = zz2 - int(curinfo[2][0]) + offset2[0] + obj2_rot = rotate_high_res(templates[mod], [phi2, theta2, psi2]) + vol_array3, s2 = try_add_obj( + vol_array3, obj2_rot, [xx2, yy2, zz2], shrinkage, ccc2 + ) + # neither work + if s1 <= 0 and s2 <= 0: + ss1.append(-1) + oo1.append(None) + pp1.append(None) + else: + # apply wedge + # obj2_rot = 
apply_wedge(obj2_rot, mAng) # vol_array3, s2 = add_obj( + # vol_array3, obj2_rot, xx2, yy2, zz2, offset_x,shrink, ccc2) #we don't need it, because the + vol_array4, s3 = try_add_obj( + vol_array2, obj2_rot, [xx2, yy2, zz2], shrinkage, ccc2 + ) # consider all + + if s3 > 0 and s1 > 0: + pp1.append(p) + oo1.append(o) + ss1.append(s1 + s2) + else: + # if the current obj clash with nbs, plot nbs only to choose best one + if s2 > s1: + pp1.append(p) + oo1.append(None) + ss1.append(s2) + else: + oo1.append(o) + pp1.append(None) + ss1.append(s1) + + return oo1, pp1, ss1 + + def plot_back(self, vol_array, ptl, nbr_idx, oom, ppm, nbr, shrink): + if ppm is not None: + coord, rot, c, mod = find_xml(self.obj_xml, nbr_idx, ppm) + obj_rot = rotate_high_res(self.templates[mod], rot) + # apply wedge + # obj_rot = apply_wedge(obj_rot, mAng) + vol_array = plot_obj(vol_array, obj_rot, coord, self.shrinkage) + # add the coor back + self.coord_save.insert(nbr, coord) + self.rot_save.insert(nbr, rot) + self.index.insert(nbr, nbr_idx) + self.ccc_save.insert(nbr, c) + self.model_save.insert(nbr, mod) + # print("put the neighbor back") + if oom is not None: + coord, rot, c, mod = find_xml(self.obj_xml, ptl, oom) + obj_rot = rotate_high_res(self.templates[mod], rot) + # apply wedge + # obj_rot = apply_wedge(obj_rot, mAng) + # t1=clash_new(x_vox,y_vox,z_vox,vol_array,vol_array.shape) + # a1,t1=add_obj(vol_array,obj_rot,xx1,yy1,zz1,offset_x,shrink,c) + # if t1<=0: + # raise RuntimeError('cant add current obj, the position is '+ str(xx1)+ ' ' +str(yy1) + ' '+ str(zz1)) + vol_array = plot_obj(vol_array, obj_rot, coord, shrink) + # add the coor back + self.coord_save.append(coord) + self.rot_save.append(rot) + self.index.append(ptl) + self.ccc_save.append(c) + self.model_save.append(mod) + # print("current one assigned") + if ppm is not None: + self.resolved_particle += 1 + # print("resolved one clash!") + + return vol_array + + def find_info(self, ch1): + # this is original code, I don't want to rewrite it for now + info = [[] for i in range(7)] + obj_xml = self.obj_xml + for i in range(min(len(obj_xml[ch1]), self.search_depth)): + xx = obj_xml[ch1][i].attrib["x"] + yy = obj_xml[ch1][i].attrib["y"] + zz = obj_xml[ch1][i].attrib["z"] + phi = obj_xml[ch1][i].attrib["phi"] + psi = obj_xml[ch1][i].attrib["theta"] + the = obj_xml[ch1][i].attrib["psi"] + ccc = obj_xml[ch1][i].attrib["CCC"] + info[0].append(xx) + info[1].append(yy) + info[2].append(zz) + info[3].append(phi) + info[4].append(psi) + info[5].append(the) + info[6].append(ccc) + return info + + def clean_particle(self, vol_array, obj_rot, coord): + # Implementation of cleaner + dims = vol_array.shape + obj_voxel = np.nonzero(obj_rot > self.shrinkage) + offset = int(obj_rot.shape[0] // 2) + x_vox = obj_voxel[2] + int(coord[0]) - offset + y_vox = obj_voxel[1] + int(coord[1]) - offset + z_vox = obj_voxel[0] + int(coord[2]) - offset + for idx in range(x_vox.size): + xx = x_vox[idx] + yy = y_vox[idx] + zz = z_vox[idx] + if 0 <= xx < dims[2] and 0 <= yy < dims[1] and 0 < zz < dims[0]: + vol_array[zz, yy, xx] = 0 + return vol_array diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..a51f3a8 --- /dev/null +++ b/src/config.py @@ -0,0 +1,93 @@ +# config.py + +# Path to the tomograms +""" +To do the template matching,the first stp is taking the input from the user: +1.The tomogram, should be black on white background +2.The templates, needs to be white on black background, to be consistent with normal convention +3.The coordinates of picked 
particles
+4. The missing wedge information, defaulting to [30, 42], corresponding to a [-60, +48] tilt range
+5. The shrinkage factor, defaulting to 0.3; this controls the contours of the template
+6. There are more higher-level parameters, like the search depth, etc.
+"""
+import os
+
+current_dir = os.getcwd()
+split_path = current_dir.split("/")
+tomoID = split_path[-2]
+############################ input && output control ##################################
+
+prefix = "s" + tomoID + ".match"
+tomogram = "s" + tomoID + ".lps25.mrc"
+
+# Paths to templates and corresponding masks
+# the templates need to have the same shape
+templates = [
+    "templates/mono-8Apx-lps30-box24-rot12-core.mrc",
+    "templates/pre-40s.mrc",
+    "templates/pre-60s.mrc",
+]
+contour_level = [0.2, 0.23, 0.23]
+# this is an important parameter; see script_s1 for help if you need it
+
+# Masks are required to be the same shape as the templates, spherical masks or shape masks
+# One can create the mask in Relion with soft edges; extend 3 and soft 3 are recommended
+# Leave it blank if you are not sure what to use; the mask will be generated internally
+masks = []
+
+# Coordinates of picked particles; you might also include the angles for local search
+# coords = pd.read_csv("/path/to/coords.csv")
+# Assuming the columns are x, y, z; change it if not
+df = "../cluster/s" + tomoID + ".unetn_4.5.csv"
+
+
+# We need the missing wedge information or a CTF model
+# You might create a CTF model via Relion or Warp, which should be generic for data collected with the same microscope
+# If a CTF model is present, it will be used by default
+ctf_model_file = "s" + tomoID + "_0000000_ctf_8.00A.mrc"  # path to the CTF model
+
+# Missing wedge information, used if a CTF model is not found
+missing_wedge = [30, 42]  # [30, 42] corresponds to a [-60, +48] tilt range
+
+# Shrinkage factor, i.e. the contour level of the volume; it can be determined in Chimera
+# It's designed to control how close two objects are allowed to be
+# only template 1 will be used for the cleaning of clashes
+shrinkage_factor = 0.9
+
+# output path
+output_path = "results/"
+write_models = True  # Generate a volume with assigned models
+
+# Keep the following parameters as default for now
+# Number of angles for global template matching
+number_of_angles = 2000
+# The minimum CCC allowed for the template matching
+min_CCC = 0.2
+
+# Range of local search angles; only for local search, keep False for a global search
+# For local search, the rough angles need to be provided in the coords.csv,
+# with columns phi, theta, psi in intrinsic ZXZ convention
+local_search_angles = False  # False or True
+local_search_range = 5  # in degrees, how far from the original orientation to search
+local_search_step_interval = 5
+# set to zero if you don't want any refinement
+
+# searching space
+matching_space = 3  # in pixels, how far from the original position to search
+
+# Search depth, controls how many candidate rotations are considered per particle
+search_depth = 200
+
+# for development
+bypass_TM = False
+bypass_CR = False
+sort_score = True  # this is useful if you don't want to mix up the index
+mpi_nn = -1  # -1 will use all the available CPUs
+pre_assigned_volume = None  # path to the pre-assigned volume
+bypass_optimizer = False  # this will bypass the optimizer in the clash resolver
+distance_tolerance = 3
+adjust_ccc = 0.11111111
+# max distance between two particles to be considered as the same particle
+# option for running only the general template matching
+testTM = False
+adjust_ccc_relion = 0.11111111
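These settings are plain module-level variables, so the rest of the pipeline reads them as attributes after loading the file through `config_loader.get_config()` (defined next). A minimal sketch, assuming it is run from a work dir that contains `config.py`:

```python
from config_loader import get_config

cfg = get_config()                     # loads ./config.py once (singleton)
print(cfg.tomogram, cfg.search_depth)  # values are ordinary module attributes
if not cfg.bypass_TM:
    print("template matching will run")
```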
diff --git a/src/config_loader.py b/src/config_loader.py
new file mode 100644
index 0000000..1840fe3
--- /dev/null
+++ b/src/config_loader.py
@@ -0,0 +1,58 @@
+# config_loader.py
+import sys
+import os
+import importlib.util
+
+
+class Config:
+    _instance = None
+
+    @staticmethod
+    def get_instance():
+        if Config._instance is None:
+            Config()
+        return Config._instance
+
+    def __init__(self):
+        if Config._instance is not None:
+            raise Exception("This class is a singleton!")
+        else:
+            Config._instance = self
+            self.load_config_old()
+
+    def load_config(self):
+        # Add the desired folder to the start of sys.path
+        # Get the directory where the script is located
+        script_dir = os.path.abspath(os.path.dirname(__file__))
+        # Add the script directory to the system path
+        sys.path.append(script_dir)
+        # desired_folder = os.path.abspath(
+        #     os.path.join(os.path.dirname(__file__), "../desired_folder")
+        # )
+        if script_dir not in sys.path:
+            sys.path.insert(0, script_dir)
+
+        # Now import the config
+        import config as cfg
+
+        self.config = cfg
+
+    def load_config_old(self):
+        # Get the current working directory
+        current_dir = os.getcwd()
+
+        # Path to config.py in the current directory
+        config_path = os.path.join(current_dir, "config.py")
+
+        if os.path.exists(config_path):
+            # Import the module from the given path
+            spec = importlib.util.spec_from_file_location("config", config_path)
+            cfg = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(cfg)
+            self.config = cfg
+        else:
+            raise FileNotFoundError("config.py not found in the current directory.")
+
+
+def get_config():
+    return Config.get_instance().config
diff --git a/src/file_handler.py b/src/file_handler.py
new file mode 100644
index 0000000..2dbf604
--- /dev/null
+++ b/src/file_handler.py
@@ -0,0 +1,335 @@
+import mrcfile
+import pandas as pd
+import numpy as np
+from scipy.spatial.transform import Rotation as R
+from lxml import etree
+import os
+import importlib.util
+import starfile
+
+import config
+
+
+def read_mrc(filename):
+    with mrcfile.open(filename) as mrc:
+        array = mrc.data
+    return array
+
+
+def read_mrc_permissive(filename):
+    with mrcfile.open(filename, permissive=True) as mrc:
+        array = mrc.data
+    return array
+
+
+def write_mrc(data, file_name, pixel_size=8.0):
+    with mrcfile.new(file_name, overwrite=True) as mrc:
+        mrc.set_data(data)
+        mrc.voxel_size = pixel_size
+    print("Writing file: ", file_name, "with ", pixel_size, " Apx!")
+
+
+def read_expand_ctf_model(ctf_model_file):
+    """
+    Read a half CTF volume (e.g. written by Warp or Relion) and expand it to a
+    full cubic volume by mirroring the stored slices along the last axis with
+    complex conjugation (Hermitian symmetry).
+
+    Parameters:
+    ctf_model_file (str): Path to the CTF model in MRC format, with shape
+        [N, N, M], where only M slices of the last axis are stored.
+
+    Returns:
+    numpy.ndarray: The expanded CTF volume with shape [N, N, N].
+ """ + + ctf_data = read_mrc(ctf_model_file) + dim1 = ctf_data.shape[1] + dim2 = ctf_data.shape[2] # this is the dim we apply the symm to + # Initialize an array for the expanded CTF data + expanded_ctf = np.zeros([dim1, dim1, dim1], dtype=complex) + + # Copy the original CTF data + expanded_ctf[:, :, :dim2] = ctf_data + + # Expand the CTF data by mirroring along the last axis + for i in range(1, dim2): + expanded_ctf[:, :, -i] = np.conj(ctf_data[:, :, i]) + + return expanded_ctf + + +def ZXZ_to_ZYZ(rot): + rot_in = R.from_euler("ZXZ", np.array(rot), degrees=True) + rot_out = rot_in.as_euler("ZYZ", degrees=True) + return rot_out + + +def ZYZ_to_ZXZ(rot): + rot_in = R.from_euler("ZYZ", np.array(rot), degrees=True) + rot_out = rot_in.as_euler("ZXZ", degrees=True) + return rot_out + + +def relion_to_coord(star): # perform the convertion and downsampling to 8Apx + # Assuming the star file is downgrade to 3.0 + # ToDo: check if the star file is 3.0 or not + coordinates = star[["rlnCoordinateX", "rlnCoordinateY", "rlnCoordinateZ"]].values + rotations = star[["rlnAngleRot", "rlnAngleTilt", "rlnAnglePsi"]].values + if "rlnImagePixelSize" in star.columns: + pixel_size = star["rlnImagePixelSize"].values[0] + coordinates = coordinates * pixel_size / 8 + print("Scale the relion coordinates pixel size to 8Apx") + else: + print("No pixel size information in the star file") + rotations_ZXZ = ZYZ_to_ZXZ(rotations) + return coordinates, rotations_ZXZ + + +def coord_to_relion(coord, rot, ccc=None): + rot_ZYZ = ZXZ_to_ZYZ(rot) + df_coor = pd.DataFrame( + coord, columns=["rlnCoordinateX", "rlnCoordinateY", "rlnCoordinateZ"] + ) + df_rotx = pd.DataFrame( + rot_ZYZ, columns=["rlnAngleRot", "rlnAngleTilt", "rlnAnglePsi"] + ) + combined_df = pd.concat([df_coor, df_rotx], axis=1) + if ccc is not None: + combined_df["ccc"] = ccc + return combined_df + + +def star2df(df): + # mpd = df['rlnMaxValueProbDistribution'] + + pixel_size = 8 # float(df['rlnImagePixelSize'][0]) + upsampling = 8 / pixel_size + phi_list = [] + theta_list = [] + psi_list = [] + df2 = {} + for i in range(len(df)): + df2["x"] = ( + df["rlnCoordinateX"] / upsampling + ) # -(df['rlnOriginXAngst']/df['rlnPixelSize'])/upsampling + df2["y"] = ( + df["rlnCoordinateY"] / upsampling + ) # -(df['rlnOriginYAngst']/df['rlnPixelSize']) /upsampling + df2["z"] = ( + df["rlnCoordinateZ"] / upsampling + ) # -(df['rlnOriginZAngst']/df['rlnPixelSize'])/upsampling + phi = df["rlnAngleRot"][i] + theta = df["rlnAngleTilt"][i] + psi = df["rlnAnglePsi"][i] + input_eulers = np.array([phi, theta, psi]) + rot = R.from_euler("ZYZ", input_eulers, degrees=True) + neweuler = rot.as_euler("ZXZ", degrees=True) + phi_list.append(neweuler[0]) + theta_list.append(neweuler[1]) + psi_list.append(neweuler[2]) + df2["phi"] = phi_list + df2["theta"] = theta_list + df2["psi"] = psi_list + df2 = pd.DataFrame(df2) + return df2 + + +def load_config(): + # Get the current working directory + current_dir = os.getcwd() + + # Path to config.py in the current directory + config_path = os.path.join(current_dir, "config.py") + + if os.path.exists(config_path): + # Import the module from the given path + spec = importlib.util.spec_from_file_location("config", config_path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + return config + else: + raise FileNotFoundError("config.py not found in the current directory.") + + +def parse_input(): + # Directly accessing the configuration + try: + config = load_config() + # Now you can use your config variables, e.g., 
config.some_setting + except FileNotFoundError as e: + print(e) + user_inputs = { + key: getattr(config, key) for key in dir(config) if not key.startswith("__") + } + print(user_inputs) + + # print(json.dumps(user_inputs, indent=4)) + # Processing and reading data + tomogram = read_mrc(user_inputs["tomogram"]) + dims = tomogram.shape + templates = [read_mrc_permissive(template) for template in user_inputs["templates"]] + if user_inputs["masks"]: + masks = [read_mrc_permissive(mask) for mask in user_inputs["masks"]] + else: + masks = None + if user_inputs["ctf_model_file"]: # os.path.exists(user_inputs["ctf_model_file"]): + ctf_model = read_expand_ctf_model(user_inputs["ctf_model_file"]) + else: + print("No CTF model found! Will use missing wedge compensation only!") + ctf_model = None + + # read the coordinate file + if config.testTM: + df = pd.DataFrame(np.random.rand(100, 3), columns=["x", "y", "z"]) + else: + coord_file = user_inputs["df"] + if coord_file.endswith("csv"): + df = pd.read_csv(coord_file) + elif coord_file.endswith("coords"): + df = pd.read_csv(coord_file, sep=" ", header=None, names=["x", "y", "z"]) + elif coord_file.endswith("star"): + df = starfile.read(coord_file) + coord_star, rot_star = relion_to_coord(df) + df_rotx = pd.DataFrame(rot_star, columns=["phi", "theta", "psi"]) + df_coord = pd.DataFrame(coord_star, columns=["x", "y", "z"]) + df = pd.concat([df_coord, df_rotx], axis=1) + else: + raise RuntimeError("No csv/coords/star file found") + length1 = len(df) + # filter file by remove the coordinates out of box #TODO + df = df[df["x"] > 0] + df = df[df["y"] > 0] + df = df[df["z"] > 0] + df = df[df["x"] < dims[2]] + df = df[df["y"] < dims[1]] + df = df[df["z"] < dims[0]] + print("Total " + str(len(df)) + " particles!") + if len(df) < length1: + print("Warning: some particles are out of box!" 
+ str(length1 - len(df))) + coords = df[["x", "y", "z"]].to_numpy() + + rotations = None + if user_inputs["local_search_angles"] is not False: + rotations = df[["phi", "theta", "psi"]].to_numpy() + # make sure the output path is valid + if not os.path.exists(user_inputs["output_path"]): + os.makedirs(user_inputs["output_path"]) + + if user_inputs["pre_assigned_volume"]: + pre_assigned_volume = read_mrc(user_inputs["pre_assigned_volume"]) + else: + pre_assigned_volume = None + + # Return the processed configuration and data + return { + "tomogram": tomogram, + "templates": templates, + "contour_level": user_inputs["contour_level"], + "masks": masks, + "coords": coords, + "ctf_model": ctf_model, + "missing_wedge": user_inputs["missing_wedge"], + "shrinkage_factor": user_inputs["shrinkage_factor"], + "search_depth": user_inputs["search_depth"], + "min_CCC": user_inputs["min_CCC"], + "number_of_angles": user_inputs["number_of_angles"], + "local_search_angles": user_inputs["local_search_angles"], + "rotations": rotations, + "output_path": user_inputs["output_path"], + "matching_space": user_inputs["matching_space"], + "sort_score": user_inputs["sort_score"], + "mpi_nn": user_inputs["mpi_nn"], + "write_model_file": user_inputs["write_models"], + "pre_assigned_volume": pre_assigned_volume, + } + + +def save_xml(scores, rot, coord, model): + obj_xml = etree.Element("objlist") + print("Scores " + str(len(scores)) + " particles!") + print("Coordiante " + str(len(coord)) + " particles!") + print("Rotation " + str(len(rot)) + " particles!") + for i in range(len(scores)): + if (scores[i] is not None) and (coord[i] is not None) and (rot[i] is not None): + subtomo = etree.SubElement(obj_xml, "subtomo") + for j in range(len(coord[i])): + # something wired happened for local, debugging + if (len(coord[i]) != len(scores[i])) or (len(coord[i]) != len(rot[i])): + print("Warning: Coord, rot and scores don't match") + print("coord is", len(coord[i])) + print("score is ", len(scores[i])) + print("rot is ", len(rot[i])) + # print(coord[i]) + # print(scores[i]) + # print(rot[i]) + indiv_xml = etree.SubElement(subtomo, "object") + indiv_xml.set("subtomo_idx", "%d" % i) + indiv_xml.set("x", "%d" % coord[i][j][0]) + indiv_xml.set("y", "%d" % coord[i][j][1]) + indiv_xml.set("z", "%d" % coord[i][j][2]) + indiv_xml.set("phi", "%d" % rot[i][j][0]) + indiv_xml.set("theta", "%d" % rot[i][j][1]) + indiv_xml.set("psi", "%d" % rot[i][j][2]) + indiv_xml.set("CCC", "%f" % scores[i][j]) + indiv_xml.set("model", "%f" % model[i][j]) + return obj_xml + + +def save_xml_one_loop(scores, rot, coord, model): + obj_xml = etree.Element("objlist") + for i in range(len(scores)): + subtomo = etree.SubElement(obj_xml, "subtomo") + indiv_xml = etree.SubElement(subtomo, "object") + indiv_xml.set("subtomo_idx", "%d" % i) + indiv_xml.set("x", "%d" % coord[i][0]) + indiv_xml.set("y", "%d" % coord[i][1]) + indiv_xml.set("z", "%d" % coord[i][2]) + indiv_xml.set("phi", "%d" % rot[i][0]) + indiv_xml.set("theta", "%d" % rot[i][1]) + indiv_xml.set("psi", "%d" % rot[i][2]) + indiv_xml.set("CCC", "%f" % float(scores[i])) + indiv_xml.set("model", "%f" % int(model[i])) + + return obj_xml + + +def adjust_ccc_values_second_half(obj_xml, coord_length, offset=0.1): + """ + Adjusts the CCC values by adding an offset to the first half of the objects in each subtomo. + + Args: + obj_xml (etree.Element): The root element of the XML structure. + offset (float): The offset value to add to the CCC value of the first half of objects. 
+ + Returns: + None; the obj_xml is modified in place. + """ + # Iterate over all subtomo elements in the XML + for subtomo in obj_xml.findall("subtomo"): + # Get all object elements within each subtomo + objects = subtomo.findall("object") + # Determine the midpoint (first half) + midpoint = coord_length // 2 + # Iterate over the first half of the objects + for obj in objects[midpoint:]: + # Get the current CCC value + current_ccc = float(obj.get("CCC")) + # Calculate the new CCC value and update the attribute + new_ccc = current_ccc + offset + obj.set("CCC", f"{new_ccc:.3f}") + + +def find_xml(obj_xml, ch1, ch2): + xx = obj_xml[ch1][ch2].attrib["x"] + yy = obj_xml[ch1][ch2].attrib["y"] + zz = obj_xml[ch1][ch2].attrib["z"] + phi = obj_xml[ch1][ch2].attrib["phi"] + theta = obj_xml[ch1][ch2].attrib["theta"] + psi = obj_xml[ch1][ch2].attrib["psi"] + ccc = obj_xml[ch1][ch2].attrib["CCC"] + model = float(obj_xml[ch1][ch2].attrib["model"]) + return ( + [int(xx), int(yy), int(zz)], + [np.float32(phi), np.float32(theta), np.float32(psi)], + ccc, + int(model), + ) diff --git a/src/filter_tomograms.py b/src/filter_tomograms.py new file mode 100644 index 0000000..daee728 --- /dev/null +++ b/src/filter_tomograms.py @@ -0,0 +1,70 @@ +import numpy as np +from scipy.fft import fftn, ifftn, fftshift +from scipy.ndimage import fourier_gaussian +from file_handler import read_mrc, write_mrc + + +def normalize_data(tomogram): + mean = np.mean(tomogram) + std_dev = np.std(tomogram) + normalized_tomogram = (tomogram - mean) / std_dev + return normalized_tomogram + + +def low_pass_filter(tomogram, pixel_size, cutoff_nm): + cutoff_frequency = cutoff_nm / pixel_size + tomogram_fft = fftn(tomogram) + tomogram_fft_shifted = fftshift(tomogram_fft) + filtered_fft_shifted = fourier_gaussian( + tomogram_fft_shifted, sigma=cutoff_frequency + ) + filtered_fft = fftshift(filtered_fft_shifted) + return np.real(ifftn(filtered_fft)) + + +def high_pass_filter(tomogram, pixel_size, cutoff_nm): + cutoff_frequency = cutoff_nm / pixel_size + tomogram_fft = fftn(tomogram) + tomogram_fft_shifted = fftshift(tomogram_fft) + low_pass_fft_shifted = fourier_gaussian( + tomogram_fft_shifted, sigma=cutoff_frequency + ) + high_pass_fft_shifted = tomogram_fft_shifted - low_pass_fft_shifted + high_pass_fft = fftshift(high_pass_fft_shifted) + return np.real(ifftn(high_pass_fft)) + + +def threshold_clampminmax_nsigma(tomogram, nsigma=3): + mean = np.mean(tomogram) + std_dev = np.std(tomogram) + threshold = mean + nsigma * std_dev + return np.maximum(tomogram, threshold) + + +def process(tomogram, pixel_size, cutoff_low_nm, cutoff_high_nm, nsigma): + # Apply low pass filter + tomogram_low_passed = low_pass_filter(tomogram, pixel_size, cutoff_low_nm) + + # Apply high pass filter + tomogram_high_passed = high_pass_filter( + tomogram_low_passed, pixel_size, cutoff_high_nm + ) + + # Normalize the tomogram + tomogram_normalized = normalize_data(tomogram_high_passed) + + # Apply threshold clampminmax nsigma + tomogram_thresholded = threshold_clampminmax_nsigma(tomogram_normalized, nsigma) + + return np.float32(tomogram_thresholded) + + +if __name__ == "__main__": + filename = "s68" + pixel_size = 0.8 # in nm + cutoff_low_nm = 25 + cutoff_high_nm = 100 + tomo = read_mrc(filename) + filtered_tomo = process(tomo, pixel_size, cutoff_low_nm, cutoff_high_nm, 3) + write_mrc(filtered_tomo, "s68.filtered.mrc") +# Example usage diff --git a/src/geo_utils.py b/src/geo_utils.py new file mode 100644 index 0000000..18f3fe7 --- /dev/null +++ b/src/geo_utils.py @@ 
-0,0 +1,31 @@ +import numpy as np + + +def findVec(point1, point2, unitSphere=False): + # setting unitSphere to True will make the vector scaled down to a sphere with a radius one, instead of it's orginal length + finalVector = [0 for coOrd in point1] + for dimension, coOrd in enumerate(point1): + # finding total differnce for that co-ordinate(x,y,z...) + deltaCoOrd = point2[dimension] - coOrd + # adding total difference + finalVector[dimension] = deltaCoOrd + if unitSphere: + totalDist = multiDimenDist(point1, point2) + unitVector = [] + for dimen in finalVector: + unitVector.append(dimen / totalDist) + return np.array(unitVector) + else: + return np.array(finalVector) + + +def multiDimenDist(point1, point2): + # find the difference between the two points, its really the same as below + deltaVals = [ + point2[dimension] - point1[dimension] for dimension in range(len(point1)) + ] + runningSquared = 0 + # because the pythagarom theorm works for any dimension we can just use that + for coOrd in deltaVals: + runningSquared += coOrd**2 + return runningSquared ** (1 / 2) diff --git a/src/global_template_matching.py b/src/global_template_matching.py new file mode 100644 index 0000000..24387d4 --- /dev/null +++ b/src/global_template_matching.py @@ -0,0 +1,283 @@ +import numpy as np +from scipy.spatial.transform import Rotation as R +from multiprocessing import Pool, cpu_count +from utils import rotate, rotate_high_res +from utils import prepare_ctf_volumes, apply_ctf, apply_wedge +from scipy.ndimage import binary_dilation +import pandas as pd +from config_loader import get_config + +config = get_config() + +""" +This script is for the benchmark ing of the template matching algorithm. +Works well on small tomograms, no optimzation for large tomograms yet +""" + + +class TemplateMatcher_general: + def __init__(self, inputs): + self.templates = inputs["templates"] + self.contour_level = inputs["contour_level"] + self.masks = inputs["masks"] + self.ctf = inputs["ctf_model"] + self.missing_wedge = inputs["missing_wedge"] + self.tomogram = inputs["tomogram"] + self.dims = self.tomogram.shape + self.number_of_angles = inputs["number_of_angles"] + self.min_CCC = inputs["min_CCC"] + self.output_path = inputs["output_path"] + self.sort_score = inputs["sort_score"] + self.mpi_nn = inputs["mpi_nn"] + self.global_angles = self.generate_radom_angles() + self.rotamer = self.prepare_rotamer() + self.maskamer = self.prepare_maskamer() + + def match_worker(self, index): + # Implementation of the template matching logic + # we have two different modes, one is global search, the other is local search + # subtomo = self.extract_subtomo(self.tomogram) + subtomo = ( + -self.tomogram + ) # here we simply use the whole tomogram as the subtomo, invert the contrast + # might consider split the whole tomograms into piceces + + angles = self.global_angles + if self.ctf is not None: + ctf_volumes = prepare_ctf_volumes(self.ctf, self.templates) + else: + ctf_volumes = [] + score, rots = self.template_match(subtomo, angles, ctf_volumes, index) + + return score, rots + + def run_multiprocessing(self): + if self.mpi_nn == -1: + ncpu = cpu_count() - 1 + else: + ncpu = self.mpi_nn + print("Using " + str(ncpu) + " cpu cores for processing~") + + # Initialize lists to store results + + # Use Pool context manager to ensure proper cleanup + with Pool(ncpu - 2) as pool: + indices = list(range(len(self.global_angles))) + results = pool.map(self.match_worker, indices) + # Process results in order + print(len(results)) + print("Template 
matching is done, sorting the results!")
+        # Initialize the result arrays
+        rot_save = np.empty(
+            (*self.dims, 3)
+        )  # extra last dimension for the three Euler angles of cur_rot
+        ccc_save = np.zeros(self.dims)
+
+        # Iterate through results and keep the best score per voxel
+        for cur_ccc, cur_rot in results:
+            mask = cur_ccc > ccc_save
+            ccc_save[mask] = cur_ccc[mask]
+
+            # Create a 4-dimensional temporary array filled with cur_rot values
+            temp_array = np.zeros((*self.dims, 3))
+            temp_array[mask] = cur_rot
+
+            # Assign the temporary array to rot_save where the mask is True
+            rot_save[mask] = temp_array[mask]
+        # Verify the result for one element
+        print(rot_save[0, 0, 0])  # Example verification
+        # Flatten ccc_save
+        flat_ccc_save = ccc_save.ravel()
+
+        # Reshape rot_save to a 2D array where each row is a 3-element list
+        flat_rot_save = rot_save.reshape(-1, 3)
+        # Get sorted indices from the flattened ccc_save array
+        sorted_indices = np.argsort(flat_ccc_save)
+        # Sort the ccc_save array and fetch corresponding values from rot_save
+        sorted_flat_ccc_save = flat_ccc_save[sorted_indices]
+        corresponding_flat_rot_save = flat_rot_save[sorted_indices]
+        # Convert flat indices to 3D indices (ignoring the last dimension)
+        sorted_indices_3d = np.unravel_index(sorted_indices, ccc_save.shape)
+        # Create DataFrame
+        df = pd.DataFrame(
+            {
+                "z": sorted_indices_3d[0],
+                "y": sorted_indices_3d[1],
+                "x": sorted_indices_3d[2],
+                "phi": corresponding_flat_rot_save[:, 0],
+                "theta": corresponding_flat_rot_save[:, 1],
+                "psi": corresponding_flat_rot_save[:, 2],
+                "ccc": sorted_flat_ccc_save,
+            }
+        )
+
+        df = df[df["ccc"] > self.min_CCC]
+        df.sort_values("ccc", inplace=True, ascending=False)
+        df.to_csv(config.output_path + config.prefix + ".tm.csv")
+        print("Done!")
+
+    def template_match(self, subtomo, angles, ctf_vols, idx):
+        num_templates = len(self.templates)
+        cur_rot = angles[idx]
+        best_ccc = None
+        for temp in range(num_templates):
+            # for a given template at the given angles, fetch from the rotamer
+            template_rot = self.rotamer[temp][idx]
+            if self.ctf is not None:
+                template_rot = apply_ctf(template_rot, ctf_vols[temp])
+            else:
+                template_rot = apply_wedge(template_rot, self.missing_wedge)
+            template_rot[template_rot < self.contour_level[temp]] = 0
+            if self.masks is None:
+                structuring_element = np.ones((3, 3, 3))
+                mask = binary_dilation(template_rot, structuring_element)
+            else:
+                mask = self.maskamer[temp][idx]
+                mask = apply_wedge(mask, self.missing_wedge)
+                mask[mask < self.contour_level[temp]] = 0
+                mask[mask >= self.contour_level[temp]] = 1
+
+            cur_ccc = self.calculate_correlation(subtomo, template_rot, mask)
+            # keep the element-wise best score across all templates
+            best_ccc = cur_ccc if best_ccc is None else np.maximum(best_ccc, cur_ccc)
+
+        return best_ccc, cur_rot
+
+    def prepare_rotamer(self):
+        # pre-rotate all templates to accelerate the later matching steps
+        rotamer = []
+        for temp in range(len(self.templates)):
+            dim = self.templates[temp].shape[0]
+            r = np.zeros([len(self.global_angles), dim, dim, dim])
+            for i, (phi, theta, psi) in enumerate(self.global_angles):
+                template = self.templates[temp]
+                r[i] = rotate_high_res(template, (phi, theta, psi))
+            rotamer.append(r)
+        return rotamer
+
+    def prepare_maskamer(self):
+        # pre-rotate all masks; only meaningful when masks are provided
+        if self.masks is None:
+            return None
+        maskamer = []
+        for temp in range(len(self.masks)):
+            dim = self.masks[temp].shape[0]
+            r = np.zeros([len(self.global_angles), dim, dim, dim])
+            for i, (phi, theta, psi) in enumerate(self.global_angles):
+                mask = self.masks[temp]
+                r[i] = rotate_high_res(mask, (phi, theta, psi))
+            maskamer.append(r)
+        return maskamer
+
+    def generate_random_angles(self):
+        angles = R.random(self.number_of_angles).as_euler("ZXZ", degrees=True)
+        return angles
+
+    def paste_to_whole_map(self, whole_map, vol, center=None):
+        """
+        Paste a smaller volume (vol) into a larger volume (whole_map).
+        The smaller volume is centered at the position 'center' if specified,
+        otherwise at the center of the whole_map.
+        """
+        if center is None:
+            center = np.array(whole_map.shape) // 2
+
+        start, end = self.calculate_start_end(center, whole_map.shape, vol.shape)
+        if start is None or end is None:
+            raise ValueError("Volume cannot be pasted, out of bounds.")
+
+        self.paste_volume(whole_map, vol, start, end)
+        return whole_map
+
+    def calculate_start_end(self, center, map_size, subvol_size):
+        center = np.array(center)
+        subvol_size = np.array(subvol_size)
+        start = center - np.ceil(subvol_size / 2.0).astype(int)
+        end = start + subvol_size
+
+        if np.any(start < 0) or np.any(end > map_size):
+            return None, None
+        return start, end
+
+    def paste_volume(self, map, subvol, start, end):
+        """
+        Paste the subvolume into the map using the start and end indices provided.
+        This function assumes that start and end are valid and within the bounds of the map.
+        """
+        slices = tuple(slice(s, e) for s, e in zip(start, end))
+        map[slices] = subvol
+
+    # Example usage:
+    """
+    whole_map = np.zeros((100, 100, 100))
+    vol = np.random.rand(10, 10, 10)
+    try:
+        updated_map = paste_to_whole_map(whole_map, vol, center=np.array([50, 50, 50]))
+        print("Volume pasted successfully.")
+    except ValueError as e:
+        print(str(e))
+    """
+
+    def perform_convolution(self, volume, template):
+        """
+        Perform convolution of a volume with a given template.
+        """
+        from numpy.fft import fftn, ifftn, ifftshift
+
+        # centering the template in the frequency domain
+        template_fft = fftn(ifftshift(template))
+        convolved_volume = np.real(ifftn(fftn(volume) * np.conj(template_fft)))
+        return convolved_volume
+
+    def calculate_correlation(self, data_volume, template_volume, mask):
+        """
+        Calculate the correlation of a data volume with a template volume.
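+        As a sketch of what this method computes (our notation, assuming a
+        binary mask M with N voxels and a mask-normalized template T): the data
+        is first standardized locally, D'(x) = (D(x) - mu(x)) / sigma(x), with
+        mu and sigma the local mean and std measured under M, and then
+            CC(p) = (1/N) * sum over m in M of T(m) * D'(p + m)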
+ The correlation is similar to the implement in PyTom + """ + # + mask = mask.astype(bool) + num_elements = np.sum(mask) + template_mean = np.mean(template_volume[mask]) + template_std = np.std(template_volume[mask]) + normalized_template = template_volume - template_mean + normalized_template /= template_std + + # Construct the larger template and mask volumes + large_template = np.zeros_like(data_volume) + large_mask = np.zeros_like(data_volume) + self.paste_to_whole_map(large_template, normalized_template) + self.paste_to_whole_map(large_mask, mask) + + # Calculate mean and std of data volume under the mask + mean_volume = self.perform_convolution(data_volume, large_mask) / num_elements + mean_square_volume = ( + self.perform_convolution(np.square(data_volume), large_mask) / num_elements + ) + volume_std_dev = mean_square_volume - np.square(mean_volume) + np.maximum(volume_std_dev, 0.0, out=volume_std_dev) + np.sqrt(volume_std_dev, out=volume_std_dev) + + adjusted_volume = data_volume - mean_volume + valid_std_dev = volume_std_dev > 0 + adjusted_volume[valid_std_dev] /= volume_std_dev[valid_std_dev] + + large_template *= large_mask.astype(np.float32) + correlation_result = ( + self.perform_convolution(adjusted_volume, large_template) / num_elements + ) + print(correlation_result.max()) + return correlation_result diff --git a/src/global_template_matching_chunk.py b/src/global_template_matching_chunk.py new file mode 100644 index 0000000..e0b83ca --- /dev/null +++ b/src/global_template_matching_chunk.py @@ -0,0 +1,346 @@ +import numpy as np +from scipy.spatial.transform import Rotation as R +from multiprocessing import Pool, cpu_count +from utils import rotate_high_res, prepare_ctf_volumes, apply_ctf, apply_wedge +from scipy.ndimage import binary_dilation +import pandas as pd +from config_loader import get_config +from numpy.fft import fftn, ifftn, ifftshift +from file_handler import write_mrc +import tempfile +import gc +import os + +config = get_config() + + +class TemplateMatcherGeneral: + def __init__(self, inputs): + self.templates = inputs["templates"] + self.contour_level = inputs["contour_level"] + self.masks = inputs["masks"] + self.ctf = inputs["ctf_model"] + self.missing_wedge = inputs["missing_wedge"] + self.tomogram = inputs["tomogram"] + self.dims = self.tomogram.shape + self.number_of_angles = inputs["number_of_angles"] + self.min_CCC = inputs["min_CCC"] + self.output_path = inputs["output_path"] + self.sort_score = inputs["sort_score"] + self.mpi_nn = inputs["mpi_nn"] + self.global_angles = self.generate_random_angles() + self.rotamer = self.prepare_rotamer() + self.maskamer = self.prepare_maskamer() + + # Memory map the tomogram to a temporary file + self.tomogram_file_path = self.create_temp_file() + + # Initialize arrays to store results incrementally + self.rot_save = np.empty( + (*self.dims, 3) + ) # Adding an extra dimension for the 3 elements of cur_rot + self.ccc_save = np.zeros(self.dims) + + def create_temp_file(self): + tomogram_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy") + np.save(tomogram_file, self.tomogram) + tomogram_file.close() + return tomogram_file.name + + def cleanup(self): + os.remove(self.tomogram_file_path) + + def match_worker(self, index, tomogram_memmap_path): + # Memory map the tomogram in the worker process + tomogram_memmap = np.load(tomogram_memmap_path, mmap_mode="r") + subtomo = -tomogram_memmap # Invert the contrast + angles = self.global_angles + ctf_volumes = ( + prepare_ctf_volumes(self.ctf, self.templates) + if 
self.ctf is not None + else [] + ) + + score, rots = self.template_match(subtomo, angles, ctf_volumes, index) + + # Explicitly delete large variables and force garbage collection + del subtomo, angles, ctf_volumes, tomogram_memmap + gc.collect() + + return score, rots + + def result_callback(self, result): + cur_ccc, cur_rot = result + mask = cur_ccc > self.ccc_save + self.ccc_save[mask] = cur_ccc[mask] + + temp_array = np.zeros((*self.dims, 3)) + temp_array[mask] = cur_rot + self.rot_save[mask] = temp_array[mask] + + def run_multiprocessing(self): + ncpu = cpu_count() - 1 if self.mpi_nn == -1 else self.mpi_nn + print(f"Using {ncpu} cpu cores for processing~") + + with Pool(ncpu - 2) as pool: + indices = list(range(len(self.global_angles))) + for index in indices: + pool.apply_async( + self.match_worker, + args=(index, self.tomogram_file_path), + callback=self.result_callback, + ) + + pool.close() + pool.join() + + print("Template matching is done, sorting the results!") + + # print(self.rot_save[0, 0, 0]) + write_mrc( + np.float32(self.ccc_save), self.output_path + config.prefix + "ccc.mrc" + ) + flat_ccc_save = self.ccc_save.ravel() + flat_rot_save = self.rot_save.reshape(-1, 3) + sorted_indices = np.argsort(flat_ccc_save) + sorted_flat_ccc_save = flat_ccc_save[sorted_indices] + corresponding_flat_rot_save = flat_rot_save[sorted_indices] + sorted_indices_3d = np.unravel_index(sorted_indices, self.ccc_save.shape) + + df = pd.DataFrame( + { + "z": sorted_indices_3d[0], + "y": sorted_indices_3d[1], + "x": sorted_indices_3d[2], + "phi": corresponding_flat_rot_save[:, 0], + "theta": corresponding_flat_rot_save[:, 1], + "psi": corresponding_flat_rot_save[:, 2], + "ccc": sorted_flat_ccc_save, + } + ) + + df = df[df["ccc"] > self.min_CCC] + df.sort_values("ccc", inplace=True, ascending=False) + df.to_csv(config.output_path + config.prefix + ".tm.csv") + print("Done!") + + del ( + self.rot_save, + self.ccc_save, + flat_ccc_save, + flat_rot_save, + sorted_indices, + sorted_flat_ccc_save, + corresponding_flat_rot_save, + ) + gc.collect() + + self.cleanup() + + def template_match(self, subtomo, angles, ctf_vols, idx): + print("Template matching for index: ", idx) + cur_rot = angles[idx] + num_templates = len(self.templates) + cur_ccc = np.zeros(self.dims) + + for temp in range(num_templates): + template_rot = self.rotamer[temp][idx] + template_rot = ( + apply_ctf(template_rot, ctf_vols[temp]) + if self.ctf is not None + else apply_wedge(template_rot, self.missing_wedge) + ) + mask = ( + self.masks[temp] + if self.masks is not None + else binary_dilation(template_rot, np.ones((3, 3, 3))) + ) + mask = ( + apply_wedge(mask, self.missing_wedge) + if self.masks is not None + else mask + ) + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + cur_ccc = self.calculate_correlation(subtomo, template_rot, mask) + print(cur_ccc.max()) + # Explicitly delete large variables and force garbage collection + del template_rot, mask + gc.collect() + + return cur_ccc, cur_rot + + def prepare_rotamer(self): + rotamer = [] + for temp in range(len(self.templates)): + dim = self.templates[temp].shape[0] + r = np.zeros([len(self.global_angles), dim, dim, dim]) + for i, (phi, theta, psi) in enumerate(self.global_angles): + template = self.templates[temp] + r[i] = rotate_high_res(template, (phi, theta, psi)) + rotamer.append(r) + return rotamer + + def prepare_maskamer(self): + maskamer = [] + for temp in range(len(self.templates)): + dim = self.templates[temp].shape[0] + r = np.zeros([len(self.global_angles), dim, dim, dim]) + 
for i, (phi, theta, psi) in enumerate(self.global_angles): + template = self.templates[temp] + r[i] = rotate_high_res(template, (phi, theta, psi)) + maskamer.append(r) + return maskamer + + def generate_random_angles(self): + return R.random(self.number_of_angles).as_euler("ZXZ", degrees=True) + + def paste_to_whole_map(self, whole_map, vol, center=None): + if center is None: + center = np.array(whole_map.shape) // 2 + + start, end = self.calculate_start_end(center, whole_map.shape, vol.shape) + if start is None or end is None: + raise ValueError("Volume cannot be pasted, out of bounds.") + self.paste_volume(whole_map, vol, start, end) + return whole_map + + def calculate_start_end(self, center, map_size, subvol_size): + center = np.array(center) + subvol_size = np.array(subvol_size) + start = center - np.ceil(subvol_size / 2.0).astype(int) + end = start + subvol_size + + if np.any(start < 0) or np.any(end > map_size): + return None, None + return start, end + + def paste_volume(self, map, subvol, start, end): + slices = tuple(slice(s, e) for s, e in zip(start, end)) + map[slices] = subvol + + def perform_convolution(self, volume, template): + template_fft = fftn(ifftshift(template)) + convolved_volume = np.real(ifftn(fftn(volume) * np.conj(template_fft))) + return convolved_volume + + def calculate_correlation(self, data_volume, template_volume, mask): + chunk_size = config.chunk_size or data_volume.shape + if chunk_size is None: + chunk_size = data_volume.shape + mask = mask.astype(bool) + num_elements = np.sum(mask) + template_mean = np.mean(template_volume[mask]) + template_std = np.std(template_volume[mask]) + normalized_template = (template_volume - template_mean) / template_std + + data_shape = data_volume.shape + padding = 10 + chunks = [range(0, s, cs) for s, cs in zip(data_shape, chunk_size)] + correlation_result = np.zeros(data_shape, dtype=np.float32) + for z in chunks[0]: + for y in chunks[1]: + for x in chunks[2]: + z_end = min(z + chunk_size[0], data_shape[0]) + y_end = min(y + chunk_size[1], data_shape[1]) + x_end = min(x + chunk_size[2], data_shape[2]) + z_start = max(0, z - padding) + y_start = max(0, y - padding) + x_start = max(0, x - padding) + data_subvol = data_volume[ + z_start:z_end, y_start:y_end, x_start:x_end + ] + subvol_correlation = self.single_correlation( + data_subvol, normalized_template, mask, num_elements + ) + z_count = max(0, z - padding // 2) + y_count = max(0, y - padding // 2) + x_count = max(0, x - padding // 2) + z_sub = 0 if z_count == 0 else padding // 2 + y_sub = 0 if y_count == 0 else padding // 2 + x_sub = 0 if x_count == 0 else padding // 2 + correlation_result[ + z_count:z_end, y_count:y_end, x_count:x_end + ] = subvol_correlation[z_sub:, y_sub:, x_sub:] + + # Explicitly delete subvolumes and force garbage collection + del data_subvol, subvol_correlation + gc.collect() + + return correlation_result + + def single_correlation(self, data_subvol, normalized_template, mask, num_elements): + large_template = np.zeros_like(data_subvol) + large_mask = np.zeros_like(data_subvol) + self.paste_to_whole_map(large_template, normalized_template) + self.paste_to_whole_map(large_mask, mask) + + mean_volume = self.perform_convolution(data_subvol, large_mask) / num_elements + mean_square_volume = ( + self.perform_convolution(np.square(data_subvol), large_mask) / num_elements + ) + volume_std_dev = mean_square_volume - np.square(mean_volume) + np.maximum(volume_std_dev, 0.0, out=volume_std_dev) + np.sqrt(volume_std_dev, out=volume_std_dev) + + # 
adjusted_volume = data_subvol - mean_volume + # valid_std_dev = volume_std_dev > 0 + # adjusted_volume[valid_std_dev] /= volume_std_dev[valid_std_dev] + + large_template *= large_mask.astype(np.float32) + subvol_correlation = ( + self.perform_convolution(data_subvol, large_template) + / num_elements + / volume_std_dev + ) + + # Explicitly delete intermediate variables and force garbage collection + del ( + large_template, + large_mask, + mean_volume, + mean_square_volume, + volume_std_dev, + ) + gc.collect() + + return subvol_correlation + + def single_correlation_norm1( + self, data_subvol, normalized_template, mask, num_elements + ): + large_template = np.zeros_like(data_subvol) + large_mask = np.zeros_like(data_subvol) + self.paste_to_whole_map(large_template, normalized_template) + self.paste_to_whole_map(large_mask, mask) + + mean_volume = self.perform_convolution(data_subvol, large_mask) / num_elements + mean_square_volume = ( + self.perform_convolution(np.square(data_subvol), large_mask) / num_elements + ) + volume_std_dev = mean_square_volume - np.square(mean_volume) + np.maximum(volume_std_dev, 0.0, out=volume_std_dev) + np.sqrt(volume_std_dev, out=volume_std_dev) + + adjusted_volume = data_subvol - mean_volume + valid_std_dev = volume_std_dev > 0 + adjusted_volume[valid_std_dev] /= volume_std_dev[valid_std_dev] + + large_template *= large_mask.astype(np.float32) + subvol_correlation = ( + self.perform_convolution(adjusted_volume, large_template) / num_elements + ) + + # Explicitly delete intermediate variables and force garbage collection + del ( + large_template, + large_mask, + mean_volume, + mean_square_volume, + volume_std_dev, + adjusted_volume, + ) + gc.collect() + + return subvol_correlation diff --git a/src/misc.py b/src/misc.py new file mode 100644 index 0000000..6a097ea --- /dev/null +++ b/src/misc.py @@ -0,0 +1,20 @@ +import numpy as np +from scipy.spatial.transform import Rotation as R + + +def flip_y_euler_angles(euler_angles): + flipped_angles = [] + flip_matrix = np.diag([-1, 1, -1]) + + for angles in euler_angles: + # Convert Euler to rotation matrix + rot_matrix = R.from_euler("ZXZ", angles, degrees=True).as_matrix() + + # Apply the flip transformation + flipped_matrix = np.dot(flip_matrix, rot_matrix) + + # Convert back to Euler angles + flipped_euler = R.from_matrix(flipped_matrix).as_euler("ZXZ", degrees=True) + flipped_angles.append(flipped_euler) + + return np.array(flipped_angles) diff --git a/src/plot_nucleosome_with_z.py b/src/plot_nucleosome_with_z.py new file mode 100644 index 0000000..c6f1502 --- /dev/null +++ b/src/plot_nucleosome_with_z.py @@ -0,0 +1,29 @@ +from file_handler import read_mrc, write_mrc, relion_to_coord +from utils import plot_obj, rotate +import starfile +import numpy as np + + +def plot_z(): + filename = "" + select_z = 100 + star = starfile.read(filename) + star = star[ + (star["rlnCoordinateZ"] > (select_z - 10)) + & (star["rlnCoordinateZ"] < (select_z + 10)) + ] + starfile.write("s68.selected_z.star") + coords, rotxs = relion_to_coord(star) + tomo = read_mrc("s68.wi8Apx.mrc") + dims = tomo.shape + nucleosome = read_mrc("") + shrink = 0.6 + vol_array = np.zeros(dims) + for i in range(len(coords)): + obj_rot = rotate(nucleosome, rotxs[i]) + vol_array = plot_obj(vol_array, obj_rot, coords[i], shrink) + write_mrc(vol_array, "s68.select_z.mrc") + + +if __name__ == "__main__": + plot_z() diff --git a/src/slurm-tm.sh b/src/slurm-tm.sh new file mode 100644 index 0000000..171ad52 --- /dev/null +++ b/src/slurm-tm.sh @@ -0,0 +1,25 @@ 
+#!/bin/bash +#SBATCH --job-name HL12-core-33 +#SBATCH -p 256GBv1 # partition (queue) +#SBATCH -N 1 +#SBATCH -t 6-2:0:00 +#SBATCH -o job_%j.out +#SBATCH -e job_%j.err + + +#SBATCH --mail-type END +#SBATCH --mail-user huabin.zhou@utsouthwestern.edu + +module load python/3.7.x-anaconda +source /cm/shared/apps/python/3.7.x-anaconda/etc/profile.d/conda.sh +conda activate pysearch +#cd /home2/s194231/work/process/hpf/0/HI6/84/linker +python run.py + + +# COMMAND GROUP 1 +hostname + + + +# END OF SCRIPT diff --git a/src/template_matching.py b/src/template_matching.py new file mode 100644 index 0000000..df634d4 --- /dev/null +++ b/src/template_matching.py @@ -0,0 +1,473 @@ +import pandas as pd +import numpy as np +import starfile +from scipy.spatial.transform import Rotation as R +from multiprocessing import Pool, cpu_count +import itertools +from utils import rotate, rotate_high_res, calculate_correlation +from utils import prepare_ctf_volumes, apply_ctf, apply_wedge +from file_handler import save_xml +from scipy.ndimage import binary_dilation +from lxml import etree +from config_loader import get_config + +config = get_config() # cannot be self.config, pickle needed + +""" +Just to be clear, there are a few things you need to be aware of: +1. The angles are in degrees and in the order of phi, theta, psi, which is intrinsic ZXZ convention. +2. The template and template2 are the same size. +""" + + +class TemplateMatcher: + def __init__(self, inputs): + self.coords = inputs["coords"] + self.templates = inputs["templates"] + self.contour_level = inputs["contour_level"] + self.masks = inputs["masks"] + self.ctf = inputs["ctf_model"] + self.missing_wedge = inputs["missing_wedge"] + self.shrink = inputs["shrinkage_factor"] + self.tomogram = inputs["tomogram"] + self.number_of_angles = inputs["number_of_angles"] + self.search_depth = inputs["search_depth"] + self.local_search_angles = inputs["local_search_angles"] + self.min_CCC = inputs["min_CCC"] + self.matching_space = inputs["matching_space"] + self.rotations = inputs["rotations"] + self.output_path = inputs["output_path"] + self.sort_score = inputs["sort_score"] + self.sort_score = inputs["sort_score"] + self.mpi_nn = inputs["mpi_nn"] + if self.local_search_angles is False: + self.global_angles = self.generate_radom_angles() + self.rotamer = self.prepare_rotamer() + else: + print("Perform Local Refinement!") + + def match_worker(self, index): + # Implementation of the template matching logic + # we have two different modes, one is global search, the other is local search + box_sizes = [temp.shape[0] for temp in self.templates] + subtomo = self.extract_subtomo(self.tomogram, self.coords[index], box_sizes) + if self.local_search_angles: + angles = self.generate_local_angles( + index, + config.local_search_range, + config.local_search_step_interval, + ) + else: + angles = self.global_angles + if self.ctf is not None: + ctf_volumes = prepare_ctf_volumes(self.ctf, self.templates) + else: + ctf_volumes = [] + score, rots, coords, models = self.template_match( + subtomo, angles, self.coords[index], ctf_volumes + ) + + return score, rots, coords, models + + def run_multiprocessing(self): + if self.mpi_nn == -1: + ncpu = cpu_count() - 1 + else: + ncpu = self.mpi_nn + print("Using " + str(ncpu) + " cpu cores for processing~") + + # Use Pool context manager to ensure proper cleanup + with Pool(ncpu - 2) as pool: + indices = list(range(len(self.coords))) + results = pool.map(self.match_worker, indices) + + # Initialize lists to store results + scores, 
rot_new, coord_new, model = [], [], [], [] + + # Process results in order + for s, r, c, m in results: + scores.append(s) + rot_new.append(r) + coord_new.append(c) + model.append(m) + # remove all the None in scores, rot_new and coord_new + if hasattr(config, "adjust_ccc") and config.adjust_ccc is not None: + print("Adjust Relion CCC by: ", config.adjust_ccc) + for i in range((len(self.coords) // 2), len(scores)): + if scores[i] is not None: + for j in range(len(scores[i])): + scores[i][j] = scores[i][j] + config.adjust_ccc + scores = [x for x in scores if x is not None] + rot_new = [x for x in rot_new if x is not None] + coord_new = [x for x in coord_new if x is not None] + model = [x for x in model if x is not None] + sss = [] + + for i in range(len(scores)): + if scores[i] is None: + sss.append(0) + else: + sss.append(scores[i][0]) + # print(sorted(sss, reverse=True)) + if self.sort_score: + sorted_indices = sorted( + range(len(sss)), key=lambda i: scores[i], reverse=True + ) + + """ + rot_new = [x for _, x in sorted(zip(sss, rot_new), reverse=True)] + coord_new = [x for _, x in sorted(zip(sss, coord_new), reverse=True)] + model = [x for _, x in sorted(zip(sss, model), reverse=True)] + scores = [x for _, x in sorted(zip(sss, scores), reverse=True)] + """ + scores = [scores[i] for i in sorted_indices] + rot_new = [rot_new[i] for i in sorted_indices] + coord_new = [coord_new[i] for i in sorted_indices] + model = [model[i] for i in sorted_indices] + obj_xml = save_xml(scores, rot_new, coord_new, model) + tree = etree.ElementTree(obj_xml) + tree.write(self.output_path + "match_raw_results.xml", pretty_print=True) + max_pos = [] + for i in range(len(scores)): + max_pos.append( + [ + coord_new[i][0][0], + coord_new[i][0][1], + coord_new[i][0][2], + rot_new[i][0][0], + rot_new[i][0][1], + rot_new[i][0][2], + ] + ) + np.savetxt( + self.output_path + config.prefix + ".rawMax.coords", + np.array(max_pos), + delimiter=" ", + ) + df_star = pd.DataFrame( + max_pos, + columns=[ + "rlnCoordinateX", + "rlnCoordinateY", + "rlnCoordinateZ", + "rlnAngleRot", + "rlnAngleTilt", + "rlnAnglePsi", + ], + ) + # starfile.write( + # df_star, self.output_path + config.prefix + "rawMax.star", overwrite=1 + # ) + return obj_xml + + def sort_results(results, original_coordinates): + scores, rotations, updated_coordinates = [], [], [] + + for score, rotation, coord in results: + scores.append(0 if score is None else score[0]) + rotations.append(rotation) + updated_coordinates.append(coord) + + sorted_indices = sorted( + range(len(scores)), key=lambda i: scores[i], reverse=True + ) + sorted_scores = [scores[i] for i in sorted_indices] + sorted_rotations = [rotations[i] for i in sorted_indices] + sorted_updated_coords = [updated_coordinates[i] for i in sorted_indices] + sorted_original_coords = [original_coordinates[i] for i in sorted_indices] + + return ( + sorted_scores, + sorted_rotations, + sorted_updated_coords, + sorted_original_coords, + ) + + def template_match(self, subtomo, angles, cur_coord, ctf_vols): + # define the search area + area = self.generate_search_area(self.templates[0].shape[0]) + ccc_save = [[] for a in range(len(area))] + coord_save = [[] for a in range(len(area))] + rot_save = [[] for a in range(len(area))] + model = [[] for a in range(len(area))] + num_templates = len(self.templates) + + for i, (phi, theta, psi) in enumerate(angles): + cur_ccc = [] + cur_rot = [phi, theta, psi] + for temp in range(num_templates): + if self.local_search_angles is True: + template = self.templates[temp] + 
template_rot = rotate(template, (phi, theta, psi)) + else: + template_rot = self.rotamer[temp][i] + # for given template at given angles,fetch from the rotamer + if self.ctf is not None: + template_rot = apply_ctf(template_rot, ctf_vols[temp]) + else: + template_rot = apply_wedge(template_rot, self.missing_wedge) + template_rot[template_rot < self.contour_level[temp]] = 0 + if self.masks is None: + structuring_element = np.ones((3, 3, 3)) + mask = binary_dilation(template_rot, structuring_element) + else: + mask = self.masks[temp] + mask = rotate(mask, (phi, theta, psi)) + mask = apply_wedge(mask, self.missing_wedge) + mask[mask < 0.5] = 0 # the mask should be binary + mask[mask >= 0.5] = 1 + + cur_ccc.append(calculate_correlation(subtomo[temp], template_rot, mask)) + ccc_save, rot_save, coor_save = self.get_best_match( + cur_ccc, + cur_coord, + cur_rot, + ccc_save, + rot_save, + coord_save, + model, + ) + ccc, rot, coor, model = self.sort_best_match( + ccc_save, rot_save, coord_save, area, model + ) + if ccc is not None: + if (len(ccc) != len(rot)) or (len(ccc) != len(coor)): + print("ccc ", len(ccc)) + print("rot ", len(rot)) + print("coor ", len(coor)) + raise RuntimeError("ccc, rot, coor have different length") + return ccc, rot, coor, model + + def generate_search_area(self, box_size): + return list( + itertools.product( + range( + box_size // 2 - self.matching_space, + box_size // 2 + self.matching_space + 1, + 1, + ), + repeat=3, + ) + ) + + def get_best_match(self, cur_ccc, cur_coord, cur_rot, ccc, rot, coor, model): + # rather than get one maxCCC at one point, we retain multiple poses + for i in range(len(cur_ccc)): + template_dim = self.templates[i].shape[0] + center_offset = template_dim // 2 + area = self.generate_search_area(template_dim) + cur_maxccc = np.max(cur_ccc[i]) + if cur_maxccc > self.min_CCC: + for j in range(len(area)): + # be careful about the order, it's z,y,x in tomogram + xx = int(area[j][2]) + yy = int(area[j][1]) + zz = int(area[j][0]) + cur_point_ccc = cur_ccc[i][area[j]] + if cur_point_ccc > self.min_CCC: + ccc[j].append(cur_point_ccc) + rot[j].append(cur_rot) + coor[j].append( + [ + int(cur_coord[0]) - center_offset + xx, + int(cur_coord[1]) - center_offset + yy, + int(cur_coord[2]) - center_offset + zz, + ] + ) + model[j].append(i) + + """ + center_offset = self.box_size // 2 + for j in range(len(area)): + xx = int(area[j][2]) # be careful about the order, it's z,y,x in tomogram + yy = int(area[j][1]) + zz = int(area[j][0]) + p = (*area[j],) + cur_maxccc = np.max(max(cur_ccc[i][p] for i in range(len(cur_ccc)))) + if cur_maxccc > self.min_CCC: + ccc[j].append(cur_maxccc) + rot[j].append(cur_rot) + coor[j].append( + [ + int(cur_coord[0]) - center_offset + xx, + int(cur_coord[1]) - center_offset + yy, + int(cur_coord[2]) - center_offset + zz, + ] + ) + """ + return ccc, rot, coor + + def sort_best_match(self, ccc, rotx, coor, area, model): + cccs = [w for sub in ccc for w in sub] + if len(cccs) > 5: + for k in range(len(area)): + # sort ccc + rotx[k] = [x for _, x in sorted(zip(ccc[k], rotx[k]), reverse=True)] + coor[k] = [x for _, x in sorted(zip(ccc[k], coor[k]), reverse=True)] + model[k] = [x for _, x in sorted(zip(ccc[k], model[k]), reverse=True)] + ccc[k] = sorted(ccc[k], reverse=True) + if len(rotx[k]) > 5: # for each position, keep top 5 orientations + rotx[k] = rotx[k][0:5] + coor[k] = coor[k][0:5] + ccc[k] = ccc[k][0:5] + model[k] = model[k][0:5] + ##serilize the list + ccc = [i for sub in ccc for i in sub] + rotx = [i for sub in rotx for i in sub] 
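+            # coor and model are flattened below in the same way as ccc and rotx,
+            # so a given index refers to the same candidate across all four lists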
+ coor = [i for sub in coor for i in sub] + model = [i for sub in model for i in sub] + # sort in sum + rotx = [x for _, x in sorted(zip(ccc, rotx), reverse=True)] + coor = [x for _, x in sorted(zip(ccc, coor), reverse=True)] + ccc = sorted(ccc, reverse=True) + top = min(len(ccc), self.search_depth * 2) # limit total candiate numbers + rotx = rotx[0:top] + coor = coor[0:top] + ccc = ccc[0:top] + model = model[0:top] + print("the maxccc is: " + str(max(ccc))) + print("the minccc is: " + str(min(ccc))) + else: + ccc = None + rotx = None + coor = None + model = None + print("skip current position") + return ccc, rotx, coor, model + + def prepare_rotamer(self): + # this is to generate all the rotated template to accelerate the later steps + rotamer = [] + for temp in range(len(self.templates)): + dim = self.templates[temp].shape[0] + r = np.zeros([len(self.global_angles), dim, dim, dim]) + for i, (phi, theta, psi) in enumerate(self.global_angles): + template = self.templates[temp] + r[i] = rotate_high_res(template, (phi, theta, psi)) + rotamer.append(r) + + return rotamer + + def generate_radom_angles(self): + angles = R.random(self.number_of_angles).as_euler("ZXZ", degrees=True) + return angles + + def generate_local_angles(self, index, range_deg, interval): + base_angles = self.rotations[index] + a = int(base_angles[0]) + b = int(base_angles[1]) + c = int(base_angles[2]) + return [ + [x, y, z] + for x in range(a - range_deg, a + range_deg + 1, interval) + for y in range(b - range_deg, b + range_deg + 1, interval) + for z in range(c - range_deg, c + range_deg + 1, interval) + ] + + @staticmethod + def extract_subtomo(tomo, coord, box_sizes): + """ + Extracts a subtomogram from a tomogram. + + Parameters: + - coord (int): Center coordinates of the subtomogram. + - tomo (numpy.ndarray): The tomogram from which to extract the subtomogram. + - box_size (list): The list of size of the cube to be extracted. + + Returns: + - list of numpy.ndarray: The extracted and normalized subtomogram. + """ + subtomograms = [] + for box_size in box_sizes: + start_x, start_y, start_z = ( + round(coord[0] - box_size / 2), + round(coord[1] - box_size / 2), + round(coord[2] - box_size / 2), + ) + end_x, end_y, end_z = ( + start_x + box_size, + start_y + box_size, + start_z + box_size, + ) + + # Calculate padding requirements + pad_before = [max(-start_z, 0), max(-start_y, 0), max(-start_x, 0)] + pad_after = [ + max(end_z - tomo.shape[0], 0), + max(end_y - tomo.shape[1], 0), + max(end_x - tomo.shape[2], 0), + ] + padding = list(zip(pad_before, pad_after)) + + # Extract and pad the subtomogram + subtomo = tomo[ + max(start_z, 0) : min(end_z, tomo.shape[0]), + max(start_y, 0) : min(end_y, tomo.shape[1]), + max(start_x, 0) : min(end_x, tomo.shape[2]), + ] + subtomo = np.pad(subtomo, padding, mode="constant") + + # Normalize and invert the contrast + # subtomogram = (subtomogram - subtomogram.mean()) / subtomogram.std() + subtomo = -subtomo + subtomograms.append(subtomo) + + return subtomograms + + def extract_subtomo_old(coor_x, coor_y, coor_z, tomo, size): + """ + Extracts a subtomogram from a tomogram. + + Parameters: + - coor_x, coor_y, coor_z (int): Center coordinates of the subtomogram. + - tomo (numpy.ndarray): The tomogram from which to extract the subtomogram. + - size (int): The size of the cube to be extracted. + + Returns: + - numpy.ndarray: The extracted and normalized subtomogram. 
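+
+        A hypothetical call (shapes are illustrative only):
+            subs = TemplateMatcher.extract_subtomo(tomo, (120, 200, 64), [32, 48])
+            # subs[0].shape == (32, 32, 32), subs[1].shape == (48, 48, 48)
+        Note that this method only inverts the contrast; the normalization step
+        is currently commented out in the body.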
+ """ + # Define the boundaries of the subtomogram + x_min, x_max = round(coor_x - size / 2), round(coor_x + size / 2) + y_min, y_max = round(coor_y - size / 2), round(coor_y + size / 2) + z_min, z_max = round(coor_z - size / 2), round(coor_z + size / 2) + + # Initialize the subtomogram volume + subtomogram = np.zeros((size, size, size), dtype=np.float32) + + # Calculate the overlap ranges for slicing + overlap_x_min, overlap_x_max = max(0, x_min), min(tomo.shape[2], x_max) + overlap_y_min, overlap_y_max = max(0, y_min), min(tomo.shape[1], y_max) + overlap_z_min, overlap_z_max = max(0, z_min), min(tomo.shape[0], z_max) + + # Extract the overlapping region + overlap = tomo[ + overlap_z_min:overlap_z_max, + overlap_y_min:overlap_y_max, + overlap_x_min:overlap_x_max, + ] + + # Calculate where to place the overlap in the subtomogram + sub_x_min = max(0, -x_min) + sub_y_min = max(0, -y_min) + sub_z_min = max(0, -z_min) + + # Place the extracted overlap into the subtomogram + subtomogram[ + sub_z_min : sub_z_min + overlap.shape[0], + sub_y_min : sub_y_min + overlap.shape[1], + sub_x_min : sub_x_min + overlap.shape[2], + ] = overlap + + # Normalize and invert the contrast + subtomogram = (subtomogram - subtomogram.mean()) / subtomogram.std() + subtomogram = -subtomogram + + return subtomogram + + # Other methods like apply_wedge, getMW, etc. + + +# Example usage: +# template_matcher = TemplateMatcher(template, template2, mask, angles, mAng, shrink) +# subtomo = SubtomoExtractor.extract(coor_x, coor_y, coor_z, tomo, size) +# ccc, rotx, coor = template_matcher.match(subtomo, x, y, z, search_depth) diff --git a/src/test_ccc.py b/src/test_ccc.py new file mode 100644 index 0000000..d1eac5a --- /dev/null +++ b/src/test_ccc.py @@ -0,0 +1,62 @@ +import numpy as np +from scipy.signal import fftconvolve +from scipy.ndimage import zoom + + +def calculate_correlation(data_volume, template_volume, mask): + """ + Calculate the correlation of a data volume with a template volume using local normalization. 
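+    This is a test implementation built on scipy.signal.fftconvolve, intended as
+    a cross-check of the FFT-product version in utils.py; small differences near
+    the volume borders are expected.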
+    """
+    mask = mask.astype(bool)
+    num_elements = np.sum(mask)
+
+    # Template normalization within the mask (zero outside the mask)
+    template_mean = np.mean(template_volume[mask])
+    template_std = np.std(template_volume[mask])
+    normalized_template = mask * (template_volume - template_mean) / template_std
+
+    # Prepare for efficient convolution by padding the template and mask
+    padded_template = np.zeros_like(data_volume)
+    padded_mask = np.zeros_like(data_volume)
+
+    # Assuming template and mask are smaller than the data_volume;
+    # the slices span exactly the template size, centered in the volume
+    template_slices = tuple(
+        slice(s // 2 - ts // 2, s // 2 - ts // 2 + ts)
+        for s, ts in zip(data_volume.shape, template_volume.shape)
+    )
+    padded_template[template_slices] = normalized_template
+    padded_mask[template_slices] = mask
+
+    # Calculate mean and variance across the volume using convolution
+    mean_volume = fftconvolve(data_volume, padded_mask, mode="same") / num_elements
+    mean_square_volume = (
+        fftconvolve(data_volume**2, padded_mask, mode="same") / num_elements
+    )
+
+    # Local standard deviation
+    volume_variance = mean_square_volume - mean_volume**2
+    np.maximum(volume_variance, 0, out=volume_variance)  # Ensure non-negative variance
+    volume_std_dev = np.sqrt(volume_variance)
+
+    # Normalize the data volume locally, avoiding division by zero
+    adjusted_volume = np.divide(
+        data_volume - mean_volume,
+        volume_std_dev,
+        out=np.zeros_like(data_volume),
+        where=volume_std_dev > 0,
+    )
+
+    # Compute correlation using convolution with the flipped template
+    correlation_result = (
+        fftconvolve(adjusted_volume, padded_template[::-1, ::-1, ::-1], mode="same")
+        / num_elements
+    )
+
+    return correlation_result
+
+
+# Example usage
+chunk = np.random.rand(100, 100, 100)  # Large data volume
+template = np.random.rand(20, 20, 20)  # Smaller template
+mask = np.ones(template.shape)  # Uniform mask for simplicity
+
+# Compute cross-correlation
+cc_result = calculate_correlation(chunk, template, mask)
+print("Cross-correlation computed successfully.")
diff --git a/src/test_config.py b/src/test_config.py
new file mode 100644
index 0000000..b48784a
--- /dev/null
+++ b/src/test_config.py
@@ -0,0 +1,85 @@
+# config.py
+
+# Path to the tomograms
+"""
+To do the template matching, the first step is taking the input from the user:
+1. The tomogram, which should be black on a white background
+2. The templates, which need to be white on a black background, consistent with the normal convention
+3. The coordinates of picked particles
+4. The missing wedge information, defaulting to [30,42], corresponding to a [-60,+48] tilt range
+5. The shrinkage factor, defaulting to 0.3; this controls the contours of the template
+6. There are more higher-level parameters, like the search depth, etc.
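+7. A typical coords file is assumed (our convention here) to be a CSV with
+   columns x, y, z, plus phi, theta, psi when local search is used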
+""" + +tomograms = "/Users/michael/work/process/dev/CATM/data/test1/s205.fiber3-inv.mrc" + +# Paths to templates and corresponding masks +templates = [ + "/Users/michael/work/process/dev/CATM/data/test1/templates/HE1-1-local30-mono-8Apx_b25_flipy.mrc", +] # the templates need to be the same size + +masks = [ + "/Users/michael/work/process/dev/CATM/data/test1/templates/mask-mono-8Apx-lps30-box25-ex2-s2.mrc", +] # can be None, but if provied, should be the same number as templates + +# Coordinates of picked particles +coords = "/Users/michael/work/process/dev/CATM/data/test1/s205.fiber3-calibrated-flip-fixed.csv" +# coords = pd.read_csv("/path/to/coords.csv") +# Assuming the columns are x, y, z,change it if not + +# Missing wedge information +missing_wedge = [30, 42] # Corresponding to [-60, +48] tilt range + +# Shrinkage factor, which the counter level of the volume, can be determined in Chimera +# It's designed to control how close two objects are allowed to be +shrinkage_factor = 0.3 + +# output path +output_path = "/Users/michael/work/process/dev/CATM/data/test1/results" + + +# Keep the following parameters as default for now +# Number of angles for global template matching +Number_of_angles = 3000 +# Range of local search angles, only for local search, defulat to None for global search +# the rough angles need to be provied in the coords.csv, with columns phi, theta, psi +local_search_angles = None +# The minimum CCCs allowed for the template matching +min_CCC = 0.2 + + +# Search depth, control how many rotataions +search_depth = 100 + + + +"""' +tomograms = "/path/to/tomogram" + +# Paths to templates and corresponding masks +templates = [ + "/path/to/template1", + "/path/to/template2", +] # the templates need to be the same size + +masks = [ + "/path/to/mask1", + "/path/to/mask2", +] # can be None, but if provied, should be the same length as templates + +# Coordinates of picked particles +coords = "/path/to/coords.csv" +# coords = pd.read_csv("/path/to/coords.csv") +# Assuming the columns are x, y, z,change it if not +# coords = coords[["x", "y", "z"]].to_numpy() + +# Missing wedge information +missing_wedge = [30, 42] # Corresponding to [-60, +48] tilt range + +# Shrinkage factor, which the counter level of the volume, can be determined in Chimera +# It's designed to control how close two objects are allowed to be +shrinkage_factor = 0.3 + +# output path +output_path = "/path/to/output" +""" \ No newline at end of file diff --git a/src/test_multiprocessing.py b/src/test_multiprocessing.py new file mode 100644 index 0000000..746f53b --- /dev/null +++ b/src/test_multiprocessing.py @@ -0,0 +1,24 @@ +import multiprocessing +from multiprocessing import Pool + + +class MyClass: + def __init__(self): + self.data = [1, 2, 3, 4, 5] # Example data + + def square_number(self, number): + """Simple function to square a number.""" + return number * number + + def run_multiprocessing(self): + """Method to run multiprocessing on the square_number function.""" + with Pool(multiprocessing.cpu_count() - 1) as pool: + results = pool.map(self.square_number, self.data) + return results + + +# Example usage +if __name__ == "__main__": + my_class_instance = MyClass() + squared_numbers = my_class_instance.run_multiprocessing() + print(squared_numbers) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..b0fc2b2 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 13 13:18:34 2021 + +@author: huabin +""" +import 
scipy.ndimage +import numpy as np +from numpy.fft import fftn, ifftn, ifftshift +from scipy.ndimage import zoom + + +def perform_convolution(volume, template): + """ + Perform convolution of a volume with a given template. + new,testing + """ + # centering the template in the frequency domain + template_fft = fftn(template) + convolved_volume = np.real(ifftshift(ifftn(fftn(volume) * np.conj(template_fft)))) + return convolved_volume + + +def perform_convolution_good(volume, template): + """ + Perform convolution of a volume with a given template. + new,0.966 + """ + # centering the template in the frequency domain + template_fft = fftn(template) + convolved_volume = np.real(ifftshift(ifftn(fftn(volume) * np.conj(template_fft)))) + return convolved_volume + + +def perform_convolution_old(volume, template): + """ + Perform convolution of a volume with a given template. + old + """ + # centering the template in the frequency domain + template_fft = fftn(ifftshift(template)) + convolved_volume = np.real(ifftn(fftn(volume) * np.conj(template_fft))) + return convolved_volume + + +def calculate_correlation(data_volume, template_volume, mask): + """ + Calculate the correlation of a data volume with a template volume. + The correlation is similar to the implement in PyTom + """ + # + mask = mask.astype(bool) + num_elements = np.sum(mask) + template_mean = np.mean(template_volume[mask]) + template_std = np.std(template_volume[mask]) + normalized_template = template_volume - template_mean + normalized_template /= template_std + + # Calculate volume mean and sqt under the mask,this is important to get local mean and variance + mask_float = mask.astype(np.float32) + mean_volume = perform_convolution(data_volume, mask_float) / num_elements + mean_square_volume = ( + perform_convolution(np.square(data_volume), mask_float) / num_elements + ) + # calculating the local standard deviation,normalize the data volume + # such that it reflects the local variability in data. + volume_std_dev = mean_square_volume - np.square(mean_volume) + np.maximum( + volume_std_dev, 0.0, out=volume_std_dev + ) # In-place maximum,mask sure it's not negative + np.sqrt(volume_std_dev, out=volume_std_dev) # In-place square root + + # normalize the data input + # adjusted_volume = data_volume - mean_volume + # valid_std_dev = volume_std_dev > 0 + # adjusted_volume[valid_std_dev] /= volume_std_dev[valid_std_dev] # In-place division + + # calculate correlation + normalized_template *= mask_float + correlation_result = ( + perform_convolution(data_volume, normalized_template) + / volume_std_dev + / num_elements + ) + return correlation_result + + +def calculate_correlation_old(data_volume, template_volume, mask): + """ + Calculate the correlation of a data volume with a template volume. + The correlation is similar to the implement in PyTom + """ + # + mask = mask.astype(bool) + num_elements = np.sum(mask) + template_mean = np.mean(template_volume[mask]) + template_std = np.std(template_volume[mask]) + normalized_template = template_volume - template_mean + normalized_template /= template_std + + # Calculate volume mean and sqt under the mask,this is important to get local mean and variance + mask_float = mask.astype(np.float32) + mean_volume = perform_convolution(data_volume, mask_float) / num_elements + mean_square_volume = ( + perform_convolution(np.square(data_volume), mask_float) / num_elements + ) + # calculating the local standard deviation,normalize the data volume + # such that it reflects the local variability in data. 
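+    # (this is the identity Var[D] = E[D^2] - E[D]^2, evaluated locally under the mask)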
+ volume_std_dev = mean_square_volume - np.square(mean_volume) + np.maximum( + volume_std_dev, 0.0, out=volume_std_dev + ) # In-place maximum,mask sure it's not negative + np.sqrt(volume_std_dev, out=volume_std_dev) # In-place square root + + # normalize the data input + adjusted_volume = data_volume - mean_volume + valid_std_dev = volume_std_dev > 0 + adjusted_volume[valid_std_dev] /= volume_std_dev[valid_std_dev] # In-place division + + # calculate correlation + normalized_template *= mask_float + correlation_result = ( + perform_convolution(adjusted_volume, normalized_template) / num_elements + ) + return correlation_result + + +def rotate(array, orient): + phi, theta, psi = orient + arrayR = scipy.ndimage.rotate(array, float(phi), axes=(1, 2), reshape=False) + arrayR = scipy.ndimage.rotate(arrayR, float(theta), axes=(0, 1), reshape=False) + arrayR = scipy.ndimage.rotate(arrayR, float(psi), axes=(1, 2), reshape=False) + return arrayR + + +def rotate_high_res(array, orient): + phi, theta, psi = orient + arrayR = scipy.ndimage.rotate( + array, float(phi), axes=(1, 2), reshape=False, order=5 + ) + arrayR = scipy.ndimage.rotate( + arrayR, float(theta), axes=(0, 1), reshape=False, order=5 + ) + arrayR = scipy.ndimage.rotate( + arrayR, float(psi), axes=(1, 2), reshape=False, order=5 + ) + return arrayR + + +# The missing wedge function is adapted and optimized from IsoNet,https://github.com/IsoNet-cryoET/IsoNet +def get_missing_wedge_mask(dim1, dim2, missing_angles): + mw = np.zeros((dim1, dim2), dtype=np.double) + missing_radians = np.pi / 180 * (90 - np.array(missing_angles)) + radius_squared = (min(dim1, dim2) / 2) ** 2 + + for i in range(dim1): + for j in range(dim2): + y, x = (i - dim1 / 2), (j - dim2 / 2) + theta = np.pi / 2 if x == 0 else abs(np.arctan(y / x)) + + if x**2 + y**2 <= radius_squared: + if ( + (x > 0 and y > 0 and theta < missing_radians[0]) + or (x < 0 and y < 0 and theta < missing_radians[0]) + or (x > 0 and y < 0 and theta < missing_radians[1]) + or (x < 0 and y > 0 and theta < missing_radians[1]) + or (int(y) == 0) + ): + mw[i, j] = 1 + + return mw + + +def apply_wedge(ori_data, missing_angles): + data = np.rot90(ori_data, k=1, axes=(0, 1)) + mw = get_missing_wedge_mask(data.shape[1], data.shape[2], missing_angles) + f_data = np.fft.fftn(data) + mwshift = np.fft.fftshift(mw) + outData = mwshift * f_data + real = np.real(np.fft.ifftn(outData)).astype(np.float32) + out = np.rot90(real, k=3, axes=(0, 1)) + return out + + +def rescale_ctf_volume(ctf_volume, original_size, target_size): + """ + Rescales a 3D CTF volume from original_size to target_size. + + :param ctf_volume: 3D numpy array representing the CTF volume. + :param original_size: Tuple of three ints representing the original size (x, y, z). + :param target_size: Tuple of three ints representing the target size (x, y, z). + :return: 3D numpy array of the rescaled CTF volume. 
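+
+    Example (hypothetical sizes):
+        rescaled = rescale_ctf_volume(ctf, (64, 64, 64), (48, 48, 48))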
+ """ + # Calculate the zoom factors for each dimension + zoom_factors = [n / o for n, o in zip(target_size, original_size)] + + # Use scipy's zoom function to rescale the volume + rescaled_volume = zoom(ctf_volume, zoom_factors, order=3) + + return rescaled_volume + + +def prepare_ctf_volumes(ctf_ori, templates): + ctf_out = [] + ctf_ori_dim = ctf_ori.shape[0] + for i in range(len(templates)): + temp_dim = templates[i].shape[0] + if ctf_ori_dim != temp_dim: + ctf_out.append( + rescale_ctf_volume(ctf_ori, ctf_ori.shape, templates[i].shape) + ) + else: + ctf_out.append(ctf_ori) + return ctf_out + + +def apply_ctf(real_data, ctf): + """ + Apply the CTF correction to the real data and return the result in real space. + """ + + # Perform Fourier transform on real data + fourier_real_data = fftn(real_data) + + # Apply CTF to Fourier-transformed data + corrected_data = fourier_real_data * ctf + + # Inverse Fourier transform to get back to real space + corrected_real_data = ifftn(corrected_data) + + # Return only the real part of the corrected data + return np.float32(np.real(corrected_real_data)) + + +def clash_test(a_vox, b_vox, c_vox, vol, dims): + count = 0 + if max(a_vox) >= dims[2] or max(b_vox) >= dims[1] or max(c_vox) >= dims[0]: + count = 1 + else: + volsel = vol[np.array(c_vox), np.array(b_vox), np.array(a_vox)] + if volsel.max() > 0 or min(a_vox) < 0 or min(b_vox) < 0 or min(c_vox) < 0: + count = 1 + + return count + + +def plot_obj(vol_array, obj_rot, coord, shrink): + dims = vol_array.shape + offset = int(obj_rot.shape[0] // 2) + obj_voxel = np.nonzero(obj_rot > shrink) + x_vox = obj_voxel[2] + int(coord[0]) - offset + y_vox = obj_voxel[1] + int(coord[1]) - offset + z_vox = obj_voxel[0] + int(coord[2]) - offset + # vol_array[np.array(z_vox),np.array(y_vox),np.array(x_vox)]=obj_rot[ + # np.array(obj_voxel[0]),np.array(obj_voxel[1]),np.array(obj_voxel[2])] + for idx in range(x_vox.size): + xx = x_vox[idx] + yy = y_vox[idx] + zz = z_vox[idx] + aa = obj_voxel[2][idx] + bb = obj_voxel[1][idx] + cc = obj_voxel[0][idx] + if 0 <= xx < dims[2] and 0 <= yy < dims[1] and 0 < zz < dims[0]: + vol_array[zz, yy, xx] = obj_rot[cc, bb, aa] + return vol_array + + +def find_info(obj_xml, ch1, search_depth): + # this is original code, I don't want to rewrite it for now + info = [[] for i in range(8)] + obj_xml = obj_xml + for i in range(min(len(obj_xml[ch1]), search_depth)): + xx = obj_xml[ch1][i].attrib["x"] + yy = obj_xml[ch1][i].attrib["y"] + zz = obj_xml[ch1][i].attrib["z"] + phi = obj_xml[ch1][i].attrib["phi"] + psi = obj_xml[ch1][i].attrib["theta"] + the = obj_xml[ch1][i].attrib["psi"] + ccc = obj_xml[ch1][i].attrib["CCC"] + model = obj_xml[ch1][i].attrib["model"] + info[0].append(xx) + info[1].append(yy) + info[2].append(zz) + info[3].append(phi) + info[4].append(psi) + info[5].append(the) + info[6].append(ccc) + info[7].append(model) + return info + + +def try_add_obj(vol_array, obj_rot, coord, shrinkage, ccc=0): + dims = vol_array.shape + offset = int(obj_rot.shape[0] / 2) + obj_voxel = np.nonzero(obj_rot > shrinkage) + x_vox = obj_voxel[2] + int(coord[0]) - offset + y_vox = obj_voxel[1] + int(coord[1]) - offset + z_vox = obj_voxel[0] + int(coord[2]) - offset + cr = clash_test(x_vox, y_vox, z_vox, vol_array, dims) + if cr == 1: + score = -1 + else: + vol_array[np.array(z_vox), np.array(y_vox), np.array(x_vox)] = obj_rot[ + np.array(obj_voxel[0]), np.array(obj_voxel[1]), np.array(obj_voxel[2]) + ] + score = float(ccc) + + return vol_array, score + + +def normalize_roi(data, mask): + """Normalize 
data using the mask.""" + masked_data = data * mask + mean = np.sum(masked_data) / np.sum(mask) + variance = np.sum((masked_data - mean) ** 2 * mask) / np.sum(mask) + normalized_data = (masked_data - mean) / np.sqrt(variance) + return normalized_data + + +def cross_correlation(chunk, template, mask): + """Calculate the cross-correlation between the template and the chunk.""" + # Normalize the chunk and template using the mask + chunk_normalized = normalize_roi(chunk, mask) + template_normalized = normalize_with_mask(template, mask) + + # Fourier transforms + F_chunk = fftn(chunk_normalized, s=chunk.shape) + F_template = fftn(template_normalized, s=chunk.shape) + F_mask = fftn(mask, s=chunk.shape) + + # Cross-correlation core calculation + RD = np.real(ifftn(F_template * np.conj(F_chunk))) # Cross-correlation + mcn = np.real(ifftn(F_mask * np.conj(F_chunk**2))) # Variance normalization + mca = np.real(ifftn(F_mask * np.conj(F_chunk))) # Mean normalization + + mca2 = mca**2 / np.sum(mask) + mcn_joint = mcn - mca2 + norm_data = np.sqrt(np.maximum(mcn_joint, 0)) # Avoid negative under sqrt + norm_template = np.sqrt(np.sum(template_normalized**2 * mask)) + + mcn_final = norm_data * norm_template + mcn_final[mcn_final == 0] = np.nan # Avoid division by zero + + cc = RD / mcn_final + cc = fftshift(cc) # Align the output with direct space data coordinates + + # Handling NaN and inf values + cc[np.isnan(cc)] = 0 + cc[np.isinf(cc)] = 0 + + return cc + + +# Helper function to normalize the template with the mask +def normalize_with_mask(template, mask): + masked_template = template * mask + mean = np.sum(masked_template) / np.sum(mask) + variance = np.sum((masked_template - mean) ** 2 * mask) / np.sum(mask) + if variance == 0: + return np.zeros_like(template) # Avoid division by zero in normalization + return (masked_template - mean) / np.sqrt(variance) -- GitLab