Commits
522 522 | |
523 523 | # run task |
524 524 | sdatmcor(infile=self.infile, outfile=self.outfile, datacolumn='data') |
525 525 | finally: |
526 526 | if omp_num_threads_org is None: |
527 527 | os.environ.pop('OMP_NUM_THREADS') |
528 528 | else: |
529 529 | os.environ['OMP_NUM_THREADS'] = omp_num_threads_org |
530 530 | |
531 531 | # consistency check |
532 - | if omp_num_threads_org is None: |
533 - | self.assertIsNone(os.environ.get('OMP_NUM_THREADS')) |
532 + | omp_num_threads_current = os.environ.get('OMP_NUM_THREADS') |
533 + | if omp_num_threads_current is None: |
534 + | self.assertIsNone(omp_num_threads_org) |
534 535 | else: |
535 - | self.assertEqual(os.environ.get('OMP_NUM_THREADS'), omp_num_threads_org) |
536 + | self.assertIsNotNone(omp_num_threads_org) |
537 + | self.assertEqual(omp_num_threads_current, omp_num_threads_org) |
536 538 | |
537 539 | # check log |
538 540 | self.assertTrue(os.path.exists(casalog.logfile()), msg='casalog file is missing!') |
539 541 | with open(casalog.logfile(), 'r') as f: |
540 542 | pattern = re.compile(r'.*Setting numThreads_ to ([0-9]+)') |
541 543 | lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f))) |
542 544 | num_threads_log = int(lines[-1].group(1)) |
543 545 | |
544 546 | casalog.post( |
545 - | f'OMP_NUM_THREAD_VALUES: initial: {OMP_NUM_THREADS_INITIAL}, ' |
547 + | f'OMP_NUM_THREAD_VALUES: initial: {OMP_NUM_THREADS_INITIAL} (returned by get_omp_num_threads), ' |
546 548 | f'at test start time: {omp_num_threads_org}, current: {num_threads}, ' |
547 549 | f'last set in logfile: {num_threads_log}') |
548 550 | self.assertEqual(num_threads, num_threads_log) |
549 551 | |
550 552 | # check output MS |
551 553 | self.check_result({19: True, 23: True}) |
552 554 | |
553 555 | def test_unset_omp_num_threads(self): |
554 556 | """Test if the task respects OMP_NUM_THREADS environment variable.""" |
555 557 | # unset OMP_NUM_THREADS if it is set |
572 574 | |
573 575 | # check log |
574 576 | self.assertTrue(os.path.exists(casalog.logfile()), msg='casalog file is missing!') |
575 577 | with open(casalog.logfile(), 'r') as f: |
576 578 | pattern = re.compile(r'.*Setting numThreads_ to ([0-9]+)') |
577 579 | lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f))) |
578 580 | num_threads_log = int(lines[-1].group(1)) |
579 581 | num_threads_expected = min(8, casalog.getNumCPUs()) |
580 582 | |
581 583 | casalog.post( |
582 - | f'OMP_NUM_THREAD_VALUES: initial: {OMP_NUM_THREADS_INITIAL}, ' |
584 + | f'OMP_NUM_THREAD_VALUES: initial: {OMP_NUM_THREADS_INITIAL} (returned by get_omp_num_threads), ' |
583 585 | f'at test start time: {omp_num_threads_org}, expected: {num_threads_expected}, ' |
584 586 | f'last set in logfile: {num_threads_log}') |
585 587 | self.assertEqual(num_threads_expected, num_threads_log) |
586 588 | |
587 589 | # check output MS |
588 590 | self.check_result({19: True, 23: True}) |
589 591 | |
590 592 | |
591 593 | class ATMParamTest(unittest.TestCase): |
592 594 | def _param_test_template(self, valid_test_cases, |