Commits

Takeshi Nakazato authored 832ba36b7db
CAS-12940 add test for sdpolaverage
No tags

casatasks/tests/tasks/test_sdpolaverage.py

Added
1 +import unittest
2 +import os
3 +import math
4 +import sys
5 +
6 +from casatasks.private.casa_transition import is_CASA6
7 +if is_CASA6:
8 + from casatasks import sdpolaverage
9 + from casatasks.private.sdutil import tbmanager
10 + from casatools import ctsys
11 + datapath = ctsys.resolve('regression/unittest/tsdfit/')
12 +
13 + # default isn't used in casatasks
14 + def default(atask):
15 + pass
16 +
17 +else:
18 + from tasks import sdpolaverage
19 + from __main__ import default
20 + from sdutil import tbmanager
21 +
22 + # Define the root for the data files
23 + datapath = os.environ.get('CASAPATH').split()[0] + "/data/regression/unittest/tsdfit/"
24 +
25 +
26 +def weighToSigma(weight):
27 + if weight > sys.float_info.min:
28 + return 1.0 / math.sqrt(weight)
29 + else:
30 + return -1.0
31 +
32 +
33 +def sigmaToWeight(sigma):
34 + if sigma > sys.float_info.min:
35 + return 1.0 / math.pow(sigma, 2)
36 + else:
37 + return 0.0
38 +
39 +
40 +def check_eq(val, expval, tol=None):
41 + """Checks that val matches expval within tol."""
42 +# print val
43 + if type(val) == dict:
44 + for k in val:
45 + check_eq(val[k], expval[k], tol)
46 + else:
47 + try:
48 + if tol and hasattr(val, '__rsub__'):
49 + are_eq = abs(val - expval) < tol
50 + else:
51 + are_eq = val == expval
52 + if hasattr(are_eq, 'all'):
53 + are_eq = are_eq.all()
54 + if not are_eq:
55 + raise ValueError('!=')
56 + except ValueError:
57 + errmsg = "%r != %r" % (val, expval)
58 + if (len(errmsg) > 66): # 66 = 78 - len('ValueError: ')
59 + errmsg = "\n%r\n!=\n%r" % (val, expval)
60 + raise ValueError(errmsg)
61 + except Exception as e:
62 + print("Error comparing {} to {}".format(val, expval))
63 + raise e
64 +
65 +
66 +class test_sdpolaverage(unittest.TestCase):
67 + def setUp(self):
68 + self.inputms = "analytic_type1.fit.ms"
69 + self.outputms = "polave.ms"
70 + #datapath = os.environ.get('CASAPATH').split()[0] + "/data/regression/unittest/tsdfit/"
71 + os.system('cp -RL ' + datapath + self.inputms + ' ' + self.inputms)
72 + default(sdpolaverage)
73 +
74 + def tearDown(self):
75 + os.system('rm -rf ' + self.inputms)
76 + os.system('rm -rf ' + self.outputms)
77 +
78 + def test_default(self):
79 + sdpolaverage(infile=self.inputms, outfile=self.outputms, datacolumn='float_data')
80 + with tbmanager(self.inputms) as tb:
81 + indata = tb.getcell('FLOAT_DATA', 0)
82 + with tbmanager(self.outputms) as tb:
83 + outdata = tb.getcell('FLOAT_DATA', 0)
84 +
85 + self.assertEqual(len(indata), len(outdata), 'Input and output data have different shape.')
86 + for i in range(len(indata)):
87 + for j in range(len(indata[0])):
88 + self.assertEqual(indata[i][j], outdata[i][j], 'Input and output data unidentical.')
89 +
90 + def test_stokes_float_data(self):
91 + sdpolaverage(infile=self.inputms, outfile=self.outputms, polaverage='stokes', datacolumn='float_data')
92 + # check data
93 + with tbmanager(self.inputms) as tb:
94 + indata = tb.getcell('FLOAT_DATA', 0)
95 + with tbmanager(self.outputms) as tb:
96 + outdata = tb.getcell('FLOAT_DATA', 0)
97 +
98 + self.assertEqual(len(outdata), 1, 'No averaging over polarization?')
99 + tol = 1e-5
100 + for i in range(len(indata[0])):
101 + mean = 0.5 * (indata[0][i] + indata[1][i])
102 + check_eq(outdata[0][i], mean, tol)
103 +
104 + # check polarization id (should be 1)
105 + with tbmanager(self.outputms) as tb:
106 + outddesc = tb.getcell('DATA_DESC_ID', 0)
107 + with tbmanager(self.outputms + '/DATA_DESCRIPTION') as tb:
108 + outpolid = tb.getcol('POLARIZATION_ID')
109 + with tbmanager(self.outputms + '/POLARIZATION') as tb:
110 + outpoltype = tb.getcell('CORR_TYPE', outpolid[outddesc])
111 +
112 + self.assertEqual(len(outpoltype), 1, 'Polarization id is inconsistent with data.')
113 + self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.')
114 +
115 + def test_stokes_corrected_data(self):
116 + sdpolaverage(infile=self.inputms, outfile=self.outputms, polaverage='stokes', datacolumn='corrected')
117 + # check data
118 + with tbmanager(self.inputms) as tb:
119 + indata = tb.getcell('CORRECTED_DATA', 0)
120 + with tbmanager(self.outputms) as tb:
121 + outdata = tb.getcell('DATA', 0)
122 +
123 + self.assertEqual(len(outdata), 1, 'No averaging over polarization?')
124 + tol = 1e-5
125 + for i in range(len(indata[0])):
126 + mean = 0.5 * (indata[0][i] + indata[1][i])
127 + check_eq(outdata[0][i].real, mean.real, tol)
128 + check_eq(outdata[0][i].imag, mean.imag, tol)
129 +
130 + # check polarization id (should be 1)
131 + with tbmanager(self.outputms) as tb:
132 + outddesc = tb.getcell('DATA_DESC_ID', 0)
133 + with tbmanager(self.outputms + '/DATA_DESCRIPTION') as tb:
134 + outpolid = tb.getcol('POLARIZATION_ID')
135 + with tbmanager(self.outputms + '/POLARIZATION') as tb:
136 + outpoltype = tb.getcell('CORR_TYPE', outpolid[outddesc])
137 +
138 + self.assertEqual(len(outpoltype), 1, 'Polarization id is inconsistent with data.')
139 + self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.')
140 +
141 +
142 +def suite():
143 + return [test_sdpolaverage]
144 +
145 +
146 +if is_CASA6:
147 + if __name__ == '__main__':
148 + unittest.main()

Everything looks good. We'll let you know here if there's anything you should know about.

Add shortcut