from __future__ import absolute_import from __future__ import print_function import unittest import os import math import sys from casatasks.private.casa_transition import is_CASA6 if is_CASA6: from casatasks import sdpolaverage from casatasks.private.sdutil import tbmanager from casatools import ctsys datapath = ctsys.resolve('unittest/sdpolaverage/') # default isn't used in casatasks def default(atask): pass else: from tasks import sdpolaverage from __main__ import default from sdutil import tbmanager # Define the root for the data files datapath = os.environ.get('CASAPATH').split()[0] + "/casatestdata/unittest/sdpolaverage/" def weighToSigma(weight): if weight > sys.float_info.min: return 1.0 / math.sqrt(weight) else: return -1.0 def sigmaToWeight(sigma): if sigma > sys.float_info.min: return 1.0 / math.pow(sigma, 2) else: return 0.0 def check_eq(val, expval, tol=None): """Checks that val matches expval within tol.""" # print val if type(val) == dict: for k in val: check_eq(val[k], expval[k], tol) else: try: if tol and hasattr(val, '__rsub__'): are_eq = abs(val - expval) < tol else: are_eq = val == expval if hasattr(are_eq, 'all'): are_eq = are_eq.all() if not are_eq: raise ValueError('!=') except ValueError: errmsg = "%r != %r" % (val, expval) if (len(errmsg) > 66): # 66 = 78 - len('ValueError: ') errmsg = "\n%r\n!=\n%r" % (val, expval) raise ValueError(errmsg) except Exception as e: print("Error comparing {} to {}".format(val, expval)) raise e class test_sdpolaverage(unittest.TestCase): def setUp(self): self.inputms = "analytic_type1.fit.ms" self.outputms = "polave.ms" #datapath = os.environ.get('CASAPATH').split()[0] + "/data/regression/unittest/tsdfit/" os.system('cp -RH ' + datapath + self.inputms + ' ' + self.inputms) default(sdpolaverage) def tearDown(self): os.system('rm -rf ' + self.inputms) os.system('rm -rf ' + self.outputms) def test_default(self): sdpolaverage(infile=self.inputms, outfile=self.outputms, datacolumn='float_data') with tbmanager(self.inputms) as tb: indata = tb.getcell('FLOAT_DATA', 0) with tbmanager(self.outputms) as tb: outdata = tb.getcell('FLOAT_DATA', 0) self.assertEqual(len(indata), len(outdata), 'Input and output data have different shape.') for i in range(len(indata)): for j in range(len(indata[0])): self.assertEqual(indata[i][j], outdata[i][j], 'Input and output data unidentical.') def test_stokes_float_data(self): sdpolaverage(infile=self.inputms, outfile=self.outputms, polaverage='stokes', datacolumn='float_data') # check data with tbmanager(self.inputms) as tb: indata = tb.getcell('FLOAT_DATA', 0) with tbmanager(self.outputms) as tb: outdata = tb.getcell('FLOAT_DATA', 0) self.assertEqual(len(outdata), 1, 'No averaging over polarization?') tol = 1e-5 for i in range(len(indata[0])): mean = 0.5 * (indata[0][i] + indata[1][i]) check_eq(outdata[0][i], mean, tol) # check polarization id (should be 1) with tbmanager(self.outputms) as tb: outddesc = tb.getcell('DATA_DESC_ID', 0) with tbmanager(self.outputms + '/DATA_DESCRIPTION') as tb: outpolid = tb.getcol('POLARIZATION_ID') with tbmanager(self.outputms + '/POLARIZATION') as tb: outpoltype = tb.getcell('CORR_TYPE', outpolid[outddesc]) self.assertEqual(len(outpoltype), 1, 'Polarization id is inconsistent with data.') self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.') def test_stokes_corrected_data(self): sdpolaverage(infile=self.inputms, outfile=self.outputms, polaverage='stokes', datacolumn='corrected') # check data with tbmanager(self.inputms) as tb: indata = tb.getcell('CORRECTED_DATA', 0) with tbmanager(self.outputms) as tb: outdata = tb.getcell('DATA', 0) self.assertEqual(len(outdata), 1, 'No averaging over polarization?') tol = 1e-5 for i in range(len(indata[0])): mean = 0.5 * (indata[0][i] + indata[1][i]) check_eq(outdata[0][i].real, mean.real, tol) check_eq(outdata[0][i].imag, mean.imag, tol) # check polarization id (should be 1) with tbmanager(self.outputms) as tb: outddesc = tb.getcell('DATA_DESC_ID', 0) with tbmanager(self.outputms + '/DATA_DESCRIPTION') as tb: outpolid = tb.getcol('POLARIZATION_ID') with tbmanager(self.outputms + '/POLARIZATION') as tb: outpoltype = tb.getcell('CORR_TYPE', outpolid[outddesc]) self.assertEqual(len(outpoltype), 1, 'Polarization id is inconsistent with data.') self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.') def suite(): return [test_sdpolaverage] if is_CASA6: if __name__ == '__main__': unittest.main()