# this is supposed to be a module that makes it easy to run stuff on the cluster or at home, etc.
import dill
import numpy as np
import os
import random
import sys
import subprocess
import time

# Template for the sbatch submission script. getscript() fills in the
# upper-case placeholders by plain text substitution:
#   NAME            -> job name
#   RESID           -> tag used in the res_RESID.txt / res_RESID.err file names
#   PARTITION       -> partition (queue) to submit to
#   NCORES          -> value of -n and of the MKL/OMP thread-count variables
#   HOURS / MINUTES -> wall-clock limit
#   MEM             -> value passed to --mem
#   FULLID          -> id passed to `python ezrun.py`, which loads FULLID.pk
# NOTE(review): -o and -e are each given twice (res_RESID.* and
# ./report/*.%j.txt); presumably sbatch honours only one of each pair --
# confirm which one and drop the other.
base="""#!/bin/bash
#
#SBATCH --job-name=NAME
#SBATCH -o res_RESID.txt  # output file
#SBATCH -e res_RESID.err   # File to which STDERR will be written
#SBATCH --partition=PARTITION # Partition to submit to
#SBATCH -n NCORES #
##SBATCH --gres=gpu
#SBATCH -N 1 # makes sure cores are all on one node
#SBATCH --time=0-HOURS:MINUTES       # Runtime in D-HH:MM (max 12 for defq on swarm)
#SBATCH --mem=MEM
#SBATCH --exclude=swarm[001-010]
#SBATCH -o ./report/output.%j.txt # STDOUT
#SBATCH -e ./report/error.%j.txt  # STDERR
export MKL_NUM_THREADS=NCORES
export MKL_DYNAMIC="FALSE"
export OMP_NUM_THREADS=NCORES
python -c "import mkl; print('max threads:', mkl.get_max_threads())"
python ezrun.py FULLID"""

def onserver():
    """Report whether this process is running on a cluster machine.

    A machine counts as part of the cluster when its hostname begins
    with ``swarm``.

    Returns:
        bool: True on a ``swarm*`` host, False anywhere else.
    """
    from socket import gethostname

    hostname = gethostname()
    return hostname.startswith('swarm')

def getscript(id,name=None,resid=None,ncores=1, mem=100,hours=12,minutes=0):
    """Render the sbatch script for one job from the module-level ``base`` template.

    Args:
        id: Job id; becomes the argument of ``python ezrun.py`` in the script.
        name: Job name; ``None`` substitutes ``'ezrun'``.
        resid: Tag for the ``res_*`` output file names; ``None`` substitutes
            ``'%j'``.
        ncores: Value for ``-n`` and for the MKL/OMP thread-count variables.
        mem: Value for ``--mem``.
        hours: Whole hours of the time limit (must be an int).
        minutes: Whole minutes of the time limit (must be an int).

    Returns:
        str: The script text with every placeholder substituted.
    """
    import re  # local import keeps this function self-contained

    # defq accepts at most 12 hours (see the template), so anything longer --
    # including 12 hours plus some minutes -- has to go to longq.
    if hours * 60 + minutes > 12 * 60:
        partition = 'longq'
    else:
        partition = 'defq'

    values = {
        'FULLID': str(id),
        'NCORES': str(ncores),
        'MEM': str(mem),
        'RESID': '%j' if resid is None else str(resid),
        'NAME': 'ezrun' if name is None else str(name),
        'HOURS': "{:02d}".format(hours),
        'MINUTES': "{:02d}".format(minutes),
        'PARTITION': partition,
    }

    # Substitute every placeholder in a single pass over the template. With
    # chained str.replace calls, text inserted for one placeholder (say a job
    # name containing "HOURS") would be rewritten by a later replacement.
    placeholder = re.compile('|'.join(re.escape(key) for key in values))
    return placeholder.sub(lambda match: values[match.group(0)], base)

def run(fun,*args,name=None,resid=None,ncores=1,mem=100,submit=None,hours=12,minutes=0):
    """Run ``fun(*args)`` right here, or hand it to the cluster as an sbatch job.

    Args:
        fun: Callable to execute.
        *args: Positional arguments for ``fun``.
        name, resid, ncores, mem, hours, minutes: Forwarded to ``getscript``
            when the job is submitted.
        submit: True to submit through sbatch, False to call ``fun`` in this
            process, None to decide with ``onserver()``.

    Returns:
        None. The return value of ``fun`` is discarded in both modes.

    Raises:
        Exception: when submitting and the current directory has no
            ``report`` subdirectory.
    """
    if submit is None:
        submit = onserver()
    if not submit:
        fun(*args)
        return

    if not os.path.isdir('report'):
        raise Exception('must have report subdirectory in directory where you run ezrun')

    # Draw the id from the OS entropy source rather than the global `random`
    # state: a caller that seeds `random` for reproducibility would otherwise
    # get the same ids (and overwrite the same files) on every launch, and
    # would have its own random stream advanced by this call.
    rng = random.SystemRandom()
    job_id = 'ezrun_' + ''.join(rng.choice('0123456789ABCDEF') for _ in range(64))

    # Serialize the callable and its arguments for `python ezrun.py <job_id>`.
    with open(job_id + '.pk', 'wb') as payload_file:
        dill.dump((fun, args), payload_file)

    # Write the submission script; closing it before sbatch runs guarantees
    # the content is flushed to disk.
    with open(job_id + '.sh', 'w') as script_file:
        script_file.write(getscript(job_id,name,resid,ncores=ncores,mem=mem,hours=hours,minutes=minutes))

    os.system('sbatch ' + job_id + '.sh')

def good_iter_to_print(i):
    """Tell whether iteration number ``i`` is a "round" one worth logging.

    An iteration qualifies when it is 0, or when it is a multiple of the
    largest power of ten not exceeding it: 1..9, 10, 20, ..., 90, 100,
    200, ... -- so progress lines get sparser as the count grows.

    Returns:
        A truthy value for qualifying iterations, a falsy one otherwise.
    """
    if i == 0:
        return True
    exponent = np.floor(np.log10(i * 1.0))
    leading_unit = 10 ** exponent
    remainder = i % leading_unit
    return remainder == 0

def _count_ezrun_jobs(state):
    """Count user ``domke``'s jobs named ``ezrun`` that squeue lists in ``state``."""
    # The arguments go to squeue as a plain list, without shell=True. With
    # shell=True on POSIX only the first list item is executed and the rest
    # become arguments of the shell itself, so the filters below were
    # silently ignored.
    # -h drops the header and -o %i prints just the job id, so each non-empty
    # output line is exactly one job regardless of squeue's column widths.
    cmd = ['squeue', '-h', '-o', '%i', '-u', 'domke', '-n', 'ezrun', '-t', state]
    try:
        rez = subprocess.run(cmd, stdout=subprocess.PIPE)
    except FileNotFoundError:
        # No squeue executable means there is no scheduler on this machine,
        # hence no jobs. The shell-based version also ended up at 0 here.
        return 0
    listing = rez.stdout.decode(errors='replace')
    return sum(1 for line in listing.splitlines() if line.strip())

def nrunning():
    """Return how many ``ezrun`` jobs of user ``domke`` are currently RUNNING."""
    return _count_ezrun_jobs('RUNNING')

def npending():
    """Return how many ``ezrun`` jobs of user ``domke`` are currently PENDING."""
    return _count_ezrun_jobs('PENDING')

def wait_for_complete():
    """Block until no ``ezrun`` job is running or pending.

    Returns immediately when not on the cluster (``onserver()`` is False).
    While waiting, a progress line is printed on iterations selected by
    ``good_iter_to_print`` so that output thins out over time.
    """
    if not onserver():
        return

    # NOTE(review): presumably gives the scheduler time to list jobs that
    # were submitted just before this call -- confirm.
    time.sleep(2)
    t0 = time.time()

    iters = 0
    while True:
        # Ask the scheduler once per pass and reuse the counts for both the
        # exit test and the progress line; re-querying for the message cost
        # extra squeue calls and could print numbers that differ from the
        # ones that kept the loop going.
        running = nrunning()
        pending = npending()
        if running <= 0 and pending <= 0:
            break

        time.sleep(.1)
        iters += 1
        # want to print if all digits of iters are zero except the first
        if good_iter_to_print(iters):
            print('waiting for jobs. time : {: 8.4f}  jobs running {:4d} pending {:4d}'.format(time.time()-t0, running, pending), flush=True)

    # NOTE(review): presumably lets the finished jobs' output files settle
    # before the caller reads them -- confirm.
    time.sleep(2)

def cancel_all_jobs():
    """Cancel every job named ``ezrun`` owned by user ``domke``, then wait.

    Issues the scancel command through the shell and afterwards blocks in
    ``wait_for_complete()`` until the queue no longer lists such jobs.
    """
    scancel_cmd = 'scancel -u domke -n ezrun'
    os.system(scancel_cmd)
    wait_for_complete()

def clean_tmps():
    """Delete leftover ``ezrun_*.pk`` and ``ezrun_*.sh`` files in the current directory.

    Works without a shell or an ``rm`` binary, stays quiet when nothing
    matches, and remains best-effort: a file that cannot be removed is
    reported on stderr and the remaining files are still processed.
    """
    for entry in os.listdir('.'):
        if not (entry.startswith('ezrun_') and entry.endswith(('.pk', '.sh'))):
            continue
        try:
            os.remove(entry)
        except FileNotFoundError:
            # A finishing job deletes its own files; losing that race is fine.
            pass
        except OSError as err:
            print('could not remove {}: {}'.format(entry, err), file=sys.stderr)

if __name__ == '__main__':
    # Entry point used by the generated sbatch script: `python ezrun.py <id>`.
    # Loads <id>.pk, runs the stored call, then removes <id>.pk and <id>.sh.
    if len(sys.argv) < 2:
        sys.exit('usage: python ezrun.py JOB_ID')

    job_id = sys.argv[1]
    print("id: " + job_id)

    # NOTE: unpickling executes code from the file; only pass ids whose
    # .pk file was written by run().
    with open(job_id + ".pk", "rb") as payload_file:
        f,args = dill.load(payload_file)
    f(*args)

    print("done running.")
    # Remove the files directly instead of building a shell `rm` command from
    # a command-line argument, which breaks on spaces or shell metacharacters.
    for suffix in ('.pk', '.sh'):
        try:
            os.remove(job_id + suffix)
        except FileNotFoundError:
            pass  # already gone, e.g. removed by clean_tmps()
        except OSError as err:
            print('could not remove {}: {}'.format(job_id + suffix, err), file=sys.stderr)
    print("cleaned up files.")
