#! /usr/bin/env python # The above is for running the doctest, not normal use. """ A set of functions for manipulating spw:chan selection strings. If this is run from a shell (i.e. not in casapy), doctest will be used to run several unit tests from the doc strings, including the one below: Example: >>> from update_spw import update_spw >>> update_spw('0~2,5', None)[0] '0~2,3' >>> update_spw('0~2,5', None)[1]['5'] # doctest warning! dicts don't always print out in the same order! '3' """ from __future__ import absolute_import import copy import os from casatasks.private.casa_transition import is_CASA6 if is_CASA6: from casatools import ms from casatools import table as tbtool _ms = ms( ) else: #from taskinit import mstool from casac import * from taskinit import ms as _ms tbtool = casac.table def update_spw(spw, spwmap=None): """ Given an spw:chan selection string, return what it should be after the spws have been remapped (i.e. by split), and a map from input to output spws (spwmap). It does not change spw OR the *channels* part of the output spw string! (See update_spwchan) If given, spwmap will be used as a dictionary from (string) input spw to (string) output spws. Otherwise it will be freshly calculated. Supplying spwmap doesn't just save work: it is also necessary for chaining update_spw() calls when the first selection includes more spws than the subsequent one(s). HOWEVER, if given, spwmap must have slots for all the spws that will appear in the output MS, i.e. it can't be grown once made. Examples: >>> from update_spw import update_spw >>> myfitspw, spws = update_spw('0~3,5;6:1~7;11~13', None) >>> myfitspw '0~3,4;5:1~7;11~13' >>> myspw = update_spw('1,5;6:8~10', spws)[0] >>> myspw # not '0,1,2:8~10' '1,4;5:8~10' >>> update_spw('0~3,5;6:1~7;11~13,7~9:0~3,11,7~8:6~8', None)[0] '0~3,4;5:1~7;11~13,6~8:0~3,9,6~7:6~8' # Let's say we want updates of both fitspw and spw, but fitspw and spw # are disjoint (in spws). 
>>> fitspw = '1~10:5~122,15~22:5~122' >>> spw = '6~14' # Initialize spwmap with the union of them. >>> spwmap = update_spw(join_spws(fitspw, spw), None)[1] >>> myfitspw = update_spw(fitspw, spwmap)[0] >>> myfitspw '0~9:5~122,14~21:5~122' >>> myspw = update_spw(spw, spwmap)[0] >>> myspw '5~13' >>> myspw = update_spw('0,1,3;5~8:20~30;44~50^2', None)[0] >>> myspw '0,1,2;3~6:20~30;44~50^2' """ # Blank is valid. Blank is good. if not spw: return '', {} # A list of [spw, chan] pairs. The chan parts will not be changed. spwchans = [] make_spwmap = False if not spwmap: spwmap = {} make_spwmap = True spws = set([]) # Because ; means different things when it separates spws and channel # ranges, I can't think of a better way to construct spwchans than an # explicit state machine. (But $spws_alone =~ s/:[^,]+//g;) inspw = True # until a : is encountered. spwgrp = '' chagrp = '' def store_spwchan(sstr, cstr): spwchans.append([sstr, cstr]) if make_spwmap: for sgrp in sstr.split(';'): if sgrp.find('~') > -1: start, end = map(int, sgrp.split('~')) spws.update(range(start, end + 1)) else: spws.add(int(sgrp)) for c in spw: if c == ',': # Start new [spw, chan] pair. # Store old one. store_spwchan(spwgrp, chagrp) # Initialize new one. spwgrp = '' chagrp = '' inspw = True elif c == ':': inspw = False elif inspw: spwgrp += c else: chagrp += c # Store final [spw, chan] pair. store_spwchan(spwgrp, chagrp) # casalog.post("spwchans ={}".format(spwchans)) # casalog.post("spws ={}".format(spws)) # Update spw (+ fitspw) if make_spwmap: i = 0 for s in sorted(spws): spwmap[str(s)] = str(i) i += 1 outstr = '' for sc in spwchans: sgrps = sc[0].split(';') for sind in range(len(sgrps)): sgrp = sgrps[sind] if sgrp.find('~') > -1: start, end = sgrp.split('~') sgrps[sind] = spwmap[start] + '~' + spwmap[end] else: sgrps[sind] = spwmap[sgrp] outstr += ';'.join(sgrps) if sc[1]: outstr += ':' + sc[1] outstr += ',' return outstr.rstrip(','), spwmap # discard final comma. 
def spwchan_to_ranges(vis, spw):
    """
    Returns the spw:chan selection string spw as a dict of channel selection
    ranges for vis, keyed by spectral window ID.

    The ranges are stored as tuples of (start channel,
                                        end channel (inclusive!),
                                        step).

    Note that '' returns an empty set!  Use '*' to select everything!

    Raises ValueError if spw selects more than one channel range for an spw.

    Example (needs CASA and the MS, so it is illustrative only):
    >>> from update_spw import spwchan_to_ranges
    >>> selranges = spwchan_to_ranges('uid___A002_X1acc4e_X1e7.ms', '7:10~20^2;40~55')
    ValueError: spwchan_to_ranges() does not support multiple channel ranges per spw.
    >>> selranges = spwchan_to_ranges('uid___A002_X1acc4e_X1e7.ms', '0~1:1~3,5;7:10~20^2')
    >>> selranges
    {0: (1, 3, 1), 1: (1, 3, 1), 5: (10, 20, 2), 7: (10, 20, 2)}
    """
    selarr = _ms.msseltoindex(vis, spw=spw)['channel']
    selranges = {}
    for rownum in range(selarr.shape[0]):
        row = selarr[rownum]
        spwid = row[0]          # The rest of the row is the channel range.
        if spwid in selranges:
            raise ValueError('spwchan_to_ranges() does not support multiple channel ranges per spw.')
        selranges[spwid] = tuple(row[1:])
    return selranges


def spwchan_to_sets(vis, spw):
    """
    Returns the spw:chan selection string spw as a dict of sets of selected
    channels for vis, keyed by spectral window ID.

    Note that '' returns an empty set!  Use '*' to select everything!

    Raises ValueError if vis is not a directory.

    Example (needs CASA and its data repository, so it is illustrative only):
    >>> from update_spw import spwchan_to_sets
    >>> vis = casa['dirs']['data'] + '/regression/unittest/split/unordered_polspw.ms'
    >>> spwchan_to_sets(vis, '0:0')
    {0: set([0])}
    >>> selsets = spwchan_to_sets(vis, '1:1~3;5~9^2,9') # 9 is a bogus spw.
    >>> selsets
    {1: set([1, 2, 3, 5, 7, 9])}
    >>> spwchan_to_sets(vis, '1:1~3;5~9^2,8')
    {1: set([1, 2, 3, 5, 7, 9]), 8: set([0])}
    >>> spwchan_to_sets(vis, '')
    {}
    """
    if not spw:        # _ms.msseltoindex(vis, spw='')['channel'] returns a
        return {}      # different kind of empty array.  Skip it.

    # Currently distinguishing whether or not vis is a valid MS from whether it
    # just doesn't have all the channels in spw is a bit crude.  Sanjay is
    # working on adding some flexibility to _ms.msseltoindex.
    if not os.path.isdir(vis):
        raise ValueError(str(vis) + ' is not a valid MS.')

    sets = {}
    try:
        scharr = _ms.msseltoindex(vis, spw=spw)['channel']
        for scr in scharr:
            # scr[0] is the spw and scr[2] is the last selected channel, so
            # bump the latter up for range().
            sets.setdefault(scr[0], set()).update(range(scr[1], scr[2] + 1,
                                                        *scr[3:]))
    except Exception:
        # Presumably spw includes channels that aren't in vis, so it needs to
        # be trimmed down to make _ms.msseltoindex happy.  (Exception instead
        # of a bare except so that KeyboardInterrupt and SystemExit still
        # propagate.)
        allrec = _ms.msseltoindex(vis, spw='*')
        spwd = spw_to_dict(spw, {}, False)
        for s in spwd:
            if s in allrec['spw']:
                endchan = allrec['channel'][s, 2]
                chans = spwd[s]
                if chans == '':
                    # '' means all of the spw's channels.  We need to get the
                    # spw's # of channels without using _ms.msseltoindex.
                    mytb = tbtool()
                    mytb.open(vis + '/SPECTRAL_WINDOW')
                    try:
                        chans = range(mytb.getcell('NUM_CHAN', s))
                    finally:
                        mytb.close()    # Don't leak the table on an error.
                sets.setdefault(s, set()).update(c for c in chans
                                                 if c <= endchan)
    return sets


def _consecutive_runs(values):
    """
    Groups the sorted, distinct integers in values into runs of consecutive
    numbers, returned as a list of [first, last] (inclusive) pairs.
    """
    runs = []
    for val in values:
        if runs and val == runs[-1][1] + 1:
            runs[-1][1] = val           # Extend the current run.
        else:
            runs.append([val, val])     # Start a new run.
    return runs


def _format_run(first, last):
    """
    Returns 'first' if first == last, and 'first~last' otherwise.
    """
    if first == last:
        return str(first)
    return str(first) + '~' + str(last)


def set_to_chanstr(chanset, totnchan=None):
    """
    Essentially the reverse of expand_tilde.  Given a set or list of integers
    chanset, returns the corresponding string form.  It will not use non-unity
    steps (^) if multiple ranges (;) are necessary, but it will use ^ if it
    helps to eliminate any ;s.

    totnchan: the total number of channels for the input spectral window, used
              to abbreviate the return string.  If it is given (and nonzero),
              channels >= totnchan are ignored.

    It returns '' for the empty set and '*' if the number of distinct selected
    channels equals totnchan.  Repeated entries in chanset count once.

    Examples:
    >>> from update_spw import set_to_chanstr
    >>> set_to_chanstr(set([0, 1, 2, 4, 5, 6, 7, 9, 11, 13]))
    '0~2;4~7;9;11;13'
    >>> set_to_chanstr(set([7, 9, 11, 13]))
    '7~13^2'
    >>> set_to_chanstr(set([7, 9]))
    '7~9^2'
    >>> set_to_chanstr([0, 1, 2])
    '0~2'
    >>> set_to_chanstr([0, 1, 2], 3)
    '*'
    >>> set_to_chanstr([0, 1, 2, 6], 3)
    '*'
    >>> set_to_chanstr([0, 1, 2, 6])
    '0~2;6'
    >>> set_to_chanstr([1, 2, 4, 5, 6, 7, 8, 9, 10, 11], 12)
    '1~2;4~11'
    >>> set_to_chanstr([5, 5])
    '5'
    >>> set_to_chanstr([])
    ''
    """
    # Going through set() removes duplicates, which would otherwise produce a
    # step of 0 and throw off the comparison with totnchan.
    if totnchan:
        chans = sorted(set(c for c in chanset if c < totnchan))
    else:
        chans = sorted(set(chanset))

    if totnchan == len(chans):
        return '*'
    if not chans:
        return ''
    if len(chans) == 1:
        return str(chans[0])

    # Use a single start~end^step range if the same step works throughout.
    step = chans[1] - chans[0]
    if all(chans[i] - chans[i - 1] == step for i in range(2, len(chans))):
        retstr = str(chans[0]) + '~' + str(chans[-1])
        if step > 1:
            retstr += '^' + str(step)
        return retstr

    # Otherwise fall back to ;-separated runs of consecutive channels.
    return ';'.join(_format_run(first, last)
                    for first, last in _consecutive_runs(chans))


def sets_to_spwchan(spwsets, nchans=None):
    """
    Returns a spw:chan selection string for a dict of sets of selected
    channels keyed by spectral window ID.

    nchans is a dict of the total number of channels keyed by spw, used to
    abbreviate the return string.  (Default: no abbreviation.)

    spws with the same channel selection are grouped together, with ';'
    between nonconsecutive spws, and the groups are ordered by their lowest
    spw (numerically).  spws with nothing selected are left out.

    Examples:
    >>> from update_spw import sets_to_spwchan
    >>> # Use nchans to get '1' instead of '1:0~3'.
    >>> sets_to_spwchan({1: [0, 1, 2, 3]}, {1: 4})
    '1'
    >>> sets_to_spwchan({1: set([1, 2, 3, 5, 7, 9]), 8: set([0])})
    '1:1~3;5;7;9,8:0'
    >>> sets_to_spwchan({0: set([4, 5, 6]), 1: [4, 5, 6], 2: [4, 5, 6]})
    '0~2:4~6'
    >>> sets_to_spwchan({0: [4], 1: [4], 3: [0, 1], 4: [0, 1], 7: [0, 1]}, {3: 2, 4: 2, 7: 2})
    '0~1:4,3~4;7'
    >>> sets_to_spwchan({2: [0], 10: [1]})
    '2:0,10:1'
    """
    if nchans is None:
        nchans = {}

    # Make a list of spws for each channel selection.
    csd = {}
    for s in spwsets:
        if spwsets[s]:
            # Convert the set of channels to a string.
            cstr = set_to_chanstr(spwsets[s], nchans.get(s))
            if cstr:
                csd.setdefault(cstr, []).append(s)

    # Now convert those spw lists into strings, remembering the lowest spw of
    # each so the final string can be sorted by spw.
    groups = []
    for cstr, slist in csd.items():
        slist.sort()
        sstr = ';'.join(_format_run(first, last)
                        for first, last in _consecutive_runs(slist))
        groups.append((slist[0], sstr, cstr))

    # Sort on the spw number, not on the string: '10' < '2' as strings.
    groups.sort(key=lambda grp: grp[0])

    # Finally stitch together the final string.
    pieces = []
    for _, sstr, cstr in groups:
        if cstr != '*':
            sstr += ':' + cstr
        pieces.append(sstr)
    return ','.join(pieces)


def update_spwchan(vis, sch0, sch1, truncate=False, widths=None):
    """
    Given an spw:chan selection string sch1, return what it must be changed to
    to get the same result if used with the output of split(vis, spw=sch0).

    '' is taken to mean '*' in the input but NOT the output!  For the output
    '' means sch0 and sch1 do not intersect.

    truncate: If True and sch0 only partially overlaps sch1, return the update
              of the intersection.
              If (False and sch0 does not cover sch1), OR a spw that is in
              both has no channels in common, raises a ValueError.

    widths is a dictionary of averaging widths (default 1) for each spw.

    Examples (all but the first need CASA and its data repository):
    >>> from update_spw import update_spwchan
    >>> newspw = update_spwchan('anything.ms', 'anything', 'anything')
    >>> newspw
    '*'
    >>> vis = casa['dirs']['data'] + '/regression/unittest/split/unordered_polspw.ms'
    >>> update_spwchan(vis, '0~1:1~3,5;7:10~20^2', '0~1:2~3,5;7:12~18^2')
    '0~1:1~2,2~3:1~4'
    >>> update_spwchan(vis, '7', '3')
    ValueError: '3' is not a subset of '7'.
    >>> update_spwchan(vis, '7:10~20^2', '7:12~18^3')
    ValueError: '7:12~18^3' is not a subset of '7:10~20^2'.
    >>> update_spwchan(vis, '7:10~20^2', '7:12~18^3', truncate=True)
    '0:1~4^3'
    >>> update_spwchan(vis, '7:10~20^2', '7:12~18^3', truncate=True, widths={7: 2})
    '0:0~2^2'
    """
    if widths is None:
        widths = {}

    # Convert '' to 'select everything'.
    if not sch0:
        sch0 = '*'
    if not sch1:
        sch1 = '*'

    # Short circuit: sch1 selects everything that sch0 kept.
    if sch1 in (sch0, '*'):
        return '*'

    sch0sets = spwchan_to_sets(vis, sch0)
    sch1sets = spwchan_to_sets(vis, sch1)

    # An spw selected by sch0 is renumbered to its position in the sorted list
    # of spws selected by sch0.
    outspw_of = dict((s, ind) for ind, s in enumerate(sorted(sch0sets)))

    outsets = {}
    nchans = {}
    for s in sorted(sch1sets):
        if s not in sch0sets:
            if not truncate:
                raise ValueError(str(s) + ' is not a selected spw of ' + sch0)
            continue

        s0 = sch0sets[s]
        s1 = sch1sets[s]

        # Check for and handle (throw or dispose) channels in sch1 that aren't
        # in sch0.
        if s1.difference(s0):
            if not truncate:
                raise ValueError("'%s' is not a subset of '%s'." % (sch1, sch0))
            s1.intersection_update(s0)
            if not s1:
                raise ValueError("'%s' does not overlap '%s'." % (sch1, sch0))

        # Adapt s1 for a post-s0 world.  s1 is a subset of s0 by now, so each
        # of its channels becomes its position in the sorted s0 channels,
        # divided by the averaging width.
        width = widths.get(s, 1)
        outchan_of = dict((c, ind) for ind, c in enumerate(sorted(s0)))
        outspw = outspw_of[s]
        outsets[outspw] = set(outchan_of[c] // width for c in s1)

        # Get the number of channels per spw that are selected by s0.
        nchans[outspw] = len(s0)

    return sets_to_spwchan(outsets, nchans)


def expand_tilde(tstr, conv_multiranges=False):
    """
    Expands a string like '8~11' to [8, 9, 10, 11].
    Returns '*' if tstr is '' or None!

    conv_multiranges: If True, '*' will be returned if tstr contains ';'.
                      (split can't yet handle multiple channel ranges per spw.)

    Quotes (' and ") in tstr are ignored.  Raises ValueError if a range cannot
    be parsed as start, start~end, or start~end^step with integers.

    Examples:
    >>> from update_spw import expand_tilde
    >>> expand_tilde('8~11')
    [8, 9, 10, 11]
    >>> expand_tilde(None)
    '*'
    >>> expand_tilde('3~7^2;9~11')
    [3, 5, 7, 9, 10, 11]
    >>> expand_tilde('3~7^2;9~11', True)
    '*'
    """
    # str(None) is 'None', not '', so None has to be caught before str().
    if tstr is None:
        return '*'

    tstr = str(tstr)  # Allows bare ints.
    if (not tstr) or (conv_multiranges and ';' in tstr):
        return '*'

    tstr = tstr.replace("'", '').replace('"', '')  # Dequote

    numset = set()
    for numrang in tstr.split(';'):
        step = 1
        try:
            if '~' in numrang:
                if '^' in numrang:
                    numrang, step = numrang.split('^')
                    step = int(step)
                start, end = map(int, numrang.split('~'))
            else:
                start = int(numrang)
                end = start
        except ValueError:
            # Both int() and a failed unpacking raise ValueError.  Reraise
            # with enough context to find the bad part of the selection.
            raise ValueError('numrang = ' + numrang + ', tstr = ' + tstr
                             + ', conv_multiranges = ' + str(conv_multiranges))
        numset.update(range(start, end + 1, step))
    return sorted(numset)


def spw_to_dict(spw, spwdict=None, conv_multiranges=True):
    """
    Expand an spw:chan string to {s0: [s0chans], s1: [s1chans, ...], ...}
    where s0, s1, ... are integers for _each_ selected spw, and s0chans is a
    set of selected chans (as integers) for s0.  '' instead of a channel set
    means that all of the channels are selected.

    The spw:chan dict is unioned with spwdict.  spwdict itself is not modified
    (a deep copy is made).

    Returning an empty dict means everything should be selected (i.e. spw = '').
    (split can't yet handle multiple channel ranges per spw.)

    conv_multiranges: If True, any spw with > 1 channel range selected will
                      have ALL of its channels selected.
                      (split can't yet handle multiple channel ranges per spw.)

    NOTE: a ';' in the spw part ends an entry just like ',' does, so in
    '5;7:10~20' only spw 7 gets channels 10~20; spw 5 gets '' (everything).

    Examples:
    >>> from update_spw import spw_to_dict
    >>> spw_to_dict('', {})
    {}
    >>> spw_to_dict('6~8:2~5', {})[6] == set([2, 3, 4, 5])
    True
    >>> spw_to_dict('6~8:2~5', {})[8] == set([2, 3, 4, 5])
    True
    >>> spw_to_dict('6~8:2~5', {6: ''})[6]
    ''
    >>> spw_to_dict('6~8:2~5', {6: '', 7: set([1, 7])})[7] == set([1, 2, 3, 4, 5, 7])
    True
    >>> spw_to_dict('7', {6: '', 7: set([1, 7])})[7]
    ''
    >>> spw_to_dict('7:123~127;233~267', {6: '', 7: set([1, 7])})[7]  # Multiple chan ranges
    ''
    >>> spw_to_dict('5,7:123~127;233~267', {6: '', 7: set([1, 7])})[5]
    ''
    >>> spw_to_dict('5:3~5,7:123~127;233~267', {6: '', 7: set([1, 7])})[5] == set([3, 4, 5])
    True
    """
    if not spw:
        return {}

    if spwdict is None:
        myspwdict = {}
    else:
        myspwdict = copy.deepcopy(spwdict)

    def enter_ranges(spwg, chag):
        """Union the channels chag of the spws spwg into myspwdict."""
        spwrange = expand_tilde(spwg)
        if spwrange == '*':         # Blank spw group.  This shouldn't happen.
            return
        charange = expand_tilde(chag, conv_multiranges)
        for s in spwrange:
            if charange == '*':
                myspwdict[s] = ''   # Everything trumps any earlier selection.
            else:
                if s not in myspwdict:
                    myspwdict[s] = set()
                if myspwdict[s] != '':      # Everything stays everything.
                    myspwdict[s].update(charange)

    # Because ; means different things when it separates spws and channel
    # ranges, I can't think of a better way to construct myspwdict than an
    # explicit state machine.  (But $spws_alone =~ s/:[^,]+//g;)
    inspw = True    # Must start with an spw.
    spwgrp = ''
    chagrp = ''
    for c in spw:
        if c == ',' or (inspw and c == ';'):    # Start new [spw, chan] pair.
            # Store old one.
            enter_ranges(spwgrp, chagrp)

            # Initialize new one.
            spwgrp = ''
            chagrp = ''
            inspw = True
        elif c == ':':
            inspw = False
        elif inspw:
            spwgrp += c
        else:
            chagrp += c

    # Store final [spw, chan] pair.
    enter_ranges(spwgrp, chagrp)
    return myspwdict


def join_spws(spw1, spw2, span_semicolon=True):
    """
    Returns the union of spw selection strings spw1 and spw2.

    span_semicolon (default True): If True, for any spws that have > 1 channel
    range, the entire spw will be selected.

    '' (everything) is returned if either spw1 or spw2 is blank.

    Examples:
    >>> from update_spw import join_spws
    >>> join_spws('0~2:3~5,3:9~13', '')
    ''
    >>> join_spws('0~2:3~5,3:9~13', '1~3:4~7')
    '0:3~5,1~2:3~7,3'
    >>> join_spws('1~10:5~122,15~22:5~122', '1~10:5~122,15~22:5~122')
    '1~10:5~122,15~22:5~122'
    >>> join_spws('', '')
    ''
    >>> join_spws('1~10:5~122,15~22:5~122', '0~21')
    '0~21,22:5~122'
    """
    if not spw1 or not spw2:
        return ''

    spwdict = spw_to_dict(spw2, spw_to_dict(spw1, {}))

    # Convert channel sets to strings.  '' stands for the whole spw.
    chanstrs = {}
    for s in spwdict:
        cstr = ''
        if isinstance(spwdict[s], set):
            cstr = set_to_chanstr(spwdict[s])
            if span_semicolon and ';' in cstr:
                cstr = ''
        chanstrs[s] = cstr

    # If consecutive spws have the same channel selection, merge them into
    # [first spw, last spw, channel string] groups.
    groups = []
    for s in sorted(chanstrs):
        if groups and s == groups[-1][1] + 1 and chanstrs[s] == groups[-1][2]:
            groups[-1][1] = s                       # Continue the spw range.
        else:
            groups.append([s, s, chanstrs[s]])      # Start a new one.

    pieces = []
    for first, last, cstr in groups:
        piece = _format_run(first, last)
        if cstr != '':                      # Add channel range, if any.
            piece += ':' + cstr
        pieces.append(piece)
    return ','.join(pieces)


def intersect_spws(spw1, spw2):
    """
    Almost the opposite of join_spws(), this returns the list of spws that the
    spw:chan selection strings spw1 and spw2 have in common.  Unlike join_spws(),
    channel ranges are ignored.  '' in the input counts as 'select everything',
    so the intersection of '' with anything is anything.  If the intersection
    really is everything, '' is returned instead of a set.

    Examples:
    >>> from update_spw import intersect_spws
    >>> intersect_spws('0~2:3~5,3:9~13', '') == set([0, 1, 2, 3])
    True
    >>> intersect_spws('0~2:3~5,3:9~13', '0~2:7~9,5') == set([0, 1, 2])
    True
    >>> intersect_spws('0~2:3~5;10~13,3:9~13', '0~2:7~9,5') == set([0, 1, 2])
    True
    >>> intersect_spws('0~2:3~5,3:9~13', '10~12:7~9,5') == set()  # Empty set
    True
    >>> intersect_spws('', '')                                    # Everything
    ''
    """
    if spw1 == '' and spw2 == '':
        return ''                           # intersection('', '') = ''
    if spw1 == '':
        return set(spw_to_dict(spw2, {}))   # intersection('', spw2) = spw2
    if spw2 == '':
        return set(spw_to_dict(spw1, {}))   # intersection(spw1, '') = spw1

    # spws are the keys and chan ranges are the values, so making sets of the
    # dicts leaves just the spws.
    return set(spw_to_dict(spw1, {})).intersection(spw_to_dict(spw2, {}))


def subtract_spws(spw1, spw2):
    """
    Returns the set of spws of spw selection string spw1 that are not in spw2.
    Like intersect_spws(), this intentionally ignores channel ranges.  It
    assumes that spw1 and spw2 refer to the same MS (this only matters for '').
    subtract_spws('', '0~5') is a tough case: it is impossible to know whether
    '' is equivalent to '0~5' without reading the MS's SPECTRAL_WINDOW
    subtable, so it returns 'UNKNOWN'.

    Examples:
    >>> from update_spw import subtract_spws
    >>> subtract_spws('0~2:3~5,3:9~13', '') == set()      # Anything - Everything
    True
    >>> subtract_spws('0~2:3~5,3:9~13', '0~2:7~9,5') == set([3])
    True
    >>> subtract_spws('', '0~2:7~9,5')                    # Everything - Something
    'UNKNOWN'
    >>> subtract_spws('0~2,3:9~13', '4~7:7') == set([0, 1, 2, 3])  # Something - Something Else
    True
    >>> subtract_spws('', '') == set()                    # Everything - Everything
    True
    """
    if spw2 == '':
        return set()            # Anything - Everything = Nothing
    if spw1 == '':
        return 'UNKNOWN'        # Everything - Something

    # spws are the keys and chan ranges are the values, so making sets of the
    # dicts leaves just the spws.
    return set(spw_to_dict(spw1, {})).difference(spw_to_dict(spw2, {}))


if __name__ == '__main__':
    import doctest
    doctest.testmod(verbose=True)