from __future__ import absolute_import
from __future__ import print_function
import os
import sys
import shutil
import re
import numpy
import math
from scipy import signal
import unittest

from casatasks.private.casa_transition import is_CASA6
if is_CASA6:
    from casatools import ctsys, table, ms
    from casatasks import sdsmooth
    from casatasks.private import sdutil

    tb = table( )
else:
    from __main__ import default
    from tasks import *
    from taskinit import *
    import sdutil
    from sdsmooth import sdsmooth
    from taskinit import mstool as ms

    # the global tb tool is used here

def gaussian_kernel(nchan, kwidth):
    """
    Build a normalized gaussian kernel of length nchan.

    kwidth is the FWHM in channels; it is converted to the gaussian sigma
    via FWHM = 2*sqrt(2*ln 2)*sigma. The non-symmetric (sym=False) window
    is normalized to unit sum and then rotated left by one channel.

    Returns a 1-D numpy array of length nchan.
    """
    # scipy.signal.gaussian was deprecated and removed in SciPy 1.13;
    # the window lives in scipy.signal.windows (available since SciPy 1.1).
    try:
        window_func = signal.windows.gaussian
    except AttributeError:
        window_func = signal.gaussian
    sigma = kwidth / (2.0 * math.sqrt(2.0 * math.log(2.0)))
    g = window_func(nchan, sigma, False)
    g /= g.sum()
    # Rotate left by one channel (g[i] <- g[i+1], g[-1] <- g[0]).
    # NOTE(review): presumably this aligns the kernel peak with the
    # centering used by numpy.convolve(mode='same') in run_test — confirm.
    return numpy.roll(g, -1)

class sdsmooth_test_base(unittest.TestCase):
    """
    Base class for sdsmooth unit test.
    The following attributes/functions are defined here.

        datapath
        decorators (invalid_argument_case, exception_case, weight_case)
        run_test (gaussian smoothing run + verification of the output MS)

    Subclasses define ``infile`` and, to use run_test, ``datacolumn``.
    """
    # Data path of input
    if is_CASA6:
        datapath=ctsys.resolve('unittest/sdsmooth/')
    else:
        datapath=os.path.join(os.environ.get('CASAPATH').split()[0],'casatestdata/unittest/sdsmooth/')

    # Input
    infile_data = 'tsdsmooth_test.ms'
    infile_float = 'tsdsmooth_test_float.ms'

    # task execution result
    result = None

    @property
    def outfile(self):
        """Output MS name: the input name (trailing '/' stripped) plus '_out'."""
        return self.infile.rstrip('/') + '_out'

    # decorators
    @staticmethod
    def invalid_argument_case(func):
        """
        Decorator for the test case that is intended to fail
        due to invalid argument.

        Runs the test body, then asserts that self.result is falsy.
        """
        import functools
        @functools.wraps(func)
        def wrapper(self):
            func(self)
            self.assertFalse(self.result, msg='The task must return False')
        return wrapper

    @staticmethod
    def exception_case(exception_type, exception_pattern):
        """
        Decorator for the test case that is intended to throw
        exception.

            exception_type: type of exception
            exception_pattern: regex for inspecting exception message
                               using re.search
        """
        def wrapper(func):
            import functools
            @functools.wraps(func)
            def _wrapper(self):
                self.assertTrue(len(exception_pattern) > 0, msg='Internal Error')
                # The "no exception" failure is raised OUTSIDE the try block.
                # Otherwise the AssertionError from self.fail() would be
                # caught whenever exception_type is Exception (AssertionError
                # is a subclass of Exception).
                try:
                    func(self)
                except exception_type as e:
                    message = str(e)
                else:
                    self.fail(msg='The task must throw exception')
                self.assertIsNotNone(re.search(exception_pattern, message), msg='error message \'%s\' is not expected.'%(message))
            return _wrapper
        return wrapper

    @staticmethod
    def weight_case(func):
        """
        Decorator for weight propagation tests.

        Before running the test body, asserts that every row of the input
        MS has a defined WEIGHT_SPECTRUM cell. It then sets
        self.weight_propagation = True, so run_test also verifies the
        propagated spectral weights.
        """
        import functools
        @functools.wraps(func)
        def wrapper(self):
            with sdutil.tbmanager(self.infile) as tb:
                for irow in range(tb.nrows()):
                    self.assertTrue(tb.iscelldefined('WEIGHT_SPECTRUM', irow))

            # weight mode flag
            self.weight_propagation = True

            func(self)

        return wrapper

    @staticmethod
    def _read_columns(table_tool, datacol_name, weight_mode):
        """
        Read (data, flag, weight) via getvarcol from an open table tool.

        weight is None unless weight_mode is True.
        """
        data = table_tool.getvarcol(datacol_name)
        flag = table_tool.getvarcol('FLAG')
        weight = table_tool.getvarcol('WEIGHT_SPECTRUM') if weight_mode else None
        return data, flag, weight

    @staticmethod
    def _expected_weight(wgt_in, kernel_array, kwidth, nchan):
        """
        Compute the expected WEIGHT_SPECTRUM (indexed [pol, chan, 0]) after smoothing.

        A window of wkwidth channels is cut out of kernel_array around its
        peak and renormalized. wkwidth is kwidth rounded to an integer and
        bumped to the next odd number. For each channel at least half_width
        away from both edges, the expected weight is
        1 / sum_j(w_j**2 / weight_in_j). Edge channels keep their input
        weight.
        """
        wkwidth = int(kwidth + 0.5)
        wkwidth += (1 if wkwidth % 2 == 0 else 0)
        half_width = wkwidth // 2
        peak_chan = kernel_array.argmax()
        start_chan = peak_chan - half_width
        wkernel = kernel_array[start_chan:start_chan+wkwidth].copy()
        wkernel /= wkernel.sum()
        wkernel_sq = wkernel * wkernel
        weight_expected = wgt_in.copy()
        for ichan in range(half_width, nchan-half_width):
            # shape (npol, wkwidth); wkernel_sq broadcasts along channels
            window = wgt_in[:, ichan-half_width:ichan-half_width+wkwidth, 0]
            weight_expected[:, ichan, 0] = 1.0 / (wkernel_sq / window).sum(axis=1)
        return weight_expected

    def run_test(self, *args, **kwargs):
        """
        Run sdsmooth (gaussian kernel) on self.infile and verify the output MS.

        kwargs are forwarded to sdsmooth. The verification uses
        kwargs['kwidth'] (default 5) and kwargs['spw'] (default '').

        Checks:
          - the task returns None and creates self.outfile;
          - the number of output rows matches the selection;
          - for pol 0 of every row, the output matches the flagged input
            convolved with gaussian_kernel;
          - in weight mode, WEIGHT_SPECTRUM matches _expected_weight.
        """
        datacol_name = self.datacolumn.upper()
        weight_mode = getattr(self, 'weight_propagation', False) is True
        kwidth = kwargs.get('kwidth', 5)

        self.result = sdsmooth(infile=self.infile, outfile=self.outfile, kernel='gaussian', datacolumn=self.datacolumn, **kwargs)

        # sanity check
        self.assertIsNone(self.result, msg='The task must complete without error')
        self.assertTrue(os.path.exists(self.outfile), msg='Output file is not properly created.')

        spw = kwargs.get('spw', '')
        if len(spw) == 0:
            expected_nrow = 2
            with sdutil.tbmanager(self.infile) as tb:
                data_in, flag_in, weight_in = self._read_columns(tb, datacol_name, weight_mode)
        else:
            myms = ms()
            a = myms.msseltoindex(self.infile, spw=spw)
            expected_nrow = len(a['spw'])
            dd_selection = a['dd']
            with sdutil.tbmanager(self.infile) as tb:
                # Query before entering try: if the query itself fails there
                # is no selected table to close.
                tsel = tb.query('DATA_DESC_ID IN %s'%(dd_selection.tolist()))
                try:
                    data_in, flag_in, weight_in = self._read_columns(tsel, datacol_name, weight_mode)
                finally:
                    tsel.close()

        with sdutil.tbmanager(self.outfile) as tb:
            nrow = tb.nrows()
            data_out, _, weight_out = self._read_columns(tb, datacol_name, weight_mode)

        # verify nrow
        self.assertEqual(nrow, expected_nrow, msg='Number of rows mismatch (expected %s actual %s)'%(expected_nrow, nrow))

        # verify data
        eps = 1.0e-6
        for key in data_out.keys():
            row_in = data_in[key]
            # flagged input samples are treated as zero
            row_in[numpy.asarray(flag_in[key], dtype=bool)] = 0.0
            row_out = data_out[key]
            self.assertEqual(row_in.shape, row_out.shape, msg='Shape mismatch in row %s'%(key))

            nchan = row_out.shape[1]
            kernel_array = gaussian_kernel(nchan, kwidth)
            # only the first polarization of each row is compared
            expected = numpy.convolve(row_in[0,:,0], kernel_array, mode='same')
            output = row_out[0,:,0]
            abs_expected = numpy.abs(expected)
            zero_mask = abs_expected <= eps
            self.assertTrue(numpy.all(numpy.abs(output[zero_mask]) < eps), msg='Failed to verify zero values: row %s'%(key))
            nonzero_mask = abs_expected > eps
            if numpy.any(nonzero_mask):
                # relative error w.r.t. the peak of the expected spectrum
                diff = numpy.abs((output[nonzero_mask] - expected[nonzero_mask]) / expected[nonzero_mask].max())
                self.assertTrue(numpy.all(diff < eps), msg='Failed to verify nonzero values: row %s'%(key))

            # weight check if this is weight test
            if weight_mode:
                wgt_in = weight_in[key]
                wgt_out = weight_out[key]
                weight_expected = self._expected_weight(wgt_in, kernel_array, kwidth, nchan)
                diff = numpy.abs((wgt_out - weight_expected) / weight_expected)
                self.assertTrue(numpy.all(diff < eps), msg='Failed to verify spectral weight: row %s'%(key))

    def _setUp(self, files, task):
        """Copy each file from datapath into the working directory (replacing any existing copy)."""
        for f in files:
            if os.path.exists(f):
                shutil.rmtree(f)
            shutil.copytree(os.path.join(self.datapath, f), f)

        if not is_CASA6:
            default(task)

    def _tearDown(self, files):
        """Remove each of the given directories if it exists."""
        for f in files:
            if os.path.exists(f):
                shutil.rmtree(f)

    def setUp(self):
        self._setUp([self.infile], sdsmooth)

    def tearDown(self):
        self._tearDown([self.infile, self.outfile])

class sdsmooth_test_fail(sdsmooth_test_base):
    """
    Unit test for task sdsmooth.

    The list of tests:
    test_sdsmooth_fail01 --- default parameters (raises an error)
    test_sdsmooth_fail02 --- invalid kernel type
    test_sdsmooth_fail03 --- invalid selection (empty selection result)
    test_sdsmooth_fail04 --- outfile exists (overwrite=False)
    test_sdsmooth_fail05 --- empty outfile
    test_sdsmooth_fail06 --- invalid data column name
    """
    invalid_argument_case = sdsmooth_test_base.invalid_argument_case
    exception_case = sdsmooth_test_base.exception_case

    infile = sdsmooth_test_base.infile_data

    @invalid_argument_case
    def test_sdsmooth_fail01(self):
        """test_sdsmooth_fail01 --- default parameters (raises an error)"""
        # casatasks throw exceptions, CASA5 tasks return False
        if is_CASA6:
            self.assertRaises(Exception, sdsmooth)
        else:
            self.result = sdsmooth()

    @invalid_argument_case
    def test_sdsmooth_fail02(self):
        """test_sdsmooth_fail02 --- invalid kernel type"""
        # casatasks throw exceptions, CASA5 tasks return False
        if is_CASA6:
            self.assertRaises(Exception, sdsmooth, infile=self.infile, kernel='normal', outfile=self.outfile)
        else:
            self.result = sdsmooth(infile=self.infile, kernel='normal', outfile=self.outfile)

    @exception_case(RuntimeError, 'Spw Expression: No match found for 3')
    def test_sdsmooth_fail03(self):
        """test_sdsmooth_fail03 --- invalid selection (empty selection result)"""
        self.result = sdsmooth(infile=self.infile, kernel='gaussian', outfile=self.outfile, spw='3')

    # raw strings: '\.' is an invalid escape in a plain string literal
    @exception_case(Exception, r'sdsmooth_test\.ms_out exists\.')
    def test_sdsmooth_fail04(self):
        """test_sdsmooth_fail04 --- outfile exists (overwrite=False)"""
        shutil.copytree(self.infile, self.outfile)
        self.result = sdsmooth(infile=self.infile, kernel='gaussian', outfile=self.outfile, overwrite=False)

    @exception_case(Exception, r'outfile is empty\.')
    def test_sdsmooth_fail05(self):
        """test_sdsmooth_fail05 --- empty outfile"""
        self.result = sdsmooth(infile=self.infile, kernel='gaussian', outfile='')

    @invalid_argument_case
    def test_sdsmooth_fail06(self):
        """test_sdsmooth_fail06 --- invalid data column name"""
        # casatasks throw exceptions, CASA5 tasks return False
        if is_CASA6:
            self.assertRaises(Exception, sdsmooth, infile=self.infile, outfile=self.outfile, kernel='gaussian', datacolumn='spectra')
        else:
            self.result = sdsmooth(infile=self.infile, outfile=self.outfile, kernel='gaussian', datacolumn='spectra')


class sdsmooth_test_complex(sdsmooth_test_base):
    """
    Unit test for task sdsmooth. Process MS having DATA column.

    The list of tests:
    test_sdsmooth_complex_fail01 --- non-existing data column (FLOAT_DATA)
    test_sdsmooth_complex_gauss01 --- gaussian smoothing (kwidth 5)
    test_sdsmooth_complex_gauss02 --- gaussian smoothing (kwidth 3)
    test_sdsmooth_complex_select --- data selection (spw)
    test_sdsmooth_complex_overwrite --- overwrite existing outfile (overwrite=True)
    """
    exception_case = sdsmooth_test_base.exception_case
    infile = sdsmooth_test_base.infile_data
    datacolumn = 'data'

    # raw string: '\(' is an invalid escape in a plain string literal
    @exception_case(RuntimeError, r'Desired column \(FLOAT_DATA\) not found in the input MS')
    def test_sdsmooth_complex_fail01(self):
        """test_sdsmooth_complex_fail01 --- non-existing data column (FLOAT_DATA)"""
        self.result = sdsmooth(infile=self.infile, outfile=self.outfile, kernel='gaussian', datacolumn='float_data')

    def test_sdsmooth_complex_gauss01(self):
        """test_sdsmooth_complex_gauss01 --- gaussian smoothing (kwidth 5)"""
        self.run_test(kwidth=5)

    def test_sdsmooth_complex_gauss02(self):
        """test_sdsmooth_complex_gauss02 --- gaussian smoothing (kwidth 3)"""
        self.run_test(kwidth=3)

    def test_sdsmooth_complex_select(self):
        """test_sdsmooth_complex_select --- data selection (spw)"""
        self.run_test(kwidth=5, spw='1')

    def test_sdsmooth_complex_overwrite(self):
        """test_sdsmooth_complex_overwrite --- overwrite existing outfile (overwrite=True)"""
        shutil.copytree(self.infile, self.outfile)
        self.run_test(kwidth=5, overwrite=True)

class sdsmooth_test_float(sdsmooth_test_base):
    """
    Unit test for task sdsmooth. Process MS having FLOAT_DATA column.

    The list of tests:
    test_sdsmooth_float_fail01 --- non-existing data column (DATA)
    test_sdsmooth_float_gauss01 --- gaussian smoothing (kwidth 5)
    test_sdsmooth_float_gauss02 --- gaussian smoothing (kwidth 3)
    test_sdsmooth_float_select --- data selection (spw)
    test_sdsmooth_float_overwrite --- overwrite existing outfile (overwrite=True)
    """
    exception_case = sdsmooth_test_base.exception_case
    infile = sdsmooth_test_base.infile_float
    datacolumn = 'float_data'

    # raw string: '\(' is an invalid escape in a plain string literal
    @exception_case(RuntimeError, r'Desired column \(DATA\) not found in the input MS')
    def test_sdsmooth_float_fail01(self):
        """test_sdsmooth_float_fail01 --- non-existing data column (DATA)"""
        self.result = sdsmooth(infile=self.infile, outfile=self.outfile, kernel='gaussian', datacolumn='data')

    def test_sdsmooth_float_gauss01(self):
        """test_sdsmooth_float_gauss01 --- gaussian smoothing (kwidth 5)"""
        self.run_test(kwidth=5)

    def test_sdsmooth_float_gauss02(self):
        """test_sdsmooth_float_gauss02 --- gaussian smoothing (kwidth 3)"""
        self.run_test(kwidth=3)

    def test_sdsmooth_float_select(self):
        """test_sdsmooth_float_select --- data selection (spw)"""
        self.run_test(kwidth=5, spw='1')

    def test_sdsmooth_float_overwrite(self):
        """test_sdsmooth_float_overwrite --- overwrite existing outfile (overwrite=True)"""
        shutil.copytree(self.infile, self.outfile)
        self.run_test(kwidth=5, overwrite=True)

class sdsmooth_test_weight(sdsmooth_test_base):
    """
    Unit test for task sdsmooth. Verify weight propagation.

    The input MS gets its WEIGHT_SPECTRUM column initialized in setUp;
    each test then smooths the DATA column and compares the propagated
    spectral weights against the expected values.

    The list of tests:
    test_sdsmooth_weight_gauss01 --- gaussian smoothing (kwidth 5)
    test_sdsmooth_weight_gauss02 --- gaussian smoothing (kwidth 3)
    """
    # class-level alias so the decorator can be used by name below
    weight_case = sdsmooth_test_base.weight_case
    datacolumn = 'data'
    infile = sdsmooth_test_base.infile_data

    def setUp(self):
        # copy the input MS into the working directory first
        sdsmooth_test_base.setUp(self)

        # initialize WEIGHT_SPECTRUM via the calibrater tool
        with sdutil.cbmanager(self.infile) as calibrater:
            calibrater.initweights()

    @weight_case
    def test_sdsmooth_weight_gauss01(self):
        """test_sdsmooth_weight_gauss01 --- gaussian smoothing (kwidth 5)"""
        self.run_test(kwidth=5)

    @weight_case
    def test_sdsmooth_weight_gauss02(self):
        """test_sdsmooth_weight_gauss02 --- gaussian smoothing (kwidth 3)"""
        self.run_test(kwidth=3)

class sdsmooth_test_boxcar(sdsmooth_test_base):
    """
    Unit test for checking boxcar smoothing.

    The input data (tsdsmooth_delta.ms) has data with the following features:
      in row0, pol0: 1 at index 100, 0 elsewhere,
      in row0, pol1: 1 at index 0 and 2047(i.e., at both ends), 0 elsewhere,
      in row1, pol0: 1 at index 10 and 11, 0 elsewhere,
      in row1, pol1: 0 throughout.
    If input spectrum has delta-function-like feature, the
    expected output spectrum will be smoothing kernel itself.
    As for the data at [row0, pol0], the output data will be:
      kwidth==1 -> spec[100] = 1
      kwidth==2 -> spec[100,101] = 1/2 (=0.5)
      kwidth==3 -> spec[99,100,101] = 1/3 (=0.333...)
      kwidth==4 -> spec[99,100,101,102] = 1/4 (=0.25)
      kwidth==5 -> spec[98,99,100,101,102] = 1/5 (=0.2)
      and so on.
    """

    infile = 'tsdsmooth_delta.ms'
    datacolumn = 'float_data'
    # key: str(row) + str(pol); value: channels holding the delta spikes
    centers = {'00': [100], '01': [0,2047], '10': [10,11], '11':[]}

    def _getLeftWidth(self, kwidth):
        """Offset (<= 0) of the first channel covered by a box of width kwidth."""
        assert(0 < kwidth)
        return (2-kwidth)//2

    def _getRightWidth(self, kwidth):
        """Offset (>= 0) of the last channel covered by a box of width kwidth."""
        assert(0 < kwidth)
        return kwidth//2

    def _checkResult(self, spec, kwidth, centers, tol=5.0e-06):
        """
        Assert that spec equals the boxcar response to deltas at `centers`.

        Each channel is expected to hold (number of boxes covering it) / kwidth.
        """
        sys.stdout.write('testing kernel_width = '+str(kwidth)+'...')
        left = self._getLeftWidth(kwidth)
        right = self._getRightWidth(kwidth)
        for i in range(len(spec)):
            count = sum(1 for c in centers if c + left <= i <= c + right)
            value = count/float(kwidth)
            # Two-sided check: a one-sided (spec[i] - value) < tol would
            # accept any value that is too small.
            self.assertTrue(abs(spec[i] - value) < tol,
                            msg='Failed at channel %d: expected %s, got %s.' % (i, value, spec[i]))
        sys.stdout.write('OK.\n')

    def _run_boxcar_test(self, datacolumn):
        """Smooth with kwidth 1..5 and check every row/pol of the output."""
        for kwidth in range(1,6):
            sdsmooth(infile=self.infile, outfile=self.outfile,
                     datacolumn=datacolumn, overwrite=True,
                     kernel='boxcar', kwidth = kwidth)
            with sdutil.tbmanager(self.outfile) as tb:
                for irow in range(tb.nrows()):
                    spec = tb.getcell(datacolumn.upper(), irow)
                    for ipol in range(len(spec)):
                        center = self.centers[str(irow)+str(ipol)]
                        self._checkResult(spec[ipol], kwidth, center)

    def test000(self):
        # testing kwidth from 1 to 5.
        self._run_boxcar_test(self.datacolumn)

    def test000_datacolumn_uppercase(self):
        # testing kwidth from 1 to 5 with an upper-case datacolumn name.
        self._run_boxcar_test("FLOAT_DATA")


class sdsmooth_selection(sdsmooth_test_base, unittest.TestCase):
    """
    Unit test for data selection and reindexing in sdsmooth.

    Each entry of ``selections`` maps a selection parameter to a tuple of
    (selection string, ids expected in the output). The ids are input row
    ids, except for 'pol' where they are polarization ids.
    """
    infile = "analytic_type1.sm.ms"
    outfile = "smoothed.ms"
    common_param = dict(infile=infile, outfile=outfile,
                        kernel='boxcar', kwidth=5)
    selections=dict(intent=("CALIBRATE_ATMOSPHERE#OFF*", [1]),
                    antenna=("DA99", [1]),
                    field=("M1*", [0]),
                    spw=(">6", [1]),
                    timerange=("2013/4/28/4:13:21",[1]),
                    scan=("0~8", [0]),
                    pol=("YY", [1]))
    verbose = False

    def _get_selection_string(self, key):
        """Return {key: selection string} for the given selection parameter."""
        if key not in self.selections:
            raise ValueError("Invalid selection parameter %s" % key)
        return {key: self.selections[key][0]}

    def _get_selected_row_and_pol(self, key):
        """Return (row ids, pol ids) of the input expected in the output."""
        if key not in self.selections:
            raise ValueError("Invalid selection parameter %s" % key)
        pols = [0,1]
        rows = [0,1]
        if key == 'pol':  #self.selection stores pol ids
            pols = self.selections[key][1]
        else: #self.selection stores row ids
            rows = self.selections[key][1]
        return (rows, pols)

    def _get_reference(self, nchan, row_offset, pol_offset, datacol):
        """
        Expected smoothed spectrum: 0.2 over the 5 channels centered on the spike.

        The spike channel is col_offset + 20*row + 10*pol, where col_offset
        is 10 for float_data and 50 for corrected.
        """
        if datacol.startswith("float"):
            col_offset = 10
        elif datacol.startswith("corr"):
            col_offset = 50
        else:
            raise ValueError("Got unexpected datacolumn.")
        spike_chan = col_offset + 20*row_offset + 10*pol_offset
        reference = numpy.zeros(nchan)
        reference[spike_chan-2:spike_chan+3] = 0.2
        if self.verbose: print("reference=%s" % str(reference))
        return reference

    def run_test(self, sel_param, datacolumn, reindex=True):
        """Run sdsmooth with one selection parameter and verify the output MS."""
        inparams = self._get_selection_string(sel_param)
        inparams.update(self.common_param)
        sdsmooth(datacolumn=datacolumn, reindex=reindex, **inparams)
        self._test_result(inparams["outfile"], sel_param, datacolumn)

    def _test_result(self, msname, sel_param, dcol, atol=1.e-5, rtol=1.e-5):
        """Compare nrow, npol and spectral values of msname with the reference."""
        # Make sure output MS exists
        self.assertTrue(os.path.exists(msname), "Could not find output MS")
        # Compare output MS with reference (nrow, npol, and spectral values)
        (rowids, polids) = self._get_selected_row_and_pol(sel_param)
        if dcol.startswith("float"):
            testcolumn = "FLOAT_DATA"
        else: #output is in DATA column
            testcolumn = "DATA"
        tb.open(msname)
        try:
            self.assertEqual(tb.nrows(), len(rowids), "Row number is wrong %d (expected: %d)" % (tb.nrows(), len(rowids)))
            for out_row, in_row in enumerate(rowids):
                sp = tb.getcell(testcolumn, out_row)
                # actual value first, expected value second
                self.assertEqual(sp.shape[0], len(polids), "Number of pol is wrong in row=%d:  %d (expected: %d)" % (out_row,sp.shape[0],len(polids)))
                nchan = sp.shape[1]
                for out_pol, in_pol in enumerate(polids):
                    reference = self._get_reference(nchan, in_row, in_pol, dcol)
                    if self.verbose: print("data=%s" % str(sp[out_pol]))
                    self.assertTrue(numpy.allclose(sp[out_pol], reference,
                                                   atol=atol, rtol=rtol),
                                    "Smoothed spectrum differs in row=%d, pol=%d" % (out_row, out_pol))
        finally:
            tb.close()


    def testIntentF(self):
        """Test selection by intent (float_data)"""
        self.run_test("intent", "float_data")

    def testIntentC(self):
        """Test selection by intent (corrected)"""
        self.run_test("intent", "corrected")

    def testAntennaF(self):
        """Test selection by antenna (float_data)"""
        self.run_test("antenna", "float_data")

    def testAntennaC(self):
        """Test selection by antenna (corrected)"""
        self.run_test("antenna", "corrected")

    def testFieldF(self):
        """Test selection by field (float_data)"""
        self.run_test("field", "float_data")

    def testFieldC(self):
        """Test selection by field (corrected)"""
        self.run_test("field", "corrected")

    def testSpwF(self):
        """Test selection by spw (float_data)"""
        self.run_test("spw", "float_data")

    def testSpwC(self):
        """Test selection by spw (corrected)"""
        self.run_test("spw", "corrected")

    def testTimerangeF(self):
        """Test selection by timerange (float_data)"""
        self.run_test("timerange", "float_data")

    def testTimerangeC(self):
        """Test selection by timerange (corrected)"""
        self.run_test("timerange", "corrected")

    def testScanF(self):
        """Test selection by scan (float_data)"""
        self.run_test("scan", "float_data")

    def testScanC(self):
        """Test selection by scan (corrected)"""
        self.run_test("scan", "corrected")

    def testPolF(self):
        """Test selection by pol (float_data)"""
        self.run_test("pol", "float_data")

    def testPolC(self):
        """Test selection by pol (corrected)"""
        self.run_test("pol", "corrected")

    def testReindexSpw(self):
        """Test reindex =T/F in spw selection"""
        outfile = self.common_param['outfile']
        for datacol in ['float_data', 'corrected']:
            print("Test: %s" % datacol.upper())
            # reindex=True renumbers to ddid 0 / spw 0; False keeps ddid 1 / spw 7
            for (reindex, ddid, spid) in zip([True, False], [0, 1], [0,7]):
                print("- reindex=%s" % str(reindex))
                self.run_test("spw", datacol, reindex=reindex)
                tb.open(outfile)
                try:
                    self.assertEqual(ddid, tb.getcell('DATA_DESC_ID', 0),
                                     "comparison of DATA_DESCRIPTION_ID failed.")
                finally: tb.close()
                tb.open(outfile+'/DATA_DESCRIPTION')
                try:
                    self.assertEqual(spid, tb.getcell('SPECTRAL_WINDOW_ID', ddid),
                                     "comparison of SPW_ID failed.")
                finally: tb.close()
                shutil.rmtree(outfile)

    def testReindexIntent(self):
        """Test reindex =T/F in intent selection"""
        outfile = self.common_param['outfile']
        for datacol in ['float_data', 'corrected']:
            print("Test: %s" % datacol.upper())
            # reindex=True renumbers STATE_ID to 0; False keeps the original 4
            for (reindex, idx) in zip([True, False], [0, 4]):
                print("- reindex=%s" % str(reindex))
                self.run_test("intent", datacol, reindex=reindex)
                tb.open(outfile)
                try:
                    self.assertEqual(idx, tb.getcell('STATE_ID', 0),
                                     "comparison of state_id failed.")
                finally: tb.close()
                shutil.rmtree(outfile)

def suite():
    """Return the test case classes of this module, in execution order."""
    test_classes = [
        sdsmooth_test_fail,
        sdsmooth_test_complex,
        sdsmooth_test_float,
        sdsmooth_test_weight,
        sdsmooth_test_boxcar,
        sdsmooth_selection,
    ]
    return test_classes

# Direct execution is supported only under CASA6 (casatasks/casatools).
if is_CASA6 and __name__ == '__main__':
    unittest.main()