Commits
14 14 | # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Library General Public |
15 15 | # License for more details. |
16 16 | # |
17 17 | # [Add the link to the JIRA ticket here once it exists] |
18 18 | # |
19 19 | # Based on the requirements listed in plone found here: |
20 20 | # https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.single.sdatmcor.html |
21 21 | # |
22 22 | # |
23 23 | ########################################################################## |
24 + | import functools |
24 25 | import itertools |
25 26 | import os |
26 27 | import re |
27 28 | import shutil |
28 29 | import unittest |
29 30 | |
30 31 | import numpy as np |
31 32 | |
32 33 | from casatasks import applycal, casalog, gencal, sdatmcor |
33 34 | from casatasks.private.sdutil import (convert_antenna_spec_autocorr, |
494 495 | |
495 496 | def test_custom_atm_params_non_conform_list_input(self): |
496 497 | """Test customized ATM parameters: non-conform layerboundaries and layertemperature.""" |
497 498 | with self.assertRaises(Exception): |
498 499 | sdatmcor( |
499 500 | infile=self.infile, outfile=self.outfile, datacolumn='data', |
500 501 | atmdetail=True, |
501 502 | layerboundaries='800m,1.5km', layertemperature='250K,200K,190K' |
502 503 | ) |
503 504 | |
505 + | def __extract_num_threads_from_logfile(self, logfile): |
506 + | with open(logfile, 'r') as f: |
507 + | pattern = re.compile(r'.*Setting numThreads_ to ([0-9]+)') |
508 + | matches_iterator = map(lambda line: pattern.search(line), f) |
509 + | last_match = functools.reduce( |
510 + | lambda previous_match, current_match: |
511 + | current_match if current_match else previous_match, |
512 + | matches_iterator |
513 + | ) |
514 + | |
515 + | self.assertIsNotNone( |
516 + | last_match, |
517 + | msg=f'No match found for pattern "{pattern.pattern}" in "{casalog.logfile()}"' |
518 + | ) |
519 + | num_threads_log = int(last_match.group(1)) |
520 + | |
521 + | return num_threads_log |
522 + | |
504 523 | def test_set_omp_num_threads(self): |
505 524 | """Test if the task respects OMP_NUM_THREADS environment variable.""" |
506 525 | omp_num_threads_org = os.environ.get('OMP_NUM_THREADS') |
507 526 | omp_num_threads_org_int = None |
508 527 | |
509 528 | try: |
510 529 | # set num_threads for OpenMP to any value different from the current one |
511 530 | if omp_num_threads_org is None: |
512 531 | num_threads = 2 |
513 532 | else: |
531 550 | # consistency check |
532 551 | omp_num_threads_current = os.environ.get('OMP_NUM_THREADS') |
533 552 | if omp_num_threads_current is None: |
534 553 | self.assertIsNone(omp_num_threads_org) |
535 554 | else: |
536 555 | self.assertIsNotNone(omp_num_threads_org) |
537 556 | self.assertEqual(omp_num_threads_current, omp_num_threads_org) |
538 557 | |
539 558 | # check log |
540 559 | self.assertTrue(os.path.exists(casalog.logfile()), msg='casalog file is missing!') |
541 - | with open(casalog.logfile(), 'r') as f: |
542 - | pattern = re.compile(r'.*Setting numThreads_ to ([0-9]+)') |
543 - | lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f))) |
544 - | num_threads_log = int(lines[-1].group(1)) |
560 + | num_threads_log = self.__extract_num_threads_from_logfile(casalog.logfile()) |
545 561 | |
546 562 | casalog.post( |
547 563 | f'OMP_NUM_THREAD_VALUES: initial: {OMP_NUM_THREADS_INITIAL} (returned by get_omp_num_threads), ' |
548 564 | f'at test start time: {omp_num_threads_org}, current: {num_threads}, ' |
549 565 | f'last set in logfile: {num_threads_log}') |
550 566 | self.assertEqual(num_threads, num_threads_log) |
551 567 | |
552 568 | # check output MS |
553 569 | self.check_result({19: True, 23: True}) |
554 570 | |
567 583 | os.environ['OMP_NUM_THREADS'] = omp_num_threads_org |
568 584 | |
569 585 | # consistency check |
570 586 | if omp_num_threads_org is None: |
571 587 | self.assertIsNone(os.environ.get('OMP_NUM_THREADS')) |
572 588 | else: |
573 589 | self.assertEqual(os.environ.get('OMP_NUM_THREADS'), omp_num_threads_org) |
574 590 | |
575 591 | # check log |
576 592 | self.assertTrue(os.path.exists(casalog.logfile()), msg='casalog file is missing!') |
577 - | with open(casalog.logfile(), 'r') as f: |
578 - | pattern = re.compile(r'.*Setting numThreads_ to ([0-9]+)') |
579 - | lines = list(filter(lambda x: x is not None, map(lambda x: re.search(pattern, x), f))) |
580 - | num_threads_log = int(lines[-1].group(1)) |
593 + | num_threads_log = self.__extract_num_threads_from_logfile(casalog.logfile()) |
581 594 | num_threads_expected = min(8, casalog.getNumCPUs()) |
582 595 | |
583 596 | casalog.post( |
584 597 | f'OMP_NUM_THREAD_VALUES: initial: {OMP_NUM_THREADS_INITIAL} (returned by get_omp_num_threads), ' |
585 598 | f'at test start time: {omp_num_threads_org}, expected: {num_threads_expected}, ' |
586 599 | f'last set in logfile: {num_threads_log}') |
587 600 | self.assertEqual(num_threads_expected, num_threads_log) |
588 601 | |
589 602 | # check output MS |
590 603 | self.check_result({19: True, 23: True}) |