import os
import sys
import shutil
import pprint as pp
import traceback
import time
import numpy as np
from matplotlib import pyplot as plt
from casatasks.private.casa_transition import is_CASA6
if is_CASA6:
from casatasks import casalog
from casatools import table, ms, msmetadata
from casatools.platform import bytes2str
import subprocess
mst_local = ms()
tbt_local = table()
msmdt_local = msmetadata()
else:
from __main__ import *
from taskinit import *
import commands
mst_local = mstool()
tbt_local = tbtool()
msmdt_local = msmdtool()
class convertToMMS():
def __init__(self,\
inpdir=None, \
mmsdir=None, \
axis='auto', \
numsubms=4,
# createmslink=False, \
cleanup=False):
'''Run the partition task to create MMSs from a directory with MSs'''
casalog.origin('convertToMMS')
self.inpdir = inpdir
self.outdir = mmsdir
self.axis = axis
self.numsubms = numsubms
# self.createmslink = createmslink
self.mmsdir = '/tmp/mmsdir'
self.cleanup = cleanup
# Input directory is mandatory
if self.inpdir is None:
casalog.post('You must give an input directory to this script')
self.usage()
return
if not os.path.exists(self.inpdir):
casalog.post('Input directory inpdir does not exist -> '+self.inpdir,'ERROR')
self.usage()
return
if not os.path.isdir(self.inpdir):
casalog.post('Value of inpdir is not a directory -> '+self.inpdir,'ERROR')
self.usage()
return
# Only work with absolute paths
self.inpdir = os.path.abspath(self.inpdir)
casalog.post('Will read input MS from '+self.inpdir)
# Verify output directory
if self.outdir is None:
self.mmsdir = os.path.join(os.getcwd(),'mmsdir')
elif self.outdir == '/':
casalog.post('inpdir is set to root!', 'WARN')
self.mmsdir = os.path.join(os.getcwd(),'mmsdir')
else:
self.outdir = os.path.abspath(self.outdir)
self.mmsdir = self.outdir
if self.mmsdir == self.inpdir:
casalog.post('Output directory cannot be same of input directory','ERROR')
return
# Cleanup output directory
if self.cleanup:
casalog.post('Cleaning up output directory '+self.mmsdir)
if os.path.isdir(self.mmsdir):
shutil.rmtree(self.mmsdir)
if not os.path.exists(self.mmsdir):
os.makedirs(self.mmsdir)
casalog.post('Will save output MMS to '+self.mmsdir)
# Walk through input directory
files = os.walk(self.inpdir,followlinks=True).next()
# Get MS list
mslist = []
mslist = self.getMSlist(files)
casalog.post('List of MSs in input directory')
casalog.post(pp.pformat(mslist))
# Get non-MS directories and other files
nonmslist = []
nonmslist = self.getFileslist(files)
casalog.post('List of other files in input directory')
casalog.post(pp.pformat(nonmslist))
# Create an MMS for each MS in list
for ms in mslist:
casalog.post('Will create an MMS for '+ms)
ret = self.runPartition(ms, self.mmsdir, self.axis, self.numsubms)
if not ret:
sys.exit(2)
# Verify later if this is still needed
time.sleep(10)
casalog.origin('convertToMMS')
casalog.post('--------------- Successfully created MMS -----------------')
# Copy non-MS files to MMS directory
for nfile in nonmslist:
bfile = os.path.basename(nfile)
lfile = os.path.join(self.mmsdir, bfile)
casalog.post('Copying non-MS file '+bfile)
# os.symlink(file, lfile)
# shutil.copytree(nfile, lfile, symlinks=False)
os.system("cp -RL {0} {1}".format(nfile, lfile))
def getMSlist(self, files):
'''Get a list of MSs from a directory.
files -> a tuple that is returned by the following call:
files = os.walk(self.inpdir,followlinks=True).next()
It will test if a directory is an MS and will only return
true MSs, that have Type:Measurement Set in table.info. It will skip
directories that start with . and those that do not end with
extension .ms.
'''
topdir = files[0]
mslist = []
# Loop through list of directories
for d in files[1]:
# Skip . entries
if d.startswith('.'):
continue
# if not d.endswith('.ms'):
# continue
# Full path for directory
mydir = os.path.join(topdir,d)
# It is probably an MS
if self.isItMS(mydir) == 1:
mslist.append(mydir)
return mslist
def isItMS(self, mydir):
'''Check the type of a directory.
mydir --> full path of a directory.
Returns 1 for an MS, 2 for a cal table and 3 for a MMS.
If 0 is returned, it means any other type or an error.'''
ret = 0
# Listing of this directory
ldir = os.listdir(mydir)
if not ldir.__contains__('table.info'):
return ret
cmd1 = 'grep Type '+mydir+'/table.info'
cmd2 = 'grep SubType '+mydir+'/table.info'
if is_CASA6:
mytype = bytes2str(subprocess.check_output(cmd1)).rstrip("\n")
stype = bytes2str(subprocess.check_output(cmd2)).rstrip("\n")
else:
mytype = commands.getoutput(cmd1)
stype = commands.getoutput(cmd2)
# It is a cal table
if mytype.__contains__('Calibration'):
ret = 2
elif mytype.__contains__('Measurement'):
# It is a Multi-MS
if stype.__contains__('CONCATENATED'):
# Further check
if ldir.__contains__('SUBMSS'):
ret = 3
# It is an MS
else:
ret = 1
return ret
def getFileslist(self, files):
'''Get a list of non-MS files from a directory.
files -> a tuple that is returned by the following call:
files = os.walk(self.inpdir,followlinks=True).next()
It will return files and directories that are not MSs. It will skip
files that start with .
'''
topdir = files[0]
fileslist = []
# Get other directories that are not MSs
for d in files[1]:
# Skip . entries
if d.startswith('.'):
continue
# Skip MS directories
if d.endswith('.ms'):
continue
# Full path for directory
mydir = os.path.join(topdir,d)
# It is not an MS
if self.isItMS(mydir) != 1:
fileslist.append(mydir)
# Get non-directory files
for f in files[2]:
# Skip . entries
if f.startswith('.'):
continue
# Full path for file
myfile = os.path.join(topdir, f)
fileslist.append(myfile)
return fileslist
def runPartition(self, ms, mmsdir, axis, subms):
'''Run partition with default values to create an MMS.
ms --> full pathname of the MS
mmsdir --> directory to save the MMS to
axis --> separationaxis to use (spw, scan, auto)
subms --> number of subMss to create
'''
try:
# CASA 6
from casatasks import partition
except ImportError:
# CASA 5
from tasks import partition
if not os.path.lexists(ms):
return False
# Create MMS name
# bname = os.path.basename(ms)
# if bname.endswith('.ms'):
# mmsname = bname.replace('.ms','.mms')
# else:
# mmsname = bname+'.mms'
# Create MMS with the same name of the MS, but in a different location
MSBaseName = os.path.basename(ms)
MMSFullName = os.path.join(self.mmsdir, MSBaseName)
if os.path.lexists(MMSFullName):
casalog.post('Output MMS already exist -->'+MMSFullName,'ERROR')
return False
casalog.post('Output MMS will be: '+MMSFullName)
# mms = os.path.join(self.mmsdir, mmsname)
# if os.path.lexists(mms):
# casalog.post('Output MMS already exist -->'+mms,'ERROR')
# return False
# Check for remainings of corrupted mms
# corrupted = mms.replace('.mms','.data')
corrupted = MMSFullName + '.data'
if os.path.exists(corrupted):
casalog.post('Cleaning up left overs','WARN')
shutil.rmtree(corrupted)
# Run partition
partition(vis=ms, outputvis=MMSFullName, createmms=True, datacolumn='all', flagbackup=False,
separationaxis=axis, numsubms=subms)
casalog.origin('convertToMMS')
# Check if MMS was created
if not os.path.exists(MMSFullName):
casalog.post('Cannot create MMS ->'+MMSFullName, 'ERROR')
return False
# If requested, create a link to this MMS with the original MS name
# if createlink:
# here = os.getcwd()
# os.chdir(mmsdir)
# mmsname = os.path.basename(mms)
## lms = mmsname.replace('.mms', '.ms')
# casalog.post('Creating symbolic link to MMS')
## os.symlink(mmsname, lms)
# os.symlink(mmsname, bname)
# os.chdir(here)
return True
def usage(self):
casalog.post('=========================================================================')
casalog.post(' convertToMMS will create a directory with multi-MSs.')
casalog.post('Usage:\n')
casalog.post(' import partitionhelper as ph')
casalog.post(' ph.convertToMMS(inpdir=\'dir\') \n')
casalog.post('Options:')
casalog.post(' inpdir
directory with input MS.')
casalog.post(' mmsdir directory to save output MMS. If not given, it will save ')
casalog.post(' the MMS in a directory called mmsdir in the current directory.')
casalog.post(" axis='auto' separationaxis parameter of partition (spw,scan,auto).")
casalog.post(" numsubms=4 number of subMSs to create in output MMS")
casalog.post(' cleanup=False if True it will remove the output directory before starting.\n')
casalog.post(' NOTE: this script will run using the default values of partition. It will try to ')
casalog.post(' create an MMS for every MS in the input directory. It will skip non-MS directories ')
casalog.post(' such as cal tables. If partition succeeds, the script will create a link to every ')
casalog.post(' other directory or file in the output directory. ')
casalog.post(' The script will not walk through sub-directories of inpdir. It will also skip ')
casalog.post(' files or directories that start with a .')
casalog.post('==========================================================================')
return
#
# -------------- HELPER functions for dealing with an MMS --------------
#
# getMMSScans 'Get the list of scans of an MMS dictionary'
# getScanList 'Get the list of scans of an MS or MMS'
# getScanNrows 'Get the number of rows of a scan in a MS. It will add the
# nrows of all sub-scans.'
# getMMSScanNrows 'Get the number of rows of a scan in an MMS dictionary.'
# getSpwIds 'Get the Spw IDs of a scan.'
# getDiskUsage 'eturn the size in bytes of an MS in disk.'
#
# ----------------------------------------------------------------------
# def getNumberOf(msfile, item='row'):
# '''Using the msmd tool, it gets the number of
# scan, spw, antenna, baseline, field, state,
# channel, row in a MS or MMS'''
#
# md = msmdtool() # or msmd() in CASA 6
# try:
# md.open(msfile)
# except:
# casalog.post('Cannot open the msfile')
# return 0
#
# if item == 'row':
# numof = md.nrows()
# elif item == 'scan':
# numof = md.nscans()
# elif item == 'spw':
# numof = md.nspw()
# elif item == 'antenna':
# numof = md.nantennas()
# elif item == 'baseline':
# numof = md.nbaselines()
# elif item == 'channel':
# numof = md.nchan()
# elif item == 'field':
# numof = md.nfields()
# elif item == 'state':
# numof = md.nstates()
# else:
# numof = 0
#
# md.close()
# return numof
# NOTE
# There is a bug in ms.getscansummary() that does not give the scans for all
# observation Ids, but only for the last one. See CAS-4409
def getMMSScans(mmsdict):
'''Get the list of scans of an MMS dictionary.
mmsdict --> output dictionary from listpartition(MMS,createdict=true)
Return a list of the scans in this MMS. '''
if not isinstance(mmsdict, dict):
casalog.post('ERROR: Input is not a dictionary', 'ERROR')
return []
tkeys = mmsdict.keys()
scanlist = []
slist = set(scanlist)
for k in tkeys:
skeys = mmsdict[k]['scanId'].keys()
for j in skeys:
slist.add(j)
return list(slist)
def getScanList(msfile, selection={}):
'''Get the list of scans of an MS or MMS.
msfile --> name of MS or MMS
selection --> dictionary with data selection
Return a list of the scans in this MS/MMS. '''
mst_local.open(msfile)
if isinstance(selection, dict) and selection != {}:
mst_local.msselect(items=selection)
scand = mst_local.getscansummary()
mst_local.close()
scanlist = scand.keys()
return scanlist
def getScanNrows(msfile, myscan, selection={}):
'''Get the number of rows of a scan in a MS. It will add the nrows of all sub-scans.
This will not take into account any selection done on the MS.
msfile --> name of the MS or MMS
myscan --> scan ID (int)
selection --> dictionary with data selection
Return the number of rows in the scan.
To compare with the dictionary returned by listpartition, do the following:
resdict = listpartition('file.mms', createdict=True)
slist = ph.getMMSScans(thisdict)
for s in slist:
mmsN = ph.getMMSScanNrows(thisdict, s)
msN = ph.getScanNrows('referenceMS', s)
assert (mmsN == msN)
'''
mst_local.open(msfile)
if isinstance(selection, dict) and selection != {}:
mst_local.msselect(items=selection)
scand = mst_local.getscansummary()
mst_local.close()
Nrows = 0
if not str(myscan) in scand:
return Nrows
subscans = scand[str(myscan)]
for ii in subscans.keys():
Nrows += scand[str(myscan)][ii]['nRow']
return Nrows
def getMMSScanNrows(thisdict, myscan):
'''Get the number of rows of a scan in an MMS dictionary.
thisdict --> output dictionary from listpartition(MMS,createdict=true)
myscan --> scan ID (int)
Return the number of rows in the given scan. '''
if not isinstance(thisdict, dict):
casalog.post('ERROR: Input is not a dictionary', 'ERROR')
return -1
tkeys = thisdict.keys()
scanrows = 0
for k in tkeys:
if myscan in thisdict[k]['scanId']:
scanrows += thisdict[k]['scanId'][myscan]['nrows']
return scanrows
def getSpwIds(msfile, myscan, selection={}):
'''Get the Spw IDs of a scan.
msfile --> name of the MS or MMS
myscan --> scan Id (int)
selection --> dictionary with data selection
Return a list with the Spw IDs. Note that the returned spw IDs are sorted.
'''
import numpy as np
mst_local.open(msfile)
if isinstance(selection, dict) and selection != {}:
mst_local.msselect(items=selection)
scand = mst_local.getscansummary()
mst_local.close()
spwlist = []
if not str(myscan) in scand:
return spwlist
subscans = scand[str(myscan)]
aspws = np.array([],dtype=int)
for ii in subscans.keys():
sscanid = ii
spwids = scand[str(myscan)][sscanid]['SpwIds']
aspws = np.append(aspws,spwids)
# Sort spws and remove duplicates
aspws.sort()
uniquespws = np.unique(aspws)
# Try to return a list
spwlist = uniquespws.ravel().tolist()
return spwlist
def getScanSpwSummary(mslist=[]):
""" Get a consolidated dictionary with scan, spw, channel information
of a list of MSs. It adds the nrows of all sub-scans of a scan.
Keyword arguments:
mslist --> list with names of MSs
Returns a dictionary such as:
mylist=['subms1.ms','subms2.ms']
outdict = getScanSpwSummary(mylist)
outdict = {0: {'MS': 'subms1.ms',
'scanId': {30: {'nchans': array([64, 64]),
'nrows': 544,
'spwIds': array([ 0, 1])}},
'size': '214M'},
1: {'MS': 'ngc5921.ms',
'scanId': {1: {'nchans': array([63]),
'nrows': 4509,
'spwIds': array([0])},
2: {'nchans': array([63]),
'nrows': 1890,
'spwIds': array([0])}},
'size': '72M'}}
"""
if mslist == []:
return {}
# Create lists for scan and spw dictionaries of each MS
msscanlist = []
msspwlist = []
# List with sizes in bytes per sub-MS
sizelist = []
# Loop through all MSs
for subms in mslist:
try:
mst_local.open(subms)
scans = mst_local.getscansummary()
msscanlist.append(scans)
spws = mst_local.getspectralwindowinfo()
msspwlist.append(spws)
except Exception as exc:
raise Exception('Cannot get scan/spw information from subMS: {0}'.format(exc))
finally:
mst_local.close()
# Get the data volume in bytes per sub-MS
sizelist.append(getDiskUsage(subms))
# Get the information to list in output
# Dictionary to return
outdict = {}
for ims in range(mslist.__len__()):
# Create temp dictionary for each sub-MS
tempdict = {}
msname = os.path.basename(mslist[ims])
tempdict['MS'] = msname
tempdict['size'] = sizelist[ims]
# Get scan dictionary for this sub-MS
scandict = msscanlist[ims]
# Get spw dictionary for this sub-MS
# NOTE: the keys of spwdict.keys() are NOT the spw Ids
spwdict = msspwlist[ims]
# The keys are the scan numbers
scanlist = scandict.keys()
# Get information per scan
tempdict['scanId'] = {}
for scan in scanlist:
newscandict = {}
subscanlist = scandict[scan].keys()
# Get spws and nrows per sub-scan
nrows = 0
aspws = np.array([],dtype='int32')
for subscan in subscanlist:
nrows += scandict[scan][subscan]['nRow']
# Get the spws for each sub-scan
spwids = scandict[scan][subscan]['SpwIds']
aspws = np.append(aspws,spwids)
newscandict['nrows'] = nrows
# Sort spws and remove duplicates
aspws.sort()
uniquespws = np.unique(aspws)
newscandict['spwIds'] = uniquespws
# Array to hold channels
charray = np.empty_like(uniquespws)
spwsize = np.size(uniquespws)
# Now get the number of channels per spw
for ind in range(spwsize):
spwid = uniquespws[ind]
for sid in spwdict.keys():
if spwdict[sid]['SpectralWindowId'] == spwid:
nchans = spwdict[sid]['NumChan']
charray[ind] = nchans
continue
newscandict['nchans'] = charray
tempdict['scanId'][int(scan)] = newscandict
outdict[ims] = tempdict
#casalog.post(pp.format(outdict))
return outdict
def getMMSSpwIds(thisdict):
'''Get the list of spws from an MMS dictionary.
thisdict --> output dictionary from listpartition(MMS,createdict=true)
Return a list of the spw Ids in the dictionary. '''
import numpy as np
if not isinstance(thisdict, dict):
casalog.post('ERROR: Input is not a dictionary', 'ERROR')
return []
tkeys = thisdict.keys()
aspws = np.array([],dtype='int32')
for k in tkeys:
scanlist = thisdict[k]['scanId'].keys()
for s in scanlist:
spwids = thisdict[k]['scanId'][s]['spwIds']
aspws = np.append(aspws, spwids)
# Sort spws and remove duplicates
aspws.sort()
uniquespws = np.unique(aspws)
# Try to return a list
spwlist = uniquespws.ravel().tolist()
return spwlist
def getSubMSSpwIds(subms, thisdict):
import numpy as np
tkeys = thisdict.keys()
aspws = np.array([],dtype='int32')
mysubms = os.path.basename(subms)
for k in tkeys:
if thisdict[k]['MS'] == mysubms:
# get the spwIds of this subMS
scanlist = thisdict[k]['scanId'].keys()
for s in scanlist:
spwids = thisdict[k]['scanId'][s]['spwIds']
aspws = np.append(aspws, spwids)
break
# Sort spws and remove duplicates
aspws.sort()
uniquespws = np.unique(aspws)
# Try to return a list
spwlist = uniquespws.ravel().tolist()
return spwlist
def getDiskUsage(msfile):
"""Return the size in bytes of an MS or MMS in disk.
Keyword arguments:
msfile --> name of the MS
This function will return a value given by the command du -hs
"""
from subprocess import Popen, PIPE, STDOUT
# Command line to run
ducmd = 'du -hs {0}'.format(msfile)
if is_CASA6:
p = Popen(ducmd, shell=True, stdin=None, stdout=PIPE, stderr=STDOUT, close_fds=True)
o, e = p.communicate() ### previously 'sizeline = p.stdout.read()' here
### left process running...
sizeline = bytes2str(o.split( )[0])
else:
p = Popen(ducmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT, close_fds=True)
sizeline = p.stdout.read()
_out, _err = p.communicate()
# Create a list of the output string, which looks like this:
# ' 75M\tuidScan23.data/uidScan23.0000.ms\n'
# This will create a list with [size,sub-ms]
mssize = sizeline.split()
return mssize[0]
def getSubtables(vis):
theSubTables = []
tbt_local.open(vis)
myKeyw = tbt_local.getkeywords()
tbt_local.close()
for k in myKeyw.keys():
theKeyw = myKeyw[k]
if (type(theKeyw)==str and theKeyw.split(' ')[0]=='Table:'
and not k=='SORTED_TABLE'):
theSubTables.append(os.path.basename(theKeyw.split(' ')[1]))
return theSubTables
def makeMMS(outputvis, submslist, copysubtables=False, omitsubtables=[], parallelaxis=''):
"""Create a Multi-MS from a list of MSs
Keyword arguments:
outputvis -- name of the output MMS
submslist -- list of input subMSs to create the output from
copysubtables -- True will copy the sub-tables from the first subMS to the others in the
output MMS. Default to False.
omitsubtables -- List of sub-tables to omit when copying to output MMS. They will be linked instead
parallelasxis -- Optionally, set the value to be written to AxisType in table.info of the output MMS
Usually this value comes from the separationaxis keyword of partition or mstransform.
Be AWARE that this function will remove the tables listed in submslist.
"""
if os.path.exists(outputvis):
raise ValueError('Output MS already exists')
if len(submslist)==0:
raise ValueError('No SubMSs given')
## make an MMS with all sub-MSs contained in a SUBMSS subdirectory
origpath = os.getcwd()
try:
try:
mst_local.createmultims(outputvis,
submslist,
[],
True, # nomodify
False, # lock
copysubtables,
omitsubtables
) # when copying the subtables, omit these
except Exception:
raise
finally:
mst_local.close()
# remove the SORTED_TABLE keywords because the sorting is not reliable after partitioning
try:
tbt_local.open(outputvis, nomodify=False)
if 'SORTED_TABLE' in tbt_local.keywordnames():
tbt_local.removekeyword('SORTED_TABLE')
tbt_local.close()
for thesubms in submslist:
tbt_local.open(outputvis+'/SUBMSS/'+os.path.basename(thesubms), nomodify=False)
if 'SORTED_TABLE' in tbt_local.keywordnames():
tobedel = tbt_local.getkeyword('SORTED_TABLE').split(' ')[1]
tbt_local.removekeyword('SORTED_TABLE')
os.system('rm -rf '+tobedel)
tbt_local.close()
except Exception:
tbt_local.close()
raise
# Create symbolic links to the subtables of the first SubMS in the reference MS (top one)
os.chdir(outputvis)
mastersubms = os.path.basename(submslist[0].rstrip('/'))
thesubtables = getSubtables('SUBMSS/'+mastersubms)
for s in thesubtables:
os.symlink('SUBMSS/'+mastersubms+'/'+s, s)
os.chdir('SUBMSS/'+mastersubms)
# Remove the SOURCE and HISTORY tables, which should not be linked
thesubtables.remove('SOURCE')
thesubtables.remove('HISTORY')
# Create sym links to all sub-tables in all subMSs
for i in range(1,len(submslist)):
thesubms = os.path.basename(submslist[i].rstrip('/'))
os.chdir('../'+thesubms)
for s in thesubtables:
os.system('rm -rf '+s)
os.symlink('../'+mastersubms+'/'+s, s)
# Write the AxisType info in the MMS
if parallelaxis != '':
setAxisType(outputvis, parallelaxis)
except Exception as exc:
os.chdir(origpath)
raise ValueError('Problem in MMS creation: {0}'.format(exc))
os.chdir(origpath)
return True
def axisType(mmsname):
"""Get the axisType information from a Multi-MS. The AxisType information
is usually added for Multi-MS with the axis which data is parallelized across.
Keyword arguments:
mmsname -- name of the Multi-MS
It returns the value of AxisType or an empty string if it doesn't exist.
"""
axis = ''
try:
tbt_local.open(mmsname, nomodify=True)
tbinfo = tbt_local.info()
except Exception as exc:
raise ValueError('Unable to open table {0}. Exception: {1}'.format(mmsname, exc))
finally:
tbt_local.close()
if 'readme' in tbinfo:
readme = tbinfo['readme']
readlist = readme.splitlines()
for val in readlist:
if val.__contains__('AxisType'):
a,b,axis = val.partition('=')
return axis.strip()
def setAxisType(mmsname, axis=''):
"""Set the AxisType keyword in a Multi-MS info. If AxisType already
exists, it will be overwritten.
Keyword arguments:
mmsname -- name of the Multi-MS
axis -- parallel axis of the Multi-MS. Options: scan; spw or scan,spw
Return True on success, False otherwise.
"""
import copy
if axis == '':
raise ValueError('Axis value cannot be empty')
try:
tbt_local.open(mmsname)
tbinfo = tbt_local.info()
except Exception as exc:
raise ValueError('Unable to open table {0}. Exception: {1}'.format(mmsname, exc))
finally:
tbt_local.close()
readme = ''
# Save original readme
if 'readme' in tbinfo:
readme = tbinfo['readme']
# Check if AxisType already exist and remove it
if axisType(mmsname) != '':
casalog.post('WARN: Will overwrite the existing AxisType value', 'WARN')
readlist = readme.splitlines()
newlist = copy.deepcopy(readlist)
for val in newlist:
if val.__contains__('AxisType'):
readlist.remove(val)
# Recreate the string
nr = ''
for val in readlist:
nr = nr + val + '\n'
readme = nr.rstrip()
# Preset for axis info
axisInfo = "AxisType = "
axis.rstrip()
axisInfo = axisInfo + axis + '\n'
# New readme
newReadme = axisInfo + readme
# Create readme record
readmerec = {'readme':newReadme}
try:
tbt_local.open(mmsname, nomodify=False)
tbt_local.putinfo(readmerec)
except Exception as exc:
raise ValueError('Unable to put readme info into table {0}. Exception: {1}'.
format(mmsname, exc))
finally:
tbt_local.close()
# Check if the axis was correctly added
check_axis = axisType(mmsname)
if check_axis != axis:
return False
return True
def buildScanDDIMap(scanSummary, ddIspectralWindowInfo):
"""
Builds a scan->DDI map and 3 list of # visibilities per DDI, scan, field
:param scanSummary: scan summary dictionary as produced by the mstool (getscansummary)
:param ddiSpectralWindowInfo: SPW info dictionary as produced by the mstool
(getspectralwindowinfo())
:returns: a dict with a scan->ddi map, and three dict with # of visibilities per
ddi, scan, and field.
"""
# Make an array for total number of visibilites per ddi and scan separatelly
nVisPerDDI = {}
nVisPerScan = {}
nVisPerField = {}
# Iterate over scan list
scanDdiMap = {}
for scan in sorted(scanSummary):
# Initialize scan sub-map
scanDdiMap[scan] = {}
# Iterate over timestamps for this scan
for timestamp in scanSummary[scan]:
# Get list of ddis for this timestamp
DDIds = scanSummary[scan][timestamp]['DDIds']
fieldId = str(scanSummary[scan][timestamp]['FieldId'])
# Get number of rows per ddi (assume all DDIs have the same number of rows)
# In ALMA data WVR DDI has only one row per antenna but it is separated from the other DDIs
nrowsPerDDI = scanSummary[scan][timestamp]['nRow'] / len(DDIds)
# Iterate over DDIs for this timestamp
for ddi in DDIds:
# Convert to string to be used as a map key
ddi = str(ddi)
# Check if DDI entry is already present for this scan, otherwise initialize it
if ddi not in scanDdiMap[scan]:
scanDdiMap[scan][ddi] = {}
scanDdiMap[scan][ddi]['nVis'] = 0
scanDdiMap[scan][ddi]['fieldId'] = fieldId
scanDdiMap[scan][ddi]['isWVR'] = ddIspectralWindowInfo[ddi]['isWVR']
# Calculate number of visibilities
nvis = nrowsPerDDI*ddIspectralWindowInfo[ddi]['NumChan']*ddIspectralWindowInfo[ddi]['NumCorr']
# Add number of rows and vis from this timestamp
scanDdiMap[scan][ddi]['nVis'] = scanDdiMap[scan][ddi]['nVis'] + nvis
# Update ddi nvis
if ddi not in nVisPerDDI:
nVisPerDDI[ddi] = nvis
else:
nVisPerDDI[ddi] = nVisPerDDI[ddi] + nvis
# Update scan nvis
if scan not in nVisPerScan:
nVisPerScan[scan] = nvis
else:
nVisPerScan[scan] = nVisPerScan[scan] + nvis
# Update field nvis
if fieldId not in nVisPerField:
nVisPerField[fieldId] = nvis
else:
nVisPerField[fieldId] = nVisPerField[fieldId] + nvis
return scanDdiMap, nVisPerDDI, nVisPerScan, nVisPerField
def getPartitionMap(msfilename, nsubms, selection={}, axis=['field','spw','scan'],plotMode=0):
"""Generates a partition scan/spw map to obtain optimal load balancing with the following criteria:
1st - Maximize the scan/spw/field distribution across sub-MSs
2nd - Generate sub-MSs with similar size
In order to balance better the size of the subMSs the allocation process
iterates over the scan,spw pairs in descending number of visibilities.
That is larger chunks are allocated first, and smaller chunks at the final
stages so that they can be used to balance the load in a stable way
Keyword arguments:
msname -- Input MS filename
nsubms -- Number of subMSs
selection -- Data selection dictionary
axis -- Vector of strings containing the axis for load distribution (scan,spw,field)
plotMode -- Integer in the range 0-3 to determine the plot generation mode
0 - Don't generate any plots
1 - Show plots but don't save them
2 - Save plots but don't show them
3 - Show and save plots
Returns a map of the sub-MSs with the corresponding scan/spw selections and the number of visibilities
"""
# Open ms tool
mst_local.open(msfilename)
# Apply data selection
if isinstance(selection, dict) and selection != {}:
mst_local.msselect(items=selection)
# Get list of DDIs and timestamps per scan
scanSummary = mst_local.getscansummary()
ddIspectralWindowInfo = mst_local.getspectralwindowinfo()
# Close ms tool
mst_local.close()
# Get list of WVR SPWs using the ms metadata tool
msmdt_local.open(msfilename)
wvrspws = msmdt_local.wvrspws()
msmdt_local.close()
# Mark WVR DDIs as identified by the ms metadata tool
for ddi in ddIspectralWindowInfo:
if ddIspectralWindowInfo[ddi] in wvrspws:
ddIspectralWindowInfo[ddi]['isWVR'] = True
else:
ddIspectralWindowInfo[ddi]['isWVR'] = False
scanDdiMap, nVisPerDDI, nVisPerScan, nVisPerField = buildScanDDIMap(scanSummary,
ddIspectralWindowInfo)
# Sort the scan/ddi pairs depending on the number of visibilities
ddiList = list()
scanList = list()
fieldList = list()
nVisList = list()
nScanDDIPairs = 0
for scan in scanDdiMap:
for ddi in scanDdiMap[scan]:
ddiList.append(ddi)
scanList.append(scan)
fieldList.append(scanDdiMap[scan][ddi]['fieldId'])
nVisList.append(scanDdiMap[scan][ddi]['nVis'])
nScanDDIPairs += 1
# Check that the number of available scan/ddi pairs is not greater than the number of subMSs
if nsubms > nScanDDIPairs:
casalog.post("Number of subMSs (%i) is greater than available scan,ddi pairs (%i), setting nsubms to %i"
% (nsubms,nScanDDIPairs,nScanDDIPairs),"WARN","getPartitionMap")
nsubms = nScanDDIPairs
ddiArray = np.array(ddiList)
scanArray = np.array(scanList)
nVisArray = np.array(nVisList)
nVisSortIndex = np.lexsort((ddiArray, scanArray, nVisArray))
# argsort/lexsort return indices by increasing value. This reverses the indices by
# decreasing value
nVisSortIndex[:] = nVisSortIndex[::-1]
ddiArray = ddiArray[nVisSortIndex]
scanArray = scanArray[nVisSortIndex]
nVisArray = nVisArray[nVisSortIndex]
# Make a map for the contribution of each subMS to each scan
scanNvisDistributionPerSubMs = {}
for scan in scanSummary:
scanNvisDistributionPerSubMs[scan] = np.zeros(nsubms)
# Make a map for the contribution of each subMS to each ddi
ddiNvisDistributionPerSubMs = {}
for ddi in ddIspectralWindowInfo:
ddiNvisDistributionPerSubMs[ddi] = np.zeros(nsubms)
# Make a map for the contribution of each subMS to each field
fieldList = np.unique(fieldList)
fieldNvisDistributionPerSubMs = {}
for field in fieldList:
fieldNvisDistributionPerSubMs[field] = np.zeros(nsubms)
# Make an array for total number of visibilites per subms
nvisPerSubMs = np.zeros(nsubms)
# Initialize final map of scans/pw pairs per subms
submScanDdiMap = {}
for subms in range (0,nsubms):
submScanDdiMap[subms] = {}
submScanDdiMap[subms]['scanList'] = list()
submScanDdiMap[subms]['ddiList'] = list()
submScanDdiMap[subms]['fieldList'] = list()
submScanDdiMap[subms]['nVisList'] = list()
submScanDdiMap[subms]['nVisTotal'] = 0
# Iterate over the scan/ddi map and assign each pair to a subMS
for pair in range(len(ddiArray)):
ddi = ddiArray[pair]
scan = scanArray[pair]
field = scanDdiMap[scan][ddi]['fieldId']
# Select the subMS that with bigger (scan/ddi/field gap)
# We use the average as a refLevel to include global structure information
# But we also take into account the actual max value in case we are distributing large uneven chunks
jointNvisGap = np.zeros(nsubms)
if 'scan' in axis:
refLevel = max(nVisPerScan[scan] //
nsubms,scanNvisDistributionPerSubMs[scan].max())
jointNvisGap = jointNvisGap + refLevel - scanNvisDistributionPerSubMs[scan]
if 'spw' in axis:
refLevel = max(nVisPerDDI[ddi] //
nsubms,ddiNvisDistributionPerSubMs[ddi].max())
jointNvisGap = jointNvisGap + refLevel - ddiNvisDistributionPerSubMs[ddi]
if 'field' in axis:
refLevel = max(nVisPerField[field] //
nsubms,fieldNvisDistributionPerSubMs[field].max())
jointNvisGap = jointNvisGap + refLevel - fieldNvisDistributionPerSubMs[field]
optimalSubMs = np.where(jointNvisGap == jointNvisGap.max())
optimalSubMs = optimalSubMs[0] # np.where returns a tuple
# In case of multiple candidates select the subms with minum number of total visibilities
if len(optimalSubMs) > 1:
subIdx = np.argmin(nvisPerSubMs[optimalSubMs])
optimalSubMs = optimalSubMs[subIdx]
else:
optimalSubMs = optimalSubMs[0]
# Store the scan/ddi pair info in the selected optimal subms
nVis = scanDdiMap[scan][ddi]['nVis']
nvisPerSubMs[optimalSubMs] = nvisPerSubMs[optimalSubMs] + nVis
submScanDdiMap[optimalSubMs]['scanList'].append(int(scan))
submScanDdiMap[optimalSubMs]['ddiList'].append(int(ddi))
submScanDdiMap[optimalSubMs]['fieldList'].append(field)
submScanDdiMap[optimalSubMs]['nVisList'].append(nVis)
submScanDdiMap[optimalSubMs]['nVisTotal'] = submScanDdiMap[optimalSubMs]['nVisTotal'] + nVis
# Also update the counters for the subms-scan and subms-ddi maps
scanNvisDistributionPerSubMs[scan][optimalSubMs] = scanNvisDistributionPerSubMs[scan][optimalSubMs] + nVis
ddiNvisDistributionPerSubMs[ddi][optimalSubMs] = ddiNvisDistributionPerSubMs[ddi][optimalSubMs] + nVis
fieldNvisDistributionPerSubMs[field][optimalSubMs] = fieldNvisDistributionPerSubMs[field][optimalSubMs] + nVis
# Generate plots
if plotMode > 0:
plt.close()
plotname_prefix = os.path.basename(msfilename) + ' axis ' + string.join(axis)
plotVisDistribution(nVisPerScan,scanNvisDistributionPerSubMs,plotname_prefix,'scan',plotMode=plotMode)
plotVisDistribution(nVisPerDDI,ddiNvisDistributionPerSubMs,plotname_prefix,'ddi',plotMode=plotMode)
plotVisDistribution(nVisPerField,fieldNvisDistributionPerSubMs,plotname_prefix,'field',plotMode=plotMode)
# Generate list of taql commands
for subms in submScanDdiMap:
# Initialize taql command
from collections import defaultdict
dmytaql = defaultdict(list)
for pair in range(len(submScanDdiMap[subms]['scanList'])):
# Get scan/ddi for this pair
ddi = submScanDdiMap[subms]['ddiList'][pair]
scan = submScanDdiMap[subms]['scanList'][pair]
dmytaql[ddi].append(scan)
mytaql = []
for ddi, scans in dmytaql.items():
scansel = '[' + ', '.join([str(x) for x in scans]) + ']'
mytaql.append(('(DATA_DESC_ID==%i && (SCAN_NUMBER IN %s))') % (ddi, scansel))
mytaql = ' OR '.join(mytaql)
# Store taql
submScanDdiMap[subms]['taql'] = mytaql
# Return map of scan/ddi pairs per subMs
return submScanDdiMap
def plotVisDistribution(nvisMap,idNvisDistributionPerSubMs,filename,idLabel,plotMode=1):
"""Generates a plot to show the distribution of scans/wp across subMs.
The plot style is a stacked bar char, where the spw/scans with higher number of visibilities are shown at the bottom
Keyword arguments:
nvisMap -- Map of total numbe of visibilities per Id
idNvisDistributionPerSubMs -- Map of visibilities per subMS for each Id
filename -- Name of MS to be shown in the title and plot filename
idLabel -- idLabel to indicate the id (spw, scan) to be used for the figure title
plotMode -- Integer in the range 0-3 to determine the plot generation mode
0 - Don't generate any plots
1 - Show plots but don't save them
2 - Save plots but don't show them
2 - Show and save plots
"""
# Create a new figure
plt.ioff()
# If plot is not to be shown then use pre-define sized figure to 1585x1170 pizels with 75 DPI
# (we cannot maximize the window to the screen size)
if plotMode==2:
plt.figure(figsize=(21.13,15.6),dpi=75) # Size is given in inches
else:
plt.figure()
# Sort the id according to the total number of visibilities to that we can
# represent bigger the groups at the bottom and the smaller ones at the top
idx = 0
idArray = np.zeros(len(nvisMap))
idNvisArray = np.zeros(len(nvisMap))
for id in nvisMap:
idArray[idx] = int(id)
idNvisArray[idx] = nvisMap[id]
idx = idx + 1
idArraySortIndex = np.argsort(idNvisArray)
idArraySortIndex[:] = idArraySortIndex[::-1]
idArraySorted = idArray[idArraySortIndex]
# Initialize color vector to alternate cold/warm colors
nid = len(nvisMap)
colorVector = list()
colorRange = range(nid)
colorVectorEven = colorRange[::2]
colorVectorOdd = colorRange[1::2]
colorVectorOdd.reverse()
while len(colorVectorOdd) > 0 or len(colorVectorEven) > 0:
if len(colorVectorOdd) > 0: colorVector.append(colorVectorOdd.pop())
if len(colorVectorEven) > 0: colorVector.append(colorVectorEven.pop())
# Generate stacked bar plot
coloridx = 0 # color index
width = 0.35 # bar width
nsubms = len(idNvisDistributionPerSubMs[idNvisDistributionPerSubMs.keys()[0]])
idx = np.arange(nsubms) # location of the bar centers in the horizontal axis
bottomLevel = np.zeros(nsubms) # Reference level for the bars to be stacked after the previous ones
legendidLabels = list() # List of legend idLabels
plotHandles=list() # List of plot handles for the legend
for id in idArraySorted:
id = str(int(id))
idplot = plt.bar(idx, idNvisDistributionPerSubMs[id], width, bottom=bottomLevel, color=plt.cm.Paired(1.*colorVector[coloridx]/nid))
# Update color index
coloridx = coloridx + 1
# Update legend lists
plotHandles.append(idplot)
legendidLabels.append(idLabel + ' ' + id)
# Update reference level
bottomLevel = bottomLevel + idNvisDistributionPerSubMs[id]
# Add legend
plt.legend( plotHandles, legendidLabels, bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)
# AQdd lable for y axis
plt.ylabel('nVis')
# Add x-ticks
xticks = list()
for subms in range(0,nsubms):
xticks.append('subMS-' + str(subms))
plt.xticks(idx+width/2., xticks )
# Add title
title = filename + ' distribution of ' + idLabel + ' visibilities across sub-MSs'
plt.title(title)
# Resize to full screen
if plotMode==1 or plotMode==3:
mng = plt.get_current_fig_manager()
mng.resize(*mng.window.maxsize())
# Show figure
if plotMode==1 or plotMode==3:
plt.ion()
plt.show()
# Save plot
if plotMode>1:
title = title.replace(' ','-') + '.png'
plt.savefig(title)
# If plot is not to be shown then close it
if plotMode==2:
plt.close()