
Add preliminary support for looking at the ddG of only a subset of residues and for getting per-residue energy changes

Merged Vishruth Mullapudi requested to merge center_chain_energetics into master
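An example invocation (the script filename here is an assumption, not part of this MR): to analyze only the residues with pose numbers 20 through 39 in each case, using 8 parsing threads:

    python analyze_mutation_scan.py scan_output/ -c 8 -m 20 -x 40

The per-residue and center-chain ddG tables are written as CSVs into analysis_output_sample/.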
1 file  +213  −0
import argparse
import concurrent.futures
import os
from itertools import dropwhile, islice, takewhile

import pandas as pd
from tqdm import tqdm

# Shared configuration; these are assigned their real values in __main__ below
# and read by the worker threads.
rosetta_output_file_name = None
output_database_name = None
rosetta_path = None
rosetta_db = None
trajectory_stride = None
USE_MULTIPROCESSING = False
PROCESS_COUNT = 1


def find_finished_jobs(output_folder):
    """Map each case directory to the list of its replicate directories whose
    Rosetta runs finished and whose wt/mut structures were both extracted."""
    return_dict = {}
    job_dirs = [os.path.abspath(os.path.join(output_folder, d))
                for d in os.listdir(output_folder)
                if os.path.isdir(os.path.join(output_folder, d))]
    for job_dir in job_dirs:
        completed_struct_dirs = []
        for potential_struct_dir in sorted(
                os.path.abspath(os.path.join(job_dir, d))
                for d in os.listdir(job_dir)
                if os.path.isdir(os.path.join(job_dir, d))):
            if rosetta_output_succeeded(potential_struct_dir):
                wt_pdb = os.path.join(potential_struct_dir, "wt_%05d.pdb" % trajectory_stride)
                mut_pdb = os.path.join(potential_struct_dir, "mut_%05d.pdb" % trajectory_stride)
                if os.path.isfile(wt_pdb) and os.path.isfile(mut_pdb):
                    completed_struct_dirs.append(potential_struct_dir)
                else:
                    print("Successful output found in %s, but the wt and mut pdb files "
                          "were not both found. Was the structure extraction run to "
                          "completion?" % potential_struct_dir)
        return_dict[job_dir] = completed_struct_dirs
    return return_dict
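
# The directory layout this script assumes under output_folder (an
# illustrative sketch; the names are examples, and the _00010 suffix follows
# from trajectory_stride = 10):
#
#   output_folder/
#       1ABC_A_GLU123A/        <- one case folder per mutation
#           0001/              <- one replicate folder per nstruct
#               rosetta.out
#               ddG.db3
#               wt_00010.pdb
#               mut_00010.pdb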


def rosetta_output_succeeded(potential_struct_dir):
    """Return True if the Rosetta log and output database both exist and the
    log contains the JobDistributor success and completion markers."""
    path_to_rosetta_output = os.path.join(potential_struct_dir, rosetta_output_file_name)
    if not os.path.isfile(path_to_rosetta_output):
        return False
    db3_file = os.path.join(potential_struct_dir, output_database_name)
    if not os.path.isfile(db3_file):
        return False
    success_line_found = False
    no_more_batches_line_found = False
    with open(path_to_rosetta_output, 'r') as f:
        for line in f:
            if line.startswith('protocols.jd2.JobDistributor') and 'reported success in' in line:
                success_line_found = True
            if line.startswith('protocols.jd2.JobDistributor') and 'no more batches to process' in line:
                no_more_batches_line_found = True
    return no_more_batches_line_found and success_line_found


def read_pose_energies_table(pdb_path, scored_state, case_name, replicate):
    """Parse the #BEGIN/#END_POSE_ENERGIES_TABLE block that Rosetta appends to
    a scored PDB into a DataFrame, tagging each row with its provenance."""
    with open(pdb_path) as fid:
        rows = takewhile(lambda x: '#END_POSE_ENERGIES_TABLE' not in x,
                         islice(dropwhile(lambda x: '#BEGIN_POSE_ENERGIES_TABLE' not in x, fid), 1, None))
        table = pd.DataFrame(x.split() for x in rows)
    table.columns = table.iloc[0]
    # Row 0 is the header; rows 1-2 are the 'weights' and whole-'pose' summary
    # rows; the per-residue rows start at index 3.
    return table.iloc[3:].assign(scored_state=scored_state,
                                 case_name=case_name,
                                 replicate=replicate)


def extract_pose_tables(casefolder, replicate_folders, renumber=None):
    pose_table = pd.DataFrame()
    for replicate_folder in replicate_folders:
        # 1) get dataframe containing center chains from pdb wt and mut
        # 2) calculate bound ddG from mut-wt aligning on the pose number and chain
        # 3) optionally renumber (needed for phf I think but not cbd?)
        # TODO: need to convert pose numbering to
        wt_pdb = os.path.join(casefolder, replicate_folder, "wt_%05d.pdb" % trajectory_stride)
        mut_pdb = os.path.join(casefolder, replicate_folder, "mut_%05d.pdb" % trajectory_stride)
        case_name = os.path.basename(casefolder)
        replicate = os.path.basename(replicate_folder)
        wt_pose_table = read_pose_energies_table(wt_pdb, 'wt_dG', case_name, replicate)
        mut_pose_table = read_pose_energies_table(mut_pdb, 'mut_dG', case_name, replicate)
        pose_table = pd.concat((pose_table, mut_pose_table, wt_pose_table))
    pose_table[['PDB_name', 'Mutated Chains', 'mutant']] = pose_table.case_name.str.split('_', expand=True)
    pose_table[['wt_aa', 'res_num', 'mutant_aa']] = pose_table.mutant.str.split(r'([A-Z]{3})(\d+)([A-Z])', expand=True) \
        .replace("", float("NaN")).dropna(how='all', axis=1)
    pose_table.reset_index(inplace=True)
    pose_table['res_num'] = pose_table['res_num'].astype(int)
    return pose_table
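
# An illustrative sketch of the table block the parser above expects inside a
# Rosetta-scored PDB (the values and the LEU_45 label are made-up examples):
#
#   #BEGIN_POSE_ENERGIES_TABLE mut_00010.pdb
#   label fa_atr fa_rep fa_sol ... total
#   weights 1 0.55 1 ... NA
#   pose -1500.2 180.3 900.1 ... -450.8
#   LEU_45 -5.21 1.03 3.12 ... -1.95
#   ...
#   #END_POSE_ENERGIES_TABLE
#
# calc_dGs() below splits each per-residue 'label' (e.g. LEU_45) on its final
# underscore to recover the residue name and pose number.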


def calc_dGs(df, minposenum, maxposenum):
    """Compute per-residue ddGs, whole-pose ddGs, and ddGs summed over only
    the residues whose pose numbers fall in [minposenum, maxposenum)."""
    per_res_ddgs = pd.DataFrame()
    total_ddgs = pd.DataFrame()
    center_ddgs = pd.DataFrame()
    id_columns = ['index', 'case_name', 'PDB_name', 'Mutated Chains', 'wt_aa',
                  'replicate', 'res_num', 'mutant_aa', 'scored_state', 'label',
                  'mutant', 'pose_num']
    for column in df.columns:
        if column not in id_columns:
            df.loc[:, column] = df.loc[:, column].astype(float)
    df[['wt_aa', 'pose_num']] = df['label'].str.rsplit('_', n=1, expand=True)
    df.loc[:, 'pose_num'] = df['pose_num'].astype(int)
    for case, data in df.groupby('case_name'):
        wt = data.loc[data['scored_state'] == 'wt_dG'].copy()
        mut = data.loc[data['scored_state'] == 'mut_dG']
        # Negate the wt score terms so that summing the mut and wt rows below
        # yields mut - wt, i.e. the ddG.
        for column in wt.columns:
            if column not in id_columns:
                wt.loc[:, column] *= -1.0
        bound_dG = pd.concat([mut, wt])
        # get the bound ddG for each replicate at each pose number (per-residue ddG)
        bound_dG = bound_dG.groupby(
            ['PDB_name', 'Mutated Chains', 'replicate', 'res_num',
             'mutant_aa', 'mutant', 'pose_num']).sum().reset_index()
        # average all the nstruct dGs together
        bound_dG = bound_dG.groupby(
            ['PDB_name', 'Mutated Chains', 'res_num', 'mutant_aa', 'mutant',
             'pose_num']).mean().round(decimals=5).reset_index()
        bound_dG.loc[:, 'scored_state'] = "bound_ddG"
        bound_dG['case_name'] = case
        per_res_ddgs = pd.concat([per_res_ddgs, bound_dG])
        total_ddgs = pd.concat([total_ddgs, bound_dG.groupby('mutant').sum().reset_index()])
        center_chains = bound_dG[(bound_dG['pose_num'] >= minposenum) & (bound_dG['pose_num'] < maxposenum)]
        temp = center_chains.groupby(['mutant']).sum().drop(columns=['res_num', 'pose_num']).reset_index()
        temp['case'] = case
        center_ddgs = pd.concat([center_ddgs, temp])
    return per_res_ddgs, total_ddgs, center_ddgs
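
# A worked toy example of the arithmetic above (the numbers are illustrative):
# for one residue position scored in two replicates,
#   replicate 1: mut total = -3.0, wt total = -5.0  ->  ddG = -3.0 - (-5.0) = 2.0
#   replicate 2: mut total = -4.0, wt total = -5.5  ->  ddG = 1.5
# The per-residue bound ddG reported for that position is the mean, 1.75, and
# the center-chain ddG for the mutant sums these means over every pose number
# in [minposenum, maxposenum).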


def main(input_dir, minposenum, maxposenum):
    print("Listing Alanine Scan Output Folders")
    pdb_dirs = find_finished_jobs(input_dir)
    print('Found {:d} directories to analyze'.format(len(pdb_dirs)))
    resultdf = pd.DataFrame()
    if USE_MULTIPROCESSING:
        with tqdm(total=len(pdb_dirs), unit="cases") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=PROCESS_COUNT) as executor:
                futures = {executor.submit(extract_pose_tables, casefolder, replicatefolders):
                           (casefolder, replicatefolders)
                           for casefolder, replicatefolders in pdb_dirs.items()}
                results = {}
                for future in concurrent.futures.as_completed(futures):
                    arg = futures[future]
                    results[arg[0]] = future.result()
                    pbar.update(1)
        resultdf = pd.concat(list(results.values()))
    else:
        for casefolder, replicatefolders in tqdm(pdb_dirs.items()):
            resultdf = pd.concat((resultdf, extract_pose_tables(casefolder, replicatefolders)))
    resultdf = resultdf.reindex(columns=['PDB_name', 'Mutated Chains',
                                         'wt_aa', 'replicate', 'res_num', 'mutant_aa', 'total',
                                         'scored_state', 'fa_atr', 'fa_dun', 'fa_elec',
                                         'fa_intra_rep', 'fa_rep', 'fa_sol', 'hbond_bb_sc',
                                         'hbond_lr_bb', 'hbond_sc', 'hbond_sr_bb', 'omega',
                                         'p_aa_pp', 'pro_close', 'rama', 'ref', 'yhh_planarity',
                                         'dslf_fa13', 'linear_chainbreak', 'overlap_chainbreak',
                                         'case_name', 'label', 'mutant'])\
        .sort_values(by=['res_num', 'replicate']).reset_index()
    per_res_ddgs, total_ddgs, center_ddgs = calc_dGs(resultdf, minposenum, maxposenum)
    per_res_ddgs.sort_values(by=['res_num', 'pose_num'], inplace=True)
    basename = os.path.basename((input_dir + "/").rstrip("\\/"))
    script_output_folder = 'analysis_output_sample'
    os.makedirs(script_output_folder, exist_ok=True)
    center_ddgs.to_csv(os.path.join(script_output_folder, basename + '-center_chain_ddGs.csv'), index=False)
    per_res_ddgs.to_csv(os.path.join(script_output_folder, basename + '-per_res_ddGs.csv'), index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extracts per-residue and center-chain ddG values from the "
                    "structures generated by the mutation scan script",
    )
    parser.add_argument(
        'output_folder',
        type=str,
        help="The mutation output folder to analyze",
    )
    parser.add_argument(
        "-c",
        "--num_cpu",
        type=int,
        default=1,
        help="The number of worker threads to run in parallel. Take care not to "
             "exceed the number of available CPU threads or the available system "
             "memory when using larger numbers of threads.",
    )
    parser.add_argument(
        "-m",
        "--min",
        type=int,
        help="Minimum pose number of the residue subset (inclusive)",
    )
    parser.add_argument(
        "-x",
        "--max",
        type=int,
        help="Maximum pose number of the residue subset (exclusive)",
    )
    args = vars(parser.parse_args())
    # Regenerates global values for all the worker threads.
    # rosetta.out is the Rosetta log and ddG.db3 the structure output sqlite
    # database found in each alanine scan output subfolder.
    rosetta_output_file_name = 'rosetta.out'
    output_database_name = 'ddG.db3'
    # Important - to correctly name extracted structures by the stride, this
    # trajectory_stride must match the stride the alanine scan was run with.
    trajectory_stride = 10
    # Enter the path to your Rosetta installation's score_jd2 binary here
    rosetta_path = '/home/vish/rosetta_3.13/main/source/bin/score_jd2.linuxgccrelease_native'
    rosetta_db = '/home/vish/rosetta_3.13/main/database'
    if args['num_cpu'] == 1:
        USE_MULTIPROCESSING = False
        PROCESS_COUNT = 1
    else:
        USE_MULTIPROCESSING = True
        # Cap the worker count at the number of available CPU threads.
        PROCESS_COUNT = min(os.cpu_count(), args['num_cpu'])
    if os.path.isdir(args['output_folder']):
        main(args['output_folder'], args['min'], args['max'])
    else:
        print('ERROR: %s is not a valid directory' % args['output_folder'])