Commits
16 16 | # |
17 17 | # [Add the link to the JIRA ticket here once it exists] |
18 18 | # |
19 19 | # Based on the requirements listed in plone found here: |
20 20 | # https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.single.sdatmcor.html |
21 21 | # |
22 22 | # |
23 23 | ########################################################################## |
24 24 | import itertools |
25 25 | import os |
26 + | import re |
26 27 | import shutil |
27 28 | import unittest |
28 29 | |
29 30 | import numpy as np |
30 31 | |
31 32 | from casatasks import applycal, casalog, gencal, sdatmcor |
32 33 | from casatasks.private.sdutil import (convert_antenna_spec_autocorr, |
33 34 | get_antenna_selection_include_autocorr, |
34 35 | table_manager) |
35 36 | import casatasks.private.task_sdatmcor as sdatmcor_impl |
494 495 | |
495 496 | def test_custom_atm_params_non_conform_list_input(self): |
496 497 | """Test customized ATM parameters: non-conform layerboundaries and layertemperature.""" |
497 498 | with self.assertRaises(Exception): |
498 499 | sdatmcor( |
499 500 | infile=self.infile, outfile=self.outfile, datacolumn='data', |
500 501 | atmdetail=True, |
501 502 | layerboundaries='800m,1.5km', layertemperature='250K,200K,190K' |
502 503 | ) |
503 504 | |
505 + | def test_omp_num_threads(self): |
506 + | """Test if the task respects OMP_NUM_THREADS environment variable.""" |
507 + | omp_num_threads_org = os.environ.get('OMP_NUM_THREADS') |
508 + | try: |
509 + | # set num_threads for OpenMP to any different value than the current one |
510 + | if omp_num_threads_org is None: |
511 + | num_threads = 2 |
512 + | else: |
513 + | iter = filter(lambda x: x != int(omp_num_threads_org), range(2, 9)) |
514 + | num_threads = next(iter) |
515 + | self.assertNotEqual(num_threads, omp_num_threads_org) |
516 + | os.environ['OMP_NUM_THREADS'] = f'{num_threads}' |
517 + | |
518 + | # run task |
519 + | sdatmcor(infile=self.infile, outfile=self.outfile, datacolumn='data') |
520 + | finally: |
521 + | if omp_num_threads_org is None: |
522 + | os.environ.pop('OMP_NUM_THREADS') |
523 + | else: |
524 + | os.environ['OMP_NUM_THREADS'] = omp_num_threads_org |
525 + | |
526 + | # consistency check |
527 + | if omp_num_threads_org is None: |
528 + | self.assertIsNone(os.environ.get('OMP_NUM_THREADS')) |
529 + | else: |
530 + | self.assertEqual(os.environ.get('OMP_NUM_THREADS'), omp_num_threads_org) |
531 + | |
532 + | # check log |
533 + | if os.path.exists(casalog.logfile()): |
534 + | with open(casalog.logfile(), 'r') as f: |
535 + | pattern = re.compile(r'.*Setting numThreads_ to ([0-9+])') |
536 + | lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f))) |
537 + | num_threads_log = int(lines[-1].group(1)) |
538 + | |
539 + | print(f'{OMP_NUM_THREADS_INITIAL} {omp_num_threads_org} {num_threads} {num_threads_log}') |
540 + | self.assertEqual(num_threads, num_threads_log) |
541 + | |
542 + | # check output MS |
543 + | self.check_result({19: True, 23: True}) |
544 + | |
504 545 | |
505 546 | class ATMParamTest(unittest.TestCase): |
506 547 | def _param_test_template(self, valid_test_cases, |
507 548 | invalid_user_input, user_default, task_default, unit=''): |
508 549 | # internal error |
509 550 | wrong_task_default = 'NG' |
510 551 | with self.assertRaises(RuntimeError): |
511 552 | param, is_customized = sdatmcor_impl.parse_atm_params( |
512 553 | '', |
513 554 | user_default, |