from __future__ import absolute_import
from __future__ import print_function
import os
import sys
import shutil
import re
import numpy
import contextlib
import unittest
from casatestutils import listing

class Casa5InitError(Exception):
    """Raised internally to divert module initialization to the CASA5 code path."""

class MixedCasa5Casa6InitError(Casa5InitError):
    """Raised internally when casa_transition imports but reports a non-CASA6 environment."""

# Module initialization: detect whether we run under CASA6 (casatools/casatasks
# packages) or CASA5 (task/tool names injected via star imports), and bind a
# common set of names used by the tests below: is_CASA6, default, sdcal,
# partition, sdutil, ms, table, measures, tb, ctsys_resolve (plus
# initweights/casalog where available).
try:
    try:
        from casatasks.private.casa_transition import is_CASA6
    except ImportError as e:
        # Plain CASA5 or older: CASA5 prior to mixed CASA5-CASA6
        raise Casa5InitError
    else:
        # Init Mixed CASA
        try:
            if not is_CASA6:
                raise MixedCasa5Casa6InitError
        except MixedCasa5Casa6InitError:
                # Mixed CASA reporting CASA5: reuse the plain CASA5 path below.
                raise Casa5InitError
        else:
            # Init CASA6
            from casatools import ctsys, table, ms, measures
            from casatasks import casalog, sdcal, partition, initweights
            from casatasks.private import sdutil

            ### for testhelper import
            sys.path.append(os.path.abspath(os.path.dirname(__file__)))

            # Shared module-level table tool used by several test classes.
            tb = table()

            ctsys_resolve = ctsys.resolve
            # default isn't used in CASA6
            def default(atask):
                pass

except Casa5InitError as e:
    # Init Plain CASA5 or older, or Mixed CASA which is CASA5
    is_CASA6 = False
    from __main__ import default
    # NOTE(review): initweights, casalog, mstool, tbtool, metool and the global
    # tb are presumably provided by these star imports — confirm per CASA5 build.
    from tasks import *
    from taskinit import *
    import sdutil
    from sdcal import sdcal
    from partition import partition

    # make the CASA5 tool constructors used here look like the CASA6 versions
    ms = mstool
    table = tbtool
    # the global tb is also used here
    measures = metool

    # Get the path to data
    # NOTE(review): raises AttributeError if CASAPATH is not set in the environment.
    dataRoot = os.path.join(os.environ.get('CASAPATH').split()[0],'casatestdata/')
    def ctsys_resolve(rel_path):
        "Resolve absolute path of a unit test data directory given as a relative path"
        return os.path.join(dataRoot,rel_path)


def _remove_with_prefix(prefix):
    """
    Best-effort removal of every file or directory whose path starts with
    ``prefix`` (the Python equivalent of ``rm -rf prefix*``).

    Errors are ignored, as with ``rm -rf``, so cleanup never masks the test
    outcome.
    """
    import glob
    for path in glob.glob(prefix + '*'):
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                pass


@contextlib.contextmanager
def mmshelper(vis, separationaxis='auto'):
    """
    Partition ``vis`` into a temporary Multi-MS and yield its name.

    The MMS is written next to ``vis`` as ``<vis>.mms``. Any path starting with
    that name is removed before partitioning and again on exit, even if the
    ``with`` body raises.

    Parameters:
        vis            -- path of the input MS (a trailing '/' is ignored)
        separationaxis -- forwarded to ``partition``

    Yields:
        the MMS path if ``partition`` created it, otherwise None.
    """
    outputvis = vis.rstrip('/') + '.mms'
    # Removed in-process instead of via a shell command built from the path:
    # no shell quoting problems for paths with spaces or metacharacters.
    _remove_with_prefix(outputvis)
    try:
        partition(vis=vis, outputvis=outputvis, separationaxis=separationaxis)
        yield outputvis if os.path.exists(outputvis) else None
    finally:
        _remove_with_prefix(outputvis)

class sdcal_test(unittest.TestCase):

    """
    Unit test for task sdcal (Tsys calibration and its application).

    The list of tests:
    test00  --- compare MS/SYSCAL TSYS_SPECTRUM with FPARAM of the Tsys table
    test00M --- same as test00 with an MMS as input
    test01  --- WEIGHT vs 1/SIGMA**2 x 1/FPARAM_ave**2 (spwmap as dictionary)
    test02  --- WEIGHT vs 1/SIGMA**2 x 1/FPARAM_ave**2 (spwmap as list)
    test03  --- CORRECTED_DATA vs DATA x FPARAM (spwmap={1:[1],3:[3],5:[5],7:[7]})
    test04  --- CORRECTED_DATA vs DATA x channel-averaged FPARAM
                (spwmap={1:[9],3:[11],5:[13],7:[15]})
    test05  --- CORRECTED_DATA vs DATA x FPARAM of the first channel
                (spwmap={1:[9],3:[11],5:[13],7:[15]})
    test06  --- WEIGHT_SPECTRUM vs 1/SIGMA**2 x 1/FPARAM**2 per channel
                (spwmap as dictionary, interp='nearest')

    NOTE(review): these tests print the comparisons but do not assert on them;
    they fail only if a task call or table access raises (or, in test00M,
    if the MMS is not created).
    """

    # Data path of input
    datapath = ctsys_resolve('unittest/sdcal/')

    # Input
    infile1 = 'uid___A002_X6218fb_X264.ms.sel'
    infiles = [infile1]
    tsystable = 'out.cal'

    # Number of leading FPARAM channels averaged when one value is needed.
    nchan_average = 128

    def setUp(self):
        for infile in self.infiles:
            if os.path.exists(infile):
                shutil.rmtree(infile)
            shutil.copytree(os.path.join(self.datapath, infile), infile)

    def tearDown(self):
        for path in self.infiles + [self.tsystable]:
            if os.path.exists(path):
                shutil.rmtree(path)

    def _compareOutFile(self, out, reference):
        self.assertTrue(os.path.exists(out))
        self.assertTrue(os.path.exists(reference), msg="Reference file doesn't exist: " + reference)
        self.assertTrue(listing.compare(out, reference), 'New and reference files are different. %s != %s. ' % (out, reference))

    # ------------------------------------------------------------------ helpers

    def _apply_tsys(self, spwmap, **apply_kwargs):
        """
        Build the Tsys table from infile1, initialize weights and apply the
        table to infile1 with the given spwmap.

        Extra keyword arguments (e.g. outfile, interp) are forwarded to the
        sdcal 'apply' call.
        """
        infile = self.infile1
        sdcal(infile=infile, calmode='tsys', outfile=self.tsystable)
        initweights(vis=infile, wtmode='nyq', dowtsp=True)
        sdcal(infile=infile, calmode='apply', spwmap=spwmap,
              applytable=self.tsystable, **apply_kwargs)

    def _average_fparam(self, fparam_cell, pol):
        """Mean of fparam_cell[pol, ch, 0] over the first nchan_average channels."""
        nchan = self.nchan_average
        # The original per-channel loop raised IndexError for short cells;
        # fail explicitly instead of silently averaging fewer channels.
        self.assertGreaterEqual(fparam_cell.shape[1], nchan,
                                msg='FPARAM has fewer than %s channels' % nchan)
        return numpy.sum(fparam_cell[pol, :nchan, 0]) / float(nchan)

    def _compare_syscal_with_tsys_table(self, ms_name):
        """Print how <ms_name>/SYSCAL TSYS_SPECTRUM compares with FPARAM of self.tsystable."""
        with sdutil.tbmanager(os.path.join(ms_name, 'SYSCAL')) as mytb:
            subt = mytb.query('', sortlist='ANTENNA_ID, TIME, SPECTRAL_WINDOW_ID', columns='TSYS_SPECTRUM')
            try:
                tsys1 = subt.getcol('TSYS_SPECTRUM')
            finally:
                subt.close()

        with sdutil.tbmanager(self.tsystable) as mytb:
            subt = mytb.query('', sortlist='ANTENNA1, TIME, SPECTRAL_WINDOW_ID', columns='FPARAM, FLAG')
            try:
                tsys2 = subt.getcol('FPARAM')
                flag = subt.getcol('FLAG')
            finally:
                subt.close()

        print('')
        print('The shape of the MS/SYSCAL/TSYS_SPECTRUM', tsys1.shape)
        # array_equal returns False on a shape mismatch, whereas an elementwise
        # `==` followed by .all() may raise in recent NumPy.
        if numpy.array_equal(tsys1, tsys2):
            print('The shape of the FPARAM extracted with sdcal', tsys2.shape)
            print('Both tables are identical.')
        else:
            print('The shape of the FPARAM of the extraction with sdcal', tsys2.shape)
            print('Both tables are not identical.')

        # `flag.all() == 0` was true whenever *any* flag was unset; the message
        # is only correct when no flag is set at all.
        if not flag.any():
            print('ALL FLAGs are set to zero.')

    def _report_weight(self):
        """Print row-0 WEIGHT of infile1 against 1/SIGMA**2 x 1/FPARAM_ave**2 for both polarizations."""
        with sdutil.tbmanager(self.infile1) as mytb:
            sigma = mytb.getcol('SIGMA')
            weight = mytb.getcol('WEIGHT')
        sigma00, sigma10 = sigma[0][0], sigma[1][0]
        weight00, weight10 = weight[0][0], weight[1][0]

        # Read FPARAM once instead of once per channel and polarization.
        with sdutil.tbmanager(self.tsystable) as mytb:
            fparam_r1 = mytb.getvarcol('FPARAM')['r1']
        fparam0_ave = self._average_fparam(fparam_r1, 0)
        fparam1_ave = self._average_fparam(fparam_r1, 1)

        print('fparam_average_r1_0', fparam0_ave)
        print('fparam_average_r1_1', fparam1_ave)
        print('SIGMA00 ', sigma00)
        print('SIGMA10 ', sigma10)
        print('WEIGHT00 ', weight00)
        print('WEIGHT10 ', weight10)
        answer0 = 1/(sigma00**2)*1/(fparam0_ave**2)
        answer1 = 1/(sigma10**2)*1/(fparam1_ave**2)
        print('pol0: 1/SIGMA**2 X 1/(FPARAM_ave)**2', answer0)
        print('pol1: 1/SIGMA**2 X 1/(FPARAM_ave)**2', answer1)
        diff0_percent = (weight00-answer0)/weight00*100
        diff1_percent = (weight10-answer1)/weight10*100
        print('difference between fparam_r1_0 and weight00', diff0_percent, '%')
        print('difference between fparam_r1_1 and weight10', diff1_percent, '%')

    def _read_first_data_sample(self, row_key):
        """
        Return (CORRECTED_DATA, DATA) element [0][0][0] of the getvarcol row
        ``row_key`` of infile1 (presumably pol 0, channel 0 — TODO confirm).
        """
        with sdutil.tbmanager(self.infile1) as mytb:
            corrected_data = mytb.getvarcol('CORRECTED_DATA')[row_key][0][0][0]
            data = mytb.getvarcol('DATA')[row_key][0][0][0]
        return corrected_data, data

    @staticmethod
    def _report_corrected_data(corrected_data, data, fparam, fparam_label, diff_label):
        """Print the relative difference (%) between CORRECTED_DATA and DATA x fparam (real parts)."""
        print("CORRECTED_DATA", corrected_data)
        print("DATA", data)
        print(fparam_label, fparam)
        diff = corrected_data.real - (data.real*fparam)
        diff_per = (diff/corrected_data.real)*100
        print(diff_label, diff_per, "%")

    # -------------------------------------------------------------------- tests

    def test00(self):
        """Test00:Check the identification of TSYS_SPECTRUM and FPARAM"""
        infile = self.infile1
        sdcal(infile=infile, calmode='tsys', outfile=self.tsystable)
        self._compare_syscal_with_tsys_table(infile)

    def test00M(self):
        """Test00M:Check the identification of TSYS_SPECTRUM and FPARAM (MMS)"""
        infile = self.infile1
        with mmshelper(infile) as mvis:
            self.assertIsNotNone(mvis)
            sdcal(infile=mvis, calmode='tsys', outfile=self.tsystable)
        # SYSCAL is read from the (still existing) MS copy, as before.
        self._compare_syscal_with_tsys_table(infile)

    def test01(self):
        """Test01: weight = 1/(SIGMA**2) X 1/(FPARAM_ave**2) dictionary version"""
        # focus on antenna1=0, data_disk_id=1
        self._apply_tsys({1: [1], 3: [3], 5: [5], 7: [7]}, outfile='')
        self._report_weight()

    def test02(self):
        """Test02: weight = 1/(SIGMA**2) X 1/(FPARAM_ave**2) list version"""
        # focus on antenna1=0, data_disk_id=1
        spwmap_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
        self._apply_tsys(spwmap_list, outfile='')
        self._report_weight()

    def test03(self):
        """Test03: Validation of CORRECTED_DATA = DATA X FPARAM (spwmap={1:[1], 3:[3], 5:[5], 7:[7]})"""
        self._apply_tsys({1: [1], 3: [3], 5: [5], 7: [7]})
        corrected_data, data = self._read_first_data_sample('r1')
        with sdutil.tbmanager(self.tsystable) as mytb:
            fparam = mytb.getvarcol('FPARAM')['r1'][0][0][0]
        self._report_corrected_data(corrected_data, data, fparam, "FPARAM",
                                    "difference between CORRECTED_DATA and DATA X FPARAM")

    def test04(self):
        """Test04: Validation of CORRECTED_DATA = DATA X FPARAM
        (spwmap={1:[9], 3:[11], 5:[13], 7:[15]})
        antenna1=0, DATA_DESC_ID=9, FPARAM_average
        """
        self._apply_tsys({1: [9], 3: [11], 5: [13], 7: [15]})
        corrected_data, data = self._read_first_data_sample('r2')
        with sdutil.tbmanager(self.tsystable) as mytb:
            fparam_ave = self._average_fparam(mytb.getvarcol('FPARAM')['r1'], 0)
        self._report_corrected_data(corrected_data, data, fparam_ave, "FPARAM average(128ch)",
                                    "difference between CORRECTED_DATA and DATA X FPARAM_average(128)")

    def test05(self):
        """Test05: Validation of CORRECTED_DATA = DATA X FPARAM
        (spwmap={1:[9], 3:[11], 5:[13], 7:[15]})
        antenna1=0, DATA_DESC_ID=9, FPARAM of the first channel
        """
        print('')
        self._apply_tsys({1: [9], 3: [11], 5: [13], 7: [15]})
        corrected_data, data = self._read_first_data_sample('r2')
        with sdutil.tbmanager(self.tsystable) as mytb:
            fparam = mytb.getvarcol('FPARAM')['r1'][0][0][0]
        self._report_corrected_data(corrected_data, data, fparam, "FPARAM",
                                    "difference between CORRECTED_DATA and DATA X FPARAM")

    def test06(self):
        """Test06: weight_spectrum = 1/(SIGMA**2) X 1/(FPARAMx**2) dictionary version"""
        # focus on antenna1=0, data_disk_id=1
        self._apply_tsys({1: [1], 3: [3], 5: [5], 7: [7]}, interp='nearest', outfile='')

        row = 0
        with sdutil.tbmanager(self.infile1) as mytb:
            sigma = mytb.getcell('SIGMA', row)
            weight_spectrum = mytb.getcell('WEIGHT_SPECTRUM', row)
        with sdutil.tbmanager(self.tsystable) as mytb:
            fparam = mytb.getcell('FPARAM', row)

        # NOTE(review): an assertion on the per-channel difference was disabled
        # in the original test; only the differences are printed.
        for ch in range(weight_spectrum.shape[1]):
            answer0 = 1/(sigma[0]**2)*1/(fparam[0][ch]**2)
            answer1 = 1/(sigma[1]**2)*1/(fparam[1][ch]**2)
            diff0 = weight_spectrum[0][ch]-answer0
            diff1 = weight_spectrum[1][ch]-answer1
            diff0_percent = diff0/weight_spectrum[0][ch]*100
            diff1_percent = diff1/weight_spectrum[1][ch]*100
            print('')
            print('pol0 & pol1 ch '+ str(ch)+ ': diff between 1/SIGMA**2 X 1/(FPARAM['+str(ch)+'])**2 and WEIGHT_SPECTRUM['+ str(ch)+']' , diff0, diff1)
            print(diff0_percent, '%', diff1_percent, '%')





class sdcal_test_base(unittest.TestCase):
    """
    Base class for sdcal unit test.
    The following attributes/functions are defined here.

        datapath
        decorators (invalid_argument_case, exception_case)
        _setUp / _tearDown helpers that copy and remove test data
    """
    # Data path of input
    datapath = ctsys_resolve('unittest/sdcal/')

    # Input
    infile = 'uid___A002_X6218fb_X264.ms.sel'
    applytable = infile + '.sky'

    # task execution result; test methods store the task return value here
    result = None

    # decorators
    @staticmethod
    def invalid_argument_case(func):
        """
        Decorator for the test case that is intended to fail
        due to invalid argument.

        After ``func`` runs, ``self.result`` must be falsy.
        """
        import functools
        @functools.wraps(func)
        def wrapper(self):
            func(self)
            self.assertFalse(self.result, msg='The task must return False')
        return wrapper

    @staticmethod
    def exception_case(exception_type, exception_pattern):
        """
        Decorator for the test case that is intended to throw
        exception.

            exception_type: type of exception
            exception_pattern: regex for inspecting exception message
                               using re.search
        """
        def wrapper(func):
            import functools
            @functools.wraps(func)
            def _wrapper(self):
                self.assertTrue(len(exception_pattern) > 0, msg='Internal Error')
                # assertRaises itself fails if nothing is raised. A self.fail()
                # inside this block would raise AssertionError, which a broad
                # exception_type (e.g. Exception) would catch, hiding the failure.
                with self.assertRaises(exception_type) as ctx:
                    func(self)
                message = str(ctx.exception)
                self.assertIsNotNone(re.search(exception_pattern, message), msg='error message \'%s\' is not expected.'%(message))
            return _wrapper
        return wrapper

    def _setUp(self, files, task):
        """Copy each entry of files from datapath into cwd (replacing stale copies), then call default(task)."""
        for f in files:
            if os.path.exists(f):
                shutil.rmtree(f)
            shutil.copytree(os.path.join(self.datapath, f), f)

        default(task)

    def _tearDown(self, files):
        """Remove each existing directory in files."""
        for f in files:
            if os.path.exists(f):
                shutil.rmtree(f)

class sdcal_test_ps(sdcal_test_base):
    """
    Unit test for task sdcal (position switch sky calibration).

    The list of tests:
    test_ps00  --- default parameters (raises an error)
    test_ps01  --- invalid calibration type
    test_ps02  --- invalid selection (invalid spw selection)
    test_ps03  --- outfile exists (overwrite=False)
    test_ps04  --- empty outfile
    test_ps05  --- position switch calibration ('ps')
    test_ps05M --- position switch calibration ('ps') for MMS
    test_ps06  --- position switch calibration ('ps') with data selection
    test_ps07  --- outfile exists (overwrite=True)
    test_ps08  --- inappropriate calmode ('otfraster')
    """
    invalid_argument_case = sdcal_test_base.invalid_argument_case
    exception_case = sdcal_test_base.exception_case

    @property
    def outfile(self):
        # The sky calibration table written by each test.
        return self.applytable

    def setUp(self):
        self._setUp([self.infile], sdcal)

    def tearDown(self):
        self._tearDown([self.infile, self.outfile])

    def normal_case(**kwargs):
        """
        Decorator for the test case that is intended to verify
        normal execution result.

        kwargs --- data selection parameters (e.g. spw='9'); when given they
                   are passed to ms.msseltoindex to derive the expected
                   antennas and spws

        Here, expected result is as follows:
            - total number of rows is 12
            - number of antennas is 2
            - number of spectral windows is 2
            - each (antenna,spw) pair has 3 rows
            - expected sky data is a certain fixed value except completely
              flagged channels
              ANT, SPW, SKY
              0     9   [1.0, 2.0, 3.0]
              1     9   [7.0, 8.0, 9.0]
              0    11   [4.0, 5.0, 6.0]
              1    11   [10.0, 11.0, 12.0]
            - channels 0~9 are flagged, each integration has spurious
              channels
              ANT, SPW, SPURIOUS CHANNELS
              0     9   [(511,512), (127,128), (383,384)]
              1     9   [(511,512), (127,128), (383,384)]
              0    11   [(511,512), (127,128), (383,384)]
              1    11   [(511,512), (127,128), (383,384)]
        """
        def wrapper(func):
            import functools
            @functools.wraps(func)
            def _wrapper(self):
                func(self)

                # sanity check
                self.assertIsNone(self.result, msg='The task must complete without error')
                self.assertTrue(os.path.exists(self.outfile), msg='Output file is not properly created.')

                # verifying nrow
                if len(kwargs) == 0:
                    expected_nrow = 12
                    antenna1_selection = None
                    spw_selection = None
                else:
                    myms = ms()
                    myargs = kwargs.copy()
                    if 'baseline' not in myargs:
                        # build a baseline selection from every ANTENNA1 in the input
                        with sdutil.tbmanager(self.infile) as mytb:
                            antenna1 = numpy.unique(mytb.getcol('ANTENNA1'))
                        myargs['baseline'] = '%s&&&'%(','.join(map(str, antenna1)))
                    indices = myms.msseltoindex(self.infile, **myargs)
                    antenna1_selection = indices['antenna1']
                    spw_selection = indices['spw']
                    expected_nrow = 3 * len(spw_selection) * len(antenna1_selection)
                with sdutil.tbmanager(self.outfile) as mytb:
                    actual_nrow = mytb.nrows()
                self.assertEqual(actual_nrow, expected_nrow, msg='Number of rows mismatch (expected %s actual %s)'%(expected_nrow, actual_nrow))

                # verifying resulting sky spectra
                expected_value = {0: {9: [1., 2., 3.],
                                      11: [4., 5., 6.]},
                                  1: {9: [7., 8., 9.],
                                      11: [10., 11., 12.]}}
                eps = 1.0e-6
                for (ant, d) in expected_value.items():
                    if antenna1_selection is not None and ant not in antenna1_selection:
                        continue
                    for (spw, val) in d.items():
                        if spw_selection is not None and spw not in spw_selection:
                            continue
                        taql = 'ANTENNA1 == %s && SPECTRAL_WINDOW_ID == %s'%(ant, spw)
                        with sdutil.table_selector(self.outfile, taql) as mytb:
                            nrow = mytb.nrows()
                            self.assertEqual(nrow, 3, msg='Number of rows mismatch')
                            for irow in range(nrow):
                                expected = val[irow]
                                self.assertGreater(expected, 0.0, msg='Internal Error')
                                fparam = mytb.getcell('FPARAM', irow)
                                flag = mytb.getcell('FLAG', irow)

                                def message_template(what, value):
                                    return 'Unexpected %s for antenna %s spw %s row %s (expected %s)'%(what, ant, spw, irow, value)

                                # first 10 channels must be flagged, the rest unflagged
                                self.assertTrue(numpy.all(flag[:, :10]), msg=message_template('flag status', True))
                                self.assertFalse(numpy.any(flag[:, 10:]), msg=message_template('flag status', False))
                                fparam_valid = fparam[numpy.logical_not(flag)]
                                error = numpy.abs((fparam_valid - expected) / expected)
                                self.assertTrue(numpy.all(error < eps), msg=message_template('sky data', expected))
            return _wrapper
        return wrapper


    @invalid_argument_case
    def test_ps00(self):
        """
        test_ps00 --- default parameters (raises an error)
        """
        # CASA6 throws an exception
        if is_CASA6:
            self.assertRaises(Exception, sdcal)
        else:
            self.result = sdcal()

    @invalid_argument_case
    def test_ps01(self):
        """
        test_ps01 --- invalid calibration type
        """
        # CASA6 throws an exception
        if is_CASA6:
            self.assertRaises(Exception, sdcal, infile=self.infile, calmode='invalid_type', outfile=self.outfile)
        else:
            self.result = sdcal(infile=self.infile, calmode='invalid_type', outfile=self.outfile)

    @exception_case(RuntimeError, 'Spw Expression: No match found for 99,')
    def test_ps02(self):
        """
        test_ps02 --- invalid selection (invalid spw selection)
        """
        self.result = sdcal(infile=self.infile, calmode='ps', spw='99', outfile=self.outfile)

    @exception_case(RuntimeError, '^overwrite is False and output file exists:')
    def test_ps03(self):
        """
        test_ps03 --- outfile exists (overwrite=False)
        """
        # copy input to output
        shutil.copytree(self.infile, self.outfile)
        self.result = sdcal(infile=self.infile, calmode='ps', outfile=self.outfile, overwrite=False)

    # raw string: '\.' is an invalid escape in a normal string literal
    @exception_case(RuntimeError, r'Output file name must be specified\.')
    def test_ps04(self):
        """
        test_ps04 --- empty outfile
        """
        self.result = sdcal(infile=self.infile, calmode='ps', outfile='', overwrite=False)

    @normal_case()
    def test_ps05(self):
        """
        test_ps05 --- position switch calibration ('ps')
        """
        self.result = sdcal(infile=self.infile, calmode='ps', outfile=self.outfile)

    @normal_case()
    def test_ps05M(self):
        """
        test_ps05M --- position switch calibration ('ps') for MMS
        """
        with mmshelper(vis=self.infile) as mvis:
            self.assertIsNotNone(mvis)
            self.result = sdcal(infile=mvis, calmode='ps', outfile=self.outfile)

    @normal_case(spw='9')
    def test_ps06(self):
        """
        test_ps06 --- position switch calibration ('ps') with data selection
        """
        self.result = sdcal(infile=self.infile, calmode='ps', spw='9', outfile=self.outfile)

    @normal_case()
    def test_ps07(self):
        """
        test_ps07 --- outfile exists (overwrite=True)
        """
        # copy input to output
        shutil.copytree(self.infile, self.outfile)
        self.result = sdcal(infile=self.infile, calmode='ps', outfile=self.outfile, overwrite=True)

    @exception_case(RuntimeError, "Error in Calibrater::solve")
    def test_ps08(self):
        """
        test_ps08 --- inappropriate calmode ('otfraster')
        """
        # the data isn't an OTF raster scan so that unexpected behavior may happen
        # if calmode is 'otfraster'
        # in this case, gap detection detects the row having only one integration
        # due to irregular time stamp distribution and causes the "Too many edge
        # points" error
        self.result = sdcal(infile=self.infile, outfile=self.outfile, calmode='otfraster')

class DataManager:
    """
    Functor decorator to handle data setup/teardown of unit test class methods
    on a per-test basis.

    io_files layout (as read below):
        {'input':    {'ms_name': <MS directory name>},
         'output':   {'cal_tbl_name': <calibration table directory name>},
         'datapath': <directory holding the input MS>}

    The input MS is copied into the current working directory before the test.
    The copy and the output table are removed afterwards, even if the test
    raises.
    """

    def __init__(self, io_files=None):
        self.io_files = io_files

    @staticmethod
    def _remove_dir(path):
        # Remove a directory tree if it exists.
        if os.path.exists(path):
            shutil.rmtree(path)

    def setUp(self):
        """Copy the input MS into cwd and remove any stale output table."""
        input_ms_name = self.io_files['input']['ms_name']
        input_ms_path = os.path.join(self.io_files['datapath'], input_ms_name)

        if not os.path.exists(input_ms_path):
            err_msg = 'Input MS not found:\n{}'.format(input_ms_path)
            raise Exception(err_msg)

        self._remove_dir(input_ms_name)
        shutil.copytree(input_ms_path, input_ms_name)

        self._remove_dir(self.io_files['output']['cal_tbl_name'])

    def tearDown(self):
        """Remove the copied input MS and the output table from cwd."""
        self._remove_dir(self.io_files['input']['ms_name'])
        self._remove_dir(self.io_files['output']['cal_tbl_name'])

    def __call__(self, func):
        import functools

        # functools.wraps keeps the test's name/docstring for unittest reporting.
        @functools.wraps(func)
        def wrapped_func(inst_UnitTest):
            inst_UnitTest.io_files = self.io_files
            self.setUp()
            try:
                func(inst_UnitTest)
            finally:
                # Clean up even when the test fails or errors.
                self.tearDown()

        return wrapped_func

class sdcal_test_bug_fix_cas_12712(unittest.TestCase):
    """
    Test fix for sdcal bug reported in CAS-12712
    """
    io_files = None

    @DataManager({ 'input' : { 'ms_name'      : 'uid___A002_X85c183_X36f.ms.sel' },
                   'output': { 'cal_tbl_name' : 'uid___A002_X85c183_X36f.ms.sel.cal.sky.tbl' },
                    # Path to search when looking for test input data,
                    # relative to some root data directory
                   'datapath' : ctsys_resolve('unittest/sdcal/')
                   })
    def test_cas_12712_01(self):
        input_ms = self.io_files['input']['ms_name']
        output_cal_tbl = self.io_files['output']['cal_tbl_name']

        spw_sel = "23"
        sdcal(infile=input_ms,
              outfile=output_cal_tbl,
              spw=spw_sel,
              overwrite=True,
              calmode='ps')

        cal_tbl_path = os.path.join(os.getcwd(), output_cal_tbl)
        err_msg = 'Calibration table not found:\n{}'.format(cal_tbl_path)
        self.assertTrue(os.path.exists(cal_tbl_path), err_msg)

        # Expected number of calibration table rows for spw=23 and antenna=0: 61
        expected_rows = 61
        ant_id = 0
        qry = 'select * from {tbl_name:} where ( ANTENNA1 == {ant_id:} and ANTENNA2 == {ant_id:} )'.format(
            tbl_name=output_cal_tbl,
            ant_id=ant_id
            )

        # Close the query result before its parent table. The assertion is
        # kept out of `finally` so that it cannot mask an exception raised
        # while querying.
        tb.open(cal_tbl_path)
        try:
            res_tbl = tb.taql(qry)
            try:
                computed_rows = res_tbl.nrows()
            finally:
                res_tbl.close()
        finally:
            tb.close()

        err_msg = "\n".join([
            "Calibration table: {}".format(cal_tbl_path),
            "Expected number of rows for ant={} and spw={}: {}".format(
                ant_id, spw_sel, expected_rows),
            "Computed number of rows for ant={} and spw={}: {}".format(
                ant_id, spw_sel, computed_rows),
            ])
        self.assertEqual(expected_rows, computed_rows, err_msg)

class sdcal_test_otfraster(sdcal_test_base):
    """
    Unit test for task sdcal (OTF raster sky calibration).
    Since basic test case is covered by sdcal_test_ps, only
    tests specific to otfraster calibration are defined here.

    The list of tests:
    test_otfraster00 --- invalid fraction (non numeric value)
    test_otfraster01 --- too many edge points (fraction 0.5)
    test_otfraster02 --- too many edge points (fraction '50%')
    test_otfraster03 --- too many edge points (noff 10000)
    ###test_otfraster04 --- negative edge points
    ###test_otfraster05 --- zero edge points
    test_otfraster06 --- inappropriate calibration mode ('ps')
    test_otfraster07 --- OTF raster calibration ('otfraster') with default setting
    test_otfraster07M --- same as test_otfraster07, input converted to MMS
    test_otfraster08 --- OTF raster calibration ('otfraster') with string fraction (numeric value)
    test_otfraster09 --- OTF raster calibration ('otfraster') with string fraction (percentage)
    test_otfraster10 --- OTF raster calibration ('otfraster') with numeric fraction
    test_otfraster11 --- OTF raster calibration ('otfraster') with auto detection
    test_otfraster12 --- OTF raster calibration ('otfraster') with custom noff
    test_otfraster13 --- check if noff takes priority over fraction
    """
    invalid_argument_case = sdcal_test_base.invalid_argument_case
    exception_case = sdcal_test_base.exception_case
    infile = 'uid___A002_X6218fb_X264.ms.sel.otfraster'

    @staticmethod
    def calculate_expected_value(table, numedge=1):
        """Compute the expected OTF raster sky spectra from MS ``table``.

        For every (ANTENNA1, DATA_DESC_ID) pair, the rows with
        ANTENNA1 == ANTENNA2 are sorted by TIME and split into raster rows
        wherever the time step exceeds 5 times the median time step. For each
        raster row, the real part of DATA over its first and its last
        ``numedge`` integrations is averaged with FLAG used as mask.

        Returns ``{antenna: {spw: [left0, right0, left1, right1, ...]}}`` where
        each entry is a numpy masked array. Pairs with fewer than 2 rows are
        skipped (no spw key is created for them).
        """
        expected_value = {}
        with sdutil.tbmanager(table) as tb:
            antenna_list = numpy.unique(tb.getcol('ANTENNA1'))
            ddid_list = numpy.unique(tb.getcol('DATA_DESC_ID'))
        with sdutil.tbmanager(os.path.join(table,'DATA_DESCRIPTION')) as tb:
            dd_spw_map = tb.getcol('SPECTRAL_WINDOW_ID')
        for antenna in antenna_list:
            expected_value[antenna] = {}
            for ddid in ddid_list:
                spw = dd_spw_map[ddid]
                taql = 'ANTENNA1 == %s && ANTENNA2 == %s && DATA_DESC_ID == %s'%(antenna,antenna,ddid)
                with sdutil.tbmanager(table) as tb:
                    # Query outside the try: if it fails there is no selection
                    # to close (closing an unbound name would mask the error).
                    tsel = tb.query(taql, sortlist='TIME')
                    try:
                        time_list = tsel.getcol('TIME')
                        data = tsel.getcol('DATA').real
                        flag = tsel.getcol('FLAG')
                    finally:
                        tsel.close()
                if len(time_list) < 2:
                    continue
                data_list = []
                time_difference = time_list[1:] - time_list[:-1]
                gap_threshold = numpy.median(time_difference) * 5
                # start index of each raster row, followed by the end sentinel
                gap_list = numpy.concatenate(([0], numpy.where(time_difference > gap_threshold)[0]+1))
                if gap_list[-1] != len(time_list):
                    gap_list = numpy.concatenate((gap_list, [len(time_list)]))
                for i in range(len(gap_list)-1):
                    start = gap_list[i]
                    end = gap_list[i+1]
                    raster_data = data[:,:,start:end]
                    raster_flag = flag[:,:,start:end]
                    raster_row = numpy.ma.masked_array(raster_data, raster_flag)
                    left_edge = raster_row[:,:,:numedge].mean(axis=2)
                    right_edge = raster_row[:,:,-numedge:].mean(axis=2)
                    data_list.extend([left_edge, right_edge])
                expected_value[antenna][spw] = data_list

        return expected_value

    @property
    def outfile(self):
        """Name of the output caltable (alias of ``self.applytable``)."""
        return self.applytable

    def setUp(self):
        self._setUp([self.infile], sdcal)

    def tearDown(self):
        self._tearDown([self.infile, self.outfile])

    def normal_case(numedge=1, **kwargs):
        """
        Decorator for the test case that is intended to verify
        normal execution result.

        numedge --- expected number of edge points
        selection --- data selection parameter as dictionary

        Here, expected result is as follows:
            - total number of rows is 24
            - number of antennas is 2
            - number of spectral windows is 2
            - each (antenna,spw) pair has 6 rows
        """
        def wrapper(func):
            import functools
            @functools.wraps(func)
            def _wrapper(self):
                func(self)

                # sanity check
                self.assertIsNone(self.result, msg='The task must complete without error')
                self.assertTrue(os.path.exists(self.outfile), msg='Output file is not properly created.')

                # verifying nrow
                if len(kwargs) == 0:
                    expected_nrow = 24
                    antenna1_selection = None
                    spw_selection = None
                else:
                    myms = ms()
                    myargs = kwargs.copy()
                    if 'baseline' not in myargs:
                        # default baseline selection: '<ant,ant,...>&&&' over
                        # every ANTENNA1 value present in the input MS
                        with sdutil.tbmanager(self.infile) as tb:
                            antenna1 = numpy.unique(tb.getcol('ANTENNA1'))
                            myargs['baseline'] = '%s&&&'%(','.join(map(str,antenna1)))
                    a = myms.msseltoindex(self.infile, **myargs)
                    antenna1_selection = a['antenna1']
                    spw_selection = a['spw']
                    expected_nrow = 6 * len(spw_selection) * len(antenna1_selection)
                with sdutil.tbmanager(self.outfile) as tb:
                    self.assertEqual(tb.nrows(), expected_nrow, msg='Number of rows mismatch (expected %s actual %s)'%(expected_nrow, tb.nrows()))

                # verifying resulting sky spectra
                eps = 1.0e-6
                expected_value = sdcal_test_otfraster.calculate_expected_value(self.infile, numedge)
                for (ant,d) in expected_value.items():
                    if antenna1_selection is not None and ant not in antenna1_selection:
                        continue
                    for (spw,val) in d.items():
                        if spw_selection is not None and spw not in spw_selection:
                            continue
                        taql = 'ANTENNA1 == %s && SPECTRAL_WINDOW_ID == %s'%(ant,spw)
                        with sdutil.table_selector(self.outfile, taql) as tb:
                            nrow = tb.nrows()
                            self.assertEqual(nrow, 6, msg='Number of rows mismatch')
                            for irow in range(nrow):
                                expected = val[irow]
                                fparam = tb.getcell('FPARAM', irow)
                                flag = tb.getcell('FLAG', irow)
                                self.assertEqual(expected.shape, fparam.shape, msg='Shape mismatch for antenna %s spw %s row %s (expected %s actual %s)'%(ant,spw,irow,list(expected.shape),list(fparam.shape)))
                                npol,nchan = expected.shape
                                for ipol in range(npol):
                                    for ichan in range(nchan):
                                        message_template = lambda x,y,z: 'Unexpected %s for antenna %s spw %s row %s pol %s channel %s (expected %s actual %s)'%(x,ant,spw,irow,ipol,ichan,y,z)
                                        _flag = flag[ipol,ichan]
                                        _mask = expected.mask[ipol,ichan]
                                        _expected = expected.data[ipol,ichan]
                                        _fparam = fparam[ipol,ichan]
                                        self.assertEqual(_mask, _flag, msg=message_template('FLAG',_mask,_flag))
                                        # _mask is a numpy.bool_ element; test its truth
                                        # value ("_mask is True" is never true for it).
                                        if _mask:
                                            self.assertEqual(0.0, _fparam, msg=message_template('FPARAM',0.0,_fparam))
                                        elif abs(_expected) < eps:
                                            self.assertLess(abs(_fparam), eps, msg=message_template('FPARAM',_expected,_fparam))
                                        else:
                                            diff = abs((_fparam - _expected) / _expected)
                                            self.assertLess(diff, eps, msg=message_template('FPARAM',_expected,_fparam))
            return _wrapper
        return wrapper

    @exception_case(RuntimeError, r'^Invalid fraction value \(.+\)$')
    def test_otfraster00(self):
        """
        test_otfraster00 --- invalid fraction (non numeric value)
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction='auto')

    @exception_case(ValueError, r'^Too many edge points\. fraction must be < 0.5\.$')
    def test_otfraster01(self):
        """
        test_otfraster01 --- too many edge points (fraction 0.5)
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction=0.5)

    @exception_case(ValueError, r'^Too many edge points\. fraction must be < 0.5\.$')
    def test_otfraster02(self):
        """
        test_otfraster02 --- too many edge points (fraction 50%)
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction='50%')

    @exception_case(RuntimeError, 'Error in Calibrater::solve')
    def test_otfraster03(self):
        """
        test_otfraster03 --- too many edge points (noff 10000)
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', noff=10000)

    #@exception_case(RuntimeError, 'Error in Calibrater::solve')
    #def test_otfraster04(self):
    #    """
    #    test_otfraster04 --- negative edge points
    #    """
    #    self.result = sdcal(infile=self.infile, outfile=self.outfile,
    #                         calmode='otfraster', noff=-3)

    #@exception_case(RuntimeError, 'Error in Calibrater::solve')
    #def test_otfraster05(self):
    #    """
    #    test_otfraster05 --- zero edge points
    #    """
    #    self.result = sdcal(infile=self.infile, outfile=self.outfile,
    #                         calmode='otfraster', noff=0)

    @exception_case(RuntimeError, 'Error in Calibrater::solve')
    def test_otfraster06(self):
        """
        test_otfraster06 --- inappropriate calibration mode ('ps')
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='ps')

    @normal_case(numedge=1)
    def test_otfraster07(self):
        """
        test_otfraster07 --- OTF raster calibration ('otfraster') with default setting
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster')

    @normal_case(numedge=1)
    def test_otfraster07M(self):
        """
        test_otfraster07M --- OTF raster calibration ('otfraster') with default setting (MMS)
        """
        with mmshelper(vis=self.infile) as mvis:
            self.assertTrue(mvis is not None)
            self.result = sdcal(infile=mvis, outfile=self.outfile,
                                calmode='otfraster')

    @normal_case(numedge=2)
    def test_otfraster08(self):
        """
        test_otfraster08 --- OTF raster calibration ('otfraster') with string fraction (numeric value)
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction='0.3')

    @normal_case(numedge=2)
    def test_otfraster09(self):
        """
        test_otfraster09 --- OTF raster calibration ('otfraster') with string fraction (percentage)
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction='30%')

    @normal_case(numedge=2)
    def test_otfraster10(self):
        """
        test_otfraster10 --- OTF raster calibration ('otfraster') with numeric fraction
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction=0.3)

    @normal_case(numedge=2)
    def test_otfraster11(self):
        """
        test_otfraster11 --- OTF raster calibration ('otfraster') with auto detection
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction=0, noff=0)

    @normal_case(numedge=3)
    def test_otfraster12(self):
        """
        test_otfraster12 --- OTF raster calibration ('otfraster') with custom noff
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', noff=3)

    @normal_case(numedge=3)
    def test_otfraster13(self):
        """
        test_otfraster13 --- check if noff takes priority over fraction
        """
        self.result = sdcal(infile=self.infile, outfile=self.outfile,
                             calmode='otfraster', fraction='90%', noff=3)

def assert_true(condition,err_msg):
    """Raise ``RuntimeError(err_msg)`` unless ``condition`` is truthy.

    Unlike the ``assert`` statement, this check is never stripped when the
    interpreter runs in optimized mode (``python -O``, ``__debug__`` False).
    Returns None when the condition holds.
    """
    if condition:
        return None
    raise RuntimeError(err_msg)

class CasaTableChecker:
    """Base class for OTF mode checkers.

    Opens the CASA table at ``tbl_path`` with a dedicated table tool, kept as
    ``self.tb`` and closed when the checker is garbage-collected.
    """
    def __init__(self,tbl_path):
        self.tb = table()
        self.path = tbl_path
        self.tb.open(self.path) # Raises RuntimeError on failure
        # assert_true instead of a bare assert: the check must survive python -O
        assert_true(self.tb.ok(), str(self.path) + ': table tool is not ok after open')
    def __del__(self):
        # self.tb may be missing if table() itself raised in __init__
        tb = getattr(self, 'tb', None)
        if tb is not None:
            tb.close()

class MsCalTableChecker(CasaTableChecker):
    """OTF mode checker: checking equality of 2 ms_caltable fparam columns within specified tolerance

    ``a == b`` returns True when the FPARAM columns have the same shape and
    agree within ``a.tol`` (absolute tolerance); otherwise it raises
    RuntimeError (via assert_true) so the reason reaches the test report.
    """
    def __init__(self,tbl_path,tol=1e-6):
        CasaTableChecker.__init__(self, tbl_path)
        assert_true('FPARAM' in self.tb.colnames(),
                   str(self.path) + ': FPARAM column missing')
        self.fparam = self.tb.getcol('FPARAM')
        self.tol=tol
    def __eq__(self,other):
        if not isinstance(other, MsCalTableChecker):
            return NotImplemented
        # Label shapes by table path: which side is "expected" depends on the
        # caller's operand order (tests here call assertEqual(result, expected)).
        assert_true(self.fparam.shape == other.fparam.shape,
                   'FPARAM columns: shape mismatch: {lhs} has {lhs_shape}, {rhs} has {rhs_shape}'.format(
                       lhs=self.path, lhs_shape=self.fparam.shape,
                       rhs=other.path, rhs_shape=other.fparam.shape))
        assert_true(numpy.allclose(self.fparam,other.fparam,atol=self.tol,rtol=0.0),
                   'FPARAM columns: error exceeds tolerance='+str(self.tol))
        return True

class MsCorrectedDataChecker(CasaTableChecker):
    """OTF mode checker: checking equality of 2 ms corrected_data columns within specified tolerance

    ``a == b`` requires identical shapes, exactly equal imaginary parts and
    real parts equal within ``a.tol_real`` (absolute); otherwise it raises
    RuntimeError (via assert_true).

    With ``convert_to_kelvin=True`` the CORRECTED_DATA column is multiplied
    element-wise by SYSCAL/TSYS_SPECTRUM of the same MS (assumes the two
    arrays broadcast, e.g. one SYSCAL row per main row — TODO confirm).
    """
    def __init__(self,tbl_path,convert_to_kelvin=False,tol_real=1e-6):
        CasaTableChecker.__init__(self, tbl_path)
        self.tol_real = tol_real
        assert_true('CORRECTED_DATA' in self.tb.colnames(),
                   str(self.path) + ': CORRECTED_DATA column missing')
        self.cdata = self.tb.getcol('CORRECTED_DATA')
        if convert_to_kelvin:
            tbl_syscal = table()
            tbl_syscal.open(os.path.join(tbl_path,'SYSCAL'))
            # always release the SYSCAL table, even if getcol fails
            try:
                tsys_spectrum = tbl_syscal.getcol('TSYS_SPECTRUM')
            finally:
                tbl_syscal.close()
            self.cdata = tsys_spectrum * self.cdata

    def __eq__(self,other):
        if not isinstance(other, MsCorrectedDataChecker):
            return NotImplemented
        assert_true(self.cdata.shape == other.cdata.shape,
                   'CORRECTED_DATA: shape: mismatch')
        assert_true((self.cdata.imag == other.cdata.imag).all(),
                   'CORRECTED_DATA: imaginary part: mismatch')
        assert_true(numpy.allclose(self.cdata.real,other.cdata.real,atol=self.tol_real,rtol=0.0),
                   'CORRECTED_DATA: real part: error exceeds tolerance='+str(self.tol_real))
        return True


class sdcal_test_otf(unittest.TestCase):
    """
    Unit tests for task sdcal,
    sky calibration mode = 'otf' : On-The-Fly (OTF) *non-raster*

    The list of tests:
    Test        | Input            | Edges    | Calibration
    Name        | MS               | Fraction | Mode
    ==========================================================================
    test_otf01  | squares.dec60_cs | 10%      | otf
    test_otf02  | squares.dec60_cs | 20%      | otf
    test_otf03  | lissajous        | 10%      | otf
    test_otf03M | lissajous (MMS)  | 10%      | otf
    test_otf04  | lissajous        | 20%      | otf
    test_otf05  | squares.dec60_cs | 10%      | otf,apply
    test_otf06  | lissajous        | 10%      | apply
    test_otf07  | lissajous        | 10%      | otf,tsys,apply
    """

    # Required checkers:
    # - compare 2 calibration tables
    # - compare 2 corrected data
    datapath=ctsys_resolve('unittest/sdcal/')
    ref_datapath=os.path.join(datapath,'otf_reference_data')
    # Per-test parameters; each test rebinds both on the instance.
    sdcal_params = {}
    current_test_params = {}

    @staticmethod
    def _remove_if_exists(path):
        """Remove the directory tree ``path`` (MS or caltable) if it exists; None is ignored."""
        if path and os.path.exists(path):
            shutil.rmtree(path)

    @staticmethod
    def _ref_ms_copy_name(infile):
        """Name of the local reference copy of ``infile`` built by setup()."""
        return 'ref_'+infile

    def setup(self):
        """Prepare the working directory for the current test (called explicitly).

        Copies the input MS, removes a stale output caltable and, when
        'compute_ref_ms' is requested, builds a reference MS calibrated with
        the reference caltable and records it as 'ref_calibrated_ms'.
        """
        # Copy input MS into current directory
        infile = self.sdcal_params['infile']
        self._remove_if_exists(infile)
        shutil.copytree(os.path.join(self.datapath, infile), infile)
        # Delete output calibration table if any
        self._remove_if_exists(self.sdcal_params.get('outfile'))
        # Compute reference calibrated ms if required
        if 'compute_ref_ms' in self.current_test_params:
            # Create a second copy of input MS
            ref_ms_name = self._ref_ms_copy_name(infile)
            self._remove_if_exists(ref_ms_name)
            shutil.copytree(src=infile,dst=ref_ms_name)
            # Calibrate it using current test caltable
            sdcal(infile=ref_ms_name,calmode='apply',applytable=self.ref_caltable())
            # Update test params
            self.current_test_params['ref_calibrated_ms'] = ref_ms_name

    def tearDown(self):
        """Remove every file setup() may have created, even if setup() failed midway."""
        casalog.post("tearDown")
        infile = self.sdcal_params.get('infile')
        self._remove_if_exists(infile)
        self._remove_if_exists(self.sdcal_params.get('outfile'))
        # Derive the reference copy name from infile rather than from
        # 'ref_calibrated_ms', which is only recorded once setup() succeeds.
        if 'compute_ref_ms' in self.current_test_params and infile:
            self._remove_if_exists(self._ref_ms_copy_name(infile))

    def run_sdcal(self):
        """Run setup() then sdcal with the current test's parameters."""
        self.setup()
        sdcal(**self.sdcal_params)

    def ref_caltable(self):
        """Absolute path of the current test's reference caltable."""
        assert_true('ref_caltable' in self.current_test_params,'sdcal_test_otf internal error')
        return os.path.join(self.ref_datapath,self.current_test_params['ref_caltable'])

    def ref_calibrated_ms(self):
        """Path of the reference calibrated MS: local copy if computed, else under ref_datapath."""
        assert_true('ref_calibrated_ms' in self.current_test_params,'sdcal_test_otf internal error')
        ref_ms_name = self.current_test_params['ref_calibrated_ms']
        if 'compute_ref_ms' in self.current_test_params:
            return ref_ms_name
        else:
            return os.path.join(self.ref_datapath,ref_ms_name)

    def _check_caltable_against_reference(self):
        """Run sdcal and compare its output caltable with the reference caltable."""
        expected_result = MsCalTableChecker(self.ref_caltable())
        self.run_sdcal()
        sdcal_result = MsCalTableChecker(self.sdcal_params['outfile'])
        self.assertEqual(sdcal_result,expected_result) # AlmostEqual semantics

    def test_otf01(self):
        """
        test_otf01 --- Compute calibration table. calmode='otf' ms=squares.dec60_cs.ms
        """
        self.sdcal_params = {
            'infile':'squares.dec60_cs.ms',
            'calmode':'otf',
            'outfile':'test_otf01.ms_caltable'
        }
        self.current_test_params = {
            'ref_caltable':'squares.dec60_cs.edges_fraction_0.1.ms_caltable'
        }
        self._check_caltable_against_reference()

    def test_otf02(self):
        """
        test_otf02 --- Compute calibration table. calmode='otf' ms=squares.dec60_cs.ms edges_fraction=20%
        """
        self.sdcal_params = {
            'infile':'squares.dec60_cs.ms',
            'calmode':'otf',
            'outfile':'test_otf02.ms_caltable',
            'fraction':0.2
        }
        self.current_test_params = {
            'ref_caltable':'squares.dec60_cs.edges_fraction_0.2.ms_caltable'
        }
        self._check_caltable_against_reference()

    def test_otf03(self):
        """
        test_otf03 --- Compute calibration table. calmode='otf' ms=lissajous.ms
        """
        self.sdcal_params = {
            'infile':'lissajous.ms',
            'calmode':'otf',
            'outfile':'test_otf03.ms_caltable'
        }
        self.current_test_params = {
            'ref_caltable':'lissajous.edges_new_fraction_0.1.ms_caltable'
        }
        self._check_caltable_against_reference()

    def test_otf03M(self):
        """
        test_otf03M --- Compute calibration table. calmode='otf' ms=lissajous.ms (MMS)
        """
        self.sdcal_params = {
            'infile':'lissajous.ms',
            'calmode':'otf',
            'outfile':'test_otf03.ms_caltable'
        }
        self.current_test_params = {
            'ref_caltable':'lissajous.edges_new_fraction_0.1.ms_caltable'
        }
        expected_result = MsCalTableChecker(self.ref_caltable())
        self.setup()
        infile = self.sdcal_params['infile']
        with mmshelper(vis=infile) as mvis:
            self.assertTrue(mvis is not None)
            self.sdcal_params['infile'] = mvis
            try:
                sdcal(**self.sdcal_params)
            finally:
                # restore even if sdcal fails, so tearDown removes the copied MS
                self.sdcal_params['infile'] = infile
        sdcal_result = MsCalTableChecker(self.sdcal_params['outfile'])
        self.assertEqual(sdcal_result,expected_result) # AlmostEqual semantics

    def test_otf04(self):
        """
        test_otf04 --- Compute calibration table. calmode='otf' ms=lissajous.ms edges_fraction=20%
        """
        self.sdcal_params = {
            'infile':'lissajous.ms',
            'calmode':'otf',
            'outfile':'test_otf04.ms_caltable',
            'fraction':'20%'
        }
        self.current_test_params = {
            'ref_caltable':'lissajous.edges_new_fraction_0.2.ms_caltable'
        }
        self._check_caltable_against_reference()

    def test_otf05(self):
        """
        test_otf05 --- Sky calibration. calmode='otf,apply' ms=squares.dec60_cs.ms
        """
        self.sdcal_params = {
            'infile':'squares.dec60_cs.ms',
            'calmode':'otf,apply'
        }
        self.current_test_params = {
            'ref_caltable':'squares.dec60_cs.edges_fraction_0.1.ms_caltable',
            'compute_ref_ms':True
        }
        self.run_sdcal()
        expected_result = MsCorrectedDataChecker(self.ref_calibrated_ms())
        sdcal_result = MsCorrectedDataChecker(self.sdcal_params['infile'])
        self.assertEqual(sdcal_result,expected_result) # AlmostEqual semantics

    def test_otf06(self):
        """
        test_otf06 --- Sky calibration reusing caltable pre-computed with calmode='otf'. calmode='apply' ms=lissajous.ms
        """
        self.sdcal_params = {
            'infile':'lissajous.ms',
            'calmode':'apply',
            'applytable':os.path.join(self.ref_datapath,'lissajous.edges_new_fraction_0.1.ms_caltable')
        }
        self.current_test_params = {
            # new reference data after George's CTTimeInterp1 fix
            #'ref_calibrated_ms':'lissajous.edges_new_fraction_0.1.sky.ms'
            'ref_calibrated_ms':'lissajous.edges_after4.7_fraction_0.1.sky.ms'
        }
        expected_result = MsCorrectedDataChecker(self.ref_calibrated_ms())
        self.run_sdcal()
        sdcal_result = MsCorrectedDataChecker(self.sdcal_params['infile'])
        self.assertEqual(sdcal_result,expected_result) # AlmostEqual semantics

    def test_otf07(self):
        """
        test_otf07 --- Sky calibration + Tsys conversion, composite calmode='otf,tsys,apply'. ms=lissajous.ms
        """
        self.sdcal_params = {
            'infile':'lissajous.ms',
            'calmode':'otf,tsys,apply',
        }
        self.current_test_params = {
            # new reference data after George's CTTimeInterp1 fix
            #'ref_calibrated_ms':'lissajous.edges_new_fraction_0.1.sky.ms'
            'ref_calibrated_ms':'lissajous.edges_after4.7_fraction_0.1.sky.ms'
        }
        expected_result = MsCorrectedDataChecker(self.ref_calibrated_ms(),convert_to_kelvin=True)
        self.run_sdcal()
        sdcal_result = MsCorrectedDataChecker(self.sdcal_params['infile'])
        self.assertEqual(sdcal_result,expected_result) # AlmostEqual semantics


class sdcal_test_otf_ephem(unittest.TestCase):
    """
    Unit tests for task sdcal,
    sky calibration mode = 'otf' : On-The-Fly (OTF) *non-raster*

    Test cases for ephemeris objects are defined in this class.

    The list of tests:
    Test            | Input            | Edges    | Calibration
    Name            | MS               | Fraction | Mode
    ==========================================================================
    test_otfephem01 | otf_ephem.ms     | 10%      | otf
    test_otfephem02 | otf_ephem.ms     | 10%      | otf,apply
    """

    datapath=ctsys_resolve('unittest/sdcal/')
    infile = 'otf_ephem.ms'
    outfile = infile + '.otfcal'

    def setUp(self):
        """Place a fresh copy of the input MS in the working directory."""
        local_copy = self.infile
        if os.path.exists(local_copy):
            shutil.rmtree(local_copy)
        shutil.copytree(os.path.join(self.datapath, local_copy), local_copy)

        default(sdcal)

    def tearDown(self):
        """Delete the input copy and the output caltable if present."""
        for leftover in (self.infile, self.outfile):
            if not os.path.exists(leftover):
                continue
            shutil.rmtree(leftover)

    def check_ephem(self, vis):
        """Assert the MS has one SOURCE entry whose upper-cased name is an 'extra' direction code."""
        with sdutil.tbmanager(os.path.join(vis, 'SOURCE')) as tb:
            source_names = tb.getcol('NAME')
        self.assertEqual(len(source_names), 1)
        me = measures()
        extra_codes = me.listcodes(me.direction())['extra']
        self.assertIn(source_names[0].upper(), extra_codes)

    def check_fresh_ms(self, vis):
        """Assert the MS has no CORRECTED_DATA column yet."""
        with sdutil.tbmanager(vis) as tb:
            columns = tb.colnames()

        self.assertNotIn('CORRECTED_DATA', columns)

    def check_caltable(self, caltable):
        """Assert FPARAM has shape (2, 1, 10) and is 1.0 everywhere."""
        with sdutil.tbmanager(caltable) as tb:
            fparam = tb.getcol('FPARAM')

        self.assertEqual(fparam.shape, (2, 1, 10))
        self.assertTrue(numpy.all(fparam == 1.0))

    def check_corrected(self, vis):
        """Assert the 40-row CORRECTED_DATA is 1+0j except every 4th row (index 3, 7, ...) which is 0."""
        with sdutil.tbmanager(vis) as tb:
            corrected = tb.getcol('CORRECTED_DATA')
            nrow = tb.nrows()

        self.assertEqual(nrow, 40)

        real_part = corrected.real
        reference = numpy.ones(real_part.shape, dtype=numpy.float64)
        # rows are on the last axis; rows 3, 7, 11, ... are expected to be zero
        reference[:, :, 3:nrow:4] = 0.0
        self.assertTrue(numpy.all(real_part == reference))

        self.assertTrue(numpy.all(corrected.imag == 0))

    def test_otfephem01(self):
        """test_otfephem01: Sky calibration of 'otf' mode for ephemeris object"""
        self.check_ephem(self.infile)
        self.assertFalse(os.path.exists(self.outfile))
        sdcal(infile=self.infile, outfile=self.outfile, calmode='otf')
        self.check_caltable(self.outfile)

    def test_otfephem02(self):
        """test_otfephem02: On-the-fly application of 'otf' calibration mode for ephemeris object"""
        self.check_ephem(self.infile)
        self.check_fresh_ms(self.infile)
        sdcal(infile=self.infile, calmode='otf,apply')
        self.check_corrected(self.infile)



# interpolator utility for testing
class Interpolator(object):
    """Test utility interpolating rows of a sky calibration table.

    ``select`` loads TIME, FPARAM, FLAG and INTERVAL of the rows with
    ANTENNA1 == ANTENNA2 == antenna and a given SPECTRAL_WINDOW_ID.
    Subclasses implement ``interpolate`` along time. ``self.finterp`` (one
    of the ``interp_freq_*`` static methods) then fills flagged channels
    along frequency.
    """
    @staticmethod
    def __interp_freq_linear(data, flag):
        """Fill flagged channels of 2-D ``data`` ([pol, chan]) by linear interpolation.

        Flagged channels below the first / above the last valid channel copy
        that edge channel's value. Polarizations without any valid channel
        are left unchanged. Returns ``(filled copy of data, flag)``, where
        ``flag`` is the very same (unmodified) input object.
        """
        outdata = data.copy()
        outflag = flag
        npol, nchan = outdata.shape
        for ipol in range(npol):
            valid_chans = numpy.where(outflag[ipol,:] == False)[0]
            if len(valid_chans) == 0:
                continue
            for ichan in range(nchan):
                if outflag[ipol,ichan] == True:
                    if ichan <= valid_chans[0]:
                        outdata[ipol,ichan] = data[ipol,valid_chans[0]]
                    elif ichan >= valid_chans[-1]:
                        outdata[ipol,ichan] = data[ipol,valid_chans[-1]]
                    else:
                        # bracket ichan with the closest valid channels i0 < ichan < i1
                        ii = abs(valid_chans - ichan).argmin()
                        if valid_chans[ii] - ichan > 0:
                            ii -= 1
                        i0 = valid_chans[ii]
                        i1 = valid_chans[ii+1]
                        outdata[ipol,ichan] = ((i1 - ichan) * data[ipol,i0] + (ichan - i0) * data[ipol,i1]) / (i1 - i0)
        return outdata, outflag

    @staticmethod
    def interp_freq_linear(data, flag):
        """'linear' mode: the input flags are ignored.

        Every channel is treated as valid, so the result is a copy of
        ``data`` together with an all-False flag array (``flag`` itself is
        not modified).
        """
        outflag = flag.copy()
        outflag[:] = False
        outdata, outflag = Interpolator.__interp_freq_linear(data, outflag)
        return outdata, outflag

    @staticmethod
    def interp_freq_nearest(data, flag):
        """Fill flagged channels of 2-D ``data`` ([pol, chan]) with the nearest valid channel.

        On a tie between two equally distant valid channels the lower one
        wins (argmin returns the first). Polarizations without any valid
        channel are left unchanged. Returns ``(filled copy of data, flag)``,
        ``flag`` being the same (unmodified) input object.
        """
        outdata = data.copy()
        outflag = flag
        npol, nchan = outdata.shape
        for ipol in range(npol):
            valid_chans = numpy.where(outflag[ipol,:] == False)[0]
            if len(valid_chans) == 0:
                continue
            for ichan in range(nchan):
                if outflag[ipol,ichan] == True:
                    if ichan <= valid_chans[0]:
                        outdata[ipol,ichan] = data[ipol,valid_chans[0]]
                    elif ichan >= valid_chans[-1]:
                        outdata[ipol,ichan] = data[ipol,valid_chans[-1]]
                    else:
                        ii = abs(valid_chans - ichan).argmin()
                        outdata[ipol,ichan] = data[ipol,valid_chans[ii]]
        return outdata, outflag

    @staticmethod
    def interp_freq_linearflag(data, flag):
        """'linearflag' mode: widen the flags, then interpolate linearly.

        Channel 0 becomes flag[0] | flag[1]; channel k (0 < k < nchan-1)
        becomes flag[k-1] | flag[k]; the last channel takes flag[-2].
        Requires at least 2 channels.
        """
        # NOTE
        # interpolation/extrapolation of flag along frequency axis is
        # also needed for linear interpolation. Due to this, number of
        # flag channels will slightly increase and causes different
        # behavior from existing scantable based single dish task
        # (sdcal2).
        #
        # It appears that effective flag at a certain channel is set to
        # the flag at previous channels (except for channel 0).
        #
        # 2015/02/26 TN
        npol,nchan = flag.shape
        outflag = flag.copy()
        for ichan in range(nchan-1):
            outflag[:,ichan] = numpy.logical_or(flag[:,ichan], flag[:,ichan+1])
        outflag[:,1:] = outflag[:,:-1]
        outflag[:,-1] = flag[:,-2]

        outdata, outflag = Interpolator.__interp_freq_linear(data, outflag)
        return outdata, outflag

    @staticmethod
    def interp_freq_nearestflag(data, flag):
        """'nearestflag' mode: same as interp_freq_nearest."""
        outdata, outflag = Interpolator.interp_freq_nearest(data, flag)
        return outdata, outflag

    def __init__(self, table, finterp='linear'):
        """table: caltable name; finterp: 'linear', 'nearest', 'linearflag' or
        'nearestflag' (case-insensitive) selecting an ``interp_freq_*`` method."""
        self.table = table
        self.taql = ''
        self.time = None
        self.data = None
        self.flag = None
        self.exposure = None
        self.finterp = getattr(Interpolator,'interp_freq_%s'%(finterp.lower()))
        print('self.finterp:', self.finterp.__name__)

    def select(self, antenna, spw):
        """Load TIME, FPARAM, FLAG and INTERVAL for the given antenna (auto) and spw."""
        self.taql = 'ANTENNA1 == %s && ANTENNA2 == %s && SPECTRAL_WINDOW_ID == %s'%(antenna, antenna, spw)
        with sdutil.table_selector(self.table, self.taql) as tb:
            self.time = tb.getcol('TIME')
            self.data = tb.getcol('FPARAM')
            self.flag = tb.getcol('FLAG')
            self.exposure = tb.getcol('INTERVAL')

    def interpolate(self, t, tau=None):
        """Interpolate selected data at time ``t`` (``tau``: on-source exposure).

        Abstract: subclasses return ``(ref, refflag, weightscale)``.
        """
        raise NotImplementedError('Not implemented')

    def weightscale_linear(self, dt_on, dt_off0, dt_off1=None, t_on=None, t_off0=None, t_off1=None):
        """Weight scale for linear time interpolation between two OFF samples.

        Returns 1 / (1 + dt_on / delta**2 * (delta1**2 / dt_off0 + delta0**2 / dt_off1))
        with delta = t_off1 - t_off0, delta0 = t_on - t_off0 and
        delta1 = t_off1 - t_on. Falls back to weightscale_nearest when
        ``dt_off1`` is None.
        """
        if dt_off1 is None:
            return self.weightscale_nearest(dt_on, dt_off0)
        else:
            delta = t_off1 - t_off0
            delta0 = t_on - t_off0
            delta1 = t_off1 - t_on
            sigmasqscale = 1.0 + dt_on / (delta * delta) * (delta1 * delta1 / dt_off0 + delta0 * delta0 / dt_off1)
            return 1.0 / sigmasqscale

    def weightscale_nearest(self, dt_on, dt_off):
        """Return dt_off / (dt_on + dt_off)."""
        return dt_off / (dt_on + dt_off)

class LinearInterpolator(Interpolator):
    """
    Reference interpolator that is linear in time.

    Inside the caltable's time range the two bracketing rows are combined
    linearly; outside it the first or last row is used unchanged.
    """
    def __init__(self, table, finterp='linear'):
        super(LinearInterpolator, self).__init__(table, finterp)

    def _lower_index(self, t):
        """
        Index of the row nearest to t, stepped back by one if that row is
        later than t. Returns -1 when t precedes the first row.
        Presumably assumes self.time is sorted ascending - TODO confirm.
        """
        offsets = self.time - t
        nearest = abs(offsets).argmin()
        if offsets[nearest] > 0.0:
            nearest -= 1
        return nearest

    def interpolate(self, t, tau):
        """
        Return (ref, refflag, weightscale) for an ON row at time t with
        integration time tau.
        """
        lower = self._lower_index(t)
        last = len(self.time) - 1
        if lower < 0:
            ref = self.data[:,:,0].copy()
            weightscale = self.weightscale_linear(tau, self.exposure[0])
        elif lower >= last:
            ref = self.data[:,:,-1].copy()
            weightscale = self.weightscale_linear(tau, self.exposure[-1])
        else:
            upper = lower + 1
            t0, t1 = self.time[lower], self.time[upper]
            d0, d1 = self.data[:,:,lower], self.data[:,:,upper]
            ref = ((t1 - t) * d0 + (t - t0) * d1) / (t1 - t0)
            weightscale = self.weightscale_linear(tau, self.exposure[lower], self.exposure[upper], t, t0, t1)
        ref, refflag = self.finterp(ref, self.interpolate_flag(t))
        return ref, refflag, weightscale

    def interpolate_flag(self, t):
        """
        Return the flag for time t: the edge row's flag outside the time
        range, otherwise the logical OR of the two bracketing rows' flags.
        """
        lower = self._lower_index(t)
        if lower < 0:
            return self.flag[:,:,0].copy()
        if lower >= len(self.time) - 1:
            return self.flag[:,:,-1].copy()
        return numpy.logical_or(self.flag[:,:,lower], self.flag[:,:,lower+1])

class NearestInterpolator(Interpolator):
    """
    Reference interpolator that picks the caltable row nearest in time.
    """
    def __init__(self, table, finterp='nearest'):
        super(NearestInterpolator, self).__init__(table, finterp)

    def interpolate(self, t, tau):
        """
        Return (ref, refflag, weightscale) for an ON row at time t with
        integration time tau, using the single nearest caltable row.
        """
        nearest = abs(self.time - t).argmin()
        weightscale = self.weightscale_nearest(tau, self.exposure[nearest])
        row_data = self.data[:,:,nearest].copy()
        row_flag = self.flag[:,:,nearest].copy()
        ref, refflag = self.finterp(row_data, row_flag)
        return ref, refflag, weightscale


class sdcal_test_apply(sdcal_test_base):

    """
    Unit test for task sdcal (apply tables).

    The list of tests:
    test_apply_sky00 --- empty applytable
    test_apply_sky01 --- empty applytable (list ver.)
    test_apply_sky02 --- empty applytable list
    test_apply_sky03 --- unexisting applytable
    test_apply_sky04 --- unexisting applytable (list ver.)
    test_apply_sky05 --- invalid selection (empty selection result)
    test_apply_sky06 --- invalid interp value
    test_apply_sky07 --- invalid applytable (not caltable)
    test_apply_sky08 --- apply data (linear)
    test_apply_sky08M --- apply data (linear) for MMS (currently skipped)
    test_apply_sky09 --- apply selected data
    test_apply_sky10 --- apply data (nearest)
    test_apply_sky11 --- apply data (linearflag for frequency interpolation)
    test_apply_sky12 --- apply data (nearestflag for frequency interpolation)
    test_apply_sky13 --- apply data (string applytable input)
    test_apply_sky14 --- apply data (interp='')
    test_apply_sky15 --- check if WEIGHT_SPECTRUM is updated properly when it exists
    test_apply_composite00 --- on-the-fly application of sky table ('ps,apply')
    test_apply_composite01 --- on-the-fly application of sky table with existing Tsys table
    test_apply_composite02 --- on-the-fly application of sky and tsys tables ('ps,tsys,apply')
    test_apply_composite03 --- on-the-fly application of Tsys table with existing sky table ('tsys,apply')
    """
    invalid_argument_case = sdcal_test_base.invalid_argument_case
    exception_case = sdcal_test_base.exception_case

    @property
    def nrow_per_chunk(self):
        # number of rows per antenna per spw is 18
        return 18

    @property
    def eps(self):
        # nominal required accuracy is 2.0e-4; a slightly looser
        # tolerance of 3.0e-4 is used for the comparison
        return 3.0e-4

    def setUp(self):
        self._setUp([self.infile, self.applytable], sdcal)


    def tearDown(self):
        self._tearDown([self.infile, self.applytable])

    def check_weight(self, inweight, outweight, scale):
        """
        Verify that outweight equals inweight * scale within self.eps
        (relative difference).

        inweight --- WEIGHT (or WEIGHT_SPECTRUM) before calibration
        outweight --- WEIGHT (or WEIGHT_SPECTRUM) after calibration
        scale --- expected multiplicative scaling factor
        """
        # shape check
        self.assertEqual(inweight.shape, outweight.shape,
                         msg='Weight shape mismatch')

        # weight should not be zero
        self.assertFalse(numpy.any(inweight == 0.0),
                         msg='Input weight contains zero')
        self.assertFalse(numpy.any(outweight == 0.0),
                         msg='Output weight contains zero')

        # check relative difference
        expected_weight = inweight * scale
        diff = numpy.abs((outweight - expected_weight) / expected_weight)
        self.assertTrue(numpy.all(diff < self.eps),
                        msg='Weight differs from expected (scale %s)'%(scale))


    def normal_case(interp='linear', tsys=1.0, **kwargs):
        """
        Decorator for the test case that is intended to verify
        normal execution result.

        interp --- interpolation option ('linear', 'nearest', '*flag')
                   comma-separated list is allowed and it will be
                   interpreted as '<interp for time>,<interp for freq>'
        tsys --- tsys scaling factor
        kwargs --- data selection parameters (e.g. spw='9') passed to
                   ms.msseltoindex to choose the antennas and spws that
                   are verified
        """
        def wrapper(func):
            import functools
            @functools.wraps(func)
            def _wrapper(self):
                # data selection
                myms = ms()
                myargs = kwargs.copy()
                if 'baseline' not in myargs:
                    with sdutil.tbmanager(self.infile) as tb:
                        antenna1 = numpy.unique(tb.getcol('ANTENNA1'))
                        myargs['baseline'] = '%s&&&'%(','.join(map(str,antenna1)))
                a = myms.msseltoindex(self.infile, **myargs)
                antennalist = a['antenna1']
                with sdutil.tbmanager(self.applytable) as tb:
                    spwlist = numpy.unique(tb.getcol('SPECTRAL_WINDOW_ID'))
                with sdutil.tbmanager(os.path.join(self.infile, 'DATA_DESCRIPTION')) as tb:
                    spwidcol = tb.getcol('SPECTRAL_WINDOW_ID').tolist()
                if len(a['spw']) > 0:
                    spwlist = sorted(set(spwlist) & set(a['spw']))
                # Must be a list, not a map() iterator: it is iterated once per
                # antenna below, and a Python 3 iterator would be exhausted after
                # the first antenna, silently skipping all later verification.
                spwddlist = [spwidcol.index(spw) for spw in spwlist]

                # preserve original flag and weight
                flag_org = {}
                weight_org = {}
                weightsp_org = {}
                for antenna in antennalist:
                    flag_org[antenna] = {}
                    weight_org[antenna] = {}
                    weightsp_org[antenna] = {}
                    for (spw,spwdd) in zip(spwlist,spwddlist):
                        taql = 'ANTENNA1 == %s && ANTENNA2 == %s && DATA_DESC_ID == %s'%(antenna, antenna, spwdd)
                        with sdutil.table_selector(self.infile, taql) as tb:
                            flag_org[antenna][spw] = tb.getcol('FLAG')
                            weight_org[antenna][spw] = tb.getcol('WEIGHT')
                            if 'WEIGHT_SPECTRUM' in tb.colnames() and tb.iscelldefined('WEIGHT_SPECTRUM', 0):
                                weightsp_org[antenna][spw] = tb.getcol('WEIGHT_SPECTRUM')

                # execute test
                func(self)

                # sanity check
                self.assertIsNone(self.result, msg='The task must complete without error')
                # verify if CORRECTED_DATA exists
                with sdutil.tbmanager(self.infile) as tb:
                    self.assertTrue('CORRECTED_DATA' in tb.colnames(), msg='CORRECTED_DATA column must be created after task execution!')

                # parse interp ('<time>[,<freq>]', case-insensitive)
                pos = interp.find(',')
                if pos == -1:
                    tinterp = interp.lower()
                    finterp = 'linearflag'
                else:
                    tinterp = interp[:pos].lower()
                    finterp = interp[pos+1:].lower()
                if len(tinterp) == 0:
                    tinterp = 'linear'
                if len(finterp) == 0:
                    finterp = 'linearflag'

                # CAS-10772
                # Linear flag interpolation along spectral axis behaves like "nearest" if science and
                # calibrater data have same set of frequency channels. This is always true for single
                # dish sky calibration.
                # So, finterp option for flags should always be 'nearestflag'.
                if finterp == 'linearflag':
                    finterp = 'nearestflag'

                # result depends on interp
                print('Interpolation option:', tinterp, finterp)
                self.assertTrue(tinterp in ['linear', 'nearest'], msg='Internal Error')
                if tinterp == 'linear':
                    interpolator = LinearInterpolator(self.applytable, finterp)
                else:
                    interpolator = NearestInterpolator(self.applytable, finterp)
                tsyssq = tsys * tsys
                for antenna in antennalist:
                    for (spw,spwdd) in zip(spwlist,spwddlist):
                        interpolator.select(antenna, spw)
                        taql = 'ANTENNA1 == %s && ANTENNA2 == %s && DATA_DESC_ID == %s'%(antenna, antenna, spwdd)
                        with sdutil.table_selector(self.infile, taql) as tb:
                            self.assertEqual(tb.nrows(), self.nrow_per_chunk, msg='Number of rows mismatch in antenna %s spw %s'%(antenna, spw))
                            has_weightsp = spw in weightsp_org[antenna]
                            for irow in range(tb.nrows()):
                                t = tb.getcell('TIME', irow)
                                dt = tb.getcell('INTERVAL', irow)
                                data = tb.getcell('DATA', irow)
                                outflag = tb.getcell('FLAG', irow)
                                corrected = tb.getcell('CORRECTED_DATA', irow)
                                ref, calflag, weightscale = interpolator.interpolate(t, dt)
                                inflag = flag_org[antenna][spw][:,:,irow]
                                expected = tsys * (data - ref) / ref
                                expected_flag = numpy.logical_or(inflag, calflag)

                                # weight test
                                self.assertEqual(tb.iscelldefined('WEIGHT_SPECTRUM', irow), has_weightsp,
                                                 msg='WEIGHT_SPECTRUM definition changed in antenna %s spw %s row %s'%(antenna, spw, irow))
                                inweight = weight_org[antenna][spw][:,irow]
                                outweight = tb.getcell('WEIGHT', irow)
                                self.check_weight(inweight, outweight, weightscale / tsyssq)
                                if has_weightsp:
                                    # WEIGHT_SPECTRUM must be scaled the same way
                                    inweightsp = weightsp_org[antenna][spw][:,:,irow]
                                    outweightsp = tb.getcell('WEIGHT_SPECTRUM', irow)
                                    self.check_weight(inweightsp, outweightsp, weightscale / tsyssq)

                                self.assertEqual(corrected.shape, expected.shape, msg='Shape mismatch in antenna %s spw %s row %s (expeted %s actual %s)'%(antenna,spw,irow,list(expected.shape),list(corrected.shape)))

                                # verify data: absolute difference for near-zero
                                # expected values, relative difference otherwise
                                diff = numpy.ones(expected.shape,dtype=float)
                                small_data = numpy.where(abs(expected) < 1.0e-7)
                                diff[small_data] = abs(corrected[small_data] - expected[small_data])
                                regular_data = numpy.where(abs(expected) >= 1.0e-7)
                                diff[regular_data] = abs((corrected[regular_data] - expected[regular_data]) / expected[regular_data])
                                self.assertTrue(all(diff.flatten() < self.eps), msg='Calibrated result differ in antenna %s spw %s row %s (expected %s actual %s diff %s)'%(antenna,spw,irow,expected,corrected,diff))

                                # verify flag
                                self.assertTrue(all(outflag.flatten() == expected_flag.flatten()), msg='Resulting flag differ in antenna%s spw %s row %s (expected %s actual %s)'%(antenna,spw,irow,expected_flag,outflag))

            return _wrapper
        return wrapper

    @exception_case(Exception, 'Applytable name must be specified.')
    def test_apply_sky00(self):
        """
        test_apply_sky00 --- empty applytable
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable='')

    @exception_case(Exception, 'Applytable name must be specified.')
    def test_apply_sky01(self):
        """
        test_apply_sky01 --- empty applytable (list ver.)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[''])

    @exception_case(Exception, 'Applytable name must be specified.')
    def test_apply_sky02(self):
        """
        test_apply_sky02 --- empty applytable list
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[])

    @exception_case(Exception, "^Table doesn't exist:")
    def test_apply_sky03(self):
        """
        test_apply_sky03 --- unexisting applytable
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable='notexist.sky')

    @exception_case(Exception, "^Table doesn't exist:")
    def test_apply_sky04(self):
        """
        test_apply_sky04 --- unexisting applytable (list ver.)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=['notexist.sky'])

    @exception_case(RuntimeError, 'Spw Expression: No match found for 99')
    def test_apply_sky05(self):
        """
        test_apply_sky05 --- invalid selection (empty selection result)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', spw='99', applytable=[self.applytable])

    #@exception_case(RuntimeError, '^Unknown interptype: \'.+\'!! Check inputs and try again\.$')
    @exception_case(RuntimeError, 'Error in Calibrater::setapply.')
    def test_apply_sky06(self):
        """
        test_apply_sky06 --- invalid interp value
        """
        # 'sinusoid' interpolation along time axis is not supported
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], interp='sinusoid')

    @exception_case(RuntimeError, '^Applytable \'.+\' is not a caltable format$')
    def test_apply_sky07(self):
        """
        test_apply_sky07 --- invalid applytable (not caltable)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.infile], interp='linear')

    @normal_case()
    def test_apply_sky08(self):
        """
        test_apply_sky08 --- apply data (linear)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], interp='linear')

    @normal_case()
    def test_apply_sky08M(self):
        """
        test_apply_sky08M --- apply data (linear) for MMS
        """
        self.skipTest('Skip test_apply_sky08M until calibrator tool supports processing MMS on serial casa')
        #self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], interp='linear')

    @normal_case(spw='9')
    def test_apply_sky09(self):
        """
        test_apply_sky09 --- apply selected data
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], spw='9', interp='linear')

    @normal_case(interp='nearest')
    def test_apply_sky10(self):
        """
        test_apply_sky10 --- apply data (nearest)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], interp='nearest')

    @normal_case(interp='linear,linearflag')
    def test_apply_sky11(self):
        """
        test_apply_sky11 --- apply data (linearflag for frequency interpolation)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], interp='linear,linearflag')

    @normal_case(interp='linear,nearestflag')
    def test_apply_sky12(self):
        """
        test_apply_sky12 --- apply data (nearestflag for frequency interpolation)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=[self.applytable], interp='linear,nearestflag')

    @normal_case(interp='linear')
    def test_apply_sky13(self):
        """
        test_apply_sky13 --- apply data (string applytable input)
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=self.applytable, interp='linear')

    @normal_case(interp='')
    def test_apply_sky14(self):
        """
        test_apply_sky14 --- apply data (interp='')
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=self.applytable, interp='')

    def fill_weightspectrum(func):
        """
        Decorator that fills WEIGHT_SPECTRUM of infile, for every row and
        channel, with the per-polarization WEIGHT value before running
        the decorated test.
        """
        import functools
        @functools.wraps(func)
        def wrapper(self):
            with sdutil.tbmanager(self.infile, nomodify=False) as tb:
                self.assertTrue('WEIGHT_SPECTRUM' in tb.colnames(), msg='Internal Error')
                nrow = tb.nrows()
                for irow in range(nrow):
                    w = tb.getcell('WEIGHT', irow)
                    wsp = numpy.ones(tb.getcell('DATA', irow).shape, dtype=float)
                    for ipol, wpol in enumerate(w):
                        wsp[ipol,:] = wpol
                    tb.putcell('WEIGHT_SPECTRUM', irow, wsp)
                    self.assertTrue(tb.iscelldefined('WEIGHT_SPECTRUM', irow), msg='Internal Error')
            func(self)
        return wrapper

    @fill_weightspectrum
    @normal_case(interp='linear')
    def test_apply_sky15(self):
        """
        test_apply_sky15 --- check if WEIGHT_SPECTRUM is updated properly when it exists
        """
        self.result = sdcal(infile=self.infile, calmode='apply', applytable=self.applytable, interp='linear')

    def modify_tsys(func):
        """
        Decorator that overwrites TSYS_SPECTRUM of the SYSCAL rows with
        SPECTRAL_WINDOW_ID 1 or 3 by a constant 100.0 before running the
        decorated test (matching normal_case(tsys=100.0) below).
        """
        import functools
        @functools.wraps(func)
        def wrapper(self):
            tsys = 100.0
            with sdutil.tbmanager(os.path.join(self.infile,'SYSCAL'), nomodify=False) as tb:
                tsel = tb.query('SPECTRAL_WINDOW_ID IN [1,3]', sortlist='ANTENNA_ID,SPECTRAL_WINDOW_ID,TIME')
                try:
                    for irow in range(tsel.nrows()):
                        tsys_spectrum = tsel.getcell('TSYS_SPECTRUM', irow)
                        tsys_spectrum[:] = tsys
                        tsel.putcell('TSYS_SPECTRUM', irow, tsys_spectrum)
                finally:
                    tsel.close()
            func(self)
        return wrapper

    @normal_case()
    def test_apply_composite00(self):
        """
        test_apply_composite00 --- on-the-fly application of sky table ('ps,apply')
        """
        self.result = sdcal(infile=self.infile, calmode='ps,apply')

    @modify_tsys
    @normal_case(tsys=100.0)
    def test_apply_composite01(self):
        """
        test_apply_composite01 --- on-the-fly application of sky table with existing Tsys table
        """
        tsystable = self.infile.rstrip('/') + '.tsys'

        try:
            # generate Tsys table
            sdcal(infile=self.infile, calmode='tsys', outfile=tsystable)

            # apply
            self.result = sdcal(infile=self.infile, calmode='ps,apply', applytable=tsystable,
                                spwmap={1:[9], 3:[11]})
        finally:
            if os.path.exists(tsystable):
                shutil.rmtree(tsystable)

    @modify_tsys
    @normal_case(tsys=100.0)
    def test_apply_composite02(self):
        """
        test_apply_composite02 --- on-the-fly application of sky and tsys tables ('ps,tsys,apply')
        """
        self.result = sdcal(infile=self.infile, calmode='ps,tsys,apply',
                            spwmap={1:[9], 3:[11]})

    @modify_tsys
    @normal_case(tsys=100.0)
    def test_apply_composite03(self):
        """
        test_apply_composite03 --- on-the-fly application of Tsys table with existing sky table ('tsys,apply')
        """
        self.result = sdcal(infile=self.infile, calmode='tsys,apply', applytable=self.applytable,
                            spwmap={1:[9], 3:[11]})

class sdcal_test_single_polarization(sdcal_test_base):
    """
    Unit test for task sdcal (calibration/application of single-polarization data).

    The list of tests:
    test_single_pol_ps --- generate caltable for single-polarization data
    test_single_pol_apply --- apply caltable to single-polarization data
    test_single_pol_apply_composite --- on-the-fly calibration/application on single-polarization data
    """
    datapath = ctsys_resolve('unittest/sdcal/')
    # Input
    infile = 'analytic_spectra.ms'
    #applytable = infile + '.sky'

    # task execution result
    result = None

    @property
    def outfile(self):
        return self.applytable

    def setUp(self):
        self._setUp([self.infile], sdcal)

    def tearDown(self):
        self._tearDown([self.infile, self.outfile])

    def _verify_caltable(self):
        """
        verify single polarization caltable

        This method checks if

            - calibration solution is properly stored in pol 0
            - pol 1 is all flagged

        Generated caltable will have the following properties:

            - number of rows is 2
            - FPARAM (pol 0) has identical value to infile rows that
              corresponds to OFF_SOURCE intents (STATE_ID 1)
            - FPARAM (pol 1) is all 0 and is all flagged
        """
        # get reference data from infile (read-only access is sufficient)
        with sdutil.tbmanager(self.infile) as tb:
            tsel = tb.query('STATE_ID == 1', sortlist='TIME')
            try:
                reftime = tsel.getcol('TIME')
                refdata = tsel.getcol('FLOAT_DATA')
                refflag = tsel.getcol('FLAG')
            finally:
                tsel.close()

        # verify caltable
        with sdutil.tbmanager(self.outfile) as tb:
            caltime = tb.getcol('TIME')
            fparam = tb.getcol('FPARAM')
            calflag = tb.getcol('FLAG')

        self.assertEqual(len(caltime), len(reftime))
        self.assertTrue(numpy.all(caltime == reftime))

        nrow = len(caltime)

        self.assertEqual(fparam.shape, calflag.shape)
        calshape = fparam.shape
        datashape = refdata.shape
        self.assertEqual(calshape[0], 2)
        self.assertEqual(datashape[0], 1)
        self.assertEqual(calshape[1], datashape[1])
        self.assertEqual(calshape[2], datashape[2])

        for irow in range(nrow):
            # FPARAM (pol 0)
            self.assertTrue(numpy.all(fparam[0, :, irow] == refdata[0, :, irow]))
            self.assertTrue(numpy.all(calflag[0, :, irow] == refflag[0, :, irow]))

            # FPARAM (pol 1): all zero and all flagged
            self.assertTrue(numpy.all(fparam[1, :, irow] == 0))
            self.assertTrue(numpy.all(calflag[1, :, irow]))


    def _verify_application(self):
        """
        verify single polarization application result

        This method checks if

            - calibration solution is properly applied

        CORRECTED_DATA column will have the following properties:

            - For calibration spectra (STATE_ID 0), CORRECTED_DATA is identical to FLOAT_DATA
            - For OFF_SOURCE spectra (STATE_ID 1), CORRECTED_DATA is all 0
            - For ON_SOURCE spectra (STATE_ID 2), CORRECTED_DATA is a calculated result of
              (ON - OFF) / OFF with interpolated OFF in time
        """
        with sdutil.tbmanager(self.infile) as tb:
            # calibration spectra (STATE_ID 0)
            tsel = tb.query('STATE_ID == 0')
            try:
                float_data = tsel.getcol('FLOAT_DATA')
                corrected_data = tsel.getcol('CORRECTED_DATA')

                self.assertTrue(numpy.all(corrected_data.real == float_data))
                self.assertTrue(numpy.all(corrected_data.imag == 0.0))
            finally:
                tsel.close()

            # OFF_SOURCE spectra (STATE_ID 1)
            # reference time/data are read inside try so that the selection
            # table is always closed, even if a getcol call fails
            tsel = tb.query('STATE_ID == 1', sortlist='TIME')
            try:
                reftime = tsel.getcol('TIME')
                refdata = tsel.getcol('FLOAT_DATA')
                corrected_data = tsel.getcol('CORRECTED_DATA')

                self.assertTrue(numpy.all(corrected_data.real == 0.0))
                self.assertTrue(numpy.all(corrected_data.imag == 0.0))
            finally:
                tsel.close()

            # ON_SOURCE spectra (STATE_ID 2)
            tsel = tb.query('STATE_ID == 2')
            try:
                sptime = tsel.getcol('TIME')
                float_data = tsel.getcol('FLOAT_DATA')
                corrected_data = tsel.getcol('CORRECTED_DATA')

                # linear interpolation below needs exactly two OFF rows
                self.assertEqual(len(reftime), 2)

                nrow = float_data.shape[2]

                off_data = numpy.zeros(corrected_data.shape, dtype=numpy.float64)
                for irow in range(nrow):
                    off_data[:,:,irow] = (refdata[:,:,1] * (sptime[irow] - reftime[0]) \
                                          + refdata[:,:,0] * (reftime[1] - sptime[irow])) \
                                            / (reftime[1] - reftime[0])
                calibrated = (float_data - off_data) / off_data

                # exclude nan
                idx_not_nan = numpy.where(numpy.isfinite(float_data))
                diff = numpy.abs((corrected_data.real - calibrated) / calibrated)
                diff_not_nan = diff[idx_not_nan]
                eps = 1.0e-7
                self.assertTrue(numpy.all(diff_not_nan < eps))
                self.assertTrue(numpy.all(corrected_data[idx_not_nan].imag == 0.0))
            finally:
                tsel.close()

    def test_single_pol_ps(self):
        """
        test_single_pol_ps --- generate caltable for single-polarization data
        """
        self.result = sdcal(infile=self.infile, calmode='ps', outfile=self.outfile)
        self._verify_caltable()

    def test_single_pol_apply(self):
        """
        test_single_pol_apply --- apply caltable to single-polarization data
        """
        self.test_single_pol_ps()
        self.assertTrue(os.path.exists(self.outfile))

        self.result = sdcal(infile=self.infile, calmode='apply', applytable=self.outfile)
        self._verify_application()

    def test_single_pol_apply_composite(self):
        """
        test_single_pol_apply_composite --- on-the-fly calibration/application on single-polarization data
        """
        self.result = sdcal(infile=self.infile, calmode='ps,apply')
        self._verify_application()

def suite():
    """Return the list of TestCase classes run by the CASA test harness."""
    test_classes = [
        sdcal_test,
        sdcal_test_ps,
        sdcal_test_otfraster,
        sdcal_test_otf,
        sdcal_test_apply,
        sdcal_test_otf_ephem,
        sdcal_test_single_polarization,
        sdcal_test_bug_fix_cas_12712,
    ]
    return test_classes

# Run directly with unittest only under CASA6; CASA5 drives tests via suite().
if is_CASA6 and __name__ == '__main__':
    unittest.main()