Commits
32 32 | |
33 33 | |
34 34 | |
35 35 | |
36 36 | |
37 37 | |
38 38 | |
39 39 | |
40 40 | |
41 41 | |
42 - | |
43 42 | |
44 43 | |
45 44 | |
46 45 | |
47 46 | |
48 47 | |
49 48 | |
50 - | |
51 - | |
52 49 | |
53 50 | using namespace std; |
54 51 | using namespace casa; |
55 52 | using namespace casacore; |
56 53 | using namespace casa::vi; |
57 54 | |
58 55 | namespace { |
59 56 | string GetCasaDataPath() { |
60 57 | if (casacore::EnvironmentVariable::isDefined("CASAPATH")) { |
61 58 | string casapath = casacore::EnvironmentVariable::get("CASAPATH"); |
158 155 | EXPECT_FLOAT_EQ(ref, result); |
159 156 | } |
160 157 | |
161 158 | template<> |
162 159 | void ValidatorUtil::ValidateScalar<Complex>(Complex const ref, |
163 160 | Complex const result) { |
164 161 | EXPECT_FLOAT_EQ(ref.real(), result.real()); |
165 162 | EXPECT_FLOAT_EQ(ref.imag(), result.imag()); |
166 163 | } |
167 164 | |
168 - | // Base class for validating polarization average |
169 165 | template<class CollapserImpl> |
170 166 | struct ValidatorBase { |
171 - | private: |
172 167 | template<class T> |
173 168 | static void CollapseData(Array<T> const &baseData, |
174 169 | Array<Bool> const &baseFlag, Array<Float> const &baseWeight, |
175 170 | Array<T> &result) { |
176 171 | ASSERT_EQ(baseData.shape(), baseWeight.shape()); |
177 172 | ASSERT_EQ(baseData.ndim(), (uInt )3); |
178 173 | IPosition const baseShape = baseData.shape(); |
179 174 | IPosition shape(baseShape); |
180 175 | shape[0] = 1; |
181 176 | result.resize(shape); |
202 197 | p_r[j] += p_d[j] * p_w[j]; |
203 198 | p_s[j] += p_w[j]; |
204 199 | } |
205 200 | } |
206 201 | result.putStorage(p_r, b1); |
207 202 | weightSum.putStorage(p_s, b2); |
208 203 | dSlice.freeStorage(p_d, b3); |
209 204 | weight.freeStorage(p_w, b4); |
210 205 | fSlice.freeStorage(p_f, b5); |
211 206 | } |
212 - | // std::cout << "DEBUG: result = " << result.nonDegenerate() |
213 - | // << " weightSum = " << weightSum.nonDegenerate() << std::endl; |
214 - | Bool b1, b2; |
215 - | T *p_r = result.getStorage(b1); |
216 - | Float const *p_s = weightSum.getStorage(b2); |
217 - | for (ssize_t i = 0; i < result.shape().product(); ++i) { |
218 - | if (p_s[i] > 0.0f) { |
219 - | p_r[i] /= p_s[i]; |
220 - | } |
221 - | } |
222 - | result.putStorage(p_r, b1); |
223 - | weightSum.freeStorage(p_s, b2); |
224 - | // std::cout << "DEBUG: r/w = " << result << std::endl; |
207 + | result /= weightSum; |
225 208 | } |
226 209 | |
227 210 | static void CollapseWeight(Array<Float> const &baseWeight, |
228 211 | Array<Float> &result) { |
229 212 | IPosition const baseShape = baseWeight.shape(); |
230 213 | IPosition resultShape(baseShape); |
231 214 | resultShape[0] = 1; |
232 215 | result.resize(resultShape); |
233 216 | result = 0; |
234 217 | size_t nPol = baseShape[0]; |
235 218 | for (size_t i = 0; i < nPol; ++i) { |
236 219 | IPosition start(baseShape.size(), 0); |
237 220 | start[0] = i; |
238 221 | IPosition end(baseShape); |
239 222 | end -= 1; |
240 223 | end[0] = i; |
241 224 | auto wslice = baseWeight(start, end); |
242 - | //CollapserImpl::AccumulateWeight(wslice, result); |
243 - | CollapserImpl::AccumulateWeight(start, end, baseWeight, result); |
225 + | CollapserImpl::AccumulateWeight(wslice, result); |
244 226 | } |
245 227 | CollapserImpl::NormalizeWeight(result); |
246 228 | } |
247 229 | |
248 230 | template<class T> |
249 231 | static void ValidateDataColumn(Array<T> const &data, MeasurementSet const &ms, |
250 232 | String const &columnName, Vector<uInt> const &rowIds) { |
251 233 | cout << "Start " << __func__ << endl; |
252 - | ASSERT_EQ(data.shape()[0], (ssize_t )1); |
253 234 | Array<T> baseData; |
254 235 | Array<Float> baseWeight; |
255 236 | Array<Bool> baseFlag; |
256 237 | Filler<T>::FillArrayReference(ms, columnName, rowIds, baseData); |
257 238 | Filler<Bool>::FillArrayReference(ms, "FLAG", rowIds, baseFlag); |
258 239 | baseWeight.resize(baseData.shape()); |
259 240 | Filler<Float>::FillWeightSp(ms, rowIds, baseWeight); |
260 241 | Cube<T> ref(data.shape()); |
261 242 | CollapseData(baseData, baseFlag, baseWeight, ref); |
262 243 | ValidatorUtil::ValidateArray<T>(ref, data); |
263 244 | } |
264 245 | |
265 246 | static void ValidateWeightColumn(Array<Float> const &weight, |
266 247 | MeasurementSet const &ms, String const &columnName, |
267 248 | Vector<uInt> const &rowIds) { |
268 - | ASSERT_EQ(weight.shape()[0], (ssize_t )1); |
269 249 | Array<Float> baseWeight; |
270 250 | Filler<Float>::FillArrayReference(ms, columnName, rowIds, baseWeight); |
271 251 | Array<Float> ref; |
272 252 | CollapseWeight(baseWeight, ref); |
273 253 | ValidatorUtil::ValidateArray(ref, weight); |
274 254 | } |
275 255 | |
276 - | public: |
277 - | static void ValidatePolarization(Vector<Int> const &corrType) { |
278 - | // correlation type is always I |
279 - | ASSERT_EQ((size_t )1, corrType.size()); |
280 - | ASSERT_TRUE(allEQ(corrType, (Int )Stokes::I)); |
281 - | } |
282 256 | static void ValidateData(Array<Complex> const &data, MeasurementSet const &ms, |
283 257 | Vector<uInt> const &rowIds) { |
284 258 | ValidateDataColumn(data, ms, "DATA", rowIds); |
285 259 | } |
286 260 | |
287 261 | static void ValidateCorrected(Cube<Complex> const &data, |
288 262 | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
289 263 | ValidateDataColumn(data, ms, "CORRECTED_DATA", rowIds); |
290 264 | } |
291 265 | |
315 289 | IPosition start(3, i, 0, 0); |
316 290 | IPosition end(3, i, nChan - 1, nRow - 1); |
317 291 | auto fslice = baseFlag(start, end); |
318 292 | ref &= fslice; |
319 293 | } |
320 294 | ValidatorUtil::ValidateArray(ref, flag); |
321 295 | } |
322 296 | |
323 297 | static void ValidateFlagRow(Vector<Bool> const &flag, |
324 298 | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
325 - | ASSERT_EQ(flag.size(), rowIds.size()); |
326 299 | Vector<Bool> ref; |
327 300 | Filler<Bool>::FillScalarReference(ms, "FLAG_ROW", rowIds, ref); |
328 301 | ASSERT_EQ(flag.shape(), ref.shape()); |
329 302 | EXPECT_TRUE(allEQ(flag, ref)); |
330 303 | } |
331 304 | |
332 305 | static void ValidateWeight(Matrix<Float> const &weight, |
333 306 | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
334 307 | ValidateWeightColumn(weight, ms, "WEIGHT", rowIds); |
335 308 | } |
353 326 | IPosition end(3, i, nChan - 1, j); |
354 327 | auto wslice = weight(start, end); |
355 328 | auto wref = scalarWeight(i, j); |
356 329 | EXPECT_TRUE(allEQ(wref, wslice)); |
357 330 | } |
358 331 | } |
359 332 | } |
360 333 | } |
361 334 | }; |
362 335 | |
363 - | // Base class for validating polarization average being skipped |
364 - | struct IdenticalValidator { |
365 - | public: |
366 - | static void ValidatePolarization(Vector<Int> const &corrType) { |
367 - | Int possibleTypes[] = { Stokes::I, Stokes::Q, Stokes::U, Stokes::V, |
368 - | Stokes::XY, Stokes::YX }; |
369 - | size_t len = sizeof(possibleTypes) / sizeof(Int); |
370 - | Vector<Int> possibleTypesV(IPosition(1, len), possibleTypes, SHARE); |
371 - | auto iterend = corrType.end(); |
372 - | for (auto iter = corrType.begin(); iter != iterend; ++iter) { |
373 - | ASSERT_TRUE(anyEQ(possibleTypesV, *iter)); |
374 - | } |
375 - | } |
376 - | static void ValidateData(Array<Complex> const &data, MeasurementSet const &ms, |
377 - | Vector<uInt> const &rowIds) { |
378 - | ValidateArrayColumn(data, ms, "DATA", rowIds); |
379 - | } |
380 - | static void ValidateCorrected(Cube<Complex> const &data, |
381 - | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
382 - | ValidateArrayColumn(data, ms, "CORRECTED_DATA", rowIds); |
383 - | } |
384 - | static void ValidateModel(Cube<Complex> const &data, MeasurementSet const &ms, |
385 - | Vector<uInt> const &rowIds) { |
386 - | ValidateArrayColumn(data, ms, "MODEL_DATA", rowIds); |
387 - | } |
388 - | static void ValidateFloat(Cube<Float> const &data, MeasurementSet const &ms, |
389 - | Vector<uInt> const &rowIds) { |
390 - | ValidateArrayColumn(data, ms, "FLOAT_DATA", rowIds); |
391 - | } |
392 - | static void ValidateFlag(Cube<Bool> const &flag, MeasurementSet const &ms, |
393 - | Vector<uInt> const &rowIds) { |
394 - | ValidateArrayColumn(flag, ms, "FLAG", rowIds); |
395 - | } |
396 - | static void ValidateFlagRow(Vector<Bool> const &flag, |
397 - | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
398 - | ValidateScalarColumn(flag, ms, "FLAG_ROW", rowIds); |
399 - | } |
400 - | static void ValidateWeight(Matrix<Float> const &weight, |
401 - | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
402 - | ValidateArrayColumn(weight, ms, "WEIGHT", rowIds); |
403 - | } |
404 - | static void ValidateWeightSp(Cube<Float> const &weight, |
405 - | MeasurementSet const &ms, Vector<uInt> const &rowIds) { |
406 - | if (ms.tableDesc().isColumn("WEIGHT_SPECTRUM")) { |
407 - | ValidateArrayColumn(weight, ms, "WEIGHT_SPECTRUM", rowIds); |
408 - | } |
409 - | } |
410 - | private: |
411 - | template<class T> |
412 - | static void ValidateArrayColumn(Array<T> const &data, |
413 - | MeasurementSet const &ms, String const &columnName, |
414 - | Vector<uInt> const &rowIds) { |
415 - | Array<T> ref; |
416 - | Filler<T>::FillArrayReference(ms, columnName, rowIds, ref); |
417 - | EXPECT_EQ(data.shape(), ref.shape()); |
418 - | EXPECT_TRUE(allEQ(data, ref)); |
419 - | } |
420 - | template<class T> |
421 - | static void ValidateScalarColumn(Array<T> const &data, |
422 - | MeasurementSet const &ms, String const &columnName, |
423 - | Vector<uInt> const &rowIds) { |
424 - | Vector<T> ref; |
425 - | Filler<T>::FillScalarReference(ms, columnName, rowIds, ref); |
426 - | EXPECT_EQ(data.shape(), ref.shape()); |
427 - | EXPECT_TRUE(allEQ(data, ref)); |
428 - | } |
429 - | }; |
430 - | |
431 - | // Base class for Geometric type validator |
432 - | struct GeometricValidatorBase { |
336 + | struct GeometricAverageValidator: public ValidatorBase<GeometricAverageValidator> { |
433 337 | static String GetMode() { |
434 338 | return "geometric"; |
435 339 | } |
436 340 | |
437 341 | static String GetTypePrefix() { |
438 342 | return "GeometricPolAverage("; |
439 343 | } |
440 - | }; |
441 - | |
442 - | // Base class for Stokes type validator |
443 - | struct StokesValidatorBase { |
444 - | static String GetMode() { |
445 - | return "stokes"; |
446 - | } |
447 - | |
448 - | static String GetTypePrefix() { |
449 - | return "StokesPolAverage("; |
450 - | } |
451 - | }; |
452 344 | |
453 - | // Validator for Geometric polarization average |
454 - | struct GeometricAverageValidator: public GeometricValidatorBase, |
455 - | public ValidatorBase<GeometricAverageValidator> { |
456 345 | static void SetWeight(IPosition const &start, IPosition const &end, |
457 346 | Array<Float> const &baseWeight, Array<Float> &weight) { |
458 347 | weight = baseWeight(start, end); |
459 348 | } |
460 349 | |
461 - | static void AccumulateWeight(IPosition const &start, IPosition const &end, |
462 - | Array<Float> const &weight, Array<Float> &result) { |
463 - | result += weight(start, end); |
350 + | static void AccumulateWeight(Array<Float> const &weight, |
351 + | Array<Float> &result) { |
352 + | result += weight; |
464 353 | } |
465 354 | |
466 355 | static void NormalizeWeight(Array<Float> &/*result*/) { |
467 356 | } |
468 357 | }; |
469 358 | |
470 - | // Validator for Stokes polarization average |
471 - | struct StokesAverageValidator: public StokesValidatorBase, public ValidatorBase< |
472 - | StokesAverageValidator> { |
473 - | static void SetWeight(IPosition const &/*start*/, IPosition const &/*end*/, |
474 - | Array<Float> const &/*baseWeight*/, Array<Float> &weight) { |
475 - | weight = 1.0f; |
476 - | } |
477 - | |
478 - | static void AccumulateWeight(IPosition const &start, IPosition const &end, |
479 - | Array<Float> const &weight, Array<Float> &result) { |
480 - | result += 1.0f / weight(start, end); |
481 - | } |
482 - | |
483 - | static void NormalizeWeight(Array<Float> &result) { |
484 - | result = 4.0f / result; |
485 - | } |
486 - | }; |
487 - | |
488 - | // Validator for Geometric polarization average including cross-polarization |
489 - | struct GeometricAverageCrossPolarizationValidator: public GeometricValidatorBase, |
490 - | public ValidatorBase<GeometricAverageCrossPolarizationValidator> { |
491 - | static void SetWeight(IPosition const &start, IPosition const &end, |
492 - | Array<Float> const &baseWeight, Array<Float> &weight) { |
493 - | // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL) |
494 - | ASSERT_EQ(start[0], end[0]); |
495 - | if (start[0] == 1 || start[0] == 2) { |
496 - | // set weight for cross-polarization component to exclude it from the average |
497 - | weight = 0.0f; |
498 - | } else { |
499 - | weight = baseWeight(start, end); |
500 - | } |
501 - | } |
502 - | |
503 - | static void AccumulateWeight(IPosition const &start, IPosition const &end, |
504 - | Array<Float> const &weight, Array<Float> &result) { |
505 - | // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL) |
506 - | ASSERT_EQ(start[0], end[0]); |
507 - | if (start[0] != 1 && start[0] != 2) { |
508 - | result += weight(start, end); |
509 - | } |
359 + | struct StokesAverageValidator: public ValidatorBase<StokesAverageValidator> { |
360 + | static String GetMode() { |
361 + | return "stokes"; |
510 362 | } |
511 363 | |
512 - | static void NormalizeWeight(Array<Float> &/*result*/) { |
364 + | static String GetTypePrefix() { |
365 + | return "StokesPolAverage("; |
513 366 | } |
514 - | }; |
515 367 | |
516 - | // Validator for Stokes polarization average including cross-polarization |
517 - | struct StokesAverageCrossPolarizationValidator: public StokesValidatorBase, |
518 - | public ValidatorBase<StokesAverageCrossPolarizationValidator> { |
519 - | static void SetWeight(IPosition const &start, IPosition const &end, |
368 + | static void SetWeight(IPosition const &/*start*/, IPosition const &/*end*/, |
520 369 | Array<Float> const &/*baseWeight*/, Array<Float> &weight) { |
521 - | // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL) |
522 - | ASSERT_EQ(start[0], end[0]); |
523 - | if (start[0] == 1 || start[0] == 2) { |
524 - | // set weight for cross-polarization component to exclude it from the average |
525 - | weight = 0.0f; |
526 - | } else { |
527 - | weight = 1.0f; |
528 - | } |
370 + | weight = 1.0f; |
529 371 | } |
530 372 | |
531 - | static void AccumulateWeight(IPosition const &start, IPosition const &end, |
532 - | Array<Float> const &weight, Array<Float> &result) { |
533 - | // Here it is assumed that polarization order is 0: XX (RR), 1: XY (RL), 2: YX (LR), 3: YY (LL) |
534 - | ASSERT_EQ(start[0], end[0]); |
535 - | if (start[0] != 1 && start[0] != 2) { |
536 - | result += 1.0f / weight(start, end); |
537 - | } |
373 + | static void AccumulateWeight(Array<Float> const &weight, |
374 + | Array<Float> &result) { |
375 + | result += 1.0f / weight; |
538 376 | } |
539 377 | |
540 378 | static void NormalizeWeight(Array<Float> &result) { |
541 379 | result = 4.0f / result; |
542 380 | } |
543 381 | }; |
544 382 | |
545 - | // Validator for Geometric polarization average (identical case = skip) |
546 - | struct GeometricIdenticalValidator: public GeometricValidatorBase, |
547 - | public IdenticalValidator { |
548 - | |
549 - | }; |
550 - | |
551 - | // Validator for Stokes polarization average (identical case = skip) |
552 - | struct StokesIdenticalValidator: public StokesValidatorBase, |
553 - | public IdenticalValidator { |
554 - | |
555 - | }; |
556 - | |
557 383 | template<class Impl> |
558 384 | class Manufacturer { |
559 385 | public: |
560 - | struct Product { |
561 - | ViFactory *factory; |
562 - | ViImplementation2 *vii; |
563 - | }; |
564 386 | static VisibilityIterator2 *ManufactureVI(MeasurementSet *ms, |
565 387 | String const &mode) { |
566 388 | cout << "### Manufacturer: " << endl << "### " << Impl::GetTestPurpose() |
567 389 | << endl; |
568 390 | |
569 391 | Record modeRec; |
570 392 | if (mode.size() > 0) { |
571 393 | modeRec.define("mode", mode); |
572 394 | } |
573 395 | |
574 396 | // build factory object |
575 - | Product p = Impl::BuildFactory(ms, modeRec); |
576 - | std::unique_ptr < ViFactory > factory(p.factory); |
397 + | std::unique_ptr<ViFactory> factory(Impl::BuildFactory(ms, modeRec)); |
577 398 | |
578 - | std::unique_ptr < VisibilityIterator2 > vi; |
399 + | std::unique_ptr<VisibilityIterator2> vi; |
579 400 | try { |
580 401 | vi.reset(new VisibilityIterator2(*factory.get())); |
581 402 | } catch (...) { |
582 403 | cout << "Failed to create VI at factory" << endl; |
583 - | // vii must be deleted since it is never managed by vi |
584 - | if (p.vii) { |
585 - | delete p.vii; |
586 - | } |
587 404 | throw; |
588 405 | } |
589 406 | |
590 407 | cout << "Created VI type \"" << vi->ViiType() << "\"" << endl; |
591 408 | |
592 409 | return vi.release(); |
593 410 | } |
594 411 | }; |
595 412 | |
596 413 | class BasicManufacturer1: public Manufacturer<BasicManufacturer1> { |
597 414 | public: |
598 - | static Product BuildFactory(MeasurementSet *ms, Record const &mode) { |
415 + | static ViFactory *BuildFactory(MeasurementSet *ms, Record const &mode) { |
599 416 | // create read-only VI impl |
600 417 | Block<MeasurementSet const *> const mss(1, ms); |
601 418 | SortColumns defaultSortColumns; |
602 419 | |
603 420 | std::unique_ptr<ViImplementation2> inputVii( |
604 421 | new VisibilityIteratorImpl2(mss, defaultSortColumns, 0.0, VbPlain, |
605 422 | False)); |
606 423 | |
607 424 | std::unique_ptr<ViFactory> factory( |
608 425 | new PolAverageVi2Factory(mode, inputVii.get())); |
609 426 | |
610 - | Product p; |
611 - | |
612 427 | // vi will be responsible for releasing inputVii so unique_ptr |
613 428 | // should release the ownership here |
614 - | p.vii = inputVii.release(); |
615 - | p.factory = factory.release(); |
429 + | inputVii.release(); |
616 430 | |
617 - | return p; |
431 + | return factory.release(); |
618 432 | } |
619 433 | |
620 434 | static String GetTestPurpose() { |
621 435 | return "Test PolAverageVi2Factory(Record const &, ViImplementation2 *)"; |
622 436 | } |
623 437 | }; |
624 438 | |
625 439 | class BasicManufacturer2: public Manufacturer<BasicManufacturer2> { |
626 440 | public: |
627 - | static Product BuildFactory(MeasurementSet *ms, Record const &mode) { |
441 + | static ViFactory *BuildFactory(MeasurementSet *ms, Record const &mode) { |
628 442 | // create factory directly from MS |
629 443 | SortColumns defaultSortColumns; |
630 444 | std::unique_ptr<ViFactory> factory( |
631 445 | new PolAverageVi2Factory(mode, ms, defaultSortColumns, 0.0, False)); |
632 446 | |
633 - | Product p; |
634 - | p.vii = nullptr; |
635 - | p.factory = factory.release(); |
636 - | return p; |
447 + | return factory.release(); |
637 448 | } |
638 449 | |
639 450 | static String GetTestPurpose() { |
640 451 | return "Test PolAverageVi2Factory(Record const &, MeasurementSet const *, ...)"; |
641 452 | } |
642 453 | }; |
643 454 | |
644 455 | class LayerManufacturer: public Manufacturer<LayerManufacturer> { |
645 456 | public: |
646 457 | class LayerFactoryWrapper: public ViFactory { |
647 458 | public: |
648 459 | LayerFactoryWrapper(MeasurementSet *ms, Record const &mode) : |
649 460 | ms_(ms), mode_(mode) { |
650 461 | } |
651 462 | |
652 463 | ViImplementation2 *createVi() const { |
653 464 | Vector<ViiLayerFactory *> v(1); |
654 - | auto layer0 = VisIterImpl2LayerFactory(ms_, IteratingParameters(0.0), |
655 - | false); |
465 + | auto layer0 = VisIterImpl2LayerFactory(ms_, IteratingParameters(0.0), false); |
656 466 | auto layer1 = PolAverageTVILayerFactory(mode_); |
657 467 | v[0] = &layer0; |
658 468 | return layer1.createViImpl2(v); |
659 469 | } |
660 470 | private: |
661 471 | MeasurementSet *ms_; |
662 472 | Record const mode_; |
663 473 | }; |
664 474 | |
665 - | static Product BuildFactory(MeasurementSet *ms, Record const &mode) { |
475 + | static ViFactory *BuildFactory(MeasurementSet *ms, Record const &mode) { |
666 476 | // create read-only VI impl |
667 - | std::unique_ptr<ViFactory> factory(new LayerFactoryWrapper(ms, mode)); |
477 + | std::unique_ptr<ViFactory> factory( |
478 + | new LayerFactoryWrapper(ms, mode)); |
668 479 | |
669 - | Product p; |
670 - | p.vii = nullptr; |
671 - | p.factory = factory.release(); |
672 - | return p; |
480 + | return factory.release(); |
673 481 | } |
674 482 | |
675 483 | static String GetTestPurpose() { |
676 484 | return "Test PolAverageTVILayerFactory"; |
677 485 | } |
678 486 | }; |
679 487 | |
680 488 | } // anonymous namespace |
681 489 | |
682 - | class PolAverageTVITestBase: public ::testing::Test { |
490 + | class PolAverageTVITest: public ::testing::Test { |
683 491 | public: |
684 - | PolAverageTVITestBase() : |
685 - | my_ms_name_("polaverage_test.ms"), my_data_name_(), ms_(nullptr) { |
686 - | } |
687 - | |
688 492 | virtual void SetUp() { |
689 493 | // my_data_name_ = "analytic_spectra.ms"; |
690 - | my_data_name_ = GetDataName(); //"analytic_type1.bl.ms"; |
691 - | std::string const data_path = ::GetCasaDataPath() + "/regression/unittest/" |
692 - | + GetRelativeDataPath() + "/"; |
693 - | // + "/regression/unittest/tsdbaseline/"; |
494 + | my_data_name_ = "analytic_type1.bl.ms"; |
495 + | my_ms_name_ = "polaverage_test.ms"; |
496 + | std::string const data_path = ::GetCasaDataPath() |
497 + | + "/regression/unittest/tsdbaseline/"; |
694 498 | // + "/regression/unittest/singledish/"; |
695 499 | |
696 - | ASSERT_TRUE(Directory(data_path).exists()); |
697 500 | copyDataFromRepository(data_path); |
698 501 | ASSERT_TRUE(File(my_data_name_).exists()); |
699 502 | deleteTable(my_ms_name_); |
700 503 | |
701 504 | // create MS |
702 - | ms_ = new MeasurementSet(my_data_name_, Table::Update); |
505 + | ms_ = MeasurementSet(my_data_name_, Table::Old); |
703 506 | } |
704 507 | |
705 508 | virtual void TearDown() { |
706 - | // delete MS explicitly to detach from MS on disk |
707 - | delete ms_; |
708 - | |
709 - | // just to make sure all locks are effectively released |
710 - | Table::relinquishAutoLocks(); |
711 - | |
712 509 | cleanup(); |
713 510 | } |
714 - | |
715 511 | protected: |
716 - | std::string const my_ms_name_; |
512 + | std::string my_ms_name_; |
717 513 | std::string my_data_name_; |
718 - | MeasurementSet *ms_; |
514 + | MeasurementSet ms_; |
719 515 | |
720 - | virtual std::string GetDataName() { |
721 - | return ""; |
516 + | VisibilityIterator2 *ManufactureVI(String const &mode) { |
517 + | return BasicManufacturer1::ManufactureVI(&ms_, mode); |
722 518 | } |
723 519 | |
724 - | virtual std::string GetRelativeDataPath() { |
725 - | return ""; |
520 + | void TestFactory(String const &mode, String const &expectedClassName) { |
521 + | |
522 + | cout << "Mode \"" << mode << "\" expected class name \"" |
523 + | << expectedClassName << "\"" << endl; |
524 + | |
525 + | if (expectedClassName.size() > 0) { |
526 + | std::unique_ptr<VisibilityIterator2> vi(ManufactureVI(mode)); |
527 + | |
528 + | // Verify type string |
529 + | String viiType = vi->ViiType(); |
530 + | EXPECT_TRUE(viiType.startsWith(expectedClassName)); |
531 + | } else { |
532 + | cout << "Creation of VI via factory will fail" << endl; |
533 + | // exception must be thrown |
534 + | EXPECT_THROW( { |
535 + | std::unique_ptr<VisibilityIterator2> vi(ManufactureVI(mode)); //new VisibilityIterator2(factory)); |
536 + | }, |
537 + | AipsError)<< "The process must throw AipsError"; |
538 + | } |
726 539 | } |
727 540 | |
728 - | template<class Validator, class Manufacturer = BasicManufacturer1> |
541 + | template<class Validator, class Manufacturer=BasicManufacturer1> |
729 542 | void TestTVI() { |
730 543 | // Create VI |
731 - | std::unique_ptr < VisibilityIterator2 |
732 - | > vi(Manufacturer::ManufactureVI(ms_, Validator::GetMode())); |
544 + | std::unique_ptr<VisibilityIterator2> vi(Manufacturer::ManufactureVI(&ms_, Validator::GetMode())); |
733 545 | ASSERT_TRUE(vi->ViiType().startsWith(Validator::GetTypePrefix())); |
734 546 | |
735 547 | // MS property |
736 548 | auto ms = vi->ms(); |
737 549 | uInt const nRowMs = ms.nrow(); |
738 550 | uInt const nRowPolarizationTable = ms.polarization().nrow(); |
739 551 | auto const desc = ms.tableDesc(); |
740 552 | auto const correctedExists = desc.isColumn("CORRECTED_DATA"); |
741 553 | auto const modelExists = desc.isColumn("MODEL_DATA"); |
742 554 | auto const dataExists = desc.isColumn("DATA"); |
743 555 | auto const floatExists = desc.isColumn("FLOAT_DATA"); |
744 556 | //auto const weightSpExists = desc.isColumn("WEIGHT_SPECTRUM"); |
745 557 | cout << "MS Property" << endl; |
746 558 | cout << "\tMS Name: \"" << ms.tableName() << "\"" << endl; |
747 559 | cout << "\tNumber of Rows: " << nRowMs << endl; |
748 560 | cout << "\tNumber of Spws: " << vi->nSpectralWindows() << endl; |
749 561 | cout << "\tNumber of Polarizations: " << vi->nPolarizationIds() << endl; |
750 562 | cout << "\tNumber of DataDescs: " << vi->nDataDescriptionIds() << endl; |
751 563 | cout << "\tChannelized Weight Exists? " |
752 - | << (vi->weightSpectrumExists() ? "True" : "False") << endl; |
564 + | << (vi->weightSpectrumExists() ? "True" : "False") << endl; |
753 565 | //cout << "\tChannelized Sigma Exists? " << (vi->sigmaSpectrumExists() ? "True" : "False") << endl; |
754 566 | |
755 567 | // mv-VI consistency check |
756 - | EXPECT_EQ(nRowPolarizationTable + 1, (uInt )vi->nPolarizationIds()); |
568 + | EXPECT_EQ(nRowPolarizationTable + 1, (uInt)vi->nPolarizationIds()); |
757 569 | |
758 570 | // VI iteration |
759 571 | Vector<uInt> swept(nRowMs, 0); |
760 572 | uInt nRowChunkSum = 0; |
761 573 | VisBuffer2 *vb = vi->getVisBuffer(); |
762 574 | vi->originChunks(); |
763 575 | while (vi->moreChunks()) { |
764 576 | vi->origin(); |
765 577 | Int const nRowChunk = vi->nRowsInChunk(); |
766 578 | nRowChunkSum += nRowChunk; |
767 579 | cout << "*************************" << endl; |
768 580 | cout << "*** Start loop on chunk " << vi->getSubchunkId().chunk() << endl; |
769 581 | cout << "*** Number of Rows: " << nRowChunk << endl; |
770 582 | cout << "*************************" << endl; |
771 583 | |
772 584 | Int nRowSubchunkSum = 0; |
773 585 | |
774 586 | while (vi->more()) { |
775 587 | auto subchunk = vi->getSubchunkId(); |
776 - | cout << "=== Start loop on subchunk " << subchunk.subchunk() << " ===" |
777 - | << endl; |
588 + | cout << "=== Start loop on subchunk " << subchunk.subchunk() << " ===" << endl; |
778 589 | |
779 590 | // cannot use getInterval due to the error |
780 591 | // "undefined reference to VisibilityIterator2::getInterval" |
781 592 | // even if the code is liked to libmsvis.so. |
782 593 | //cout << "Interval: " << vi->getInterval() << endl; |
783 594 | |
784 595 | cout << "Antenna1: " << vb->antenna1() << endl; |
785 596 | cout << "Antenna2: " << vb->antenna2() << endl; |
786 597 | cout << "Array Id: " << vb->arrayId() << endl; |
787 598 | cout << "Data Desc Ids: " << vb->dataDescriptionIds() << endl; |
792 603 | cout << "Field Id: " << vb->fieldId() << endl; |
793 604 | cout << "Flag Row: " << vb->flagRow() << endl; |
794 605 | cout << "Observation Id: " << vb->observationId() << endl; |
795 606 | cout << "Processor Id: " << vb->processorId() << endl; |
796 607 | cout << "Scan: " << vb->scan() << endl; |
797 608 | cout << "State Id: " << vb->stateId() << endl; |
798 609 | cout << "Time: " << vb->time() << endl; |
799 610 | cout << "Time Centroid: " << vb->timeCentroid() << endl; |
800 611 | cout << "Time Interval: " << vb->timeInterval() << endl; |
801 612 | auto const corrTypes = vb->correlationTypes(); |
802 - | auto toStokes = [](Vector<Int> const &corrTypes) { |
803 - | Vector<String> typeNames(corrTypes.size()); |
804 - | for (size_t i = 0; i < corrTypes.size(); ++i) { |
805 - | typeNames[i] = Stokes::name((Stokes::StokesTypes)corrTypes[i]); |
806 - | } |
807 - | return typeNames; |
808 - | }; |
809 - | cout << "Correlation Types: " << toStokes(corrTypes) << endl; |
613 + | cout << "Correlation Types: " << corrTypes << endl; |
810 614 | //cout << "UVW: " << vb->uvw() << endl; |
811 615 | |
812 616 | cout << "---" << endl; |
813 617 | Int nRowSubchunk = vb->nRows(); |
814 618 | Vector<uInt> rowIds = vb->rowIds(); |
815 619 | for (auto iter = rowIds.begin(); iter != rowIds.end(); ++iter) { |
816 620 | swept[*iter] += 1; |
817 621 | } |
818 622 | nRowSubchunkSum += nRowSubchunk; |
819 623 | Int nAnt = vb->nAntennas(); |
859 663 | EXPECT_EQ(visShape, visCubeCorrected.shape()); |
860 664 | } |
861 665 | EXPECT_EQ(!modelExists, visCubeModel.empty()); |
862 666 | if (!visCubeModel.empty()) { |
863 667 | EXPECT_EQ(visShape, visCubeModel.shape()); |
864 668 | } |
865 669 | EXPECT_EQ(!floatExists, visCubeFloat.empty()); |
866 670 | if (!visCubeFloat.empty()) { |
867 671 | EXPECT_EQ(visShape, visCubeFloat.shape()); |
868 672 | } |
869 - | EXPECT_EQ((ssize_t )nRowSubchunk, weight.shape()[1]); |
673 + | EXPECT_EQ((uInt)1, flagRow.size()); |
674 + | EXPECT_EQ((ssize_t)1, weight.shape()[0]); |
675 + | EXPECT_EQ((ssize_t)nRowSubchunk, weight.shape()[1]); |
870 676 | // NB: weight spectrum is created on-the-fly based on WEIGHT |
871 677 | // so that weightSp is always non-empty. |
872 678 | // see VisBufferImpl2::fillWeightSpectrum. |
873 679 | //EXPECT_EQ(!weightSpExists, weightSp.empty()); |
874 680 | EXPECT_FALSE(weightSp.empty()); |
875 681 | if (!weightSp.empty()) { |
876 - | EXPECT_EQ((ssize_t )nChan, weightSp.shape()[1]); |
877 - | EXPECT_EQ((ssize_t )nRowSubchunk, weightSp.shape()[2]); |
682 + | EXPECT_EQ((ssize_t)1, weightSp.shape()[0]); |
683 + | EXPECT_EQ((ssize_t)nChan, weightSp.shape()[1]); |
684 + | EXPECT_EQ((ssize_t)nRowSubchunk, weightSp.shape()[2]); |
878 685 | } |
879 686 | |
880 687 | // polarization averaging specific check |
688 + | // length of the correlation (polarization) axis must be 1 |
689 + | ASSERT_EQ((ssize_t)1, visShape[0]); |
881 690 | // polarization id always points to the row to be appended |
882 - | ASSERT_EQ(nRowPolarizationTable, (uInt )vb->polarizationId()); |
883 - | Validator::ValidatePolarization(corrTypes); |
691 + | ASSERT_EQ(nRowPolarizationTable, (uInt)vb->polarizationId()); |
692 + | // correlation type is always I |
693 + | ASSERT_EQ((size_t)1, corrTypes.size()); |
694 + | ASSERT_TRUE(allEQ(corrTypes, (Int)Stokes::I)); |
884 695 | |
885 696 | // validation of polarization average |
886 697 | if (!visCube.empty()) { |
887 698 | cout << "validate DATA" << endl; |
888 699 | Validator::ValidateData(visCube, ms, rowIds); |
889 700 | } |
890 701 | if (!visCubeCorrected.empty()) { |
891 702 | cout << "validate CORRECTED_DATA" << endl; |
892 703 | Validator::ValidateCorrected(visCubeCorrected, ms, rowIds); |
893 704 | } |
911 722 | // chunk-subchunk consistency check |
912 723 | EXPECT_EQ(nRowChunk, nRowSubchunkSum); |
913 724 | |
914 725 | vi->nextChunk(); |
915 726 | } |
916 727 | |
917 728 | // chunk-ms consistency check |
918 729 | EXPECT_EQ(nRowMs, nRowChunkSum); |
919 730 | |
920 731 | // iteration check |
921 - | EXPECT_TRUE(allEQ(swept, (uInt )1)); |
732 + | EXPECT_TRUE(allEQ(swept, (uInt)1)); |
922 733 | |
923 734 | } |
924 735 | |
925 736 | private: |
926 737 | void copyRegular(String const &src, String const &dst) { |
927 738 | RegularFile r(src); |
928 739 | r.copy(dst); |
929 740 | } |
930 741 | void copySymLink(String const &src, String const &dst) { |
931 742 | Path p = SymLink(src).followSymLink(); |
975 786 | copySymLink(full_path, work_path); |
976 787 | } else if (f.isRegular()) { |
977 788 | copyRegular(full_path, work_path); |
978 789 | } else if (f.isDirectory()) { |
979 790 | copyDirectory(full_path, work_path); |
980 791 | } |
981 792 | } |
982 793 | } |
983 794 | void cleanup() { |
984 795 | if (my_data_name_.size() > 0) { |
985 - | deleteTable(my_data_name_); |
796 + | File f(my_data_name_); |
797 + | if (f.isRegular()) { |
798 + | RegularFile r(my_data_name_); |
799 + | r.remove(); |
800 + | } else if (f.isDirectory()) { |
801 + | Directory d(my_data_name_); |
802 + | d.removeRecursive(); |
803 + | } |
986 804 | } |
987 805 | deleteTable(my_ms_name_); |
988 806 | } |
989 807 | void deleteTable(std::string const &name) { |
990 808 | File file(name); |
991 809 | if (file.exists()) { |
992 810 | std::cout << "Removing " << name << std::endl; |
993 811 | Table::deleteTable(name, true); |
994 812 | } |
995 813 | } |
996 - | }; |
997 - | |
998 - | // Fixture class for standard test |
999 - | class PolAverageTVITest: public PolAverageTVITestBase { |
1000 - | protected: |
1001 - | virtual std::string GetDataName() { |
1002 - | return "analytic_type1.bl.ms"; |
1003 - | } |
1004 - | |
1005 - | virtual std::string GetRelativeDataPath() { |
1006 - | return "tsdbaseline"; |
1007 - | } |
1008 - | |
1009 - | VisibilityIterator2 *ManufactureVI(String const &mode) { |
1010 - | return BasicManufacturer1::ManufactureVI(ms_, mode); |
1011 - | } |
1012 - | |
1013 - | void TestFactory(String const &mode, String const &expectedClassName) { |
1014 - | |
1015 - | cout << "Mode \"" << mode << "\" expected class name \"" |
1016 - | << expectedClassName << "\"" << endl; |
1017 - | |
1018 - | if (expectedClassName.size() > 0) { |
1019 - | std::unique_ptr < VisibilityIterator2 > vi(ManufactureVI(mode)); |
1020 - | |
1021 - | // Verify type string |
1022 - | String viiType = vi->ViiType(); |
1023 - | EXPECT_TRUE(viiType.startsWith(expectedClassName)); |
1024 - | } else { |
1025 - | cout << "Creation of VI via factory will fail" << endl; |
1026 - | // exception must be thrown |
1027 - | EXPECT_THROW( { |
1028 - | std::unique_ptr<VisibilityIterator2> vi(ManufactureVI(mode)); //new VisibilityIterator2(factory)); |
1029 - | }, |
1030 - | AipsError)<< "The process must throw AipsError"; |
1031 - | } |
1032 - | } |
1033 - | |
1034 - | }; |
1035 - | |
1036 - | // Fixture class for testing four polarization (cross-pol, stokes IQUV) |
1037 - | class PolAverageTVIFourPolarizationTest: public PolAverageTVITestBase { |
1038 - | protected: |
1039 - | virtual std::string GetDataName() { |
1040 - | return "crosspoltest.ms"; |
1041 - | } |
1042 - | |
1043 - | virtual std::string GetRelativeDataPath() { |
1044 - | return "sdsave"; |
1045 - | } |
1046 - | |
1047 - | void SetCorrTypeToStokes() { |
1048 - | ScalarColumn<Int> dataDescIdColumn(*ms_, "DATA_DESC_ID"); |
1049 - | Vector<Int> dataDescIdList = dataDescIdColumn.getColumn(); |
1050 - | ScalarColumn<Int> polarizationIdColumn(ms_->dataDescription(), |
1051 - | "POLARIZATION_ID"); |
1052 - | Vector<Int> polarizationIdList(dataDescIdList.size()); |
1053 - | for (size_t i = 0; i < dataDescIdList.size(); ++i) { |
1054 - | polarizationIdList[i] = polarizationIdColumn(dataDescIdList[i]); |
1055 - | } |
1056 - | std::cout << "polarizationIdList = " << polarizationIdList << std::endl; |
1057 - | uInt n = GenSort<Int>::sort(polarizationIdList, Sort::Ascending, |
1058 - | Sort::HeapSort | Sort::NoDuplicates); |
1059 - | std::cout << "polarizationIdList (sorted n = " << n << ") = " |
1060 - | << polarizationIdList << std::endl; |
1061 - | |
1062 - | ArrayColumn<Int> corrTypeColumn(ms_->polarization(), "CORR_TYPE"); |
1063 - | Int const newCorrTypes[] = { Stokes::I, Stokes::Q, Stokes::U, Stokes::V }; |
1064 - | for (uInt i = 0; i < n; ++i) { |
1065 - | auto row = polarizationIdList[i]; |
1066 - | std::cout << "row = " << row << std::endl; |
1067 - | Vector<Int> corrType = corrTypeColumn(row); |
1068 - | std::cout << "corrType = " << corrType << std::endl; |
1069 - | ASSERT_LE(corrType.size(), sizeof(newCorrTypes) / sizeof(Int)); |
1070 - | for (size_t j = 0; j < corrType.size(); ++j) { |
1071 - | corrType[j] = newCorrTypes[j]; |
1072 - | } |
1073 - | std::cout << "new corrType = " << corrType << std::endl; |
1074 - | corrTypeColumn.put(row, corrType); |
1075 - | } |
1076 - | } |
1077 - | }; |
1078 - | |
1079 - | // Fixture class for testing dirty (partially flagged) data |
1080 - | // NB: use same data as PolAverageTVITest |
1081 - | class PolAverageTVIDirtyDataTest: public PolAverageTVITest { |
1082 - | public: |
1083 - | virtual void SetUp() { |
1084 - | // call parent's SetUp method |
1085 - | PolAverageTVITestBase::SetUp(); |
1086 - | |
1087 - | // corrupt data |
1088 - | CorruptData(); |
1089 - | } |
1090 - | |
1091 - | private: |
1092 - | // Make input data dirty |
1093 - | void CorruptData() { |
1094 - | // Accessor to FLAG column |
1095 - | ArrayColumn<Bool> flagColumn(*ms_, "FLAG"); |
1096 - | Cube<Bool> flag = flagColumn.getColumn(); |
1097 - | |
1098 - | // Accessor to DATA columns |
1099 - | Cube<Float> floatData; |
1100 - | Cube<Complex> complexData, correctedData; |
1101 - | if (ms_->tableDesc().isColumn("DATA")) { |
1102 - | ArrayColumn<Complex> dataColumn(*ms_, "DATA"); |
1103 - | dataColumn.getColumn(complexData); |
1104 - | ASSERT_EQ(flag.shape(), complexData.shape()); |
1105 - | } |
1106 - | if (ms_->tableDesc().isColumn("FLOAT_DATA")) { |
1107 - | ArrayColumn<Float> dataColumn(*ms_, "FLOAT_DATA"); |
1108 - | dataColumn.getColumn(floatData); |
1109 - | ASSERT_EQ(flag.shape(), floatData.shape()); |
1110 - | } |
1111 - | if (ms_->tableDesc().isColumn("CORRECTED_DATA")) { |
1112 - | ArrayColumn<Complex> dataColumn(*ms_, "CORRECTED_DATA"); |
1113 - | dataColumn.getColumn(correctedData); |
1114 - | ASSERT_EQ(flag.shape(), correctedData.shape()); |
1115 - | } |
1116 - | |
1117 - | // corrupt row 0, channel 10, pol 1 |
1118 - | size_t row = 0; |
1119 - | size_t chan = 10; |
1120 - | size_t pol = 1; |
1121 - | ASSERT_GT((size_t )flag.nplane(), row); |
1122 - | ASSERT_GT((size_t )flag.ncolumn(), chan); |
1123 - | ASSERT_GT((size_t )flag.nrow(), pol); |
1124 - | Float corruptValue = std::numeric_limits<float>::quiet_NaN(); |
1125 - | flag(pol, chan, row) = True; |
1126 - | ASSERT_EQ(flag(pol, chan, row), True); |
1127 - | if (!floatData.empty()) { |
1128 - | floatData(pol, chan, row) = corruptValue; |
1129 - | ASSERT_TRUE(std::isnan(floatData(pol, chan, row))); |
1130 - | } |
1131 - | if (!complexData.empty()) { |
1132 - | complexData(pol, chan, row) = corruptValue; |
1133 - | ASSERT_TRUE(std::isnan(complexData(pol, chan, row).real())); |
1134 - | } |
1135 - | if (!correctedData.empty()) { |
1136 - | correctedData(pol, chan, row) = corruptValue; |
1137 - | ASSERT_TRUE(std::isnan(correctedData(pol, chan, row).real())); |
1138 - | } |
1139 - | |
1140 - | // corrupt row 1, channel 100, all pols |
1141 - | row = 1; |
1142 - | chan = 100; |
1143 - | ASSERT_GT((size_t )flag.nplane(), row); |
1144 - | ASSERT_GT((size_t )flag.ncolumn(), chan); |
1145 - | IPosition blc(3, 0, chan, row); |
1146 - | IPosition trc(3, flag.nrow() - 1, chan, row); |
1147 - | flag(blc, trc) = True; |
1148 - | ASSERT_EQ(flag(0, chan, row), True); |
1149 - | ASSERT_EQ(flag(1, chan, row), True); |
1150 - | if (!floatData.empty()) { |
1151 - | floatData(blc, trc) = corruptValue; |
1152 - | ASSERT_TRUE(std::isnan(floatData(0, chan, row))); |
1153 - | ASSERT_TRUE(std::isnan(floatData(1, chan, row))); |
1154 - | } |
1155 - | if (!complexData.empty()) { |
1156 - | complexData(blc, trc) = corruptValue; |
1157 - | ASSERT_TRUE(std::isnan(complexData(0, chan, row).real())); |
1158 - | ASSERT_TRUE(std::isnan(complexData(1, chan, row).real())); |
1159 - | } |
1160 - | if (!correctedData.empty()) { |
1161 - | correctedData(blc, trc) = corruptValue; |
1162 - | ASSERT_TRUE(std::isnan(correctedData(0, chan, row).real())); |
1163 - | ASSERT_TRUE(std::isnan(correctedData(1, chan, row).real())); |
1164 - | } |
1165 - | |
1166 - | // write back to MS |
1167 - | flagColumn.putColumn(flag); |
1168 - | if (ms_->tableDesc().isColumn("DATA")) { |
1169 - | ArrayColumn<Complex> dataColumn(*ms_, "DATA"); |
1170 - | dataColumn.putColumn(complexData); |
1171 - | } |
1172 - | if (ms_->tableDesc().isColumn("FLOAT_DATA")) { |
1173 - | ArrayColumn<Float> dataColumn(*ms_, "FLOAT_DATA"); |
1174 - | dataColumn.putColumn(floatData); |
1175 - | } |
1176 - | if (ms_->tableDesc().isColumn("CORRECTED_DATA")) { |
1177 - | ArrayColumn<Complex> dataColumn(*ms_, "CORRECTED_DATA"); |
1178 - | dataColumn.putColumn(correctedData); |
1179 - | } |
1180 - | } |
1181 814 | |
1182 815 | }; |
1183 816 | |
1184 817 | TEST_F(PolAverageTVITest, Factory) { |
1185 818 | |
1186 819 | TestFactory("default", "StokesPolAverage"); |
1187 820 | TestFactory("Default", "StokesPolAverage"); |
1188 821 | TestFactory("DEFAULT", "StokesPolAverage"); |
1189 822 | TestFactory("geometric", "GeometricPolAverage"); |
1190 823 | TestFactory("Geometric", "GeometricPolAverage"); |
1191 824 | TestFactory("GEOMETRIC", "GeometricPolAverage"); |
1192 825 | TestFactory("stokes", "StokesPolAverage"); |
1193 826 | TestFactory("Stokes", "StokesPolAverage"); |
1194 827 | TestFactory("STOKES", "StokesPolAverage"); |
1195 828 | // empty mode (default) |
1196 829 | TestFactory("", "StokesPolAverage"); |
1197 830 | // invalid mode (throw exception) |
1198 831 | TestFactory("invalid", ""); |
1199 832 | } |
1200 833 | |
1201 834 | TEST_F(PolAverageTVITest, GeometricAverage) { |
1202 - | // Use different types of constructor to create factory |
835 + | // Use difference type of constructor to create factory |
1203 836 | TestTVI<GeometricAverageValidator, BasicManufacturer1>(); |
1204 837 | TestTVI<GeometricAverageValidator, BasicManufacturer2>(); |
1205 838 | TestTVI<GeometricAverageValidator, LayerManufacturer>(); |
1206 839 | } |
1207 840 | |
1208 841 | TEST_F(PolAverageTVITest, StokesAverage) { |
1209 - | // Use different types of constructor to create factory |
842 + | // Use difference type of constructor to create factory |
1210 843 | TestTVI<StokesAverageValidator, BasicManufacturer1>(); |
1211 844 | TestTVI<StokesAverageValidator, BasicManufacturer2>(); |
1212 845 | TestTVI<StokesAverageValidator, LayerManufacturer>(); |
1213 846 | } |
1214 847 | |
1215 - | TEST_F(PolAverageTVIDirtyDataTest, GeometricAverageCorrupted) { |
1216 - | TestTVI<GeometricAverageValidator, BasicManufacturer1>(); |
1217 - | } |
1218 - | |
1219 - | TEST_F(PolAverageTVIDirtyDataTest, StokesAverageCorrupted) { |
1220 - | TestTVI<StokesAverageValidator, BasicManufacturer1>(); |
1221 - | } |
1222 - | |
1223 - | TEST_F(PolAverageTVIFourPolarizationTest, GeometricAverageSkipped) { |
1224 - | // Edit CORR_TYPE to be IQUV |
1225 - | SetCorrTypeToStokes(); |
1226 - | |
1227 - | TestTVI<GeometricIdenticalValidator, BasicManufacturer1>(); |
1228 - | } |
1229 - | |
1230 - | TEST_F(PolAverageTVIFourPolarizationTest, StokesAverageSkipped) { |
1231 - | // Edit CORR_TYPE to be IQUV |
1232 - | SetCorrTypeToStokes(); |
1233 - | |
1234 - | TestTVI<StokesIdenticalValidator, BasicManufacturer1>(); |
1235 - | } |
1236 - | |
1237 - | TEST_F(PolAverageTVIFourPolarizationTest, GeometricAverageCrossPol) { |
1238 - | TestTVI<GeometricAverageCrossPolarizationValidator, BasicManufacturer1>(); |
1239 - | } |
1240 - | |
1241 - | TEST_F(PolAverageTVIFourPolarizationTest, StokesAverageCrossPol) { |
1242 - | TestTVI<StokesAverageCrossPolarizationValidator, BasicManufacturer1>(); |
1243 - | } |
848 + | // TODO: define test on not-well-behaved data |
849 + | // such as partially flagged, contains NaN, etc. |
850 + | // TODO: define test on the data that should not average |
851 + | // along polarization axis (IQUV or single polarization or multi-pol but not containing XX and YY) |
1244 852 | |
1245 853 | int main(int argc, char **argv) { |
1246 854 | ::testing::InitGoogleTest(&argc, argv); |
1247 855 | std::cout << "PolAverageTVI test " << std::endl; |
1248 856 | return RUN_ALL_TESTS(); |
1249 857 | } |