Commits

Ville Suoranta authored 4b561e35a35
Revert "Merge pull request #19 in CASA/casa from hotfix/CAS-9664 to master"

This reverts commit 08f6cdd446551759730b9dd3f868cf1ab0daddee, reversing changes made to 0937f096c1b7f58f80d81be8bf97763181e1d141.
No tags

code/mstransform/TVI/test/PolAverageTVI_GTest.cc

Modified
32 32 #include <casacore/casa/OS/File.h>
33 33 #include <casacore/casa/OS/RegularFile.h>
34 34 #include <casacore/casa/OS/SymLink.h>
35 35 #include <casacore/casa/OS/Directory.h>
36 36 #include <casacore/casa/OS/DirectoryIterator.h>
37 37 #include <casacore/casa/Exceptions/Error.h>
38 38 #include <casacore/casa/iostream.h>
39 39 #include <casacore/casa/Arrays/ArrayMath.h>
40 40 #include <casacore/casa/iomanip.h>
41 41 #include <casacore/casa/Containers/Record.h>
42 -#include <casacore/casa/Utilities/GenSort.h>
43 42
44 43 #include <msvis/MSVis/VisibilityIteratorImpl2.h>
45 44 #include <msvis/MSVis/LayeredVi2Factory.h>
46 45
47 46 #include <gtest/gtest.h>
48 47 #include <memory>
49 48 #include <typeinfo>
50 -#include <limits>
51 -#include <cmath>
52 49
53 50 using namespace std;
54 51 using namespace casa;
55 52 using namespace casacore;
56 53 using namespace casa::vi;
57 54
58 55 namespace {
59 56 string GetCasaDataPath() {
60 57 if (casacore::EnvironmentVariable::isDefined("CASAPATH")) {
61 58 string casapath = casacore::EnvironmentVariable::get("CASAPATH");
158 155 EXPECT_FLOAT_EQ(ref, result);
159 156 }
160 157
161 158 template<>
162 159 void ValidatorUtil::ValidateScalar<Complex>(Complex const ref,
163 160 Complex const result) {
164 161 EXPECT_FLOAT_EQ(ref.real(), result.real());
165 162 EXPECT_FLOAT_EQ(ref.imag(), result.imag());
166 163 }
167 164
168 -// Base class for validating polarization average
169 165 template<class CollapserImpl>
170 166 struct ValidatorBase {
171 -private:
172 167 template<class T>
173 168 static void CollapseData(Array<T> const &baseData,
174 169 Array<Bool> const &baseFlag, Array<Float> const &baseWeight,
175 170 Array<T> &result) {
176 171 ASSERT_EQ(baseData.shape(), baseWeight.shape());
177 172 ASSERT_EQ(baseData.ndim(), (uInt )3);
178 173 IPosition const baseShape = baseData.shape();
179 174 IPosition shape(baseShape);
180 175 shape[0] = 1;
181 176 result.resize(shape);
202 197 p_r[j] += p_d[j] * p_w[j];
203 198 p_s[j] += p_w[j];
204 199 }
205 200 }
206 201 result.putStorage(p_r, b1);
207 202 weightSum.putStorage(p_s, b2);
208 203 dSlice.freeStorage(p_d, b3);
209 204 weight.freeStorage(p_w, b4);
210 205 fSlice.freeStorage(p_f, b5);
211 206 }
212 -// std::cout << "DEBUG: result = " << result.nonDegenerate()
213 -// << " weightSum = " << weightSum.nonDegenerate() << std::endl;
214 - Bool b1, b2;
215 - T *p_r = result.getStorage(b1);
216 - Float const *p_s = weightSum.getStorage(b2);
217 - for (ssize_t i = 0; i < result.shape().product(); ++i) {
218 - if (p_s[i] > 0.0f) {
219 - p_r[i] /= p_s[i];
220 - }
221 - }
222 - result.putStorage(p_r, b1);
223 - weightSum.freeStorage(p_s, b2);
224 -// std::cout << "DEBUG: r/w = " << result << std::endl;
207 + result /= weightSum;
225 208 }
226 209
227 210 static void CollapseWeight(Array<Float> const &baseWeight,
228 211 Array<Float> &result) {
229 212 IPosition const baseShape = baseWeight.shape();
230 213 IPosition resultShape(baseShape);
231 214 resultShape[0] = 1;
232 215 result.resize(resultShape);
233 216 result = 0;
234 217 size_t nPol = baseShape[0];
235 218 for (size_t i = 0; i < nPol; ++i) {
236 219 IPosition start(baseShape.size(), 0);
237 220 start[0] = i;
238 221 IPosition end(baseShape);
239 222 end -= 1;
240 223 end[0] = i;
241 224 auto wslice = baseWeight(start, end);
242 - //CollapserImpl::AccumulateWeight(wslice, result);
243 - CollapserImpl::AccumulateWeight(start, end, baseWeight, result);
225 + CollapserImpl::AccumulateWeight(wslice, result);
244 226 }
245 227 CollapserImpl::NormalizeWeight(result);
246 228 }
247 229
248 230 template<class T>
249 231 static void ValidateDataColumn(Array<T> const &data, MeasurementSet const &ms,
250 232 String const &columnName, Vector<uInt> const &rowIds) {
251 233 cout << "Start " << __func__ << endl;
252 - ASSERT_EQ(data.shape()[0], (ssize_t )1);
253 234 Array<T> baseData;
254 235 Array<Float> baseWeight;
255 236 Array<Bool> baseFlag;
256 237 Filler<T>::FillArrayReference(ms, columnName, rowIds, baseData);
257 238 Filler<Bool>::FillArrayReference(ms, "FLAG", rowIds, baseFlag);
258 239 baseWeight.resize(baseData.shape());
259 240 Filler<Float>::FillWeightSp(ms, rowIds, baseWeight);
260 241 Cube<T> ref(data.shape());
261 242 CollapseData(baseData, baseFlag, baseWeight, ref);
262 243 ValidatorUtil::ValidateArray<T>(ref, data);
263 244 }
264 245
265 246 static void ValidateWeightColumn(Array<Float> const &weight,
266 247 MeasurementSet const &ms, String const &columnName,
267 248 Vector<uInt> const &rowIds) {
268 - ASSERT_EQ(weight.shape()[0], (ssize_t )1);
269 249 Array<Float> baseWeight;
270 250 Filler<Float>::FillArrayReference(ms, columnName, rowIds, baseWeight);
271 251 Array<Float> ref;
272 252 CollapseWeight(baseWeight, ref);
273 253 ValidatorUtil::ValidateArray(ref, weight);
274 254 }
275 255
276 -public:
277 - static void ValidatePolarization(Vector<Int> const &corrType) {
278 - // correlation type is always I
279 - ASSERT_EQ((size_t )1, corrType.size());
280 - ASSERT_TRUE(allEQ(corrType, (Int )Stokes::I));
281 - }
282 256 static void ValidateData(Array<Complex> const &data, MeasurementSet const &ms,
283 257 Vector<uInt> const &rowIds) {
284 258 ValidateDataColumn(data, ms, "DATA", rowIds);
285 259 }
286 260
287 261 static void ValidateCorrected(Cube<Complex> const &data,
288 262 MeasurementSet const &ms, Vector<uInt> const &rowIds) {
289 263 ValidateDataColumn(data, ms, "CORRECTED_DATA", rowIds);
290 264 }
291 265
315 289 IPosition start(3, i, 0, 0);
316 290 IPosition end(3, i, nChan - 1, nRow - 1);
317 291 auto fslice = baseFlag(start, end);
318 292 ref &= fslice;
319 293 }
320 294 ValidatorUtil::ValidateArray(ref, flag);
321 295 }
322 296
323 297 static void ValidateFlagRow(Vector<Bool> const &flag,
324 298 MeasurementSet const &ms, Vector<uInt> const &rowIds) {
325 - ASSERT_EQ(flag.size(), rowIds.size());
326 299 Vector<Bool> ref;
327 300 Filler<Bool>::FillScalarReference(ms, "FLAG_ROW", rowIds, ref);
328 301 ASSERT_EQ(flag.shape(), ref.shape());
329 302 EXPECT_TRUE(allEQ(flag, ref));
330 303 }
331 304
332 305 static void ValidateWeight(Matrix<Float> const &weight,
333 306 MeasurementSet const &ms, Vector<uInt> const &rowIds) {
334 307 ValidateWeightColumn(weight, ms, "WEIGHT", rowIds);
335 308 }
353 326 IPosition end(3, i, nChan - 1, j);
354 327 auto wslice = weight(start, end);
355 328 auto wref = scalarWeight(i, j);
356 329 EXPECT_TRUE(allEQ(wref, wslice));
357 330 }
358 331 }
359 332 }
360 333 }
361 334 };
362 335
363 -// Base class for validating polarization average being skipped
364 -struct IdenticalValidator {
365 -public:
366 - static void ValidatePolarization(Vector<Int> const &corrType) {
367 - Int possibleTypes[] = { Stokes::I, Stokes::Q, Stokes::U, Stokes::V,
368 - Stokes::XY, Stokes::YX };
369 - size_t len = sizeof(possibleTypes) / sizeof(Int);
370 - Vector<Int> possibleTypesV(IPosition(1, len), possibleTypes, SHARE);
371 - auto iterend = corrType.end();
372 - for (auto iter = corrType.begin(); iter != iterend; ++iter) {
373 - ASSERT_TRUE(anyEQ(possibleTypesV, *iter));
374 - }
375 - }
376 - static void ValidateData(Array<Complex> const &data, MeasurementSet const &ms,
377 - Vector<uInt> const &rowIds) {
378 - ValidateArrayColumn(data, ms, "DATA", rowIds);
379 - }
380 - static void ValidateCorrected(Cube<Complex> const &data,
381 - MeasurementSet const &ms, Vector<uInt> const &rowIds) {
382 - ValidateArrayColumn(data, ms, "CORRECTED_DATA", rowIds);
383 - }
384 - static void ValidateModel(Cube<Complex> const &data, MeasurementSet const &ms,
385 - Vector<uInt> const &rowIds) {
386 - ValidateArrayColumn(data, ms, "MODEL_DATA", rowIds);
387 - }
388 - static void ValidateFloat(Cube<Float> const &data, MeasurementSet const &ms,
389 - Vector<uInt> const &rowIds) {
390 - ValidateArrayColumn(data, ms, "FLOAT_DATA", rowIds);
391 - }
392 - static void ValidateFlag(Cube<Bool> const &flag, MeasurementSet const &ms,
393 - Vector<uInt> const &rowIds) {
394 - ValidateArrayColumn(flag, ms, "FLAG", rowIds);
395 - }
396 - static void ValidateFlagRow(Vector<Bool> const &flag,
397 - MeasurementSet const &ms, Vector<uInt> const &rowIds) {
398 - ValidateScalarColumn(flag, ms, "FLAG_ROW", rowIds);
399 - }
400 - static void ValidateWeight(Matrix<Float> const &weight,
401 - MeasurementSet const &ms, Vector<uInt> const &rowIds) {
402 - ValidateArrayColumn(weight, ms, "WEIGHT", rowIds);
403 - }
404 - static void ValidateWeightSp(Cube<Float> const &weight,
405 - MeasurementSet const &ms, Vector<uInt> const &rowIds) {
406 - if (ms.tableDesc().isColumn("WEIGHT_SPECTRUM")) {
407 - ValidateArrayColumn(weight, ms, "WEIGHT_SPECTRUM", rowIds);
408 - }
409 - }
410 -private:
411 - template<class T>
412 - static void ValidateArrayColumn(Array<T> const &data,
413 - MeasurementSet const &ms, String const &columnName,
414 - Vector<uInt> const &rowIds) {
415 - Array<T> ref;
416 - Filler<T>::FillArrayReference(ms, columnName, rowIds, ref);
417 - EXPECT_EQ(data.shape(), ref.shape());
418 - EXPECT_TRUE(allEQ(data, ref));
419 - }
420 - template<class T>
421 - static void ValidateScalarColumn(Array<T> const &data,
422 - MeasurementSet const &ms, String const &columnName,
423 - Vector<uInt> const &rowIds) {
424 - Vector<T> ref;
425 - Filler<T>::FillScalarReference(ms, columnName, rowIds, ref);
426 - EXPECT_EQ(data.shape(), ref.shape());
427 - EXPECT_TRUE(allEQ(data, ref));
428 - }
429 -};
430 -
431 -// Base class for Geometric type validator
432 -struct GeometricValidatorBase {
336 +struct GeometricAverageValidator: public ValidatorBase<GeometricAverageValidator> {
433 337 static String GetMode() {
434 338 return "geometric";
435 339 }
436 340
437 341 static String GetTypePrefix() {
438 342 return "GeometricPolAverage(";
439 343 }
440 -};
441 -
442 -// Base class for Stokes type validator
443 -struct StokesValidatorBase {
444 - static String GetMode() {
445 - return "stokes";
446 - }
447 -
448 - static String GetTypePrefix() {
449 - return "StokesPolAverage(";
450 - }
451 -};
452 344
453 -// Validator for Geometric polarization average
454 -struct GeometricAverageValidator: public GeometricValidatorBase,
455 - public ValidatorBase<GeometricAverageValidator> {
456 345 static void SetWeight(IPosition const &start, IPosition const &end,
457 346 Array<Float> const &baseWeight, Array<Float> &weight) {
458 347 weight = baseWeight(start, end);
459 348 }
460 349
461 - static void AccumulateWeight(IPosition const &start, IPosition const &end,
462 - Array<Float> const &weight, Array<Float> &result) {
463 - result += weight(start, end);
350 + static void AccumulateWeight(Array<Float> const &weight,
351 + Array<Float> &result) {
352 + result += weight;
464 353 }
465 354
466 355 static void NormalizeWeight(Array<Float> &/*result*/) {
467 356 }
468 357 };
469 358
470 -// Validator for Stokes polarization average
471 -struct StokesAverageValidator: public StokesValidatorBase, public ValidatorBase<
472 - StokesAverageValidator> {
473 - static void SetWeight(IPosition const &/*start*/, IPosition const &/*end*/,
474 - Array<Float> const &/*baseWeight*/, Array<Float> &weight) {
475 - weight = 1.0f;
476 - }
477 -
478 - static void AccumulateWeight(IPosition const &start, IPosition const &end,
479 - Array<Float> const &weight, Array<Float> &result) {
480 - result += 1.0f / weight(start, end);
481 - }
482 -
483 - static void NormalizeWeight(Array<Float> &result) {
484 - result = 4.0f / result;
485 - }
486 -};
487 -
488 -// Validator for Geometric polarization average including cross-polarization
489 -struct GeometricAverageCrossPolarizationValidator: public GeometricValidatorBase,
490 - public ValidatorBase<GeometricAverageCrossPolarizationValidator> {
491 - static void SetWeight(IPosition const &start, IPosition const &end,
492 - Array<Float> const &baseWeight, Array<Float> &weight) {
493 - // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL)
494 - ASSERT_EQ(start[0], end[0]);
495 - if (start[0] == 1 || start[0] == 2) {
496 - // set weight for cross-polarization component to exclude it from the average
497 - weight = 0.0f;
498 - } else {
499 - weight = baseWeight(start, end);
500 - }
501 - }
502 -
503 - static void AccumulateWeight(IPosition const &start, IPosition const &end,
504 - Array<Float> const &weight, Array<Float> &result) {
505 - // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL)
506 - ASSERT_EQ(start[0], end[0]);
507 - if (start[0] != 1 && start[0] != 2) {
508 - result += weight(start, end);
509 - }
359 +struct StokesAverageValidator: public ValidatorBase<StokesAverageValidator> {
360 + static String GetMode() {
361 + return "stokes";
510 362 }
511 363
512 - static void NormalizeWeight(Array<Float> &/*result*/) {
364 + static String GetTypePrefix() {
365 + return "StokesPolAverage(";
513 366 }
514 -};
515 367
516 -// Validator for Stokes polarization average including cross-polarization
517 -struct StokesAverageCrossPolarizationValidator: public StokesValidatorBase,
518 - public ValidatorBase<StokesAverageCrossPolarizationValidator> {
519 - static void SetWeight(IPosition const &start, IPosition const &end,
368 + static void SetWeight(IPosition const &/*start*/, IPosition const &/*end*/,
520 369 Array<Float> const &/*baseWeight*/, Array<Float> &weight) {
521 - // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL)
522 - ASSERT_EQ(start[0], end[0]);
523 - if (start[0] == 1 || start[0] == 2) {
524 - // set weight for cross-polarization component to exclude it from the average
525 - weight = 0.0f;
526 - } else {
527 - weight = 1.0f;
528 - }
370 + weight = 1.0f;
529 371 }
530 372
531 - static void AccumulateWeight(IPosition const &start, IPosition const &end,
532 - Array<Float> const &weight, Array<Float> &result) {
533 - // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL)
534 - ASSERT_EQ(start[0], end[0]);
535 - if (start[0] != 1 && start[0] != 2) {
536 - result += 1.0f / weight(start, end);
537 - }
373 + static void AccumulateWeight(Array<Float> const &weight,
374 + Array<Float> &result) {
375 + result += 1.0f / weight;
538 376 }
539 377
540 378 static void NormalizeWeight(Array<Float> &result) {
541 379 result = 4.0f / result;
542 380 }
543 381 };
544 382
545 -// Validator for Geometric polarization average (identical case = skip)
546 -struct GeometricIdenticalValidator: public GeometricValidatorBase,
547 - public IdenticalValidator {
548 -
549 -};
550 -
551 -// Validator for Stokes polarization average (identical case = skip)
552 -struct StokesIdenticalValidator: public StokesValidatorBase,
553 - public IdenticalValidator {
554 -
555 -};
556 -
557 383 template<class Impl>
558 384 class Manufacturer {
559 385 public:
560 - struct Product {
561 - ViFactory *factory;
562 - ViImplementation2 *vii;
563 - };
564 386 static VisibilityIterator2 *ManufactureVI(MeasurementSet *ms,
565 387 String const &mode) {
566 388 cout << "### Manufacturer: " << endl << "### " << Impl::GetTestPurpose()
567 389 << endl;
568 390
569 391 Record modeRec;
570 392 if (mode.size() > 0) {
571 393 modeRec.define("mode", mode);
572 394 }
573 395
574 396 // build factory object
575 - Product p = Impl::BuildFactory(ms, modeRec);
576 - std::unique_ptr < ViFactory > factory(p.factory);
397 + std::unique_ptr<ViFactory> factory(Impl::BuildFactory(ms, modeRec));
577 398
578 - std::unique_ptr < VisibilityIterator2 > vi;
399 + std::unique_ptr<VisibilityIterator2> vi;
579 400 try {
580 401 vi.reset(new VisibilityIterator2(*factory.get()));
581 402 } catch (...) {
582 403 cout << "Failed to create VI at factory" << endl;
583 - // vii must be deleted since it is never managed by vi
584 - if (p.vii) {
585 - delete p.vii;
586 - }
587 404 throw;
588 405 }
589 406
590 407 cout << "Created VI type \"" << vi->ViiType() << "\"" << endl;
591 408
592 409 return vi.release();
593 410 }
594 411 };
595 412
596 413 class BasicManufacturer1: public Manufacturer<BasicManufacturer1> {
597 414 public:
598 - static Product BuildFactory(MeasurementSet *ms, Record const &mode) {
415 + static ViFactory *BuildFactory(MeasurementSet *ms, Record const &mode) {
599 416 // create read-only VI impl
600 417 Block<MeasurementSet const *> const mss(1, ms);
601 418 SortColumns defaultSortColumns;
602 419
603 420 std::unique_ptr<ViImplementation2> inputVii(
604 421 new VisibilityIteratorImpl2(mss, defaultSortColumns, 0.0, VbPlain,
605 422 False));
606 423
607 424 std::unique_ptr<ViFactory> factory(
608 425 new PolAverageVi2Factory(mode, inputVii.get()));
609 426
610 - Product p;
611 -
612 427 // vi will be responsible for releasing inputVii so unique_ptr
613 428 // should release the ownership here
614 - p.vii = inputVii.release();
615 - p.factory = factory.release();
429 + inputVii.release();
616 430
617 - return p;
431 + return factory.release();
618 432 }
619 433
620 434 static String GetTestPurpose() {
621 435 return "Test PolAverageVi2Factory(Record const &, ViImplementation2 *)";
622 436 }
623 437 };
624 438
625 439 class BasicManufacturer2: public Manufacturer<BasicManufacturer2> {
626 440 public:
627 - static Product BuildFactory(MeasurementSet *ms, Record const &mode) {
441 + static ViFactory *BuildFactory(MeasurementSet *ms, Record const &mode) {
628 442 // create factory directly from MS
629 443 SortColumns defaultSortColumns;
630 444 std::unique_ptr<ViFactory> factory(
631 445 new PolAverageVi2Factory(mode, ms, defaultSortColumns, 0.0, False));
632 446
633 - Product p;
634 - p.vii = nullptr;
635 - p.factory = factory.release();
636 - return p;
447 + return factory.release();
637 448 }
638 449
639 450 static String GetTestPurpose() {
640 451 return "Test PolAverageVi2Factory(Record const &, MeasurementSet const *, ...)";
641 452 }
642 453 };
643 454
644 455 class LayerManufacturer: public Manufacturer<LayerManufacturer> {
645 456 public:
646 457 class LayerFactoryWrapper: public ViFactory {
647 458 public:
648 459 LayerFactoryWrapper(MeasurementSet *ms, Record const &mode) :
649 460 ms_(ms), mode_(mode) {
650 461 }
651 462
652 463 ViImplementation2 *createVi() const {
653 464 Vector<ViiLayerFactory *> v(1);
654 - auto layer0 = VisIterImpl2LayerFactory(ms_, IteratingParameters(0.0),
655 - false);
465 + auto layer0 = VisIterImpl2LayerFactory(ms_, IteratingParameters(0.0), false);
656 466 auto layer1 = PolAverageTVILayerFactory(mode_);
657 467 v[0] = &layer0;
658 468 return layer1.createViImpl2(v);
659 469 }
660 470 private:
661 471 MeasurementSet *ms_;
662 472 Record const mode_;
663 473 };
664 474
665 - static Product BuildFactory(MeasurementSet *ms, Record const &mode) {
475 + static ViFactory *BuildFactory(MeasurementSet *ms, Record const &mode) {
666 476 // create read-only VI impl
667 - std::unique_ptr<ViFactory> factory(new LayerFactoryWrapper(ms, mode));
477 + std::unique_ptr<ViFactory> factory(
478 + new LayerFactoryWrapper(ms, mode));
668 479
669 - Product p;
670 - p.vii = nullptr;
671 - p.factory = factory.release();
672 - return p;
480 + return factory.release();
673 481 }
674 482
675 483 static String GetTestPurpose() {
676 484 return "Test PolAverageTVILayerFactory";
677 485 }
678 486 };
679 487
680 488 } // anonymous namespace
681 489
682 -class PolAverageTVITestBase: public ::testing::Test {
490 +class PolAverageTVITest: public ::testing::Test {
683 491 public:
684 - PolAverageTVITestBase() :
685 - my_ms_name_("polaverage_test.ms"), my_data_name_(), ms_(nullptr) {
686 - }
687 -
688 492 virtual void SetUp() {
689 493 // my_data_name_ = "analytic_spectra.ms";
690 - my_data_name_ = GetDataName(); //"analytic_type1.bl.ms";
691 - std::string const data_path = ::GetCasaDataPath() + "/regression/unittest/"
692 - + GetRelativeDataPath() + "/";
693 -// + "/regression/unittest/tsdbaseline/";
494 + my_data_name_ = "analytic_type1.bl.ms";
495 + my_ms_name_ = "polaverage_test.ms";
496 + std::string const data_path = ::GetCasaDataPath()
497 + + "/regression/unittest/tsdbaseline/";
694 498 // + "/regression/unittest/singledish/";
695 499
696 - ASSERT_TRUE(Directory(data_path).exists());
697 500 copyDataFromRepository(data_path);
698 501 ASSERT_TRUE(File(my_data_name_).exists());
699 502 deleteTable(my_ms_name_);
700 503
701 504 // create MS
702 - ms_ = new MeasurementSet(my_data_name_, Table::Update);
505 + ms_ = MeasurementSet(my_data_name_, Table::Old);
703 506 }
704 507
705 508 virtual void TearDown() {
706 - // delete MS explicitly to detach from MS on disk
707 - delete ms_;
708 -
709 - // just to make sure all locks are effectively released
710 - Table::relinquishAutoLocks();
711 -
712 509 cleanup();
713 510 }
714 -
715 511 protected:
716 - std::string const my_ms_name_;
512 + std::string my_ms_name_;
717 513 std::string my_data_name_;
718 - MeasurementSet *ms_;
514 + MeasurementSet ms_;
719 515
720 - virtual std::string GetDataName() {
721 - return "";
516 + VisibilityIterator2 *ManufactureVI(String const &mode) {
517 + return BasicManufacturer1::ManufactureVI(&ms_, mode);
722 518 }
723 519
724 - virtual std::string GetRelativeDataPath() {
725 - return "";
520 + void TestFactory(String const &mode, String const &expectedClassName) {
521 +
522 + cout << "Mode \"" << mode << "\" expected class name \""
523 + << expectedClassName << "\"" << endl;
524 +
525 + if (expectedClassName.size() > 0) {
526 + std::unique_ptr<VisibilityIterator2> vi(ManufactureVI(mode));
527 +
528 + // Verify type string
529 + String viiType = vi->ViiType();
530 + EXPECT_TRUE(viiType.startsWith(expectedClassName));
531 + } else {
532 + cout << "Creation of VI via factory will fail" << endl;
533 + // exception must be thrown
534 + EXPECT_THROW( {
535 + std::unique_ptr<VisibilityIterator2> vi(ManufactureVI(mode)); //new VisibilityIterator2(factory));
536 + },
537 + AipsError)<< "The process must throw AipsError";
538 + }
726 539 }
727 540
728 - template<class Validator, class Manufacturer = BasicManufacturer1>
541 + template<class Validator, class Manufacturer=BasicManufacturer1>
729 542 void TestTVI() {
730 543 // Create VI
731 - std::unique_ptr < VisibilityIterator2
732 - > vi(Manufacturer::ManufactureVI(ms_, Validator::GetMode()));
544 + std::unique_ptr<VisibilityIterator2> vi(Manufacturer::ManufactureVI(&ms_, Validator::GetMode()));
733 545 ASSERT_TRUE(vi->ViiType().startsWith(Validator::GetTypePrefix()));
734 546
735 547 // MS property
736 548 auto ms = vi->ms();
737 549 uInt const nRowMs = ms.nrow();
738 550 uInt const nRowPolarizationTable = ms.polarization().nrow();
739 551 auto const desc = ms.tableDesc();
740 552 auto const correctedExists = desc.isColumn("CORRECTED_DATA");
741 553 auto const modelExists = desc.isColumn("MODEL_DATA");
742 554 auto const dataExists = desc.isColumn("DATA");
743 555 auto const floatExists = desc.isColumn("FLOAT_DATA");
744 556 //auto const weightSpExists = desc.isColumn("WEIGHT_SPECTRUM");
745 557 cout << "MS Property" << endl;
746 558 cout << "\tMS Name: \"" << ms.tableName() << "\"" << endl;
747 559 cout << "\tNumber of Rows: " << nRowMs << endl;
748 560 cout << "\tNumber of Spws: " << vi->nSpectralWindows() << endl;
749 561 cout << "\tNumber of Polarizations: " << vi->nPolarizationIds() << endl;
750 562 cout << "\tNumber of DataDescs: " << vi->nDataDescriptionIds() << endl;
751 563 cout << "\tChannelized Weight Exists? "
752 - << (vi->weightSpectrumExists() ? "True" : "False") << endl;
564 + << (vi->weightSpectrumExists() ? "True" : "False") << endl;
753 565 //cout << "\tChannelized Sigma Exists? " << (vi->sigmaSpectrumExists() ? "True" : "False") << endl;
754 566
755 567 // mv-VI consistency check
756 - EXPECT_EQ(nRowPolarizationTable + 1, (uInt )vi->nPolarizationIds());
568 + EXPECT_EQ(nRowPolarizationTable + 1, (uInt)vi->nPolarizationIds());
757 569
758 570 // VI iteration
759 571 Vector<uInt> swept(nRowMs, 0);
760 572 uInt nRowChunkSum = 0;
761 573 VisBuffer2 *vb = vi->getVisBuffer();
762 574 vi->originChunks();
763 575 while (vi->moreChunks()) {
764 576 vi->origin();
765 577 Int const nRowChunk = vi->nRowsInChunk();
766 578 nRowChunkSum += nRowChunk;
767 579 cout << "*************************" << endl;
768 580 cout << "*** Start loop on chunk " << vi->getSubchunkId().chunk() << endl;
769 581 cout << "*** Number of Rows: " << nRowChunk << endl;
770 582 cout << "*************************" << endl;
771 583
772 584 Int nRowSubchunkSum = 0;
773 585
774 586 while (vi->more()) {
775 587 auto subchunk = vi->getSubchunkId();
776 - cout << "=== Start loop on subchunk " << subchunk.subchunk() << " ==="
777 - << endl;
588 + cout << "=== Start loop on subchunk " << subchunk.subchunk() << " ===" << endl;
778 589
779 590 // cannot use getInterval due to the error
780 591 // "undefined reference to VisibilityIterator2::getInterval"
781 592 // even if the code is liked to libmsvis.so.
782 593 //cout << "Interval: " << vi->getInterval() << endl;
783 594
784 595 cout << "Antenna1: " << vb->antenna1() << endl;
785 596 cout << "Antenna2: " << vb->antenna2() << endl;
786 597 cout << "Array Id: " << vb->arrayId() << endl;
787 598 cout << "Data Desc Ids: " << vb->dataDescriptionIds() << endl;
792 603 cout << "Field Id: " << vb->fieldId() << endl;
793 604 cout << "Flag Row: " << vb->flagRow() << endl;
794 605 cout << "Observation Id: " << vb->observationId() << endl;
795 606 cout << "Processor Id: " << vb->processorId() << endl;
796 607 cout << "Scan: " << vb->scan() << endl;
797 608 cout << "State Id: " << vb->stateId() << endl;
798 609 cout << "Time: " << vb->time() << endl;
799 610 cout << "Time Centroid: " << vb->timeCentroid() << endl;
800 611 cout << "Time Interval: " << vb->timeInterval() << endl;
801 612 auto const corrTypes = vb->correlationTypes();
802 - auto toStokes = [](Vector<Int> const &corrTypes) {
803 - Vector<String> typeNames(corrTypes.size());
804 - for (size_t i = 0; i < corrTypes.size(); ++i) {
805 - typeNames[i] = Stokes::name((Stokes::StokesTypes)corrTypes[i]);
806 - }
807 - return typeNames;
808 - };
809 - cout << "Correlation Types: " << toStokes(corrTypes) << endl;
613 + cout << "Correlation Types: " << corrTypes << endl;
810 614 //cout << "UVW: " << vb->uvw() << endl;
811 615
812 616 cout << "---" << endl;
813 617 Int nRowSubchunk = vb->nRows();
814 618 Vector<uInt> rowIds = vb->rowIds();
815 619 for (auto iter = rowIds.begin(); iter != rowIds.end(); ++iter) {
816 620 swept[*iter] += 1;
817 621 }
818 622 nRowSubchunkSum += nRowSubchunk;
819 623 Int nAnt = vb->nAntennas();
859 663 EXPECT_EQ(visShape, visCubeCorrected.shape());
860 664 }
861 665 EXPECT_EQ(!modelExists, visCubeModel.empty());
862 666 if (!visCubeModel.empty()) {
863 667 EXPECT_EQ(visShape, visCubeModel.shape());
864 668 }
865 669 EXPECT_EQ(!floatExists, visCubeFloat.empty());
866 670 if (!visCubeFloat.empty()) {
867 671 EXPECT_EQ(visShape, visCubeFloat.shape());
868 672 }
869 - EXPECT_EQ((ssize_t )nRowSubchunk, weight.shape()[1]);
673 + EXPECT_EQ((uInt)1, flagRow.size());
674 + EXPECT_EQ((ssize_t)1, weight.shape()[0]);
675 + EXPECT_EQ((ssize_t)nRowSubchunk, weight.shape()[1]);
870 676 // NB: weight spectrum is created on-the-fly based on WEIGHT
871 677 // so that weightSp is always non-empty.
872 678 // see VisBufferImpl2::fillWeightSpectrum.
873 679 //EXPECT_EQ(!weightSpExists, weightSp.empty());
874 680 EXPECT_FALSE(weightSp.empty());
875 681 if (!weightSp.empty()) {
876 - EXPECT_EQ((ssize_t )nChan, weightSp.shape()[1]);
877 - EXPECT_EQ((ssize_t )nRowSubchunk, weightSp.shape()[2]);
682 + EXPECT_EQ((ssize_t)1, weightSp.shape()[0]);
683 + EXPECT_EQ((ssize_t)nChan, weightSp.shape()[1]);
684 + EXPECT_EQ((ssize_t)nRowSubchunk, weightSp.shape()[2]);
878 685 }
879 686
880 687 // polarization averaging specific check
688 + // length of the correlation (polarization) axis must be 1
689 + ASSERT_EQ((ssize_t)1, visShape[0]);
881 690 // polarization id always points to the row to be appended
882 - ASSERT_EQ(nRowPolarizationTable, (uInt )vb->polarizationId());
883 - Validator::ValidatePolarization(corrTypes);
691 + ASSERT_EQ(nRowPolarizationTable, (uInt)vb->polarizationId());
692 + // correlation type is always I
693 + ASSERT_EQ((size_t)1, corrTypes.size());
694 + ASSERT_TRUE(allEQ(corrTypes, (Int)Stokes::I));
884 695
885 696 // validation of polarization average
886 697 if (!visCube.empty()) {
887 698 cout << "validate DATA" << endl;
888 699 Validator::ValidateData(visCube, ms, rowIds);
889 700 }
890 701 if (!visCubeCorrected.empty()) {
891 702 cout << "validate CORRECTED_DATA" << endl;
892 703 Validator::ValidateCorrected(visCubeCorrected, ms, rowIds);
893 704 }
911 722 // chunk-subchunk consistency check
912 723 EXPECT_EQ(nRowChunk, nRowSubchunkSum);
913 724
914 725 vi->nextChunk();
915 726 }
916 727
917 728 // chunk-ms consistency check
918 729 EXPECT_EQ(nRowMs, nRowChunkSum);
919 730
920 731 // iteration check
921 - EXPECT_TRUE(allEQ(swept, (uInt )1));
732 + EXPECT_TRUE(allEQ(swept, (uInt)1));
922 733
923 734 }
924 735
925 736 private:
926 737 void copyRegular(String const &src, String const &dst) {
927 738 RegularFile r(src);
928 739 r.copy(dst);
929 740 }
930 741 void copySymLink(String const &src, String const &dst) {
931 742 Path p = SymLink(src).followSymLink();
975 786 copySymLink(full_path, work_path);
976 787 } else if (f.isRegular()) {
977 788 copyRegular(full_path, work_path);
978 789 } else if (f.isDirectory()) {
979 790 copyDirectory(full_path, work_path);
980 791 }
981 792 }
982 793 }
983 794 void cleanup() {
984 795 if (my_data_name_.size() > 0) {
985 - deleteTable(my_data_name_);
796 + File f(my_data_name_);
797 + if (f.isRegular()) {
798 + RegularFile r(my_data_name_);
799 + r.remove();
800 + } else if (f.isDirectory()) {
801 + Directory d(my_data_name_);
802 + d.removeRecursive();
803 + }
986 804 }
987 805 deleteTable(my_ms_name_);
988 806 }
989 807 void deleteTable(std::string const &name) {
990 808 File file(name);
991 809 if (file.exists()) {
992 810 std::cout << "Removing " << name << std::endl;
993 811 Table::deleteTable(name, true);
994 812 }
995 813 }
996 -};
997 -
998 -// Fixture class for standard test
999 -class PolAverageTVITest: public PolAverageTVITestBase {
1000 -protected:
1001 - virtual std::string GetDataName() {
1002 - return "analytic_type1.bl.ms";
1003 - }
1004 -
1005 - virtual std::string GetRelativeDataPath() {
1006 - return "tsdbaseline";
1007 - }
1008 -
1009 - VisibilityIterator2 *ManufactureVI(String const &mode) {
1010 - return BasicManufacturer1::ManufactureVI(ms_, mode);
1011 - }
1012 -
1013 - void TestFactory(String const &mode, String const &expectedClassName) {
1014 -
1015 - cout << "Mode \"" << mode << "\" expected class name \""
1016 - << expectedClassName << "\"" << endl;
1017 -
1018 - if (expectedClassName.size() > 0) {
1019 - std::unique_ptr < VisibilityIterator2 > vi(ManufactureVI(mode));
1020 -
1021 - // Verify type string
1022 - String viiType = vi->ViiType();
1023 - EXPECT_TRUE(viiType.startsWith(expectedClassName));
1024 - } else {
1025 - cout << "Creation of VI via factory will fail" << endl;
1026 - // exception must be thrown
1027 - EXPECT_THROW( {
1028 - std::unique_ptr<VisibilityIterator2> vi(ManufactureVI(mode)); //new VisibilityIterator2(factory));
1029 - },
1030 - AipsError)<< "The process must throw AipsError";
1031 - }
1032 - }
1033 -
1034 -};
1035 -
1036 -// Fixture class for testing four polarization (cross-pol, stokes IQUV)
1037 -class PolAverageTVIFourPolarizationTest: public PolAverageTVITestBase {
1038 -protected:
1039 - virtual std::string GetDataName() {
1040 - return "crosspoltest.ms";
1041 - }
1042 -
1043 - virtual std::string GetRelativeDataPath() {
1044 - return "sdsave";
1045 - }
1046 -
1047 - void SetCorrTypeToStokes() {
1048 - ScalarColumn<Int> dataDescIdColumn(*ms_, "DATA_DESC_ID");
1049 - Vector<Int> dataDescIdList = dataDescIdColumn.getColumn();
1050 - ScalarColumn<Int> polarizationIdColumn(ms_->dataDescription(),
1051 - "POLARIZATION_ID");
1052 - Vector<Int> polarizationIdList(dataDescIdList.size());
1053 - for (size_t i = 0; i < dataDescIdList.size(); ++i) {
1054 - polarizationIdList[i] = polarizationIdColumn(dataDescIdList[i]);
1055 - }
1056 - std::cout << "polarizationIdList = " << polarizationIdList << std::endl;
1057 - uInt n = GenSort<Int>::sort(polarizationIdList, Sort::Ascending,
1058 - Sort::HeapSort | Sort::NoDuplicates);
1059 - std::cout << "polarizationIdList (sorted n = " << n << ") = "
1060 - << polarizationIdList << std::endl;
1061 -
1062 - ArrayColumn<Int> corrTypeColumn(ms_->polarization(), "CORR_TYPE");
1063 - Int const newCorrTypes[] = { Stokes::I, Stokes::Q, Stokes::U, Stokes::V };
1064 - for (uInt i = 0; i < n; ++i) {
1065 - auto row = polarizationIdList[i];
1066 - std::cout << "row = " << row << std::endl;
1067 - Vector<Int> corrType = corrTypeColumn(row);
1068 - std::cout << "corrType = " << corrType << std::endl;
1069 - ASSERT_LE(corrType.size(), sizeof(newCorrTypes) / sizeof(Int));
1070 - for (size_t j = 0; j < corrType.size(); ++j) {
1071 - corrType[j] = newCorrTypes[j];
1072 - }
1073 - std::cout << "new corrType = " << corrType << std::endl;
1074 - corrTypeColumn.put(row, corrType);
1075 - }
1076 - }
1077 -};
1078 -
1079 -// Fixture class for testing dirty (partially flagged) data
1080 -// NB: use same data as PolAverageTVITest
1081 -class PolAverageTVIDirtyDataTest: public PolAverageTVITest {
1082 -public:
1083 - virtual void SetUp() {
1084 - // call parent's SetUp method
1085 - PolAverageTVITestBase::SetUp();
1086 -
1087 - // corrupt data
1088 - CorruptData();
1089 - }
1090 -
1091 -private:
1092 - // Make input data dirty
1093 - void CorruptData() {
1094 - // Accessor to FLAG column
1095 - ArrayColumn<Bool> flagColumn(*ms_, "FLAG");
1096 - Cube<Bool> flag = flagColumn.getColumn();
1097 -
1098 - // Accessor to DATA columns
1099 - Cube<Float> floatData;
1100 - Cube<Complex> complexData, correctedData;
1101 - if (ms_->tableDesc().isColumn("DATA")) {
1102 - ArrayColumn<Complex> dataColumn(*ms_, "DATA");
1103 - dataColumn.getColumn(complexData);
1104 - ASSERT_EQ(flag.shape(), complexData.shape());
1105 - }
1106 - if (ms_->tableDesc().isColumn("FLOAT_DATA")) {
1107 - ArrayColumn<Float> dataColumn(*ms_, "FLOAT_DATA");
1108 - dataColumn.getColumn(floatData);
1109 - ASSERT_EQ(flag.shape(), floatData.shape());
1110 - }
1111 - if (ms_->tableDesc().isColumn("CORRECTED_DATA")) {
1112 - ArrayColumn<Complex> dataColumn(*ms_, "CORRECTED_DATA");
1113 - dataColumn.getColumn(correctedData);
1114 - ASSERT_EQ(flag.shape(), correctedData.shape());
1115 - }
1116 -
1117 - // corrupt row 0, channel 10, pol 1
1118 - size_t row = 0;
1119 - size_t chan = 10;
1120 - size_t pol = 1;
1121 - ASSERT_GT((size_t )flag.nplane(), row);
1122 - ASSERT_GT((size_t )flag.ncolumn(), chan);
1123 - ASSERT_GT((size_t )flag.nrow(), pol);
1124 - Float corruptValue = std::numeric_limits<float>::quiet_NaN();
1125 - flag(pol, chan, row) = True;
1126 - ASSERT_EQ(flag(pol, chan, row), True);
1127 - if (!floatData.empty()) {
1128 - floatData(pol, chan, row) = corruptValue;
1129 - ASSERT_TRUE(std::isnan(floatData(pol, chan, row)));
1130 - }
1131 - if (!complexData.empty()) {
1132 - complexData(pol, chan, row) = corruptValue;
1133 - ASSERT_TRUE(std::isnan(complexData(pol, chan, row).real()));
1134 - }
1135 - if (!correctedData.empty()) {
1136 - correctedData(pol, chan, row) = corruptValue;
1137 - ASSERT_TRUE(std::isnan(correctedData(pol, chan, row).real()));
1138 - }
1139 -
1140 - // corrupt row 1, channel 100, all pols
1141 - row = 1;
1142 - chan = 100;
1143 - ASSERT_GT((size_t )flag.nplane(), row);
1144 - ASSERT_GT((size_t )flag.ncolumn(), chan);
1145 - IPosition blc(3, 0, chan, row);
1146 - IPosition trc(3, flag.nrow() - 1, chan, row);
1147 - flag(blc, trc) = True;
1148 - ASSERT_EQ(flag(0, chan, row), True);
1149 - ASSERT_EQ(flag(1, chan, row), True);
1150 - if (!floatData.empty()) {
1151 - floatData(blc, trc) = corruptValue;
1152 - ASSERT_TRUE(std::isnan(floatData(0, chan, row)));
1153 - ASSERT_TRUE(std::isnan(floatData(1, chan, row)));
1154 - }
1155 - if (!complexData.empty()) {
1156 - complexData(blc, trc) = corruptValue;
1157 - ASSERT_TRUE(std::isnan(complexData(0, chan, row).real()));
1158 - ASSERT_TRUE(std::isnan(complexData(1, chan, row).real()));
1159 - }
1160 - if (!correctedData.empty()) {
1161 - correctedData(blc, trc) = corruptValue;
1162 - ASSERT_TRUE(std::isnan(correctedData(0, chan, row).real()));
1163 - ASSERT_TRUE(std::isnan(correctedData(1, chan, row).real()));
1164 - }
1165 -
1166 - // write back to MS
1167 - flagColumn.putColumn(flag);
1168 - if (ms_->tableDesc().isColumn("DATA")) {
1169 - ArrayColumn<Complex> dataColumn(*ms_, "DATA");
1170 - dataColumn.putColumn(complexData);
1171 - }
1172 - if (ms_->tableDesc().isColumn("FLOAT_DATA")) {
1173 - ArrayColumn<Float> dataColumn(*ms_, "FLOAT_DATA");
1174 - dataColumn.putColumn(floatData);
1175 - }
1176 - if (ms_->tableDesc().isColumn("CORRECTED_DATA")) {
1177 - ArrayColumn<Complex> dataColumn(*ms_, "CORRECTED_DATA");
1178 - dataColumn.putColumn(correctedData);
1179 - }
1180 - }
1181 814
1182 815 };
1183 816
1184 817 TEST_F(PolAverageTVITest, Factory) {
1185 818
1186 819 TestFactory("default", "StokesPolAverage");
1187 820 TestFactory("Default", "StokesPolAverage");
1188 821 TestFactory("DEFAULT", "StokesPolAverage");
1189 822 TestFactory("geometric", "GeometricPolAverage");
1190 823 TestFactory("Geometric", "GeometricPolAverage");
1191 824 TestFactory("GEOMETRIC", "GeometricPolAverage");
1192 825 TestFactory("stokes", "StokesPolAverage");
1193 826 TestFactory("Stokes", "StokesPolAverage");
1194 827 TestFactory("STOKES", "StokesPolAverage");
1195 828 // empty mode (default)
1196 829 TestFactory("", "StokesPolAverage");
1197 830 // invalid mode (throw exception)
1198 831 TestFactory("invalid", "");
1199 832 }
1200 833
1201 834 TEST_F(PolAverageTVITest, GeometricAverage) {
1202 - // Use different types of constructor to create factory
835 + // Use difference type of constructor to create factory
1203 836 TestTVI<GeometricAverageValidator, BasicManufacturer1>();
1204 837 TestTVI<GeometricAverageValidator, BasicManufacturer2>();
1205 838 TestTVI<GeometricAverageValidator, LayerManufacturer>();
1206 839 }
1207 840
1208 841 TEST_F(PolAverageTVITest, StokesAverage) {
1209 - // Use different types of constructor to create factory
842 + // Use difference type of constructor to create factory
1210 843 TestTVI<StokesAverageValidator, BasicManufacturer1>();
1211 844 TestTVI<StokesAverageValidator, BasicManufacturer2>();
1212 845 TestTVI<StokesAverageValidator, LayerManufacturer>();
1213 846 }
1214 847
1215 -TEST_F(PolAverageTVIDirtyDataTest, GeometricAverageCorrupted) {
1216 - TestTVI<GeometricAverageValidator, BasicManufacturer1>();
1217 -}
1218 -
1219 -TEST_F(PolAverageTVIDirtyDataTest, StokesAverageCorrupted) {
1220 - TestTVI<StokesAverageValidator, BasicManufacturer1>();
1221 -}
1222 -
1223 -TEST_F(PolAverageTVIFourPolarizationTest, GeometricAverageSkipped) {
1224 - // Edit CORR_TYPE to be IQUV
1225 - SetCorrTypeToStokes();
1226 -
1227 - TestTVI<GeometricIdenticalValidator, BasicManufacturer1>();
1228 -}
1229 -
1230 -TEST_F(PolAverageTVIFourPolarizationTest, StokesAverageSkipped) {
1231 - // Edit CORR_TYPE to be IQUV
1232 - SetCorrTypeToStokes();
1233 -
1234 - TestTVI<StokesIdenticalValidator, BasicManufacturer1>();
1235 -}
1236 -
1237 -TEST_F(PolAverageTVIFourPolarizationTest, GeometricAverageCrossPol) {
1238 - TestTVI<GeometricAverageCrossPolarizationValidator, BasicManufacturer1>();
1239 -}
1240 -
1241 -TEST_F(PolAverageTVIFourPolarizationTest, StokesAverageCrossPol) {
1242 - TestTVI<StokesAverageCrossPolarizationValidator, BasicManufacturer1>();
1243 -}
848 +// TODO: define test on not-well-behaved data
849 +// such as partially flagged, contains NaN, etc.
850 +// TODO: define test on the data that should not average
851 +// along polarization axis (IQUV or single polarization or multi-pol but not containing XX and YY)
1244 852
1245 853 int main(int argc, char **argv) {
1246 854 ::testing::InitGoogleTest(&argc, argv);
1247 855 std::cout << "PolAverageTVI test " << std::endl;
1248 856 return RUN_ALL_TESTS();
1249 857 }

Everything looks good. We'll let you know here if there's anything you should know about.

Add shortcut