from collections import Counter
import datetime
import os
import shutil
from casatasks import casalog
from casatools import ms as mstool
from casatools import singledishms
from . import sdutil
from .mstools import write_history
# Module-level ms tool instance. It is used for ms.msseltoindex() when
# building data selections and is passed to write_history() for the outfile.
ms = mstool()
@sdutil.callable_sdtask_decorator
def sdbaseline(infile=None, datacolumn=None, antenna=None, field=None,
               spw=None, timerange=None, scan=None, pol=None, intent=None,
               reindex=None, maskmode=None, thresh=None, avg_limit=None,
               minwidth=None, edge=None, blmode=None, dosubtract=None,
               blformat=None, bloutput=None, bltable=None, blfunc=None,
               order=None, npiece=None, applyfft=None, fftmethod=None,
               fftthresh=None, addwn=None, rejwn=None, clipthresh=None,
               clipniter=None, blparam=None, verbose=None,
               updateweight=None, sigmavalue=None,
               showprogress=None, minnrow=None,
               outfile=None, overwrite=None):
    """Fit baselines to, or apply a baseline table to, single-dish spectra.

    blmode='fit' runs _do_fit (baseline fitting; the subtracted data are
    written to outfile when dosubtract is True). blmode='apply' runs
    _do_apply, which applies the baseline table *bltable*. When dosubtract
    is True, the task history is written to outfile.

    When dosubtract is False, the by-product MS written during processing is
    deleted on exit. Only data created by this call is deleted. A file that
    already existed at outfile before the call is never removed or modified,
    even when an error is raised part-way through.

    Raises:
        ValueError: for a missing infile/blparam, an existing outfile that may
            not be overwritten, an unsupported blmode, or invalid parameters
            detected by the helper functions.
    """
    temp_outfile = ''
    # Set once outfile is known. It is True only if nothing existed at
    # outfile before processing, so anything found there later was created
    # by this call and may be cleaned up.
    outfile_created_here = False
    try:
        # CAS-12985 requests the following params be given case insensitively,
        # so they need to be converted to lowercase here (2021/1/28 WK)
        blfunc = blfunc.lower()
        blmode = blmode.lower()
        fftmethod = fftmethod.lower()
        if isinstance(fftthresh, str):
            fftthresh = fftthresh.lower()
        if (spw == ''):
            spw = '*'

        if not os.path.exists(infile):
            raise ValueError("infile='" + str(infile) + "' does not exist.")
        if (outfile == '') or not isinstance(outfile, str):
            outfile = infile.rstrip('/') + '_bs'
            casalog.post("outfile is empty or non-string. set to '" + outfile + "'")
        outfile_created_here = not os.path.exists(outfile)
        if (not overwrite) and dosubtract and os.path.exists(outfile):
            raise ValueError("outfile='%s' exists, and cannot overwrite it." % (outfile))
        if (blfunc == 'variable') and not os.path.exists(blparam):
            raise ValueError("input file '%s' does not exists" % blparam)

        if (blmode == 'fit'):
            temp_outfile = _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan,
                                   pol, intent, reindex, maskmode, thresh, avg_limit, minwidth,
                                   edge, dosubtract, blformat, bloutput, blfunc, order, npiece,
                                   applyfft, fftmethod, fftthresh, addwn, rejwn, clipthresh,
                                   clipniter, blparam, verbose, updateweight, sigmavalue,
                                   outfile, overwrite)
        elif (blmode == 'apply'):
            _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
                      reindex, bltable, updateweight, sigmavalue, outfile, overwrite)
        else:
            raise ValueError("Unsupported blmode = %s" % blmode)

        # Remove {WEIGHT|SIGMA}_SPECTRUM columns if updateweight=True (CAS-13161).
        # When _do_fit wrote to a temporary file (dosubtract=False with an
        # existing outfile), outfile holds the user's pre-existing data and
        # must not be modified; the temporary file is deleted below anyway.
        if updateweight and (temp_outfile == ''):
            with sdutil.table_manager(outfile, nomodify=False) as mytb:
                cols_spectrum = ['WEIGHT_SPECTRUM', 'SIGMA_SPECTRUM']
                cols_remove = [col for col in cols_spectrum if col in mytb.colnames()]
                if len(cols_remove) > 0:
                    mytb.removecols(' '.join(cols_remove))

        # Write history to outfile
        if dosubtract:
            # The module-level name ``sdbaseline`` is bound to whatever the
            # decorator returned. If it exposes the undecorated function as
            # ``__wrapped__`` (functools.wraps does), use that so the real
            # parameter names are recorded; otherwise behave as before.
            task_func = getattr(sdbaseline, '__wrapped__', sdbaseline)
            task_code = task_func.__code__
            param_names = task_code.co_varnames[:task_code.co_argcount]
            var_local = locals()
            param_vals = [var_local[p] for p in param_names]
            write_history(ms, outfile, 'sdbaseline', param_names,
                          param_vals, casalog)
    finally:
        if not dosubtract:
            # Remove the (skeleton) MS written during fitting, but never data
            # that existed at outfile before this call.
            if temp_outfile != '':
                remove_data(temp_outfile)
            elif outfile_created_here:
                remove_data(outfile)
# blformat values handled by this task, index-aligned with the file-name
# extensions used for default bloutput names (infile + '_blparam.' + ext).
blformat_item = ['csv', 'text', 'table']
blformat_ext = ['csv', 'txt', 'bltable']
# Error message used when an addwn/rejwn value fails validation.
mesg_invalid_wavenumber = 'wrong value given for addwn/rejwn'
def remove_data(filename):
    """Remove the file, directory tree or symbolic link at *filename*.

    This is a no-op when nothing exists at *filename*. A symbolic link is
    removed itself and never followed. shutil.rmtree() refuses to operate on
    a symlink, and os.path.exists() is False for a dangling link, so links
    must be detected with os.path.islink() before the other checks.
    """
    if os.path.islink(filename):
        os.remove(filename)
    elif os.path.isdir(filename):
        shutil.rmtree(filename)
    elif os.path.exists(filename):
        # regular file or other non-directory entry
        os.remove(filename)
def is_empty(blformat):
    """Tell whether blformat carries no format at all.

    None, '', [] and lists made only of empty entries (e.g. ['', '']) are
    empty; nested lists are inspected recursively.
    """
    if not isinstance(blformat, list):
        return not blformat
    return all(is_empty(entry) for entry in blformat)
def prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite):
    """Validate blformat/bloutput and return them in normalised form.

    Returns:
        (blformat, bloutput): blformat as a list of strings, and bloutput as
        the list produced by normalise_bloutput (one name per entry of
        blformat_item; '' for formats not requested).

    Raises:
        ValueError: if either argument is not a string/list of strings, their
            lengths differ, they contain duplicate non-empty entries, or an
            output file exists and overwrite is False.
    """
    # force to string list
    blformat = force_to_string_list(blformat, 'blformat')
    bloutput = force_to_string_list(bloutput, 'bloutput')

    # The default bloutput value '' is expanded to a list with the length of
    # blformat, filled with ''. Build a new list instead of using ``*=``:
    # force_to_string_list returns a caller-supplied list unchanged, so an
    # in-place repetition would mutate the caller's object.
    if (bloutput == ['']):
        bloutput = bloutput * len(blformat)

    # check length
    if (len(blformat) != len(bloutput)):
        raise ValueError('blformat and bloutput must have the same length.')

    # check duplication
    if has_duplicate_nonnull_element(blformat):
        raise ValueError('duplicate elements in blformat.')
    if has_duplicate_nonnull_element_ex(bloutput, blformat):
        raise ValueError('duplicate elements in bloutput.')

    # fill bloutput items to be output, then rearrange them
    # in the order of blformat_item.
    bloutput = normalise_bloutput(infile, blformat, bloutput, overwrite)

    return blformat, bloutput
def force_to_string_list(s, name):
    """Return *s* as a list of strings.

    A str is wrapped in a one-element list; a list whose items are all str is
    returned as-is (same object). Anything else raises ValueError mentioning
    *name*.
    """
    mesg = '%s must be string or list of string.' % name
    if isinstance(s, str):
        return [s]
    if isinstance(s, list) and all(isinstance(item, str) for item in s):
        return s
    raise ValueError(mesg)
def has_duplicate_nonnull_element(in_list):
    """Return True if an element other than '' appears more than once."""
    return any(count > 1 and elem != ''
               for elem, count in Counter(in_list).items())
def has_duplicate_nonnull_element_ex(lst, base):
    """Check *lst* for duplicates, looking only at positions where *base* is not ''.

    lst and base are expected to have the same length. The elements of lst
    whose counterpart in base (same index) is non-empty are collected, and
    the result is True if that selection repeats an element other than ''.
    """
    selected = [elem for idx, elem in enumerate(lst) if base[idx] != '']
    return has_duplicate_nonnull_element(selected)
def normalise_bloutput(infile, blformat, bloutput, overwrite):
    """Return one output file name per supported format, in blformat_item order.

    Each name is resolved by get_normalised_name ('' for formats that are not
    listed in blformat).
    """
    names = []
    for item_name, item_ext in zip(blformat_item, blformat_ext):
        names.append(get_normalised_name(infile, blformat, bloutput,
                                         item_name, item_ext, overwrite))
    return names
def get_normalised_name(infile, blformat, bloutput, name, ext, overwrite):
    """Resolve the output file name for format *name*.

    Returns '' when *name* is not among blformat (compared case-insensitively).
    An empty bloutput entry defaults to infile + '_blparam.' + ext. An
    existing file of that name is removed when overwrite is True; otherwise
    ValueError is raised.
    """
    lowered = [fmt.lower() for fmt in blformat]
    if name not in lowered:
        return ''
    fname = bloutput[lowered.index(name)]
    if fname == '':
        fname = infile + '_blparam.' + ext
    if os.path.exists(fname):
        if not overwrite:
            raise ValueError(fname + ' exists.')
        remove_data(fname)
    return fname
def output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile):
    """Write the header of the text-format baseline output file.

    Nothing is written when the bloutput entry for the 'text' format is ''.
    Otherwise the file is (re)created with the source table, the output file
    (infile when outfile is '') and the mask mode listed between two
    60-character '#' separator lines, followed by an empty line.
    blformat and blfunc are accepted but not used.
    """
    text_output = bloutput[blformat_item.index('text')]
    if text_output == '':
        return

    separator = '#' * 60 + '\n'
    fields = (
        ('Source Table', infile),
        ('Output File', outfile if outfile != '' else infile),
        ('Mask mode', maskmode),
    )
    with open(text_output, 'w') as f:
        f.write(separator)
        for label, value in fields:
            f.write('%12s: %s\n' % (label, value))
        f.write(separator)
        f.write('\n')
def get_temporary_file_name(basename):
    """Return a scratch name derived from *basename*.

    The result is basename + '_sdbaseline_pid<PID>_<YYYYmmddHHMMSSffffff>',
    built from the current process id and the current local time.
    """
    pid_part = '_sdbaseline_pid' + str(os.getpid())
    time_part = '_' + datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    return basename + pid_part + time_part
def parse_wavenumber_param(wn):
    """Convert an addwn/rejwn specification into a comma-separated string.

    Accepted forms (a, b are non-negative integers):
      * int a                        -> 'a'
      * list/tuple of integers       -> sorted unique values, e.g. '0,2,5'
      * 'a'                          -> 'a'
      * 'a,b,c,...'                  -> sorted unique values
      * 'a-b' or 'a~b'               -> every integer from min(a, b) to max(a, b)
      * '<=a', '=<a', 'a>=', 'a=>'   -> 0,1,...,a
      * '<a', 'a>'                   -> 0,1,...,a-1   (a must be non-zero)
      * '>=a', '=>a', 'a<=', 'a=<'   -> 'a,-999'
      * '>a', 'a<'                   -> 'a+1,-999'
    The trailing -999 is interpreted on the C++ side as "up to the Nyquist
    wavenumber" (CAS-3759).

    Raises:
        ValueError: for bool values, strings containing '.', an empty string,
            negative values, unparsable strings, or unsupported types.
    """
    if isinstance(wn, bool):
        # bool is a subclass of int; reject it before the int branch.
        raise ValueError(mesg_invalid_wavenumber)
    if isinstance(wn, (list, tuple)):
        __check_positive_or_zero(wn)
        return ','.join(__get_strlist(sorted(set(wn))))
    if isinstance(wn, int):
        __check_positive_or_zero(wn)
        return str(wn)
    if not isinstance(wn, str):
        raise ValueError(mesg_invalid_wavenumber)
    if wn == '':
        # would otherwise fail with IndexError on wn[0] below
        raise ValueError(mesg_invalid_wavenumber)

    if '.' in wn:
        # case of float value as string
        raise ValueError(mesg_invalid_wavenumber)
    elif ',' in wn:
        # cases 'a,b,c,...'
        val0 = wn.split(',')
        __check_positive_or_zero(val0)
        res = sorted(set(int(v) for v in val0))
    elif ('-' in wn) or ('~' in wn):
        # cases 'a-b', 'a~b' : return [a,a+1,...,b-1,b]
        val = wn.split('-' if '-' in wn else '~')
        __check_positive_or_zero(val)
        lower, upper = sorted([int(val[0]), int(val[1])])
        res = list(range(lower, upper + 1))
    elif wn[:2] in ('<=', '=<'):
        # cases '<=a','=<a' : return [0,1,...,a-1,a]
        val = wn[2:]
        __check_positive_or_zero(val)
        res = list(range(int(val) + 1))
    elif wn[-2:] in ('>=', '=>'):
        # cases 'a>=','a=>' : return [0,1,...,a-1,a]
        val = wn[:-2]
        __check_positive_or_zero(val)
        res = list(range(int(val) + 1))
    elif wn[0] == '<':
        # case '<a' : return [0,1,...,a-2,a-1]
        val = wn[1:]
        __check_positive_or_zero(val, False)
        res = list(range(int(val)))
    elif wn[-1] == '>':
        # case 'a>' : return [0,1,...,a-2,a-1]
        val = wn[:-1]
        __check_positive_or_zero(val, False)
        res = list(range(int(val)))
    elif wn[:2] in ('>=', '=>'):
        # cases '>=a','=>a' : return [a,-999], which is then interpreted
        # on the C++ side as [a,a+1,...,a_nyq] (CAS-3759)
        val = wn[2:]
        __check_positive_or_zero(val)
        res = [int(val), -999]
    elif wn[-2:] in ('<=', '=<'):
        # cases 'a<=','a=<' : return [a,-999], which is then interpreted
        # on the C++ side as [a,a+1,...,a_nyq] (CAS-3759)
        val = wn[:-2]
        __check_positive_or_zero(val)
        res = [int(val), -999]
    elif wn[0] == '>':
        # case '>a' : return [a+1,-999], which is then interpreted
        # on the C++ side as [a+1,a+2,...,a_nyq] (CAS-3759)
        val = int(wn[1:]) + 1
        __check_positive_or_zero(val)
        res = [val, -999]
    elif wn[-1] == '<':
        # case 'a<' : return [a+1,-999], which is then interpreted
        # on the C++ side as [a+1,a+2,...,a_nyq] (CAS-3759)
        val = int(wn[:-1]) + 1
        __check_positive_or_zero(val)
        res = [val, -999]
    else:
        # case 'a'
        __check_positive_or_zero(wn)
        res = [int(wn)]

    return ','.join(__get_strlist(res))
def __get_strlist(param):
    """Return a new list holding str() of each element of *param*."""
    return list(map(str, param))
def check_fftthresh(fftthresh):
    """Validate fftthresh value.

    Accepted values:
      (1) a positive int/float, or a string holding a positive number
      (2) 'top' followed by a positive integer, e.g. 'top3'
      (3) a positive number followed by 'sigma', e.g. '3.0sigma'

    Raises:
        ValueError: if the type is not int/float/str (bool is rejected), if a
            string cannot be parsed, or if the threshold is not positive.
    """
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(fftthresh, bool) or not isinstance(fftthresh, (int, float, str)):
        raise ValueError('fftthresh must be float or integer or string.')

    if not isinstance(fftthresh, str):
        not_positive = fftthresh <= 0.0
    else:
        try:
            if len(fftthresh) > 3 and fftthresh[:3] == 'top':
                not_positive = int(fftthresh[3:]) <= 0
            elif len(fftthresh) > 5 and fftthresh[-5:] == 'sigma':
                not_positive = float(fftthresh[:-5]) <= 0.0
            else:
                not_positive = float(fftthresh) <= 0.0
        except Exception:
            raise ValueError('fftthresh has a wrong format.')

    if not_positive:
        raise ValueError('threshold given to fftthresh must be positive.')
def __check_positive_or_zero(param, allowzero=True):
    """Validate wavenumber value(s) via __do_check_positive_or_zero.

    *param* may be an int, a str holding an int, or a list/tuple of such
    items. List/tuple items are converted with int() and checked one at a
    time, in order. Any other type raises ValueError.
    """
    if isinstance(param, (list, tuple)):
        values = (int(item) for item in param)
    elif isinstance(param, int):
        values = (param,)
    elif isinstance(param, str):
        values = (int(param),)
    else:
        raise ValueError(mesg_invalid_wavenumber)
    for value in values:
        __do_check_positive_or_zero(value, allowzero)
def __do_check_positive_or_zero(param, allowzero):
    """Raise ValueError if *param* is negative, or zero while allowzero is False."""
    if param < 0:
        raise ValueError(mesg_invalid_wavenumber)
    if param == 0 and not allowzero:
        raise ValueError(mesg_invalid_wavenumber)
def prepare_for_baselining(**keywords):
    """Pick the singledishms baseline method and its keyword arguments.

    keywords must contain 'sdms' (the tool), 'blfunc' and every parameter
    the chosen blfunc needs. 'poly' and 'chebyshev' use subtract_baseline
    (with blfunc passed through); other blfuncs use subtract_baseline_<blfunc>.

    Returns:
        (params, baseline_func): dict of keyword arguments and the bound method.

    Raises:
        ValueError: if blfunc is not supported.
        KeyError: if a required keyword is missing.
    """
    blfunc = keywords['blfunc']
    # (accepted blfunc names, blfunc-specific keys, method name gets '_<blfunc>')
    dispatch = (
        (('poly', 'chebyshev'), ['blfunc', 'order'], False),
        (('cspline',), ['npiece'], True),
        (('sinusoid',), ['applyfft', 'fftmethod', 'fftthresh', 'addwn', 'rejwn'], True),
        (('variable',), ['blparam', 'verbose'], True),
    )
    for names, specific_keys, suffixed in dispatch:
        if blfunc in names:
            break
    else:
        raise ValueError("Unsupported blfunc = %s" % blfunc)

    funcname = 'subtract_baseline' + ('_' + blfunc if suffixed else '')

    keys = ['datacolumn', 'outfile', 'bloutput', 'dosubtract', 'spw',
            'updateweight', 'sigmavalue'] + specific_keys
    if blfunc != 'variable':
        keys += ['clip_threshold_sigma', 'num_fitting_max']
    keys += ['linefinding', 'threshold', 'avg_limit', 'minwidth', 'edge']

    params = {key: keywords[key] for key in keys}
    baseline_func = getattr(keywords['sdms'], funcname)
    return params, baseline_func
def remove_sorted_table_keyword(infile):
    """Remove the SORTED_TABLE keyword from *infile*, remembering its value.

    Returns a dict with keys 'is_sorttab', 'sorttab_keywd' and 'sorttab_name'
    that restore_sorted_table_keyword() can use to put the keyword back.
    When the keyword is absent, the dict holds False/''/''.
    """
    keyword = 'SORTED_TABLE'
    info = {'is_sorttab': False, 'sorttab_keywd': '', 'sorttab_name': ''}
    with sdutil.table_manager(infile, nomodify=False) as tb:
        if keyword in tb.keywordnames():
            info['is_sorttab'] = True
            info['sorttab_keywd'] = keyword
            info['sorttab_name'] = tb.getkeyword(keyword)
            tb.removekeyword(keyword)
    return info
def restore_sorted_table_keyword(infile, sorttab_info):
    """Put back the keyword recorded by remove_sorted_table_keyword().

    Does nothing unless sorttab_info says the keyword was present and its
    recorded value is non-empty.
    """
    if not sorttab_info['is_sorttab'] or sorttab_info['sorttab_name'] == '':
        return
    with sdutil.table_manager(infile, nomodify=False) as tb:
        tb.putkeyword(sorttab_info['sorttab_keywd'], sorttab_info['sorttab_name'])
def _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
              reindex, bltable, updateweight, sigmavalue, outfile, overwrite):
    """Apply the baseline table *bltable* to *infile*, writing *outfile*.

    An existing outfile (different from infile) is removed first when
    overwrite is True. The SORTED_TABLE keyword of infile is removed while
    the tool runs and is restored afterwards, also when an error is raised,
    so a failure does not leave infile modified.

    Raises:
        ValueError: if bltable does not exist.
    """
    if not os.path.exists(bltable):
        raise ValueError("file specified in bltable '%s' does not exist." % bltable)

    # Note: the condition "infile != outfile" in the following line is for safety
    # to prevent from accidentally removing infile by setting outfile=infile.
    # Don't remove it.
    if overwrite and (infile != outfile) and os.path.exists(outfile):
        remove_data(outfile)

    sorttab_info = remove_sorted_table_keyword(infile)
    try:
        with sdutil.tool_manager(infile, singledishms) as mysdms:
            selection = ms.msseltoindex(vis=infile, spw=spw, field=field,
                                        baseline=antenna, time=timerange,
                                        scan=scan)
            mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field,
                                 antenna=antenna, timerange=timerange,
                                 scan=scan, polarization=pol, intent=intent,
                                 reindex=reindex)
            mysdms.apply_baseline_table(bltable=bltable,
                                        datacolumn=datacolumn,
                                        spw=spw,
                                        updateweight=updateweight,
                                        sigmavalue=sigmavalue,
                                        outfile=outfile)
    finally:
        restore_sorted_table_keyword(infile, sorttab_info)
def _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
            reindex, maskmode, thresh, avg_limit, minwidth, edge, dosubtract, blformat,
            bloutput, blfunc, order, npiece, applyfft, fftmethod, fftthresh, addwn,
            rejwn, clipthresh, clipniter, blparam, verbose, updateweight, sigmavalue,
            outfile, overwrite):
    """Fit baselines to *infile* using the singledishms method chosen for *blfunc*.

    Fitting results go to the files resolved from blformat/bloutput; the
    output MS is written to outfile, or to a temporary name when dosubtract
    is False and outfile already exists. For blfunc='variable', the
    SORTED_TABLE keyword of infile is removed while the tool runs and is
    restored afterwards, also when an error is raised.

    Returns:
        The temporary output name used in place of outfile, or '' if outfile
        itself was used.

    Raises:
        ValueError: if blformat is empty while dosubtract is False, or on
            invalid blformat/bloutput/addwn/rejwn/fftthresh values.
    """
    temp_outfile = ''

    if (not dosubtract) and is_empty(blformat):
        raise ValueError("blformat must be specified when dosubtract is False")
    blformat, bloutput = prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite)

    output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile)

    # Set temporary name for output MS if dosubtract is False and outfile exists
    # for not removing/overwriting outfile that already exists
    if os.path.exists(outfile):
        # Note: the condition "infile != outfile" in the following line is for safety
        # to prevent from accidentally removing infile by setting outfile=infile
        # Don't remove it.
        if dosubtract and overwrite and (infile != outfile):
            remove_data(outfile)
        elif (not dosubtract):
            outfile = get_temporary_file_name(infile)
            temp_outfile = outfile

    # Non-None only when the SORTED_TABLE keyword was taken off infile.
    sorttab_info = None
    if (blfunc == 'variable'):
        sorttab_info = remove_sorted_table_keyword(infile)
    elif (blfunc == 'sinusoid'):
        addwn = parse_wavenumber_param(addwn)
        rejwn = parse_wavenumber_param(rejwn)
        check_fftthresh(fftthresh)

    try:
        with sdutil.tool_manager(infile, singledishms) as mysdms:
            selection = ms.msseltoindex(vis=infile, spw=spw, field=field, baseline=antenna,
                                        time=timerange, scan=scan)
            mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field, antenna=antenna,
                                 timerange=timerange, scan=scan, polarization=pol, intent=intent,
                                 reindex=reindex)
            params, func = prepare_for_baselining(sdms=mysdms,
                                                  blfunc=blfunc,
                                                  datacolumn=datacolumn,
                                                  outfile=outfile,
                                                  bloutput=','.join(bloutput),
                                                  dosubtract=dosubtract,
                                                  spw=spw,
                                                  pol=pol,
                                                  linefinding=(maskmode == 'auto'),
                                                  threshold=thresh,
                                                  avg_limit=avg_limit,
                                                  minwidth=minwidth,
                                                  edge=edge,
                                                  order=order,
                                                  npiece=npiece,
                                                  applyfft=applyfft,
                                                  fftmethod=fftmethod,
                                                  fftthresh=fftthresh,
                                                  addwn=addwn,
                                                  rejwn=rejwn,
                                                  clip_threshold_sigma=clipthresh,
                                                  num_fitting_max=clipniter + 1,
                                                  blparam=blparam,
                                                  verbose=verbose,
                                                  updateweight=updateweight,
                                                  sigmavalue=sigmavalue)
            func(**params)
    finally:
        if sorttab_info is not None:
            restore_sorted_table_keyword(infile, sorttab_info)

    return temp_outfile