Commits
498 498 | with self.assertRaises(Exception): |
499 499 | sdatmcor( |
500 500 | infile=self.infile, outfile=self.outfile, datacolumn='data', |
501 501 | atmdetail=True, |
502 502 | layerboundaries='800m,1.5km', layertemperature='250K,200K,190K' |
503 503 | ) |
504 504 | |
505 505 | def test_omp_num_threads(self): |
506 506 | """Test if the task respects OMP_NUM_THREADS environment variable.""" |
507 507 | omp_num_threads_org = os.environ.get('OMP_NUM_THREADS') |
508 + | if omp_num_threads_org is not None: |
509 + | self.assertTrue( |
510 + | omp_num_threads_org.isdigit(), |
511 + | msg="invalid value of OMP_NUM_THREADS environment variable" |
512 + | ) |
513 + | omp_num_threads_org = int(omp_num_threads_org) |
508 514 | try: |
509 515 | # set num_threads for OpenMP to any value different from the current one |
510 516 | if omp_num_threads_org is None: |
511 517 | num_threads = 2 |
512 518 | else: |
513 - | iter = filter(lambda x: x != int(omp_num_threads_org), range(2, 9)) |
514 - | num_threads = next(iter) |
519 + | num_threads = omp_num_threads_org + 1 |
515 520 | self.assertNotEqual(num_threads, omp_num_threads_org) |
516 521 | os.environ['OMP_NUM_THREADS'] = f'{num_threads}' |
517 522 | |
518 523 | # run task |
519 524 | sdatmcor(infile=self.infile, outfile=self.outfile, datacolumn='data') |
520 525 | finally: |
521 526 | if omp_num_threads_org is None: |
522 527 | os.environ.pop('OMP_NUM_THREADS') |
523 528 | else: |
524 - | os.environ['OMP_NUM_THREADS'] = omp_num_threads_org |
529 + | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads_org) |
525 530 | |
526 531 | # consistency check |
527 532 | if omp_num_threads_org is None: |
528 533 | self.assertIsNone(os.environ.get('OMP_NUM_THREADS')) |
529 534 | else: |
530 - | self.assertEqual(os.environ.get('OMP_NUM_THREADS'), omp_num_threads_org) |
535 + | self.assertEqual(os.environ.get('OMP_NUM_THREADS'), str(omp_num_threads_org)) |
531 536 | |
532 537 | # check log |
533 538 | if os.path.exists(casalog.logfile()): |
534 539 | with open(casalog.logfile(), 'r') as f: |
535 540 | pattern = re.compile(r'.*Setting numThreads_ to ([0-9+])') |
536 541 | lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f))) |
537 542 | num_threads_log = int(lines[-1].group(1)) |
538 543 | |
539 544 | print(f'{OMP_NUM_THREADS_INITIAL} {omp_num_threads_org} {num_threads} {num_threads_log}') |
540 545 | self.assertEqual(num_threads, num_threads_log) |