########################################################################
# test_task_sdpolaverage.py
#
# Copyright (C) 2018
# Associated Universities, Inc. Washington DC, USA
#
# This script is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public
# License for more details.
#
# [Add the link to the JIRA ticket here once it exists]
#
# Based on the requirements listed in the CASA documentation found here:
# https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.single.sdpolaverage.html
#
#
##########################################################################
import math
import os
import shutil
import sys
import unittest

from casatasks import sdpolaverage
from casatasks.private.sdutil import table_manager
from casatools import ctsys

datapath = ctsys.resolve('unittest/sdpolaverage/')


def weighToSigma(weight):
    """Convert a weight to the equivalent sigma, 1 / sqrt(weight).

    Returns -1.0 as a sentinel when ``weight`` is not larger than the
    smallest positive normalized float (zero, negative or denormal input).
    """
    if weight > sys.float_info.min:
        return 1.0 / math.sqrt(weight)
    return -1.0


def sigmaToWeight(sigma):
    """Convert a sigma to the equivalent weight, 1 / sigma**2.

    Returns 0.0 when ``sigma`` is not larger than the smallest positive
    normalized float (zero, negative or denormal input).
    """
    if sigma > sys.float_info.min:
        return 1.0 / math.pow(sigma, 2)
    return 0.0


def check_eq(val, expval, tol=None):
    """Check that ``val`` matches ``expval`` within ``tol``.

    Dicts are compared recursively, key by key (every key of ``val`` is
    looked up in ``expval``). Otherwise, when ``tol`` is truthy and ``val``
    supports subtraction, ``abs(val - expval) < tol`` must hold; if not,
    plain ``==`` is used. Array-valued results are reduced with ``.all()``.

    Raises:
        ValueError: if the values differ; the message shows both values.
        Exception: any other error raised while comparing is re-raised
            after printing both values.
    """
    if isinstance(val, dict):
        for key in val:
            check_eq(val[key], expval[key], tol)
        return
    try:
        if tol and hasattr(val, '__rsub__'):
            are_eq = abs(val - expval) < tol
        else:
            are_eq = val == expval
        if hasattr(are_eq, 'all'):
            are_eq = are_eq.all()
        if not are_eq:
            raise ValueError('!=')
    except ValueError:
        errmsg = "%r != %r" % (val, expval)
        if len(errmsg) > 66:  # 66 = 78 - len('ValueError: ')
            errmsg = "\n%r\n!=\n%r" % (val, expval)
        raise ValueError(errmsg)
    except Exception:
        print("Error comparing {} to {}".format(val, expval))
        raise


class test_sdpolaverage(unittest.TestCase):
    """Tests of the sdpolaverage task on a two-correlation single-dish MS."""

    def setUp(self):
        self.inputms = "analytic_type1.fit.ms"
        self.outputms = "polave.ms"
        # Remove leftovers of an aborted run first: a stale copy would make
        # copytree fail (and with `cp -R` would silently nest the new copy).
        self._cleanup()
        # copytree raises on failure, unlike os.system('cp ...').
        shutil.copytree(os.path.join(datapath, self.inputms), self.inputms)

    def tearDown(self):
        self._cleanup()

    def _cleanup(self):
        """Delete the working copy of the input MS and the output MS."""
        for path in (self.inputms, self.outputms):
            shutil.rmtree(path, ignore_errors=True)

    @staticmethod
    def _getcell(tablename, column, row=0):
        """Return cell ``row`` of ``column`` in table ``tablename``."""
        with table_manager(tablename) as tb:
            return tb.getcell(column, row)

    def _check_averaged_shape(self, indata, outdata):
        """Check that the output has one correlation and the input's channel count."""
        self.assertEqual(len(outdata), 1, 'No averaging over polarization?')
        self.assertEqual(len(outdata[0]), len(indata[0]),
                         'Number of channels changed by averaging.')

    def _check_stokes_i_polarization(self):
        """Check that output row 0 is labelled with a single Stokes I correlation.

        Follows DATA_DESC_ID -> DATA_DESCRIPTION.POLARIZATION_ID ->
        POLARIZATION.CORR_TYPE. CORR_TYPE 1 is Stokes I in the casacore
        Stokes enumeration.
        """
        outddesc = self._getcell(self.outputms, 'DATA_DESC_ID')
        with table_manager(os.path.join(self.outputms, 'DATA_DESCRIPTION')) as tb:
            outpolid = tb.getcol('POLARIZATION_ID')
        outpoltype = self._getcell(os.path.join(self.outputms, 'POLARIZATION'),
                                   'CORR_TYPE', outpolid[outddesc])
        self.assertEqual(len(outpoltype), 1,
                         'Polarization id is inconsistent with data.')
        self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.')

    def test_default(self):
        """With default parameters the data must pass through unchanged."""
        sdpolaverage(infile=self.inputms, outfile=self.outputms,
                     datacolumn='float_data')
        indata = self._getcell(self.inputms, 'FLOAT_DATA')
        outdata = self._getcell(self.outputms, 'FLOAT_DATA')
        self.assertEqual(len(indata), len(outdata),
                         'Input and output data have different shape.')
        for inrow, outrow in zip(indata, outdata):
            self.assertEqual(len(inrow), len(outrow),
                             'Input and output data have different shape.')
            for inval, outval in zip(inrow, outrow):
                self.assertEqual(inval, outval,
                                 'Input and output data unidentical.')

    def test_stokes_float_data(self):
        """polaverage='stokes' on FLOAT_DATA gives the mean of the two correlations."""
        sdpolaverage(infile=self.inputms, outfile=self.outputms,
                     polaverage='stokes', datacolumn='float_data')
        # check data
        indata = self._getcell(self.inputms, 'FLOAT_DATA')
        outdata = self._getcell(self.outputms, 'FLOAT_DATA')
        self._check_averaged_shape(indata, outdata)
        tol = 1e-5
        for chan in range(len(indata[0])):
            mean = 0.5 * (indata[0][chan] + indata[1][chan])
            check_eq(outdata[0][chan], mean, tol)
        # check correlation type (should be Stokes I)
        self._check_stokes_i_polarization()

    def test_stokes_corrected_data(self):
        """polaverage='stokes' on CORRECTED_DATA gives the complex mean in DATA."""
        sdpolaverage(infile=self.inputms, outfile=self.outputms,
                     polaverage='stokes', datacolumn='corrected')
        # check data: the selected CORRECTED_DATA column is written to DATA
        indata = self._getcell(self.inputms, 'CORRECTED_DATA')
        outdata = self._getcell(self.outputms, 'DATA')
        self._check_averaged_shape(indata, outdata)
        tol = 1e-5
        for chan in range(len(indata[0])):
            mean = 0.5 * (indata[0][chan] + indata[1][chan])
            check_eq(outdata[0][chan].real, mean.real, tol)
            check_eq(outdata[0][chan].imag, mean.imag, tol)
        # check correlation type (should be Stokes I)
        self._check_stokes_i_polarization()


if __name__ == '__main__':
    unittest.main()