Commits

Takeshi Nakazato authored 1fa18c65950
CAS-13713 add test for OMP_NUM_THREADS env var.

casatasks/tests/tasks/test_task_sdatmcor.py

Modified
16 16 #
17 17 # [Add the link to the JIRA ticket here once it exists]
18 18 #
19 19 # Based on the requirements listed in plone found here:
20 20 # https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.single.sdatmcor.html
21 21 #
22 22 #
23 23 ##########################################################################
24 24 import itertools
25 25 import os
26 +import re
26 27 import shutil
27 28 import unittest
28 29
29 30 import numpy as np
30 31
31 32 from casatasks import applycal, casalog, gencal, sdatmcor
32 33 from casatasks.private.sdutil import (convert_antenna_spec_autocorr,
33 34 get_antenna_selection_include_autocorr,
34 35 table_manager)
35 36 import casatasks.private.task_sdatmcor as sdatmcor_impl
494 495
495 496 def test_custom_atm_params_non_conform_list_input(self):
496 497 """Test customized ATM parameters: non-conform layerboundaries and layertemperature."""
497 498 with self.assertRaises(Exception):
498 499 sdatmcor(
499 500 infile=self.infile, outfile=self.outfile, datacolumn='data',
500 501 atmdetail=True,
501 502 layerboundaries='800m,1.5km', layertemperature='250K,200K,190K'
502 503 )
503 504
505 + def test_omp_num_threads(self):
506 + """Test if the task respects OMP_NUM_THREADS environment variable."""
507 + omp_num_threads_org = os.environ.get('OMP_NUM_THREADS')
508 + try:
509 + # set num_threads for OpenMP to any different value than the current one
510 + if omp_num_threads_org is None:
511 + num_threads = 2
512 + else:
513 + iter = filter(lambda x: x != int(omp_num_threads_org), range(2, 9))
514 + num_threads = next(iter)
515 + self.assertNotEqual(num_threads, omp_num_threads_org)
516 + os.environ['OMP_NUM_THREADS'] = f'{num_threads}'
517 +
518 + # run task
519 + sdatmcor(infile=self.infile, outfile=self.outfile, datacolumn='data')
520 + finally:
521 + if omp_num_threads_org is None:
522 + os.environ.pop('OMP_NUM_THREADS')
523 + else:
524 + os.environ['OMP_NUM_THREADS'] = omp_num_threads_org
525 +
526 + # consistency check
527 + if omp_num_threads_org is None:
528 + self.assertIsNone(os.environ.get('OMP_NUM_THREADS'))
529 + else:
530 + self.assertEqual(os.environ.get('OMP_NUM_THREADS'), omp_num_threads_org)
531 +
532 + # check log
533 + if os.path.exists(casalog.logfile()):
534 + with open(casalog.logfile(), 'r') as f:
535 + pattern = re.compile(r'.*Setting numThreads_ to ([0-9+])')
536 + lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f)))
537 + num_threads_log = int(lines[-1].group(1))
538 +
539 + print(f'{OMP_NUM_THREADS_INITIAL} {omp_num_threads_org} {num_threads} {num_threads_log}')
540 + self.assertEqual(num_threads, num_threads_log)
541 +
542 + # check output MS
543 + self.check_result({19: True, 23: True})
544 +
504 545
505 546 class ATMParamTest(unittest.TestCase):
506 547 def _param_test_template(self, valid_test_cases,
507 548 invalid_user_input, user_default, task_default, unit=''):
508 549 # internal error
509 550 wrong_task_default = 'NG'
510 551 with self.assertRaises(RuntimeError):
511 552 param, is_customized = sdatmcor_impl.parse_atm_params(
512 553 '',
513 554 user_default,

Everything looks good. We'll let you know here if there's anything you should know about.

Add shortcut