Commits

Federico Montesino Pouzols authored ff3bfd9290c Merge
Merge remote-tracking branch 'origin/master' into bugfix/CAS-11397
No tags

code/mstransform/MSTransform/StatWt.cc

Modified
11 11 //# This library is distributed in the hope that it will be useful,
12 12 //# but WITHOUT ANY WARRANTY, without even the implied warranty of
13 13 //# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 14 //# Lesser General Public License for more details.
15 15 //#
16 16 //# You should have received a copy of the GNU Lesser General Public
17 17 //# License along with this library; if not, write to the Free Software
18 18 //# Foundation, Inc., 59 Temple Place, Suite 330, Boston,
19 19 //# MA 02111-1307 USA
20 20
21 +#include <mstransform/MSTransform/StatWt.h>
22 +
21 23 #include <casacore/casa/Containers/ValueHolder.h>
22 24 #include <casacore/casa/Quanta/QuantumHolder.h>
23 25 #include <casacore/casa/System/ProgressMeter.h>
24 26 #include <casacore/ms/MSOper/MSMetaData.h>
25 27 #include <casacore/tables/Tables/ArrColDesc.h>
26 28 #include <casacore/tables/Tables/TableProxy.h>
27 29 #include <casacore/tables/DataMan/TiledShapeStMan.h>
28 30
29 -#include <mstransform/MSTransform/StatWt.h>
31 +#include <mstransform/MSTransform/StatWtColConfig.h>
30 32 #include <mstransform/TVI/StatWtTVI.h>
31 33 #include <mstransform/TVI/StatWtTVILayerFactory.h>
32 34 #include <msvis/MSVis/ViImplementation2.h>
33 35 #include <msvis/MSVis/IteratingParameters.h>
34 -#include <msvis/MSVis/LayeredVi2Factory.h>
35 36
36 37 using namespace casacore;
37 38
38 39 namespace casa {
39 40
40 -StatWt::StatWt(MeasurementSet* ms) : _ms(ms), _saf() {
41 +StatWt::StatWt(
42 + MeasurementSet* ms,
43 + const StatWtColConfig* const statwtColConfig
44 +) : _ms(ms),
45 + _saf(), _statwtColConfig(statwtColConfig) {
41 46 ThrowIf(! _ms, "Input MS pointer cannot be NULL");
47 + ThrowIf(
48 + ! _statwtColConfig,
49 + "Input column configuration pointer cannot be NULL"
50 + );
42 51 }
43 52
44 53 StatWt::~StatWt() {}
45 54
46 55 void StatWt::setOutputMS(const casacore::String& outname) {
47 56 _outname = outname;
48 57 }
49 58
50 59 void StatWt::setTimeBinWidth(const casacore::Quantity& binWidth) {
51 60 _timeBinWidth = vi::StatWtTVI::getTimeBinWidthInSec(binWidth);
69 78 }
70 79
71 80 void StatWt::setPreview(casacore::Bool preview) {
72 81 _preview = preview;
73 82 }
74 83
75 84 void StatWt::setTVIConfig(const Record& config) {
76 85 _tviConfig = config;
77 86 }
78 87
79 -Record StatWt::writeWeights() const {
80 - auto hasWtSp = _ms->isColumn(MSMainEnums::WEIGHT_SPECTRUM);
81 - auto mustWriteWtSp = ! _preview
82 - && _tviConfig.isDefined(vi::StatWtTVI::CHANBIN);
83 - if (mustWriteWtSp) {
84 - auto type = _tviConfig.type(_tviConfig.fieldNumber(vi::StatWtTVI::CHANBIN));
85 - if (type == TpArrayBool) {
86 - // default variant type
87 - mustWriteWtSp = False;
88 - }
89 - else if (type == TpString) {
90 - auto val = _tviConfig.asString(vi::StatWtTVI::CHANBIN);
91 - val.downcase();
92 - if (val == "spw") {
93 - mustWriteWtSp = False;
88 +Record StatWt::writeWeights() {
89 + auto mustWriteWt = False;
90 + auto mustWriteWtSp = False;
91 + auto mustWriteSig = False;
92 + auto mustWriteSigSp = False;
93 + _statwtColConfig->getColWriteFlags(
94 + mustWriteWt, mustWriteWtSp, mustWriteSig, mustWriteSigSp
95 + );
96 + shared_ptr<vi::VisibilityIterator2> vi;
97 + std::shared_ptr<vi::StatWtTVILayerFactory> factory;
98 + _constructVi(vi, factory);
99 + vi::VisBuffer2 *vb = vi->getVisBuffer();
100 + ProgressMeter pm(0, _ms->nrow(), "StatWt Progress");
101 + uInt64 count = 0;
102 + for (vi->originChunks(); vi->moreChunks(); vi->nextChunk()) {
103 + for (vi->origin(); vi->more(); vi->next()) {
104 + auto nrow = vb->nRows();
105 + if (_preview) {
106 + // just need to run the flags to accumulate
107 + // flagging info
108 + vb->flagCube();
94 109 }
95 - }
96 - }
97 - auto mustInitWtSp = False;
98 - if (! hasWtSp && mustWriteWtSp) {
99 - // we must create WEIGHT_SPECTRUM
100 - hasWtSp = True;
101 - mustInitWtSp = True;
102 - // from Calibrater.cc
103 - // Nominal default tile shape
104 - IPosition dts(3, 4, 32, 1024);
105 - // Discern DATA's default tile shape and use it
106 - const auto dminfo = _ms->dataManagerInfo();
107 - for (uInt i=0; i<dminfo.nfields(); ++i) {
108 - Record col = dminfo.asRecord(i);
109 - if (anyEQ(col.asArrayString("COLUMNS"), String("DATA"))) {
110 - dts = IPosition(col.asRecord("SPEC").asArrayInt("DEFAULTTILESHAPE"));
111 - break;
110 + else {
111 + if (mustWriteWtSp) {
112 + auto& x = vb->weightSpectrum();
113 + ThrowIf(
114 + x.empty(),
115 + "WEIGHT_SPECTRUM is only partially initialized. "
116 + "StatWt2 cannot deal with such an MS"
117 + );
118 + vb->setWeightSpectrum(x);
119 + }
120 + if (mustWriteSigSp) {
121 + auto& x = vb->sigmaSpectrum();
122 + ThrowIf(
123 + x.empty(),
124 + "SIGMA_SPECTRUM is only partially initialized. "
125 + "StatWt2 cannot deal with such an MS"
126 + );
127 + vb->setSigmaSpectrum(x);
128 + }
129 + if (mustWriteWt) {
130 + vb->setWeight(vb->weight());
131 + }
132 + if (mustWriteSig) {
133 + vb->setSigma(vb->sigma());
134 + }
135 + vb->setFlagCube(vb->flagCube());
136 + vb->setFlagRow(vb->flagRow());
137 + vb->writeChangesBack();
112 138 }
139 + count += nrow;
140 + pm.update(count);
113 141 }
114 - // Add the column
115 - String colWtSp = MS::columnName(MS::WEIGHT_SPECTRUM);
116 - TableDesc tdWtSp;
117 - tdWtSp.addColumn(ArrayColumnDesc<Float>(colWtSp, "weight spectrum", 2));
118 - TiledShapeStMan wtSpStMan("TiledWgtSpectrum", dts);
119 - _ms->addColumn(tdWtSp, wtSpStMan);
120 142 }
121 - else if (! _preview) {
122 - // check to see if extant WEIGHT_SPECTRUM needs to be initialized
123 - ArrayColumn<Float> col(*_ms, MS::columnName(MS::WEIGHT_SPECTRUM));
124 - try {
125 - col.get(0);
126 - // its initialized, so even if we are using the full spw for
127 - // binning, we still need to update WEIGHT_SPECTRUM
128 - mustWriteWtSp = True;
129 - }
130 - catch (const AipsError& x) {
131 - // its not initialized, so we aren't going to write to it unless
132 - // chanbin has been specified to be less than the spw width
133 - mustInitWtSp = mustWriteWtSp;
134 - }
143 + if (_preview) {
144 + LogIO log(LogOrigin("StatWt", __func__));
145 + log << LogIO::NORMAL
146 + << "RAN IN PREVIEW MODE. NO WEIGHTS NOR FLAGS WERE CHANGED."
147 + << LogIO::POST;
135 148 }
149 + factory->getTVI()->summarizeFlagging();
150 + Double mean, variance;
151 + factory->getTVI()->summarizeStats(mean, variance);
152 + Record ret;
153 + ret.define("mean", mean);
154 + ret.define("variance", variance);
155 + return ret;
156 +}
157 +
158 +void StatWt::_constructVi(
159 + std::shared_ptr<vi::VisibilityIterator2>& vi,
160 + std::shared_ptr<vi::StatWtTVILayerFactory>& factory
161 +) const {
136 162 // default sort columns are from MSIter and are ARRAY_ID, FIELD_ID, DATA_DESC_ID, and TIME
137 163 // I'm adding scan and state because, according to the statwt requirements, by default, scan
138 164 // and state changes should mark boundaries in the weights computation
139 165 std::vector<Int> scs;
140 166 scs.push_back(MS::ARRAY_ID);
141 167 if (! _combine.contains("scan")) {
142 168 scs.push_back(MS::SCAN_NUMBER);
143 169 }
144 170 if (! _combine.contains("state")) {
145 171 scs.push_back(MS::STATE_ID);
152 178 Block<int> sort(scs.size());
153 179 uInt i = 0;
154 180 for (const auto& col: scs) {
155 181 sort[i] = col;
156 182 ++i;
157 183 }
158 184 vi::SortColumns sc(sort, False);
159 185 vi::IteratingParameters ipar(_timeBinWidth, sc);
160 186 vi::VisIterImpl2LayerFactory data(_ms, ipar, True);
161 187 unique_ptr<Record> config(dynamic_cast<Record*>(_tviConfig.clone()));
162 - vi::StatWtTVILayerFactory statWtLayerFactory(*config);
188 + factory.reset(new vi::StatWtTVILayerFactory(*config));
163 189 Vector<vi::ViiLayerFactory*> facts(2);
164 190 facts[0] = &data;
165 - facts[1] = &statWtLayerFactory;
166 - vi::VisibilityIterator2 vi(facts);
167 - vi::VisBuffer2 *vb = vi.getVisBuffer();
168 - Vector<Int> vr(1);
169 - ProgressMeter pm(0, _ms->nrow(), "StatWt Progress");
170 - uInt64 count = 0;
171 - for (vi.originChunks(); vi.moreChunks(); vi.nextChunk()) {
172 - for (vi.origin(); vi.more(); vi.next()) {
173 - auto nrow = vb->nRows();
174 - if (_preview) {
175 - // just need to run the flags to accumulate
176 - // flagging info
177 - vb->flagCube();
178 - }
179 - else {
180 - if (mustInitWtSp) {
181 - auto nchan = vb->nChannels();
182 - auto ncor = vb->nCorrelations();
183 - Cube<Float> newwtsp(ncor, nchan, nrow, 0);
184 - vb->initWeightSpectrum(newwtsp);
185 - vb->writeChangesBack();
186 - }
187 - if (mustWriteWtSp) {
188 - vb->setWeightSpectrum(vb->weightSpectrum());
189 - }
190 - vb->setWeight(vb->weight());
191 - vb->setFlagCube(vb->flagCube());
192 - vb->setFlagRow(vb->flagRow());
193 - vb->writeChangesBack();
194 - }
195 - count += nrow;
196 - pm.update(count);
197 - }
198 - }
199 - if (_preview) {
200 - LogIO log(LogOrigin("StatWt", __func__));
201 - log << LogIO::NORMAL
202 - << "RAN IN PREVIEW MODE. NO WEIGHTS NOR FLAGS WERE CHANGED."
203 - << LogIO::POST;
204 - }
205 - statWtLayerFactory.getTVI()->summarizeFlagging();
206 - Double mean, variance;
207 - statWtLayerFactory.getTVI()->summarizeStats(mean, variance);
208 - Record ret;
209 - ret.define("mean", mean);
210 - ret.define("variance", variance);
211 - return ret;
191 + facts[1] = factory.get();
192 + vi.reset(new vi::VisibilityIterator2(facts));
212 193 }
213 194
214 195 }
215 -

Everything looks good. We'll let you know here if there's anything you should know about.

Add shortcut