1 + | import unittest |
2 + | import os |
3 + | import math |
4 + | import sys |
5 + | |
6 + | from casatasks.private.casa_transition import is_CASA6 |
7 + | if is_CASA6: |
8 + | from casatasks import sdpolaverage |
9 + | from casatasks.private.sdutil import tbmanager |
10 + | from casatools import ctsys |
11 + | datapath = ctsys.resolve('regression/unittest/tsdfit/') |
12 + | |
13 + | |
14 + | def default(atask): |
15 + | pass |
16 + | |
17 + | else: |
18 + | from tasks import sdpolaverage |
19 + | from __main__ import default |
20 + | from sdutil import tbmanager |
21 + | |
22 + | |
23 + | datapath = os.environ.get('CASAPATH').split()[0] + "/data/regression/unittest/tsdfit/" |
24 + | |
25 + | |
26 + | def weighToSigma(weight): |
27 + | if weight > sys.float_info.min: |
28 + | return 1.0 / math.sqrt(weight) |
29 + | else: |
30 + | return -1.0 |
31 + | |
32 + | |
33 + | def sigmaToWeight(sigma): |
34 + | if sigma > sys.float_info.min: |
35 + | return 1.0 / math.pow(sigma, 2) |
36 + | else: |
37 + | return 0.0 |
38 + | |
39 + | |
40 + | def check_eq(val, expval, tol=None): |
41 + | """Checks that val matches expval within tol.""" |
42 + | |
43 + | if type(val) == dict: |
44 + | for k in val: |
45 + | check_eq(val[k], expval[k], tol) |
46 + | else: |
47 + | try: |
48 + | if tol and hasattr(val, '__rsub__'): |
49 + | are_eq = abs(val - expval) < tol |
50 + | else: |
51 + | are_eq = val == expval |
52 + | if hasattr(are_eq, 'all'): |
53 + | are_eq = are_eq.all() |
54 + | if not are_eq: |
55 + | raise ValueError('!=') |
56 + | except ValueError: |
57 + | errmsg = "%r != %r" % (val, expval) |
58 + | if (len(errmsg) > 66): |
59 + | errmsg = "\n%r\n!=\n%r" % (val, expval) |
60 + | raise ValueError(errmsg) |
61 + | except Exception as e: |
62 + | print("Error comparing {} to {}".format(val, expval)) |
63 + | raise e |
64 + | |
65 + | |
66 + | class test_sdpolaverage(unittest.TestCase): |
67 + | def setUp(self): |
68 + | self.inputms = "analytic_type1.fit.ms" |
69 + | self.outputms = "polave.ms" |
70 + | |
71 + | os.system('cp -RL ' + datapath + self.inputms + ' ' + self.inputms) |
72 + | default(sdpolaverage) |
73 + | |
74 + | def tearDown(self): |
75 + | os.system('rm -rf ' + self.inputms) |
76 + | os.system('rm -rf ' + self.outputms) |
77 + | |
78 + | def test_default(self): |
79 + | sdpolaverage(infile=self.inputms, outfile=self.outputms, datacolumn='float_data') |
80 + | with tbmanager(self.inputms) as tb: |
81 + | indata = tb.getcell('FLOAT_DATA', 0) |
82 + | with tbmanager(self.outputms) as tb: |
83 + | outdata = tb.getcell('FLOAT_DATA', 0) |
84 + | |
85 + | self.assertEqual(len(indata), len(outdata), 'Input and output data have different shape.') |
86 + | for i in range(len(indata)): |
87 + | for j in range(len(indata[0])): |
88 + | self.assertEqual(indata[i][j], outdata[i][j], 'Input and output data unidentical.') |
89 + | |
90 + | def test_stokes_float_data(self): |
91 + | sdpolaverage(infile=self.inputms, outfile=self.outputms, polaverage='stokes', datacolumn='float_data') |
92 + | |
93 + | with tbmanager(self.inputms) as tb: |
94 + | indata = tb.getcell('FLOAT_DATA', 0) |
95 + | with tbmanager(self.outputms) as tb: |
96 + | outdata = tb.getcell('FLOAT_DATA', 0) |
97 + | |
98 + | self.assertEqual(len(outdata), 1, 'No averaging over polarization?') |
99 + | tol = 1e-5 |
100 + | for i in range(len(indata[0])): |
101 + | mean = 0.5 * (indata[0][i] + indata[1][i]) |
102 + | check_eq(outdata[0][i], mean, tol) |
103 + | |
104 + | |
105 + | with tbmanager(self.outputms) as tb: |
106 + | outddesc = tb.getcell('DATA_DESC_ID', 0) |
107 + | with tbmanager(self.outputms + '/DATA_DESCRIPTION') as tb: |
108 + | outpolid = tb.getcol('POLARIZATION_ID') |
109 + | with tbmanager(self.outputms + '/POLARIZATION') as tb: |
110 + | outpoltype = tb.getcell('CORR_TYPE', outpolid[outddesc]) |
111 + | |
112 + | self.assertEqual(len(outpoltype), 1, 'Polarization id is inconsistent with data.') |
113 + | self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.') |
114 + | |
115 + | def test_stokes_corrected_data(self): |
116 + | sdpolaverage(infile=self.inputms, outfile=self.outputms, polaverage='stokes', datacolumn='corrected') |
117 + | |
118 + | with tbmanager(self.inputms) as tb: |
119 + | indata = tb.getcell('CORRECTED_DATA', 0) |
120 + | with tbmanager(self.outputms) as tb: |
121 + | outdata = tb.getcell('DATA', 0) |
122 + | |
123 + | self.assertEqual(len(outdata), 1, 'No averaging over polarization?') |
124 + | tol = 1e-5 |
125 + | for i in range(len(indata[0])): |
126 + | mean = 0.5 * (indata[0][i] + indata[1][i]) |
127 + | check_eq(outdata[0][i].real, mean.real, tol) |
128 + | check_eq(outdata[0][i].imag, mean.imag, tol) |
129 + | |
130 + | |
131 + | with tbmanager(self.outputms) as tb: |
132 + | outddesc = tb.getcell('DATA_DESC_ID', 0) |
133 + | with tbmanager(self.outputms + '/DATA_DESCRIPTION') as tb: |
134 + | outpolid = tb.getcol('POLARIZATION_ID') |
135 + | with tbmanager(self.outputms + '/POLARIZATION') as tb: |
136 + | outpoltype = tb.getcell('CORR_TYPE', outpolid[outddesc]) |
137 + | |
138 + | self.assertEqual(len(outpoltype), 1, 'Polarization id is inconsistent with data.') |
139 + | self.assertEqual(outpoltype[0], 1, 'Has wrong polarization id.') |
140 + | |
141 + | |
142 + | def suite(): |
143 + | return [test_sdpolaverage] |
144 + | |
145 + | |
146 + | if is_CASA6: |
147 + | if __name__ == '__main__': |
148 + | unittest.main() |