import os
import sys
import shutil
import pprint as pp
import traceback
import time
import numpy as np
from matplotlib import pyplot as plt

from casatasks.private.casa_transition import is_CASA6
if is_CASA6:
    from casatasks import casalog
    from casatools import table, ms, msmetadata
    from casatools.platform import bytes2str

    import subprocess

    mst_local = ms()
    tbt_local = table()
    msmdt_local = msmetadata()
else:
    from __main__ import *
    from taskinit import *

    import commands

    mst_local = mstool()
    tbt_local = tbtool()
    msmdt_local = msmdtool()

class convertToMMS():
    '''Convert every MS found at the top level of a directory into a Multi-MS.

    For each directory of ``inpdir`` that is a Measurement Set, partition is run
    with ``createmms=True`` and the MMS is written to the output directory under
    the same base name. Every other non-hidden file or directory (cal tables,
    MMSs, plain files, ...) is copied into the output directory, dereferencing
    symbolic links. Sub-directories of ``inpdir`` are not walked.
    '''

    def __init__(self,\
                 inpdir=None, \
                 mmsdir=None, \
                 axis='auto', \
                 numsubms=4,
                 cleanup=False):
        '''Run the partition task to create MMSs from a directory with MSs.

           inpdir    --> directory with the input MSs (mandatory).
           mmsdir    --> output directory. When None or '/', ./mmsdir is used.
           axis      --> separationaxis passed to partition (spw, scan, auto).
           numsubms  --> number of subMSs passed to partition.
           cleanup   --> if True, remove the output directory before starting.

           Calls sys.exit(2) if partition fails for any of the MSs.
        '''
        import subprocess

        casalog.origin('convertToMMS')

        self.inpdir = inpdir
        self.outdir = mmsdir
        self.axis = axis
        self.numsubms = numsubms
        self.mmsdir = '/tmp/mmsdir'
        self.cleanup = cleanup

        # Input directory is mandatory
        if self.inpdir is None:
            casalog.post('You must give an input directory to this script')
            self.usage()
            return

        if not os.path.exists(self.inpdir):
            casalog.post('Input directory inpdir does not exist -> '+self.inpdir,'ERROR')
            self.usage()
            return

        if not os.path.isdir(self.inpdir):
            casalog.post('Value of inpdir is not a directory -> '+self.inpdir,'ERROR')
            self.usage()
            return

        # Only work with absolute paths
        self.inpdir = os.path.abspath(self.inpdir)
        casalog.post('Will read input MS from '+self.inpdir)

        # Verify output directory
        if self.outdir is None:
            self.mmsdir = os.path.join(os.getcwd(),'mmsdir')
        elif self.outdir == '/':
            casalog.post('mmsdir is set to root!', 'WARN')
            self.mmsdir = os.path.join(os.getcwd(),'mmsdir')
        else:
            self.outdir = os.path.abspath(self.outdir)
            self.mmsdir = self.outdir

        if self.mmsdir == self.inpdir:
            casalog.post('Output directory cannot be same of input directory','ERROR')
            return

        # Cleanup output directory
        if self.cleanup:
            casalog.post('Cleaning up output directory '+self.mmsdir)
            if os.path.isdir(self.mmsdir):
                shutil.rmtree(self.mmsdir)

        if not os.path.exists(self.mmsdir):
            os.makedirs(self.mmsdir)

        casalog.post('Will save output MMS to '+self.mmsdir)

        # Only the top level of inpdir is examined. The builtin next() works on
        # both Python 2 and 3 (generator.next() does not exist in Python 3).
        files = next(os.walk(self.inpdir, followlinks=True))

        # Get MS list
        mslist = self.getMSlist(files)

        casalog.post('List of MSs in input directory')
        casalog.post(pp.pformat(mslist))

        # Get non-MS directories and other files. The output directory may live
        # inside inpdir (e.g. the default ./mmsdir when running from inpdir);
        # it must never be copied into itself.
        outreal = os.path.realpath(self.mmsdir)
        nonmslist = [f for f in self.getFileslist(files)
                     if os.path.realpath(f) != outreal]

        casalog.post('List of other files in input directory')
        casalog.post(pp.pformat(nonmslist))

        # Create an MMS for each MS in list
        for ms in mslist:
            casalog.post('Will create an MMS for '+ms)
            ret = self.runPartition(ms, self.mmsdir, self.axis, self.numsubms)
            if not ret:
                sys.exit(2)

            # Verify later if this is still needed
            time.sleep(10)

            casalog.origin('convertToMMS')
            casalog.post('--------------- Successfully created MMS -----------------')

        # Copy non-MS files to MMS directory
        for nfile in nonmslist:
            bfile = os.path.basename(nfile)
            lfile = os.path.join(self.mmsdir, bfile)
            casalog.post('Copying non-MS file '+bfile)
            # Same as "cp -RL" (recursive, follow symlinks), but the arguments are
            # passed as a list so no shell parses names with spaces/metacharacters.
            if subprocess.call(['cp', '-RL', nfile, lfile]) != 0:
                casalog.post('Failed to copy '+nfile, 'WARN')

    def getMSlist(self, files):
        '''Get a list of MSs from a directory.
           files -> a tuple as returned by:
           files = next(os.walk(self.inpdir, followlinks=True))

           Returns the full paths of the directories for which isItMS() returns 1
           (plain MS). Directories whose name starts with . are skipped. The
           directory extension is not checked.
           '''
        topdir = files[0]
        mslist = []

        for d in files[1]:
            # Skip . entries
            if d.startswith('.'):
                continue

            mydir = os.path.join(topdir, d)
            if self.isItMS(mydir) == 1:
                mslist.append(mydir)

        return mslist

    def isItMS(self, mydir):
        '''Check the type of a directory.
           mydir  --> full path of a directory.
                Returns 1 for an MS, 2 for a cal table and 3 for a MMS.
                If 0 is returned, it means any other type or an error.'''
        try:
            ldir = os.listdir(mydir)
        except OSError:
            return 0

        if 'table.info' not in ldir:
            return 0

        # Equivalent of "grep Type" / "grep SubType" on table.info, done in
        # Python: no external process, no shell, and no error when a pattern
        # does not occur. Decoding with 'replace' tolerates odd readme bytes.
        try:
            with open(os.path.join(mydir, 'table.info'), 'rb') as fd:
                lines = fd.read().decode('utf-8', 'replace').splitlines()
        except (IOError, OSError):
            return 0

        mytype = '\n'.join([l for l in lines if 'Type' in l])
        stype = '\n'.join([l for l in lines if 'SubType' in l])

        # It is a cal table
        if 'Calibration' in mytype:
            return 2

        if 'Measurement' in mytype:
            # A Multi-MS is CONCATENATED and must also contain SUBMSS
            if 'CONCATENATED' in stype:
                return 3 if 'SUBMSS' in ldir else 0
            # It is an MS
            return 1

        return 0

    def getFileslist(self, files):
        '''Get a list of non-MS files from a directory.
           files -> a tuple as returned by:
           files = next(os.walk(self.inpdir, followlinks=True))

           Returns the full paths of all files and of the directories for which
           isItMS() is not 1, so that together with getMSlist() every non-hidden
           entry of the directory is covered. Entries starting with . are skipped.
           '''
        topdir = files[0]
        fileslist = []

        # Directories that are not plain MSs (cal tables, MMSs, broken .ms, ...)
        for d in files[1]:
            if d.startswith('.'):
                continue

            mydir = os.path.join(topdir, d)
            if self.isItMS(mydir) != 1:
                fileslist.append(mydir)

        # Non-directory files
        for f in files[2]:
            if f.startswith('.'):
                continue
            fileslist.append(os.path.join(topdir, f))

        return fileslist

    def runPartition(self, ms, mmsdir, axis, subms):
        '''Run partition with default values to create an MMS.
           ms         --> full pathname of the MS
           mmsdir     --> directory to save the MMS to
           axis       --> separationaxis to use (spw, scan, auto)
           subms      --> number of subMSs to create

           Returns True if the MMS exists after partition, False otherwise.
        '''
        try:
            # CASA 6
            from casatasks import partition
        except ImportError:
            # CASA 5
            from tasks import partition

        if not os.path.lexists(ms):
            return False

        # Create MMS with the same name of the MS, but in a different location
        MSBaseName = os.path.basename(ms)
        MMSFullName = os.path.join(mmsdir, MSBaseName)
        if os.path.lexists(MMSFullName):
            casalog.post('Output MMS already exist -->'+MMSFullName,'ERROR')
            return False

        casalog.post('Output MMS will be: '+MMSFullName)

        # Check for remainings of corrupted mms
        corrupted = MMSFullName + '.data'
        if os.path.exists(corrupted):
            casalog.post('Cleaning up left overs','WARN')
            shutil.rmtree(corrupted)

        # Run partition
        partition(vis=ms, outputvis=MMSFullName, createmms=True, datacolumn='all', flagbackup=False,
                  separationaxis=axis, numsubms=subms)
        casalog.origin('convertToMMS')

        # Check if MMS was created
        if not os.path.exists(MMSFullName):
            casalog.post('Cannot create MMS ->'+MMSFullName, 'ERROR')
            return False

        return True

    def usage(self):
        '''Post the usage help of convertToMMS to the logger.'''
        casalog.post('=========================================================================')
        casalog.post('          convertToMMS will create a directory with multi-MSs.')
        casalog.post('Usage:\n')
        casalog.post('  import partitionhelper as ph')
        casalog.post('  ph.convertToMMS(inpdir=\'dir\') \n')
        casalog.post('Options:')
        casalog.post('   inpdir <dir>        directory with input MS.')
        casalog.post('   mmsdir <dir>        directory to save output MMS. If not given, it will save ')
        casalog.post('                       the MMS in a directory called mmsdir in the current directory.')
        casalog.post("   axis='auto'         separationaxis parameter of partition (spw,scan,auto).")
        casalog.post("   numsubms=4          number of subMSs to create in output MMS")
        casalog.post('   cleanup=False       if True it will remove the output directory before starting.\n')

        casalog.post(' NOTE: this script will run using the default values of partition. It will try to ')
        casalog.post(' create an MMS for every MS in the input directory. It will skip non-MS directories ')
        casalog.post(' such as cal tables. If partition succeeds, the script will copy every ')
        casalog.post(' other directory or file to the output directory. ')
        casalog.post(' The script will not walk through sub-directories of inpdir. It will also skip ')
        casalog.post(' files or directories that start with a .')
        casalog.post('==========================================================================')
        return
        
#
# -------------- HELPER functions for dealing with an MMS --------------
#
#    getMMSScans        'Get the list of scans of an MMS dictionary'
#    getScanList        'Get the list of scans of an MS or MMS'
#    getScanNrows       'Get the number of rows of a scan in a MS. It will add the 
#                         nrows of all sub-scans.'
#    getMMSScanNrows    'Get the number of rows of a scan in an MMS dictionary.'
#    getSpwIds          'Get the Spw IDs of a scan.'
#    getDiskUsage       'Return the disk usage of an MS (human-readable, as given by du -hs).'
#
# ----------------------------------------------------------------------

# def getNumberOf(msfile, item='row'):
#     '''Using the msmd tool, it gets the number of
#        scan, spw, antenna, baseline, field, state,
#        channel, row in a MS or MMS'''
#     
#     md = msmdtool()      # or msmd() in CASA 6
#     try:
#         md.open(msfile)
#     except:
#         casalog.post('Cannot open the msfile')
#         return 0
#     
#     if item == 'row':
#         numof = md.nrows()
#     elif item == 'scan':
#         numof = md.nscans()
#     elif item == 'spw':
#         numof = md.nspw()
#     elif item == 'antenna':
#         numof = md.nantennas()
#     elif item == 'baseline':
#         numof = md.nbaselines()
#     elif item == 'channel':
#         numof = md.nchan()
#     elif item == 'field':
#         numof = md.nfields()
#     elif item == 'state':
#         numof = md.nstates()
#     else:
#         numof = 0
#         
#     md.close()
#     return numof
    
      
# NOTE
# There is a bug in ms.getscansummary() that does not give the scans for all 
# observation Ids, but only for the last one. See CAS-4409
def getMMSScans(mmsdict):
    '''Collect the distinct scan Ids present in an MMS dictionary.

       mmsdict  --> output dictionary from listpartition(MMS,createdict=true)

       Returns a list with each scan Id of every sub-MS entry appearing once.
       A non-dictionary input is logged as an error and yields an empty list.'''

    if not isinstance(mmsdict, dict):
        casalog.post('ERROR: Input is not a dictionary', 'ERROR')
        return []

    seen = set()
    for entry in mmsdict.values():
        # Iterating a dict yields its keys, i.e. the scan Ids of this sub-MS
        seen.update(entry['scanId'])

    return list(seen)
    
def getScanList(msfile, selection=None):
    '''Get the list of scans of an MS or MMS.
       msfile     --> name of MS or MMS
       selection  --> optional dictionary with data selection, passed to
                      ms.msselect(items=...). None or {} means no selection.

       Return a list with the keys of ms.getscansummary() for this MS/MMS
       (one entry per scan). '''

    mst_local.open(msfile)
    try:
        if isinstance(selection, dict) and selection:
            mst_local.msselect(items=selection)
        scand = mst_local.getscansummary()
    finally:
        # Always release the shared ms tool, even if the selection fails
        mst_local.close()

    # On Python 3 dict.keys() is a view; return a real list as documented
    return list(scand.keys())
    
    
def getScanNrows(msfile, myscan, selection=None):
    '''Get the number of rows of a scan in a MS. It will add the nrows of all sub-scans.
       The MS is opened from disk, so only the selection given in `selection`
       is applied.
       msfile     --> name of the MS or MMS
       myscan     --> scan ID (int)
       selection  --> optional dictionary with data selection, passed to
                      ms.msselect(items=...). None or {} means no selection.

       Return the number of rows in the scan, or 0 if the scan is not present.

       To compare with the dictionary returned by listpartition, do the following:

        resdict = listpartition('file.mms', createdict=True)
        slist = ph.getMMSScans(resdict)
        for s in slist:
            mmsN = ph.getMMSScanNrows(resdict, s)
            msN = ph.getScanNrows('referenceMS', s)
            assert (mmsN == msN)
    '''
    mst_local.open(msfile)
    try:
        if isinstance(selection, dict) and selection:
            mst_local.msselect(items=selection)
        scand = mst_local.getscansummary()
    finally:
        # Always release the shared ms tool, even if the selection fails
        mst_local.close()

    # getscansummary() keys are looked up as strings
    subscans = scand.get(str(myscan))
    if subscans is None:
        return 0

    return sum(sub['nRow'] for sub in subscans.values())


def getMMSScanNrows(thisdict, myscan):
    '''Total the rows of one scan across all sub-MSs of an MMS dictionary.

       thisdict  --> output dictionary from listpartition(MMS,createdict=true)
       myscan    --> scan ID (int)

       Returns the summed 'nrows' of the scan over the sub-MSs that contain it,
       0 if no sub-MS has it, or -1 (after logging an error) if thisdict is
       not a dictionary. '''

    if not isinstance(thisdict, dict):
        casalog.post('ERROR: Input is not a dictionary', 'ERROR')
        return -1

    return sum(entry['scanId'][myscan]['nrows']
               for entry in thisdict.values()
               if myscan in entry['scanId'])
   

def getSpwIds(msfile, myscan, selection=None):
    '''Get the Spw IDs of a scan.
       msfile     --> name of the MS or MMS
       myscan     --> scan Id (int)
       selection  --> optional dictionary with data selection, passed to
                      ms.msselect(items=...). None or {} means no selection.

       Return a sorted list with the unique Spw IDs of all sub-scans of the
       scan, or an empty list if the scan is not present.
    '''
    mst_local.open(msfile)
    try:
        if isinstance(selection, dict) and selection:
            mst_local.msselect(items=selection)
        scand = mst_local.getscansummary()
    finally:
        # Always release the shared ms tool, even if the selection fails
        mst_local.close()

    if str(myscan) not in scand:
        return []

    aspws = np.array([], dtype=int)
    for sub in scand[str(myscan)].values():
        aspws = np.append(aspws, sub['SpwIds'])

    # np.unique returns the sorted unique values; tolist gives Python ints
    return np.unique(aspws).ravel().tolist()


def getScanSpwSummary(mslist=None):
    """ Get a consolidated dictionary with scan, spw, channel information
       of a list of MSs. It adds the nrows of all sub-scans of a scan.

       Keyword arguments:
       mslist    --> list with names of MSs. None or an empty sequence
                     returns an empty dictionary.

       Raises Exception if the scan/spw information of a sub-MS cannot be read.

       Returns a dictionary such as:
       mylist=['subms1.ms','subms2.ms']
       outdict = getScanSpwSummary(mylist)
       outdict = {0: {'MS': 'subms1.ms',
                      'scanId': {30: {'nchans': array([64, 64]),
                                      'nrows': 544,
                                      'spwIds': array([ 0,  1])}},
                      'size': '214M'},
                  1: {'MS': 'subms2.ms',
                      'scanId': {1: {'nchans': array([63]),
                                     'nrows': 4509,
                                     'spwIds': array([0])},
                                 2: {'nchans': array([63]),
                                     'nrows': 1890,
                                     'spwIds': array([0])}},
                      'size': '72M'}}
       'size' is the string returned by getDiskUsage(). 'nchans' is 0 for an
       spw that is not found in the spectral window info of the sub-MS.
    """
    if not mslist:
        return {}

    # Scan and spw dictionaries, and disk usage, of each MS
    msscanlist = []
    msspwlist = []
    sizelist = []

    for subms in mslist:
        try:
            mst_local.open(subms)
            msscanlist.append(mst_local.getscansummary())
            msspwlist.append(mst_local.getspectralwindowinfo())
        except Exception as exc:
            raise Exception('Cannot get scan/spw information from subMS: {0}'.format(exc))
        finally:
            mst_local.close()

        sizelist.append(getDiskUsage(subms))

    outdict = {}
    for ims, (subms, scandict, spwdict, mssize) in enumerate(
            zip(mslist, msscanlist, msspwlist, sizelist)):

        # NOTE: the keys of spwdict are NOT the spw Ids; build an
        # spw Id -> number of channels lookup once per sub-MS.
        nchanBySpw = {}
        for spwinfo in spwdict.values():
            nchanBySpw[spwinfo['SpectralWindowId']] = spwinfo['NumChan']

        scanIds = {}
        # The keys of scandict are the scan numbers
        for scan, subscans in scandict.items():
            nrows = 0
            aspws = np.array([], dtype='int32')
            for subscan in subscans.values():
                nrows += subscan['nRow']
                aspws = np.append(aspws, subscan['SpwIds'])

            # Sorted spws without duplicates
            uniquespws = np.unique(aspws)

            # zeros (not empty): an spw missing from spwdict must not yield
            # uninitialized garbage
            charray = np.zeros_like(uniquespws)
            for ind, spwid in enumerate(uniquespws):
                charray[ind] = nchanBySpw.get(int(spwid), 0)

            scanIds[int(scan)] = {'nrows': nrows,
                                  'spwIds': uniquespws,
                                  'nchans': charray}

        outdict[ims] = {'MS': os.path.basename(subms),
                        'size': mssize,
                        'scanId': scanIds}

    return outdict


def getMMSSpwIds(thisdict):
    '''Collect every spw Id used anywhere in an MMS dictionary.

       thisdict  --> output dictionary from listpartition(MMS,createdict=true)

       Returns a sorted list of unique spw Ids over all sub-MSs and scans. A
       non-dictionary input is logged as an error and yields an empty list. '''

    if not isinstance(thisdict, dict):
        casalog.post('ERROR: Input is not a dictionary', 'ERROR')
        return []

    # Seed with an empty int32 array so the empty-input case concatenates cleanly
    chunks = [np.array([], dtype='int32')]
    for entry in thisdict.values():
        for scaninfo in entry['scanId'].values():
            chunks.append(np.ravel(scaninfo['spwIds']))

    # np.unique sorts and removes duplicates in one step
    return np.unique(np.concatenate(chunks)).ravel().tolist()

def getSubMSSpwIds(subms, thisdict):
    '''Return the spw Ids used by one sub-MS of an MMS dictionary.

       subms     --> name or path of the sub-MS; only its base name is compared
                     with the 'MS' entries of thisdict
       thisdict  --> output dictionary from listpartition(MMS,createdict=true)

       Returns a sorted list of unique spw Ids over all scans of the first
       matching sub-MS, or an empty list if no entry matches. '''

    target = os.path.basename(subms)
    chunks = [np.array([], dtype='int32')]

    for entry in thisdict.values():
        if entry['MS'] != target:
            continue
        chunks.extend(np.ravel(info['spwIds']) for info in entry['scanId'].values())
        # Only the first matching sub-MS is used
        break

    return np.unique(np.concatenate(chunks)).ravel().tolist()

def getDiskUsage(msfile):
    """Return the disk usage of an MS or MMS as a human-readable string.

    Keyword arguments:
       msfile  --> name of the MS

    The value is the first field printed by ``du -hs`` (e.g. '214M'); it is
    not a number of bytes. Raises ValueError if du does not report a size.
    """
    from subprocess import Popen, PIPE

    # Arguments are passed as a list (no shell), so paths with spaces or shell
    # metacharacters work. stderr is kept separate so that a du error message
    # is never mistaken for the size.
    p = Popen(['du', '-hs', msfile], stdin=None, stdout=PIPE, stderr=PIPE, close_fds=True)
    out, err = p.communicate()

    # communicate() returns bytes on Python 3 and str on Python 2; decode()
    # works on both. The output looks like ' 75M\tuidScan23.data/uidScan23.0000.ms\n'
    fields = out.decode('utf-8', 'replace').split()
    if not fields:
        raise ValueError('Cannot get disk usage of {0}: {1}'.format(
            msfile, err.decode('utf-8', 'replace').strip()))

    return str(fields[0])


def getSubtables(vis):
    '''Return the base names of the sub-tables referenced by the keywords of a table.

       vis  --> name of the MS (or any table)

       A keyword whose value is a string of the form 'Table: <path>' is taken
       as a sub-table reference; the SORTED_TABLE keyword is excluded.'''
    tbt_local.open(vis)
    try:
        myKeyw = tbt_local.getkeywords()
    finally:
        # Release the shared table tool even if reading keywords fails
        tbt_local.close()

    theSubTables = []
    for k, theKeyw in myKeyw.items():
        if k == 'SORTED_TABLE' or not isinstance(theKeyw, str):
            continue
        words = theKeyw.split(' ')
        # Guard against a bare 'Table:' value without a path
        if words[0] == 'Table:' and len(words) > 1:
            theSubTables.append(os.path.basename(words[1]))

    return theSubTables


def _removePath(path):
    '''Best-effort removal of a file, symlink or directory tree (like "rm -rf").'''
    if os.path.islink(path) or os.path.isfile(path):
        try:
            os.remove(path)
        except OSError:
            pass
    elif os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)


def makeMMS(outputvis, submslist, copysubtables=False, omitsubtables=None, parallelaxis=''):
    """Create a Multi-MS from a list of MSs

    Keyword arguments:
        outputvis        -- name of the output MMS
        submslist        -- list of input subMSs to create the output from
        copysubtables    -- True will copy the sub-tables from the first subMS to the others in the
                            output MMS. Default to False.
        omitsubtables    -- List of sub-tables to omit when copying to output MMS. They will be linked instead.
                            None means an empty list.
        parallelaxis     -- Optionally, set the value to be written to AxisType in table.info of the output MMS
                            Usually this value comes from the separationaxis keyword of partition or mstransform.

      Raises ValueError if outputvis exists, submslist is empty, or any step of the
      creation fails. Returns True on success.

          Be AWARE that this function will remove the tables listed in submslist.
    """

    if os.path.exists(outputvis):
        raise ValueError('Output MS already exists')

    if len(submslist)==0:
        raise ValueError('No SubMSs given')

    if omitsubtables is None:
        omitsubtables = []

    origpath = os.getcwd()
    # The code below changes directory; keep an absolute path so that
    # setAxisType still finds the MMS when outputvis is relative.
    outputvis_abs = os.path.abspath(outputvis)
    # Base names of the subMSs inside outputvis/SUBMSS (a trailing '/' is tolerated)
    subnames = [os.path.basename(s.rstrip('/')) for s in submslist]

    try:
        ## make an MMS with all sub-MSs contained in a SUBMSS subdirectory
        try:
            mst_local.createmultims(outputvis,
                                    submslist,
                                    [],
                                    True,  # nomodify
                                    False, # lock
                                    copysubtables,
                                    omitsubtables  # when copying the subtables, omit these
            )
        finally:
            mst_local.close()

        # remove the SORTED_TABLE keywords because the sorting is not reliable after partitioning
        tbt_local.open(outputvis, nomodify=False)
        try:
            if 'SORTED_TABLE' in tbt_local.keywordnames():
                tbt_local.removekeyword('SORTED_TABLE')
        finally:
            # Close whether or not the keyword was present
            tbt_local.close()

        for subname in subnames:
            tbt_local.open(os.path.join(outputvis, 'SUBMSS', subname), nomodify=False)
            try:
                if 'SORTED_TABLE' in tbt_local.keywordnames():
                    tobedel = tbt_local.getkeyword('SORTED_TABLE').split(' ')[1]
                    tbt_local.removekeyword('SORTED_TABLE')
                    _removePath(tobedel)
            finally:
                tbt_local.close()

        # Create symbolic links to the subtables of the first SubMS in the reference MS (top one)
        os.chdir(outputvis)
        mastersubms = subnames[0]
        thesubtables = getSubtables('SUBMSS/'+mastersubms)

        for s in thesubtables:
            os.symlink('SUBMSS/'+mastersubms+'/'+s, s)

        os.chdir('SUBMSS/'+mastersubms)

        # SOURCE and HISTORY must not be linked. SOURCE is an optional sub-table,
        # so only drop the names that are actually present.
        linkedtables = [s for s in thesubtables if s not in ('SOURCE', 'HISTORY')]

        # Create sym links to all sub-tables in all subMSs
        for thesubms in subnames[1:]:
            os.chdir('../'+thesubms)
            for s in linkedtables:
                _removePath(s)
                os.symlink('../'+mastersubms+'/'+s, s)

        # Write the AxisType info in the MMS
        if parallelaxis != '':
            setAxisType(outputvis_abs, parallelaxis)

    except Exception as exc:
        raise ValueError('Problem in MMS creation: {0}'.format(exc))
    finally:
        os.chdir(origpath)

    return True

def axisType(mmsname):
    """Read the AxisType entry from the table info readme of a Multi-MS.
       The AxisType entry is usually written for a Multi-MS with the axis its
       data is parallelized across.

       Keyword arguments:
           mmsname    --    name of the Multi-MS

        Returns the stripped text after '=' on the last readme line containing
        'AxisType', or an empty string if there is none. Raises ValueError if
        the table cannot be opened.
    """
    try:
        tbt_local.open(mmsname, nomodify=True)
        info = tbt_local.info()
    except Exception as exc:
        raise ValueError('Unable to open table {0}. Exception: {1}'.format(mmsname, exc))
    finally:
        tbt_local.close()

    found = ''
    if 'readme' in info:
        for line in info['readme'].splitlines():
            if 'AxisType' in line:
                # Text after the first '='; a later AxisType line overrides
                found = line.partition('=')[2]

    return found.strip()

def setAxisType(mmsname, axis=''):
    """Set the AxisType keyword in a Multi-MS info. If AxisType already
       exists, it will be overwritten.

    Keyword arguments:
        mmsname    --    name of the Multi-MS
        axis       --    parallel axis of the Multi-MS. Options: scan; spw or scan,spw.
                         Surrounding whitespace is ignored.

        Return True if the AxisType read back from the table equals axis, False otherwise.
        Raises ValueError if axis is empty or the table cannot be read/updated.
    """
    # axisType() returns a stripped value, so store and compare the stripped axis
    axis = axis.strip()
    if axis == '':
        raise ValueError('Axis value cannot be empty')

    try:
        tbt_local.open(mmsname)
        tbinfo = tbt_local.info()
    except Exception as exc:
        raise ValueError('Unable to open table {0}. Exception: {1}'.format(mmsname, exc))
    finally:
        tbt_local.close()

    # Save original readme
    readme = ''
    if 'readme' in tbinfo:
        readme = tbinfo['readme']

    # Drop every existing AxisType line (including empty-valued ones) so that
    # exactly one remains after the update and the read-back check is reliable
    readlist = readme.splitlines()
    keptlines = [val for val in readlist if 'AxisType' not in val]
    if len(keptlines) != len(readlist):
        casalog.post('WARN: Will overwrite the existing AxisType value', 'WARN')
        readme = '\n'.join(keptlines).rstrip()

    # New readme with the axis info on the first line
    newReadme = 'AxisType = ' + axis + '\n' + readme
    readmerec = {'readme': newReadme}

    try:
        tbt_local.open(mmsname, nomodify=False)
        tbt_local.putinfo(readmerec)
    except Exception as exc:
        raise ValueError('Unable to put readme info into table {0}. Exception: {1}'.
                         format(mmsname, exc))
    finally:
        tbt_local.close()

    # Check if the axis was correctly added
    return axisType(mmsname) == axis

def buildScanDDIMap(scanSummary, ddIspectralWindowInfo):
    """
    Builds a scan->DDI map and three dicts with the number of visibilities per DDI, scan and field.

    :param scanSummary: scan summary dictionary as produced by the mstool (getscansummary()).
                        For each scan it maps timestamps to records that provide
                        'DDIds', 'FieldId' and 'nRow'.
    :param ddIspectralWindowInfo: SPW info dictionary keyed by DDI (as a string), as produced
                                  by the mstool (getspectralwindowinfo()). Each record must
                                  provide 'NumChan', 'NumCorr' and 'isWVR'.
    :returns: a tuple (scanDdiMap, nVisPerDDI, nVisPerScan, nVisPerField).
              scanDdiMap[scan][ddi] is a dict with 'nVis', 'fieldId' (field of the first
              timestamp in which the DDI appears for that scan) and 'isWVR'. The other three
              dicts hold the total number of visibilities per DDI, scan and field id.
              DDI and field keys are strings; scan keys are those of scanSummary.
    """
    nVisPerDDI = {}
    nVisPerScan = {}
    nVisPerField = {}
    scanDdiMap = {}

    for scan in sorted(scanSummary):
        ddiMap = {}
        scanDdiMap[scan] = ddiMap

        for timestamp in scanSummary[scan]:
            record = scanSummary[scan][timestamp]
            DDIds = record['DDIds']
            if len(DDIds) == 0:
                # Nothing to count; also avoids a division by zero below
                continue

            fieldId = str(record['FieldId'])
            # Assume all DDIs share the rows evenly. In ALMA data the WVR DDI has only
            # one row per antenna, but it is separated from the other DDIs.
            nrowsPerDDI = record['nRow'] / len(DDIds)

            for ddi in DDIds:
                # String keys, as used by ddIspectralWindowInfo
                ddi = str(ddi)
                spwInfo = ddIspectralWindowInfo[ddi]
                if ddi not in ddiMap:
                    ddiMap[ddi] = {'nVis': 0,
                                   'fieldId': fieldId,
                                   'isWVR': spwInfo['isWVR']}

                nvis = nrowsPerDDI * spwInfo['NumChan'] * spwInfo['NumCorr']

                ddiMap[ddi]['nVis'] = ddiMap[ddi]['nVis'] + nvis
                nVisPerDDI[ddi] = nVisPerDDI.get(ddi, 0) + nvis
                nVisPerScan[scan] = nVisPerScan.get(scan, 0) + nvis
                nVisPerField[fieldId] = nVisPerField.get(fieldId, 0) + nvis

    return scanDdiMap, nVisPerDDI, nVisPerScan, nVisPerField

def getPartitionMap(msfilename, nsubms, selection=None, axis=('field', 'spw', 'scan'), plotMode=0):
    """Generates a partition scan/spw map to obtain optimal load balancing with the following criteria:

       1st - Maximize the scan/spw/field distribution across sub-MSs
       2nd - Generate sub-MSs with similar size

       In order to balance better the size of the subMSs the allocation process
       iterates over the scan,spw pairs in descending number of visibilities.

       That is larger chunks are allocated first, and smaller chunks at the final
       stages so that they can be used to balance the load in a stable way

    Keyword arguments:
        msfilename      --    Input MS filename
        nsubms          --    Number of subMSs. Reduced (with a warning) to the number
                              of available scan/ddi pairs when it is larger.
        selection       --    Data selection dictionary passed to ms.msselect().
                              None or an empty dict means no selection.
        axis            --    Sequence of strings with the axes for load distribution
                              (any of 'scan', 'spw', 'field')
        plotMode        --    Integer in the range 0-3 to determine the plot generation mode
                                0 - Don't generate any plots
                                1 - Show plots but don't save them
                                2 - Save plots but don't show them
                                3 - Show and save plots

        Returns a dict keyed by sub-MS index (0 .. nsubms-1). Each value holds the
        'scanList', 'ddiList', 'fieldList' and 'nVisList' of the scan/ddi pairs
        assigned to that sub-MS, their 'nVisTotal', and a 'taql' string that
        selects exactly those pairs.
    """
    from collections import defaultdict

    # Read the scan summary and per-DDI SPW info; always release the ms tool
    mst_local.open(msfilename)
    try:
        if isinstance(selection, dict) and selection:
            mst_local.msselect(items=selection)
        scanSummary = mst_local.getscansummary()
        ddIspectralWindowInfo = mst_local.getspectralwindowinfo()
    finally:
        mst_local.close()

    # Get the list of WVR SPWs using the ms metadata tool
    msmdt_local.open(msfilename)
    try:
        wvrspws = msmdt_local.wvrspws()
    finally:
        msmdt_local.close()

    # Mark WVR DDIs: compare the SPW id of each DDI (not its whole info record)
    # against the WVR SPW ids reported by msmd
    wvrSpwIds = set(int(spw) for spw in wvrspws)
    for ddi in ddIspectralWindowInfo:
        spwInfo = ddIspectralWindowInfo[ddi]
        spwId = spwInfo.get('SpectralWindowId')
        spwInfo['isWVR'] = spwId is not None and int(spwId) in wvrSpwIds

    scanDdiMap, nVisPerDDI, nVisPerScan, nVisPerField = buildScanDDIMap(scanSummary,
                                                                        ddIspectralWindowInfo)

    # Flatten the scan/ddi pairs
    ddiList, scanList, fieldList, nVisList = [], [], [], []
    for scan in scanDdiMap:
        for ddi in scanDdiMap[scan]:
            ddiList.append(ddi)
            scanList.append(scan)
            fieldList.append(scanDdiMap[scan][ddi]['fieldId'])
            nVisList.append(scanDdiMap[scan][ddi]['nVis'])
    nScanDDIPairs = len(ddiList)

    # Check that the number of available scan/ddi pairs is not greater than the number of subMSs
    if nsubms > nScanDDIPairs:
        casalog.post("Number of subMSs (%i) is greater than available scan,ddi pairs (%i), setting nsubms to %i" 
                     % (nsubms,nScanDDIPairs,nScanDDIPairs),"WARN","getPartitionMap")
        nsubms = nScanDDIPairs

    # Order the pairs by decreasing nVis (ties broken by scan, then ddi).
    # lexsort sorts by its last key first and in increasing order, hence the reversal.
    ddiArray = np.array(ddiList)
    scanArray = np.array(scanList)
    nVisSortIndex = np.lexsort((ddiArray, scanArray, np.array(nVisList)))[::-1]
    ddiArray = ddiArray[nVisSortIndex]
    scanArray = scanArray[nVisSortIndex]

    # Per-id arrays with the contribution of each subMS
    scanNvisDistributionPerSubMs = {scan: np.zeros(nsubms) for scan in scanSummary}
    ddiNvisDistributionPerSubMs = {ddi: np.zeros(nsubms) for ddi in ddIspectralWindowInfo}
    fieldNvisDistributionPerSubMs = {field: np.zeros(nsubms) for field in sorted(set(fieldList))}

    # Total number of visibilities per subMS
    nvisPerSubMs = np.zeros(nsubms)

    # Final map of scan/ddi pairs per subMS
    submScanDdiMap = {}
    for subms in range(nsubms):
        submScanDdiMap[subms] = {'scanList': [],
                                 'ddiList': [],
                                 'fieldList': [],
                                 'nVisList': [],
                                 'nVisTotal': 0}

    # Assign each pair to the subMS with the biggest joint (scan/ddi/field) gap.
    # The reference level is the ideal even share (total // nsubms), raised to the
    # current maximum when a large chunk already pushed one subMS above it.
    for ddi, scan in zip(ddiArray, scanArray):
        field = scanDdiMap[scan][ddi]['fieldId']

        jointNvisGap = np.zeros(nsubms)
        if 'scan' in axis:
            dist = scanNvisDistributionPerSubMs[scan]
            refLevel = max(nVisPerScan[scan] // nsubms, dist.max())
            jointNvisGap = jointNvisGap + refLevel - dist
        if 'spw' in axis:
            dist = ddiNvisDistributionPerSubMs[ddi]
            refLevel = max(nVisPerDDI[ddi] // nsubms, dist.max())
            jointNvisGap = jointNvisGap + refLevel - dist
        if 'field' in axis:
            dist = fieldNvisDistributionPerSubMs[field]
            refLevel = max(nVisPerField[field] // nsubms, dist.max())
            jointNvisGap = jointNvisGap + refLevel - dist

        # Among the subMSs with the largest gap, pick the one with fewest visibilities
        candidates = np.flatnonzero(jointNvisGap == jointNvisGap.max())
        optimalSubMs = int(candidates[np.argmin(nvisPerSubMs[candidates])])

        # Store the scan/ddi pair in the selected subMS
        nVis = scanDdiMap[scan][ddi]['nVis']
        nvisPerSubMs[optimalSubMs] += nVis
        entry = submScanDdiMap[optimalSubMs]
        entry['scanList'].append(int(scan))
        entry['ddiList'].append(int(ddi))
        entry['fieldList'].append(field)
        entry['nVisList'].append(nVis)
        entry['nVisTotal'] = entry['nVisTotal'] + nVis

        # Update the per-id distributions
        scanNvisDistributionPerSubMs[scan][optimalSubMs] += nVis
        ddiNvisDistributionPerSubMs[ddi][optimalSubMs] += nVis
        fieldNvisDistributionPerSubMs[field][optimalSubMs] += nVis

    # Generate plots
    if plotMode > 0:
        plt.close()
        # normpath so that an MS path with a trailing '/' still yields its name
        plotname_prefix = os.path.basename(os.path.normpath(msfilename)) + ' axis ' + ' '.join(axis)
        plotVisDistribution(nVisPerScan,scanNvisDistributionPerSubMs,plotname_prefix,'scan',plotMode=plotMode)
        plotVisDistribution(nVisPerDDI,ddiNvisDistributionPerSubMs,plotname_prefix,'ddi',plotMode=plotMode)
        plotVisDistribution(nVisPerField,fieldNvisDistributionPerSubMs,plotname_prefix,'field',plotMode=plotMode)

    # Build one TaQL selection per subMS, grouping scans by DDI
    for entry in submScanDdiMap.values():
        scansPerDdi = defaultdict(list)
        for ddi, scan in zip(entry['ddiList'], entry['scanList']):
            scansPerDdi[ddi].append(scan)

        clauses = []
        for ddi, scans in scansPerDdi.items():
            scansel = '[' + ', '.join([str(x) for x in scans]) + ']'
            clauses.append(('(DATA_DESC_ID==%i && (SCAN_NUMBER IN %s))') % (ddi, scansel))

        entry['taql'] = ' OR '.join(clauses)

    # Return map of scan/ddi pairs per subMs
    return submScanDdiMap


def plotVisDistribution(nvisMap,idNvisDistributionPerSubMs,filename,idLabel,plotMode=1):
    """Generates a plot to show the distribution of scans/spws/fields across subMSs.
    The plot style is a stacked bar chart, where the ids with a higher number of
    visibilities are shown at the bottom.
    
    Keyword arguments:
        nvisMap                          --    Map of total number of visibilities per id
                                               (keys are id strings convertible to int)
        idNvisDistributionPerSubMs       --    Map from id string to the per-subMS array of visibilities
        filename                         --    Name of MS to be shown in the title and plot filename
        idLabel                          --    idLabel to indicate the id (spw, scan) to be used for the figure title
        plotMode                         --    Integer in the range 0-3 to determine the plot generation mode
                                                 0 - Don't generate any plots
                                                 1 - Show plots but don't save them
                                                 2 - Save plots but don't show them
                                                 3 - Show and save plots

    In modes 2 and 3 the figure is saved in the current directory as a PNG whose
    name is the title with spaces replaced by '-'.
    """
    if plotMode <= 0:
        return

    # Build the figure without interactive redraws
    plt.ioff()

    # If the plot is not shown, use a pre-defined 1585x1170 pixels figure at 75 DPI
    # (the window cannot be maximized to the screen size)
    if plotMode==2:
        plt.figure(figsize=(21.13,15.6),dpi=75) # Size is given in inches
    else:
        plt.figure()

    # Sort the ids by decreasing total number of visibilities so that bigger
    # groups are stacked at the bottom and smaller ones at the top
    idKeys = list(nvisMap)
    idNvisArray = np.array([nvisMap[key] for key in idKeys], dtype=float)
    sortIndex = np.argsort(idNvisArray)[::-1]
    sortedIds = [str(int(idKeys[i])) for i in sortIndex]

    # Color order that alternates between both ends of the colormap
    nid = len(nvisMap)
    colorVectorEven = list(range(nid))[::2]
    colorVectorOdd = list(range(nid))[1::2]
    colorVectorOdd.reverse()
    colorVector = []
    while colorVectorOdd or colorVectorEven:
        if colorVectorOdd:
            colorVector.append(colorVectorOdd.pop())
        if colorVectorEven:
            colorVector.append(colorVectorEven.pop())

    # Generate the stacked bar plot
    distributions = list(idNvisDistributionPerSubMs.values())
    nsubms = len(distributions[0]) if distributions else 0
    width = 0.35 # bar width
    barPos = np.arange(nsubms) # left edge of each bar in the horizontal axis
    bottomLevel = np.zeros(nsubms) # Reference level for the bars to be stacked after the previous ones
    legendidLabels = [] # List of legend labels
    plotHandles = [] # List of plot handles for the legend
    for coloridx, idStr in enumerate(sortedIds):
        counts = idNvisDistributionPerSubMs[idStr]
        # align='edge' keeps bars at [x, x+width] so the ticks at x+width/2 are centered
        # (matplotlib >= 2.0 centers bars by default)
        idplot = plt.bar(barPos, counts, width, bottom=bottomLevel, align='edge',
                         color=plt.cm.Paired(1.*colorVector[coloridx]/nid))
        plotHandles.append(idplot)
        legendidLabels.append(idLabel + ' ' + idStr)
        bottomLevel = bottomLevel + counts

    # Add legend
    plt.legend( plotHandles, legendidLabels, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)

    # Add label for y axis
    plt.ylabel('nVis')

    # Add x-ticks at the center of each bar
    xticks = ['subMS-' + str(subms) for subms in range(nsubms)]
    plt.xticks(barPos+width/2., xticks )

    # Add title
    title = filename + ' distribution of ' + idLabel + ' visibilities across sub-MSs'
    plt.title(title)

    if plotMode==1 or plotMode==3:
        # Resize to full screen. Only Tk figure managers expose window.maxsize(),
        # so this is best-effort on other backends.
        mng = plt.get_current_fig_manager()
        try:
            mng.resize(*mng.window.maxsize())
        except AttributeError:
            pass

        # Show figure
        plt.ion()
        plt.show()

    # Save plot
    if plotMode>1:
        title = title.replace(' ','-') + '.png'
        plt.savefig(title)

    # If plot is not to be shown then close it
    if plotMode==2:
        plt.close()