Commit f9eb6015 authored by David Trudgian's avatar David Trudgian
Browse files

Parallel srun submission per SLURM allocation

parent d917dc1b
......@@ -63,12 +63,12 @@ def main():
p.load()
if arguments['run']:
runner = executors.LocalExecutor(p.commands, os.path.dirname((os.path.abspath(param_file))))
runner = executors.LocalExecutor(p.commands, os.path.dirname((os.path.abspath(param_file))), p)
runner.run()
if arguments['srun']:
runner = executors.SrunExecutor(p.commands, os.path.dirname(
(os.path.abspath(param_file))))
(os.path.abspath(param_file))), p)
runner.run()
except Exception as e:
......
from abc import ABCMeta, abstractmethod
from subprocess import call
from pathos.multiprocessing import Pool
import os
import datetime
import logging
import re
from .templates import template_env
logger = logging.getLogger(__name__)
class BaseExecutor(object):
class BaseExecutor:
metaclass = ABCMeta
commands = []
logdir =""
params = None
def __init__(self, commands, cwd, params):
def __init__(self, commands, cwd):
logger.info("Initializing local executor")
logger.info("Initializing %s " % self.__class__.__name__)
self.commands = commands
self.params = params
logdir = os.path.join(cwd, "param_runner_%s" % datetime.datetime.now().strftime('%Y%m%d-%H%M%s') )
......@@ -28,7 +36,7 @@ class BaseExecutor(object):
self.logdir = logdir
def run(self):
def start_trace(self):
trace_path = os.path.join(self.logdir, 'trace.txt')
trace_file = open(trace_path, 'w', buffering = False)
......@@ -40,46 +48,69 @@ class BaseExecutor(object):
trace_file.write("Return Code\tDuration\tStart\tEnd\n")
cmd_idx = 0
for cmd_idx, cmd in enumerate(self.commands):
def write_trace(self, msg):
trace_path = os.path.join(self.logdir, 'trace.txt')
trace_file = open(trace_path, 'w', buffering = False)
trace_file.write("msg")
stdout_path = os.path.join(self.logdir, "%d.out" % cmd_idx)
stderr_path = os.path.join(self.logdir, "%d.err" % cmd_idx)
stderr_file = open(stderr_path, 'w')
stdout_file = open(stdout_path, 'w')
trace_file.close()
start = datetime.datetime.now()
ret = self.run_cmd(cmd, stderr_file, stdout_file)
end = datetime.datetime.now()
duration = end - start
stderr_file.close()
stdout_file.close()
def wrap_cmd(self, cmd_idx):
summary_line = "%d\t" % cmd_idx
for arg in cmd.values():
summary_line += "%s\t" % str(arg['value'])
logger.debug(" - Task %d is running" % cmd_idx)
summary_line += "%d\t%s\t%s\t%s\n" % (
ret,
duration,
start,
end
)
cmd = self.commands[cmd_idx]
trace_file.write(summary_line)
stdout_path = os.path.join(self.logdir, "%d.out" % cmd_idx)
stderr_path = os.path.join(self.logdir, "%d.err" % cmd_idx)
stderr_file = open(stderr_path, 'w')
stdout_file = open(stdout_path, 'w')
trace_file.close()
start = datetime.datetime.now()
ret = self.run_cmd(cmd, stderr_file, stdout_file)
end = datetime.datetime.now()
duration = end - start
stderr_file.close()
stdout_file.close()
summary_line = "%d\t" % cmd_idx
for arg in cmd.values():
summary_line += "%s\t" % str(arg['value'])
summary_line += "%d\t%s\t%s\t%s\n" % (
ret,
duration,
start,
end
)
self.write_trace(summary_line)
@abstractmethod
def run(self):
pass
@abstractmethod
def run_cmd(self):
pass
class LocalExecutor(BaseExecutor):
def run(self):
for cmd_idx, cmd in enumerate(self.commands):
wrap_cmd(cmd_idx,cmd)
def run_cmd(self, cmd, stderr_file, stdout_file):
ret = call(cmd['__command']['value'], shell=True, stderr=stderr_file,
......@@ -90,11 +121,38 @@ class LocalExecutor(BaseExecutor):
class SrunExecutor(BaseExecutor):
def run(self):
if 'SLURM_JOB_CPUS_PER_NODE' not in os.environ:
raise EnvironmentError('The srun executor can only be run inside a SLURM allocation.')
logger.info(" - Running in SLURM JOB %s" % os.environ.get('SLURM_JOB_ID' ))
cpus_per_node = int(re.match('\d+', os.environ.get('SLURM_JOB_CPUS_PER_NODE') ).group(0))
logger.info(" - %d CPUs per node" % cpus_per_node)
nodes = int(os.environ.get('SLURM_NNODES'))
logger.info(" - %d nodes" % nodes)
cpus_per_task = self.params.vals['cpus_per_task']
conc_tasks = (cpus_per_node / cpus_per_task) * nodes
logger.info(" - %d Concurrent tasks" % conc_tasks)
p = Pool(conc_tasks)
for cmd_idx, cmd in enumerate(self.commands):
logger.debug(" - Task %d is waiting in the pool" % cmd_idx)
p.apply_async( self.wrap_cmd, [cmd_idx])
p.close()
p.join()
def run_cmd(self, cmd, stderr_file, stdout_file):
srun_cmd = "srun -n1 '%s'" % cmd['__command']['value']
cpus_per_task = self.params.vals['cpus_per_task']
srun_cmd = "srun -n1 -c%d '%s'" % ( cpus_per_task, cmd['__command']['value'])
ret = call(srun_cmd, shell=True, stderr=stderr_file,
stdout=stdout_file)
stdout=stdout_file, env=os.environ)
return ret
\ No newline at end of file
return ret
......@@ -240,13 +240,6 @@ class ParamFile(object):
class IntRange(object):
"""Compute integer parameters from an int_range definition
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment