#########################################################################
# test_task_statwt.py
#
# Copyright (C) 2020
# Associated Universities, Inc. Washington DC, USA.
#
# This script is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
#
# Based on the requirements listed in plone found here:
# https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.manipulation.statwt.html
#
# Test case: requirement
#
##########################################################################
import os
import sys
import shutil
import unittest
import numpy as np
import numpy.ma as ma

subdir = 'unittest/statwt/'

# https://stackoverflow.com/questions/52580105/exception-similar-to-modulenotfounderror-in-python-2-7
try:
    ModuleNotFoundError
except NameError:
    ModuleNotFoundError = ImportError

from casatools import ctsys, table, ms
from casatasks import statwt
datadir = ctsys.resolve(subdir)
mytb = table()
myms = ms()

src = os.path.join(datadir, 'ngc5921_small.statwt.ms')
vlass = os.path.join(datadir, 'test_vlass_subset.ms')

# Reference data location
refdir = os.path.join(datadir, 'statwt_reference')

# rows and target_row are the row numbers from the subtable formed
# by the baseline query
# In the chan_flags, a value of False means the channel is good (not flagged)
# so should be used. It follows the convention of the FLAGS column in the MS.

# EVEN IF THIS IS NO LONGER USED BY THE TESTS, IT SHOULDN'T BE DELETED BECAUSE
# IT IS USEFUL IN SANITY CHECKING NEW TESTS
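#
# A typical call (mirroring the usage in _check_weights below), where data,
# flags, and exposures come from a per-baseline table query and row_to_rows
# maps each output row to the [start, end) range of input rows averaged
# into it:
#
#     weights, wt, mod_flags = get_weights(
#         data[:, :, start:end], flags[:, :, start:end], chan_flags,
#         exposures[start:end], combine_corr, exposures[row],
#         chanbins, flags[:, :, row:row+1], wtrange
#     )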
def get_weights(
    data, flags, chan_flags, exposures, combine_corr, target_exposure, chanbins,
    target_flags, wtrange
):  
    shape = data.shape
    ncorr_groups = 1 if combine_corr else shape[0]
    nchanbins = 1 if chanbins is None else len(chanbins)
    tchanbins = chanbins
    if nchanbins == 1:
        tchanbins = [[0, shape[1]]]
    ncorr = shape[0]
    weights = np.zeros([shape[0], shape[1]])
    wt = np.zeros(shape[0])
    nrows = data.shape[2]
    median_axis = 1 if ncorr_groups > 1 else None
    # copy so that modifications do not write through to the caller's array
    mod_flags = target_flags.copy()
    if chan_flags is None:
        myflags = flags[:]
    else:
        t_flags = np.expand_dims(np.expand_dims(chan_flags, 0), 2)
        myflags = np.logical_or(flags, t_flags)
    for corr in range(ncorr_groups):
        end_corr = corr + 1 if ncorr_groups > 1 else ncorr + 1
        for cb in tchanbins:
            var = variance(
                data[corr:end_corr, cb[0]:cb[1], :],
                myflags[corr:end_corr, cb[0]:cb[1], :], exposures
            )
            if flags[corr:end_corr, cb[0]:cb[1]].all():
                weights[corr:end_corr, cb[0]:cb[1]] = 0
                mod_flags[corr:end_corr, cb[0]:cb[1]] = True
            if var == 0:
                weights[corr:end_corr, cb[0]:cb[1]] = 0 
                mod_flags[corr:end_corr, cb[0]:cb[1]] = True
            else:
                weights[corr:end_corr, cb[0]:cb[1]] = target_exposure/var
            if wtrange is not None:
                condition = np.logical_or(
                    np.less(
                        weights[corr:end_corr, cb[0]:cb[1]], wtrange[0]
                    ),
                    np.greater(
                        weights[corr:end_corr, cb[0]:cb[1]], wtrange[1]
                    )
                )
                exp_condition = np.expand_dims(condition, 2)
                weights[corr:end_corr, cb[0]:cb[1]] = np.where(
                    condition, 0, weights[corr:end_corr, cb[0]:cb[1]]
                )
                mod_flags[corr:end_corr, cb[0]:cb[1]] = np.where(
                    exp_condition, True, mod_flags[corr:end_corr, cb[0]:cb[1]]
                )
            mweights = ma.array(
                weights[corr:end_corr, :],
                mask=mod_flags[corr:end_corr, :]
            )
            wt[corr:end_corr] = np.median(mweights, median_axis)
    mod_flags = np.where(np.expand_dims(weights, 2) == 0, True, mod_flags)
    return (weights, wt, mod_flags)

def variance(data, flags, exposures):
    if flags.all():
        return 0
    expo = ma.masked_array(np.resize(exposures, data.shape), mask=flags)
    d = ma.array(data, mask=flags)
    myreal = np.real(d)
    myimag = np.imag(d)
    mean_r = np.sum(expo*myreal)/np.sum(expo)
    mean_i = np.sum(expo*myimag)/np.sum(expo)
    var_r = np.sum(expo * (myreal - mean_r)*(myreal - mean_r))/d.count()
    var_i = np.sum(expo * (myimag - mean_i)*(myimag - mean_i))/d.count()
    return (var_r + var_i)/2
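
# Quick sanity-check sketch for variance() (illustrative only, not used by the
# tests): with no flags and equal exposures the exposure weighting cancels, so
# the result reduces to the mean of the real and imaginary population
# variances, e.g.
#
#     d = np.arange(8).reshape(1, 4, 2) + 1j * np.arange(8).reshape(1, 4, 2)
#     f = np.zeros(d.shape, dtype=bool)
#     e = np.ones(2)
#     variance(d, f, e)  # == (np.var(d.real) + np.var(d.imag)) / 2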

def _get_dst_cols(dst, other="", dodata=True):
    mytb.open(dst)
    wt = mytb.getcol("WEIGHT")
    wtsp = mytb.getcol("WEIGHT_SPECTRUM")
    flag = mytb.getcol("FLAG")
    frow = mytb.getcol("FLAG_ROW")
    if dodata:
        data = mytb.getcol("CORRECTED_DATA")
    if len(other) > 0:
        if isinstance(other, list):
            othercol = []
            for x in other:
                othercol.append(mytb.getcol(x))
        else:
            othercol = mytb.getcol(other)
    mytb.close()
    cols = [wt, wtsp, flag, frow]
    if dodata:
        cols.append(data)
    if len(other) > 0:
        if isinstance(other, list):
            for x in othercol:
                cols.append(x)
        else:
            cols.append(othercol)
    return cols

def _get_table_cols(mytb):
    times = mytb.getcol("TIME")
    wt = mytb.getcol("WEIGHT")
    wtsp = None if mytb.colnames().count('WEIGHT_SPECTRUM') == 0 \
        else mytb.getcol("WEIGHT_SPECTRUM")
    flag = mytb.getcol("FLAG")
    frow = mytb.getcol("FLAG_ROW")
    data_col_name = 'CORRECTED_DATA' \
        if mytb.colnames().count('CORRECTED_DATA') > 0 else 'DATA'
    data = mytb.getcol(data_col_name)
    sigma = mytb.getcol("SIGMA")
    sisp = None if mytb.colnames().count('SIGMA_SPECTRUM') == 0 \
        else mytb.getcol("SIGMA_SPECTRUM")
    return [times, wt, wtsp, flag, frow, data, sigma, sisp]

# per correlation: returns the weight for a single correlation and row, i.e.
# the reciprocal of the mean of the real and imaginary sample variances, or 0
# if there are fewer than two unflagged channels
def _variance2(dr, di, flag, corr, row):
    fr = np.extract(np.logical_not(flag[corr, :, row]), dr[corr, :, row])
    fi = np.extract(np.logical_not(flag[corr, :, row]), di[corr, :, row])
    if len(fr) <= 1:
        return 0
    else:
        vr = np.var(fr, ddof=1)
        vi = np.var(fi, ddof=1)
        return 2/(vr + vi)

class statwt_test(unittest.TestCase):
    
    def tearDown(self):
        mytb.done()
        myms.done()
    
    def _check_weights(
        self, msname, row_to_rows, data_column, chan_flags, combine_corr,
        chanbins, wtrange
    ):
        if data_column.startswith('c'):
            col_data = 'CORRECTED_DATA'
            check_sigma = False
        elif data_column.startswith('d'):
            col_data = 'DATA'
            check_sigma = True
        else:
            raise Exception("Unhandled column spec " + data_column)
        for ant1 in range(10):
            for ant2 in range((ant1 + 1), 10):
                query_str = 'ANTENNA1=' + str(ant1) + ' AND ANTENNA2=' \
                     + str(ant2)
                mytb.open(msname)
                subt = mytb.query(query_str)
                data = subt.getcol(col_data)
                flags = subt.getcol('FLAG')
                exposures = subt.getcol('EXPOSURE')
                wt = subt.getcol('WEIGHT')
                wtsp = subt.getcol('WEIGHT_SPECTRUM')
                flag_row = subt.getcol('FLAG_ROW')
                if check_sigma:
                    sigma = subt.getcol('SIGMA')
                    sisp = subt.getcol('SIGMA_SPECTRUM')
                subt.done()
                mytb.done()
                nrows = data.shape[2]
                for row in range(nrows):
                    start = row_to_rows[row][0]
                    end = row_to_rows[row][1]
                    (weights, ewt, mod_flags) = get_weights(
                        data[:,:,start:end], flags[:, :, start:end], chan_flags,
                        exposures[start: end], combine_corr, exposures[row],
                        chanbins, flags[:, :, row:row+1], wtrange
                    )
                    self.assertTrue(
                        np.allclose(weights, wtsp[:, :, row]),
                        'Failed wtsp, got ' + str(wtsp[:, :, row])
                        + '\nexpected ' + str(weights) + '\nbaseline '
                        + str([ant1, ant2]) + '\nrow ' + str(row)
                    )
                    self.assertTrue(
                        np.allclose(ewt, wt[:, row]),
                        'Failed weight, got ' + str(wt[:, row])
                        + '\nexpected ' + str(np.median(weights, 1))
                        + '\nbaseline ' + str([ant1, ant2]) + '\nrow '
                        + str(row)
                    )
                    self.assertTrue(
                        (mod_flags == np.expand_dims(flags[:, :, row], 2)).all(),
                        'Failed flag, got ' + str(flags[:, :, row])
                        + '\nexpected ' + str(mod_flags) + '\nbaseline '
                        + str([ant1, ant2]) + '\nrow ' + str(row)
                    )
                    eflag_row = mod_flags.all()
                    self.assertTrue(
                        (eflag_row == flag_row[row]).all(),
                        'Failed flag_row, got ' + str(flag_row[row])
                        + '\nexpected ' + str(eflag_row) + '\nbaseline '
                        + str([ant1, ant2]) + '\nrow ' + str(row)
                    )
                    # all flags must be True where wtsp = 0
                    self.assertTrue(np.extract(weights == 0, mod_flags).all())
                    if check_sigma:
                        esigma = np.where(ewt == 0, -1, 1/np.sqrt(ewt))
                        self.assertTrue(
                            np.allclose(esigma, sigma[:, row]),
                            'Failed sigma, got ' + str(sigma[:, row])
                            + '\nexpected ' + str(esigma)
                            + '\nbaseline ' + str([ant1, ant2]) + '\nrow '
                            + str(row)
                        )
                        esisp = np.where(weights == 0, -1, 1/np.sqrt(weights))
                        self.assertTrue(
                            np.allclose(esisp, sisp[:, :, row]),
                            'Failed sigma_spectrum, got ' + str(sisp[:, :, row])
                            + '\nexpected ' + str(esisp)
                            + '\nbaseline ' + str([ant1, ant2]) + '\nrow '
                            + str(row)
                        )
              
    def compare(self, dst, ref):
        self.assertTrue(mytb.open(dst), "Table open failed for " + dst)
        mytb1 = table()
        ref = os.path.join(refdir, ref)
        self.assertTrue(mytb1.open(ref), "Table open failed for " + ref)
        self.compareTables(mytb, mytb1)
        mytb.done()
        mytb1.done()
                        
    def compareTables(self, dst, ref):
        self.assertEqual(dst.nrows(), ref.nrows(), 'number of rows differ')
        [
            gtimes, gwt, gwtsp, gflag, gfrow, gdata, gsigma, gsisp
        ] = _get_table_cols(dst)
        [
            etimes, ewt, ewtsp, eflag, efrow, edata, esigma, esisp
        ] = _get_table_cols(ref)
        self.assertTrue(np.allclose(gwt, ewt), 'WEIGHT comparison failed')
        if ewtsp is not None or gwtsp is not None:
            self.assertTrue(
                np.allclose(gwtsp, ewtsp), 'WEIGHT_SPECTRUM comparison failed'
            )
        self.assertTrue((gflag == eflag).all(), 'FLAG comparison failed')
        self.assertTrue((gfrow == efrow).all(), 'FLAG_ROW comparison failed')
        # all flags must be True where wtsp = 0
        self.assertTrue(np.extract(gwtsp == 0, gflag).all())
        self.assertTrue(np.allclose(gsigma, esigma), 'SIGMA comparison failed')
        if gsisp is not None or esisp is not None:
            self.assertTrue(np.allclose(
                gsisp, esisp), 'SIGMA_SPECTRUM comparison failed'
            )

    def test_algorithm(self):
        """ Test the algorithm, includes excludechans tests"""
        dst = "ngc5921.split.ms"
        cflags = np.array(63 * [False])
        cflags[10:21] = True
        row_to_rows = []
        for row in range(60):
            row_to_rows.append((row, row+1))
        for combine in ["", "corr"]:
            c = 0
            for fitspw in ["0:0~9;21~62", "", "0:10~20"]:
                shutil.copytree(src, dst)
                excludechans = c == 2
                statwt(
                    dst, combine=combine, fitspw=fitspw,
                    excludechans=excludechans
                )
                chan_flags = cflags if fitspw else None
                if combine == '':
                    if fitspw == '':
                        ref = 'ngc5921_statwt_ref_test_algorithm_sep_corr_no_fitspw.ms'
                    else: 
                        ref = 'ngc5921_statwt_ref_test_algorithm_sep_corr_fitspw.ms'
                else:
                    if fitspw == '':
                        ref = 'ngc5921_statwt_ref_test_algorithm_combine_corr_no_fitspw.ms'
                    else:
                        ref = 'ngc5921_statwt_ref_test_algorithm_combine_corr_has_fitspw.ms'
                self.compare(dst, ref)
                shutil.rmtree(dst)
                c += 1               

    def test_timebin(self):
        """ Test time binning"""
        dst = "ngc5921.split.timebin.ms"
        combine = "corr"
        for timebin in ["300s", 10]:
            shutil.copytree(src, dst) 
            statwt(dst, timebin=timebin, combine=combine)
            ref = 'ngc5921_statwt_ref_test_timebin_' + str(timebin) + '.ms'
            self.compare(dst, ref)
            shutil.rmtree(dst)

    def test_chanbin(self):
        """Test channel binning"""
        dst = "ngc5921.split.chanbin_0.ms"
        row_to_rows = []
        for i in range(60):
            row_to_rows.append([i, i+1])
        bins = [
            [0, 8], [8, 16], [16, 24], [24, 32], [32, 40], [40, 48],
            [48, 56], [56,63]
        ]
        for combine in ["", "corr"]:
            for i in [0, 2]:
                for chanbin in ["195.312kHz", 8]:
                    if i == 2 and combine != '' and chanbin != 8:
                        # only run the check for i == 2 once
                        continue
                    shutil.copytree(src, dst)
                    if i == 0:
                        statwt(dst, chanbin=chanbin, combine=combine)
                    elif i == 2:
                        # check WEIGHT_SPECTRUM is created, only check once,
                        # this test is long as it is
                        mytb.open(dst, nomodify=False)
                        x = mytb.ncols()
                        self.assertTrue(
                            mytb.removecols("WEIGHT_SPECTRUM"),
                            "column not removed"
                        )
                        y = mytb.ncols()
                        self.assertTrue(y == x-1, "wrong number of columns")
                        mytb.done()
                        statwt(dst, chanbin=chanbin, combine=combine)
                    if combine == '':
                        ref = 'ngc5921_statwt_ref_test_chanbin_sep_corr.ms'
                    else:
                        ref = 'ngc5921_statwt_ref_test_chanbin_combine_corr.ms'
                    self.compare(dst, ref)
                    shutil.rmtree(dst)

    def test_minsamp(self):
        """Test minimum number of points"""
        dst = "ngc5921.split.minsamp.ms"
        combine = "corr"
        trow = 12
        for minsamp in [60, 80]:
            shutil.copytree(src, dst)
            statwt(dst, minsamp=minsamp, combine=combine)
            [wt, wtsp, flag, frow, data] = _get_dst_cols(dst)
            if minsamp == 60:
                self.assertTrue(
                    (wt[:, trow] > 0).all(), "Incorrect weight row " + str(trow)
                )
                self.assertTrue(
                    (wtsp[:, :, trow] > 0).all(),
                    "Incorrect weight spectrum row " + str(trow)
                )
                self.assertFalse(
                    flag[:,:,trow].all(), "Incorrect flag row " + str(trow)
                )
                self.assertFalse(
                    frow[trow], "Incorrect flagrow row " + str(trow)
                )
            else:
                self.assertTrue(
                    (wt[:, trow] == 0).all(),
                    "Incorrect weight row " + str(trow)
                )
                self.assertTrue(
                    (wtsp[:, :, trow] == 0).all(),
                    "Incorrect weight spectrum row " + str(trow)
                )
                self.assertTrue(
                    flag[:,:,trow].all(), "Incorrect flag row " + str(trow)
                )
                self.assertTrue(
                    frow[trow], "Incorrect flagrow row " + str(trow)
                )
            shutil.rmtree(dst)
            
    def test_fieldsel(self):
        """Test field selection"""
        dst = "ngc5921.split.fieldsel.ms"
        combine = "corr"
        ref = 'ngc5921_statwt_ref_test_fieldsel.ms'
        for field in ["2", "N5921_2"]:
            shutil.copytree(src, dst)
            statwt(dst, field=field, combine=combine)
            self.compare(dst, ref)
            shutil.rmtree(dst)
          
    def test_spwsel(self):
        """Test spw selection"""
        dst = "ngc5921.split.spwsel.ms"
        ref = 'ngc5921_statwt_ref_test_algorithm_combine_corr_no_fitspw.ms'
        combine = "corr"
        spw="0"
        # data set only has one spw
        shutil.copytree(src, dst)
        statwt(dst, spw=spw, combine=combine)
        self.compare(dst, ref)
        shutil.rmtree(dst)

    def test_scansel(self):
        """CAS-11858 Test scan selection"""
        dst = "ngc5921.split.scansel.ms"
        ref = 'ngc5921_statwt_ref_test_scansel.ms'
        combine = "corr"
        [origwt, origwtsp, origflag, origfrow, origdata] = _get_dst_cols(src)
        scan = "5"
        shutil.copytree(src, dst)
        statwt(dst, scan=scan, combine=combine)
        self.compare(dst, ref)
        shutil.rmtree(dst)
        
    def test_default_boundaries(self):
        """Test default scan, field, etc boundaries"""
        dst = "ngc5921.split.normalbounds.ms"
        ref = 'ngc5921_statwt_ref_test_default_boundaries.ms'
        timebin = "6000s"
        # there are three field_ids, and there is a change in field_id when
        # there is a change in scan number, so specifying combine="field" in the
        # absence of "scan" will give the same result as combine=""
        row_to_rows = []
        for i in range(12):
            row_to_rows.append([0, 12])
        for i in range(12, 17):
            row_to_rows.append([12, 17])
        for i in range(17, 33):
            row_to_rows.append([17, 33])
        for i in range(33, 35):
            row_to_rows.append([33, 35])
        for i in range(35, 38):
            row_to_rows.append([35, 38])
        for i in range(38, 56):
            row_to_rows.append([38, 56])
        for i in range(56, 60):
            row_to_rows.append([56, 60])
        for combine in ["corr", "corr,field"]:
            shutil.copytree(src, dst)
            statwt(dst, timebin=timebin, combine=combine)
            self.compare(dst, ref)
            shutil.rmtree(dst)
            
    def test_no_scan_boundaries(self):
        """Test no scan boundaries"""
        dst = "ngc5921.no_scan_bounds.ms"
        timebin = "6000s"
        ref = os.path.join(refdir, 'ngc5921_statwt_ref_test_no_scan_bounds.ms')
        combine = "corr, scan"
        shutil.copytree(src, dst)
        statwt(dst, timebin=timebin, combine=combine)
        self.compare(dst, ref)
        shutil.rmtree(dst)
    
    def test_no_scan_nor_field_boundaries(self):
        """Test no scan nor field boundaries"""
        dst = "ngc5921.no_scan_nor_field_bounds.ms"
        timebin = "6000s"
        ref = os.path.join(refdir, 'ngc5921_statwt_ref_test_no_scan_nor_field_bounds.ms')
        for combine in ["corr,scan,field", "corr,field,scan"]:
            shutil.copytree(src, dst)
            statwt(dst, timebin=timebin, combine=combine)
            self.compare(dst, ref)
            shutil.rmtree(dst)
                
    def test_statalg(self):
        """Test statalg"""
        # just testing inputs
        dst = "ngc5921.split.statalg.ms"
        for statalg in ["cl", "ch", "h", "f", "bogus"]:
            shutil.copytree(src, dst)
            if statalg == "cl":
                statwt(vis=dst, statalg=statalg)
            elif statalg == "ch":
                statwt(vis=dst, statalg=statalg, zscore=5, maxiter=3)
            elif statalg == "h":
                statwt(vis=dst, statalg=statalg, fence=0.2)
            elif statalg == "f":
                statwt(vis=dst, statalg=statalg, center="median",
                       lside=False)
            elif statalg == "bogus":
                self.assertRaises(
                        RuntimeError, statwt, vis=dst, statalg=statalg
                )

            shutil.rmtree(dst)
                
    def test_wtrange(self):
        """Test weight range"""
        dst = "ngc5921.wtrange.split.timebin.ms"
        ref = "ngc5921_statwt_ref_test_wtrange_300s.ms"
        combine = "corr"
        timebin = "300s"
        wtrange = [1, 2]
        """
        row_to_rows = []
        for i in range(10):
            row_to_rows.append([0, 10])
        for i in range(2):
            row_to_rows.append([10, 12])
        for i in range(5):
            row_to_rows.append([12, 17])
        for i in range(5):
            row_to_rows.append([17, 22])
        for i in range(5):
            row_to_rows.append([22, 27])
        for i in range(5):
            row_to_rows.append([27, 32])
        for i in range(1):
            row_to_rows.append([32, 33])
        for i in range(2):
            row_to_rows.append([33, 35])
        for i in range(3):
            row_to_rows.append([35, 38])
        for i in range(5):
            row_to_rows.append([38, 43])
        for i in range(5):
            row_to_rows.append([43, 48])
        for i in range(5):
            row_to_rows.append([48, 53])
        for i in range(3):
            row_to_rows.append([53, 56])
        for i in range(4):
            row_to_rows.append([56, 60])
        """
        for i in [0, 1]:
            shutil.copytree(src, dst) 
            statwt(dst, timebin=timebin, combine=combine, wtrange=wtrange)
            self.compare(dst, ref)
            # self._check_weights(
            #    dst, row_to_rows, 'c', None, True, None, wtrange
            # )
            shutil.rmtree(dst)

    def test_preview(self):
        """ Test preview mode"""
        dst = "ngc5921.split.preview.ms"
        [refwt, refwtsp, refflag, reffrow, refdata] = _get_dst_cols(src)
        combine = "corr"
        timebin = "300s"
        wtrange = [1, 2]
        preview = True
        shutil.copytree(src, dst)
        statwt(
            dst, timebin=timebin, combine=combine, wtrange=wtrange,
            preview=preview
        )
        [tstwt, tstwtsp, tstflag, tstfrow, tstdata] = _get_dst_cols(dst)
        self.assertTrue(np.all(tstflag == refflag), "FLAGs don't match")
        self.assertTrue(np.all(tstfrow == reffrow), "FLAG_ROWs don't match")
        self.assertTrue(
            np.all(np.isclose(tstwt, refwt)), "WEIGHTs don't match"
        )
        self.assertTrue(
            np.all(np.isclose(tstwtsp, refwtsp)), "WEIGHT_SPECTRUMs don't match"
        )
        shutil.rmtree(dst)

    def test_data_col(self):
        """Test using data column"""
        dst = "ngc5921.split.data.ms"
        ref = 'ngc5921_statwt_ref_test_data_col.ms'
        combine = "corr"
        timebin = 1
        data = "data"
        """
        row_to_rows = []
        for i in range(60):
            row_to_rows.append([i, i+1])
        """
        shutil.copytree(src, dst)
        self.assertTrue(mytb.open(dst, nomodify=False))
        self.assertTrue(mytb.removecols("DATA"))
        self.assertTrue(mytb.renamecol("CORRECTED_DATA", "DATA"))
        mytb.done()
        statwt(dst, timebin=timebin, combine=combine, datacolumn=data)
        # self._check_weights(dst, row_to_rows, 'd', None, True, None, None)
        self.compare(dst, ref)
        shutil.rmtree(dst)

    def test_sliding_time_window(self):
        """Test sliding time window"""
        dst = "ngc5921.split.sliding_time_window.ms"
        ref = 'ngc5921_statwt_ref_test_sliding_time_window.ms'
        timebin = "300s"
        """
        row_to_rows = []
        row_to_rows.append([0, 6])
        row_to_rows.append([0, 7])
        row_to_rows.append([0, 8])
        row_to_rows.append([0, 9])
        row_to_rows.append([0, 9])
        row_to_rows.append([0, 10])
        row_to_rows.append([1, 12])
        row_to_rows.append([2, 12])
        row_to_rows.append([3, 12])
        row_to_rows.append([5, 12])
        row_to_rows.append([6, 12])
        row_to_rows.append([6, 12])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([17, 20])
        row_to_rows.append([17, 21])
        row_to_rows.append([17, 22])
        row_to_rows.append([18, 23])
        row_to_rows.append([19, 24])
        row_to_rows.append([20, 25])
        row_to_rows.append([21, 26])
        row_to_rows.append([22, 27])
        row_to_rows.append([23, 28])
        row_to_rows.append([24, 29])
        row_to_rows.append([25, 30])
        row_to_rows.append([26, 31])
        row_to_rows.append([27, 32])
        row_to_rows.append([28, 33])
        row_to_rows.append([29, 33])
        row_to_rows.append([30, 33])
        row_to_rows.append([33, 35])
        row_to_rows.append([33, 35])
        row_to_rows.append([35, 38])
        row_to_rows.append([35, 38])
        row_to_rows.append([35, 38])
        row_to_rows.append([38, 41])
        row_to_rows.append([38, 42])
        row_to_rows.append([38, 43])
        row_to_rows.append([39, 44])
        row_to_rows.append([40, 45])
        row_to_rows.append([41, 46])
        row_to_rows.append([42, 47])
        row_to_rows.append([43, 48])
        row_to_rows.append([44, 49])
        row_to_rows.append([45, 50])
        row_to_rows.append([46, 51])
        row_to_rows.append([47, 52])
        row_to_rows.append([48, 53])
        row_to_rows.append([49, 54])
        row_to_rows.append([50, 55])
        row_to_rows.append([51, 56])
        row_to_rows.append([52, 56])
        row_to_rows.append([53, 56])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        """
        shutil.copytree(src, dst)
        statwt(dst, timebin=timebin, slidetimebin=True)
        # self._check_weights(
        #    dst, row_to_rows, 'c', None, False, None, None
        # )
        self.compare(dst, ref)
        shutil.rmtree(dst)
        
    def test_sliding_window_timebin_int(self):
        """Test sliding window with timebin as int specified"""
        dst = "ngc5921.split.sliding_time_window.ms"
        # row_to_rows = []
        """
        # odd int, timebin = 5
        row_to_rows.append([0, 5])
        row_to_rows.append([0, 5])
        row_to_rows.append([0, 5])
        row_to_rows.append([1, 6])
        row_to_rows.append([2, 7])
        row_to_rows.append([3, 8])
        row_to_rows.append([4, 9])
        row_to_rows.append([5, 10])
        row_to_rows.append([6, 11])
        row_to_rows.append([7, 12])
        row_to_rows.append([7, 12])
        row_to_rows.append([7, 12])
        
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        
        row_to_rows.append([17, 22])
        row_to_rows.append([17, 22])
        row_to_rows.append([17, 22])
        row_to_rows.append([18, 23])
        row_to_rows.append([19, 24])
        row_to_rows.append([20, 25])
        row_to_rows.append([21, 26])
        row_to_rows.append([22, 27])
        row_to_rows.append([23, 28])
        row_to_rows.append([24, 29])
        row_to_rows.append([25, 30])
        row_to_rows.append([26, 31])
        row_to_rows.append([27, 32])
        row_to_rows.append([28, 33])
        row_to_rows.append([28, 33])
        row_to_rows.append([28, 33])
        
        row_to_rows.append([33, 35])
        row_to_rows.append([33, 35])
        
        row_to_rows.append([35, 38])
        row_to_rows.append([35, 38])
        row_to_rows.append([35, 38])
        
        row_to_rows.append([38, 43])
        row_to_rows.append([38, 43])
        row_to_rows.append([38, 43])
        row_to_rows.append([39, 44])
        row_to_rows.append([40, 45])
        row_to_rows.append([41, 46])
        row_to_rows.append([42, 47])
        row_to_rows.append([43, 48])
        row_to_rows.append([44, 49])
        row_to_rows.append([45, 50])
        row_to_rows.append([46, 51])
        row_to_rows.append([47, 52])
        row_to_rows.append([48, 53])
        row_to_rows.append([49, 54])
        row_to_rows.append([50, 55])
        row_to_rows.append([51, 56])
        row_to_rows.append([51, 56])
        row_to_rows.append([51, 56])
        
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        """
        """
        # even timebin = 6
        row_to_rows.append([0, 6])
        row_to_rows.append([0, 6])
        row_to_rows.append([0, 6])
        row_to_rows.append([1, 7])
        row_to_rows.append([2, 8])
        row_to_rows.append([3, 9])
        row_to_rows.append([4, 10])
        row_to_rows.append([5, 11])
        row_to_rows.append([6, 12])
        row_to_rows.append([6, 12])
        row_to_rows.append([6, 12])
        row_to_rows.append([6, 12])
        
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        row_to_rows.append([12, 17])
        
        row_to_rows.append([17, 23])
        row_to_rows.append([17, 23])
        row_to_rows.append([17, 23])
        row_to_rows.append([18, 24])
        row_to_rows.append([19, 25])
        row_to_rows.append([20, 26])
        row_to_rows.append([21, 27])
        row_to_rows.append([22, 28])
        row_to_rows.append([23, 29])
        row_to_rows.append([24, 30])
        row_to_rows.append([25, 31])
        row_to_rows.append([26, 32])
        row_to_rows.append([27, 33])
        row_to_rows.append([27, 33])
        row_to_rows.append([27, 33])
        row_to_rows.append([27, 33])
        
        row_to_rows.append([33, 35])
        row_to_rows.append([33, 35])
        
        row_to_rows.append([35, 38])
        row_to_rows.append([35, 38])
        row_to_rows.append([35, 38])
        
        row_to_rows.append([38, 44])
        row_to_rows.append([38, 44])
        row_to_rows.append([38, 44])
        row_to_rows.append([39, 45])
        row_to_rows.append([40, 46])
        row_to_rows.append([41, 47])
        row_to_rows.append([42, 48])
        row_to_rows.append([43, 49])
        row_to_rows.append([44, 50])
        row_to_rows.append([45, 51])
        row_to_rows.append([46, 52])
        row_to_rows.append([47, 53])
        row_to_rows.append([48, 54])
        row_to_rows.append([49, 55])
        row_to_rows.append([50, 56])
        row_to_rows.append([50, 56])
        row_to_rows.append([50, 56])
        row_to_rows.append([50, 56])
        
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        row_to_rows.append([56, 60])
        """

        for timebin in [5, 6]:
            ref = 'ngc5921_statwt_ref_test_sliding_time_window_' + str(timebin) + '.ms'
            shutil.copytree(src, dst)
            statwt(dst, timebin=timebin, slidetimebin=True)
            #self._check_weights(
            #    dst, row_to_rows, 'c', None, False, None, None
            #)
            self.compare(dst, ref)
            shutil.rmtree(dst)

    def test_residual(self):
        """ Test using corrected_data - model_data column"""
        dst = "ngc5921.split.residualwmodel.ms"
        ref = 'ngc5921_statwt_ref_test_residual.ms'
        data = "residual"
        # row_to_rows = []
        # for i in range(60):
        #    row_to_rows.append([i, i+1])
        shutil.copytree(src, dst)
        statwt(dst, datacolumn=data)
        # self._check_weights(
        #    dst, row_to_rows, data, None, False, None, None
        # )
        self.compare(dst, ref)
        shutil.rmtree(dst)
            
    def test_residual_no_model(self):
        """Test datacolumn='residual' in the absence of a MODEL_DATA column"""
        dst = "ngc5921.split.residualwoutmodel.ms"
        ref = 'ngc5921_statwt_ref_test_residual_no_model.ms'
        data = "residual"
        shutil.copytree(src, dst)
        self.assertTrue(mytb.open(dst, nomodify=False))
        self.assertTrue(mytb.removecols("MODEL_DATA"))
        mytb.done()
        statwt(dst, datacolumn=data)
        # self._check_weights(
        #    dst, row_to_rows, data, None, False, None, None
        # )
        self.compare(dst, ref)
        shutil.rmtree(dst)

    def test_residual_data(self):
        """Test using data - model_data column"""
        dst = "ngc5921.split.residualdatawmodel.ms"
        ref = 'ngc5921_statwt_ref_test_residual_data.ms'
        data = "residual_data"
        # row_to_rows = []
        # for i in range(60):
        #     row_to_rows.append([i, i+1])
        shutil.copytree(src, dst)
        statwt(dst, datacolumn=data)
        # self._check_weights(
        #    dst, row_to_rows, data, None, False, None, None
        # )
        self.compare(dst, ref)
        shutil.rmtree(dst)
        
    def test_residual_data_no_model(self):
        """Test using residual data in absence of MODEL_DATA"""
        dst = "ngc5921.split.residualdatawoutmodel.ms"
        ref = 'ngc5921_statwt_ref_test_residual_data_no_model.ms'
        data = "residual_data"
        # row_to_rows = []
        # for i in range(60):
        #     row_to_rows.append([i, i+1])
        shutil.copytree(src, dst)
        self.assertTrue(mytb.open(dst, nomodify=False))
        self.assertTrue(mytb.removecols("MODEL_DATA"))
        mytb.done()
        statwt(dst, datacolumn=data)
        # self._check_weights(
        #     dst, row_to_rows, data, None, False, None, None
        # )
        self.compare(dst, ref)
        shutil.rmtree(dst)

    def test_returned_stats(self):
        """ Test returned stats, CAS-10881"""
        dst = "ngc5921.split.statstest.ms"
        shutil.copytree(src, dst)
        res = statwt(dst)
        self.assertTrue(
            np.isclose(res['mean'], 3.691224144843796),
            "mean is incorrect"
        )
        self.assertTrue(
            np.isclose(res['variance'], 6.860972180192186),
            "variance is incorrect"
        )
        shutil.rmtree(dst)
        
    def test_multi_spw_no_spectrum_columns(self):
        "Test multi spw with no sigma nor weight spectrum columns works"
        for tb in [1, "5s"]:
            dst = "statwt_test_vlass_timebin" + str(tb) + ".ms"
            shutil.copytree(vlass, dst)
            res = statwt(
                vis=dst, combine='scan,field,state', timebin=tb,
                datacolumn='residual_data'
            )
            ref = 'test_vlass_timebin' + str(tb) + '.ms'
            self.compare(dst, ref)
            shutil.rmtree(dst)

    def test_chanbin_multi_spw_no_spectrum_columns(self):
        """
        Test specifying chanbin when multi spw with no sigma nor weight
        spectrum columns works
        """
        ref = os.path.join(refdir, 'ref_vlass_wtsp_creation.ms')
        for spw in ["", "0"]:
            dst = "statwt_test_vlass_spw_select_" + str(spw) + ".ms"
            shutil.copytree(vlass, dst)
            if spw == '':
                try:
                    statwt(
                        vis=dst, combine='scan,field,state', chanbin=1,
                        timebin='1yr', datacolumn='residual_data',
                        selectdata=True, spw=spw
                    )
                except Exception:
                    self.fail("statwt raised an unexpected exception")
                self.compare(dst, ref)
            else:
                # Currently there is a bug which requires statwt to be run twice
                self.assertRaises(
                        RuntimeError, statwt, vis=dst, combine='scan,field,state',
                        chanbin=1, timebin='1yr', datacolumn='residual_data',
                        selectdata=True, spw=spw
                    )

                res = statwt(
                    vis=dst, combine='scan,field,state', chanbin=1,
                    timebin='1yr', datacolumn='residual_data',
                    selectdata=True, spw=spw
                )
                self.assertTrue(res)
                mytb.open(ref)
                reftab = mytb.query("DATA_DESC_ID == 0")
                mytb.done()
                mytb.open(dst)
                dsttab = mytb.query("DATA_DESC_ID == 0")
                mytb.done()
                self.compareTables(dsttab, reftab)
                reftab.done()
                dsttab.done()
            shutil.rmtree(dst)

if __name__ == '__main__':
    unittest.main()