#########################################################################
# test_task_testconcat.py
#
# Copyright (C) 2018
# Associated Universities, Inc. Washington DC, USA
#
# This script is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# [Add the link to the JIRA ticket here once it exists]
#
# Based on the requirements listed in plone found here:
#
#
##########################################################################
import glob
import os
import shutil
import sys
import unittest

import numpy as np

from casatools import ctsys
from casatools import table as tbtool
from casatools import ms as mstool
from casatasks import split, testconcat

# Shared casatools instances used by the helper and the tests below.
tb = tbtool()
ms = mstool()

# Directory holding the input MeasurementSets for these tests.
datapath = ctsys.resolve('unittest/testconcat/')

# Prefix for all progress and diagnostic messages.
myname = 'test_testconcat'

# Name of the MS that testconcat writes in every test.
msname = 'testconcatenated.ms'

def _cell_agrees(value, expected, tolerance):
    """Return True if a table cell value matches an expected value.

    Scalars, arrays and strings are handled uniformly; ``expected`` is
    broadcast against ``value`` (so a scalar may be compared with an array).

    Parameters
    ----------
    value : scalar or array-like
        Cell content as returned by ``tb.getcell``.
    expected : scalar or array-like
        Expected content.
    tolerance : number
        0 requests exact equality; otherwise every element must satisfy
        ``abs(value - expected) < tolerance``.

    Returns
    -------
    bool
        False (instead of raising) when the two cannot be compared, e.g.
        arrays of incompatible shape or non-numeric data with a tolerance.
    """
    actual = np.asarray(value)
    try:
        if tolerance == 0:
            return bool(np.all(actual == expected))
        return bool(np.all(np.abs(actual - expected) < tolerance))
    except (TypeError, ValueError):
        # Incompatible shapes raise ValueError; arithmetic on non-numeric
        # dtypes raises (a subclass of) TypeError.
        return False


def checktable(thename, theexpectation):
    """Compare selected cells of a subtable of ``msname`` with expected values.

    Parameters
    ----------
    thename : str
        Subtable of ``msname`` to open, e.g. ``"SOURCE"``.
    theexpectation : sequence
        Entries ``[column, row, expected_value, tolerance]``; a tolerance
        of 0 demands exact equality.

    Returns
    -------
    bool
        True if every listed cell agrees, False at the first mismatch.

    Raises
    ------
    Whatever ``tb.open`` / ``tb.getcell`` raise. test4 relies on a read of
    a non-existent row raising. The table is closed in every case once it
    has been opened.
    """
    tb.open(msname + "/" + thename)
    try:
        for mycell in theexpectation:
            colname, row, expected, tolerance = mycell[:4]
            print(myname, ": comparing ", mycell)
            value = tb.getcell(colname, row)
            if not _cell_agrees(value, expected, tolerance):
                print(myname, ":  Error in MS subtable", thename, ":")
                print("     column ", colname, " row ", row, " contains ", value)
                print("     expected value is ", expected)
                return False
    finally:
        # Close even when getcell raises, so the shared table tool is not
        # left open for the next check.
        tb.close()
    print(myname, ": table ", thename, " as expected.")
    return True


###########################
# beginning of actual test 

class test_testconcat(unittest.TestCase):
    '''Run testconcat on sets of input MSs and check the result.

    Each test writes its output to ``msname`` and verifies that the
    expected table files exist. It copies the output to
    ``<test name>.ms``, which tearDown removes again. It then compares
    selected subtable cells against known values with ``checktable``.
    '''

    # Files the output MS must contain: the main-table files
    # (table.dat, table.f0 ... table.f8) plus table.dat and table.f0 of
    # every subtable listed below.
    mscomponents = frozenset(
        ["table.dat"] + ["table.f%d" % i for i in range(9)]
        + [sub + "/" + fname
           for sub in ("ANTENNA", "DATA_DESCRIPTION", "FEED", "FIELD",
                       "FLAG_CMD", "HISTORY", "OBSERVATION", "POINTING",
                       "POLARIZATION", "PROCESSOR", "SOURCE",
                       "SPECTRAL_WINDOW", "STATE")
           for fname in ("table.dat", "table.f0")]
    )

    @classmethod
    def setUpClass(cls):
        '''Copy the input MSs from datapath into the working directory.

        MSs already present in the working directory are left untouched.
        Only the copies made here are recorded in ``myinputlist``, so
        tearDownClass deletes nothing else.
        '''
        cls.myinputlist = []
        # Glob in datapath directly instead of chdir-ing into it, so an
        # error cannot leave the process in the wrong directory.
        for src in sorted(glob.glob(os.path.join(datapath, "*.ms"))):
            mymsname = os.path.basename(src)
            if os.path.exists(mymsname):
                continue
            print("Copying ", mymsname)
            shutil.copytree(src, mymsname)
            cls.myinputlist.append(mymsname)

    def setUp(self):
        self.res = None
        self.tempname = ''

    def tearDown(self):
        shutil.rmtree(msname, ignore_errors=True)
        if self.tempname:
            shutil.rmtree(self.tempname, ignore_errors=True)

    @classmethod
    def tearDownClass(cls):
        for ff in cls.myinputlist:
            shutil.rmtree(ff, ignore_errors=True)

    def _concat_and_check_components(self, vis, **kwargs):
        '''Run testconcat on vis and check that all mscomponents exist.

        Extra keyword arguments are passed on to testconcat. The output is
        also copied to ``<test name>.ms``. Returns a list of error messages
        (empty if no component is missing).
        '''
        self.tempname = self._testMethodName + '.ms'
        self.res = testconcat(vis=vis, testconcatvis=msname, **kwargs)
        self.assertIsNone(self.res)

        print(myname, ": Success! Now checking output ...")
        errors = []
        for name in sorted(self.mscomponents):
            path = msname + "/" + name
            if os.access(path, os.F_OK):
                print(myname, ": ", name, "present.")
            else:
                print(myname, ": Error  ", path, "doesn't exist ...")
                errors.append(path + ' does not exist')
        if not errors:
            print(myname, ": MS exists. All tables present.")

        shutil.rmtree(self.tempname, ignore_errors=True)
        shutil.copytree(msname, self.tempname)
        print(myname, ": OK. Checking tables in detail ...")
        return errors

    def _check_table(self, name, expected):
        '''Return [] if checktable(name, expected) passes, else one error message.'''
        if checktable(name, expected):
            return []
        return ['Check of table ' + name + ' failed']

    def test1(self):
        '''Testconcat 1: 4 parts, same sources but different spws'''
        errors = self._concat_and_check_components(
            ['part1.ms', 'part2.ms', 'part3.ms', 'part4.ms'])
        #             col name, row number, expected value, tolerance
        errors += self._check_table("SOURCE", [
            ['SOURCE_ID',           55, 13, 0],
            ['SPECTRAL_WINDOW_ID',  55, 3, 0]
            ])
        errors += self._check_table("SPECTRAL_WINDOW", [
            ['NUM_CHAN',           3, 128, 0]
            ])
        self.assertEqual(errors, [])

    def test2(self):
        '''Testconcat 2: 3 parts, different sources, different spws '''
        errors = self._concat_and_check_components(
            ['part1.ms', 'part2-mod.ms', 'part3.ms'])
        #             col name, row number, expected value, tolerance
        errors += self._check_table("SOURCE", [
            ['SOURCE_ID',           41, 13, 0],
            ['SPECTRAL_WINDOW_ID',  41, 2, 0]
            ])
        errors += self._check_table("SPECTRAL_WINDOW", [
            ['NUM_CHAN',           2, 128, 0]
            ])
        self.assertEqual(errors, [])

    def test3(self):
        '''Testconcat 3: 3 parts, different sources, same spws'''
        errors = self._concat_and_check_components(
            ['part1.ms', 'part2-mod2.ms', 'part3.ms'])
        #             col name, row number, expected value, tolerance
        errors += self._check_table("SOURCE", [
            ['SOURCE_ID',           28, 13, 0],
            ['SPECTRAL_WINDOW_ID',  28, 1, 0]
            ])
        errors += self._check_table("SPECTRAL_WINDOW", [
            ['NUM_CHAN',           1, 128, 0]
            ])
        self.assertEqual(errors, [])

    def test4(self):
        '''Testconcat 4: five MSs with identical sources but different time/intervals on them (CSV-268)'''
        errors = self._concat_and_check_components(
            ['shortpart1.ms', 'shortpart2.ms', 'shortpart3.ms',
             'shortpart4.ms', 'shortpart5.ms'],
            copypointing=False)

        # SOURCE rows to check: (row, SOURCE_ID, SPECTRAL_WINDOW_ID)
        for row, source_id, spw_id in ((0, 0, 0), (7, 0, 7), (8, 1, 0), (15, 1, 7)):
            errors += self._check_table("SOURCE", [
                ['SOURCE_ID',           row, source_id, 0],
                ['SPECTRAL_WINDOW_ID',  row, spw_id, 0]
                ])

        # SOURCE must have no row 16. With this huge tolerance checktable
        # passes whenever the row exists, so only an exception counts as
        # the expected outcome.
        print("The following should fail: SOURCE row 16 should not exist")
        try:
            row16_exists = checktable("SOURCE", [
                ['SOURCE_ID',           16, 0, 100000],
                ['SPECTRAL_WINDOW_ID',  16, 0, 100000]
                ])
        except Exception:
            print("Expected error.")
            row16_exists = False
        if row16_exists:
            errors.append('SOURCE row 16 should not exist; check of table SOURCE failed')

        errors += self._check_table("SPECTRAL_WINDOW", [
            ['NUM_CHAN',           8, 4, 0]
            ])
        self.assertEqual(errors, [])

    def test5(self):
        '''Testconcat 5: two MSs with different state table (CAS-2601)'''
        errors = self._concat_and_check_components(
            ['A2256LC2_4.5s-1.ms', 'A2256LC2_4.5s-2.ms'])
        #             col name, row number, expected value, tolerance
        errors += self._check_table("STATE", [
            ['CAL',  0, 0, 0],
            ['SIG',  0, 1, 0],
            ['SUB_SCAN',  2, 1, 0]
            ])
        self.assertEqual(errors, [])

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()