Commits

Takeshi Nakazato authored ef19b7e2845
CAS-13713 correct handling of environment variable

NAOJ code review Refs #2425 #2426
No tags

casatasks/tests/tasks/test_task_sdatmcor.py

Modified
498 498 with self.assertRaises(Exception):
499 499 sdatmcor(
500 500 infile=self.infile, outfile=self.outfile, datacolumn='data',
501 501 atmdetail=True,
502 502 layerboundaries='800m,1.5km', layertemperature='250K,200K,190K'
503 503 )
504 504
505 505 def test_omp_num_threads(self):
506 506 """Test if the task respects OMP_NUM_THREADS environment variable."""
507 507 omp_num_threads_org = os.environ.get('OMP_NUM_THREADS')
508 + if omp_num_threads_org is not None:
509 + self.assertTrue(
510 + omp_num_threads_org.isdigit(),
511 + msg="invalid value of OMP_NUM_THREADS environment variable"
512 + )
513 + omp_num_threads_org = int(omp_num_threads_org)
508 514 try:
509 515 # set num_threads for OpenMP to any value different from the current one
510 516 if omp_num_threads_org is None:
511 517 num_threads = 2
512 518 else:
513 - iter = filter(lambda x: x != int(omp_num_threads_org), range(2, 9))
514 - num_threads = next(iter)
519 + num_threads = omp_num_threads_org + 1
515 520 self.assertNotEqual(num_threads, omp_num_threads_org)
516 521 os.environ['OMP_NUM_THREADS'] = f'{num_threads}'
517 522
518 523 # run task
519 524 sdatmcor(infile=self.infile, outfile=self.outfile, datacolumn='data')
520 525 finally:
521 526 if omp_num_threads_org is None:
522 527 os.environ.pop('OMP_NUM_THREADS')
523 528 else:
524 - os.environ['OMP_NUM_THREADS'] = omp_num_threads_org
529 + os.environ['OMP_NUM_THREADS'] = str(omp_num_threads_org)
525 530
526 531 # consistency check
527 532 if omp_num_threads_org is None:
528 533 self.assertIsNone(os.environ.get('OMP_NUM_THREADS'))
529 534 else:
530 - self.assertEqual(os.environ.get('OMP_NUM_THREADS'), omp_num_threads_org)
535 + self.assertEqual(os.environ.get('OMP_NUM_THREADS'), str(omp_num_threads_org))
531 536
532 537 # check log
533 538 if os.path.exists(casalog.logfile()):
534 539 with open(casalog.logfile(), 'r') as f:
535 540 pattern = re.compile(r'.*Setting numThreads_ to ([0-9+])')
536 541 lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f)))
537 542 num_threads_log = int(lines[-1].group(1))
538 543
539 544 print(f'{OMP_NUM_THREADS_INITIAL} {omp_num_threads_org} {num_threads} {num_threads_log}')
540 545 self.assertEqual(num_threads, num_threads_log)

Everything looks good. We'll let you know here if there's anything you should know about.

Add shortcut