import os
import sys
import math
import shutil
import string
import time
import numpy as np
import math

try:
    # CASA 6
    from casatools import table
    # Most of this helper file is about table operations. The ms and image tools are used
    # only for a couple of functions (for which there might be a better place)
    from casatools import ms, image

    tb_local = table()
    tb_local2 = table()
    ms_local = ms()
    image_local = image()
except ImportError:
    # CASA 5
    from taskinit import tbtool
    from taskinit import mstool, iatool

    tb_local = tbtool()
    tb_local2 = tbtool()
    ms_local = mstool()
    image_local = iatool()


'''
A set of common helper functions for unit tests:
   compTables - compare two CASA tables
   compVarColTables - Compare a variable column of two tables
   DictDiffer - a class with methods to take a difference of two 
                Python dictionaries
   verify_ms - Function to verify spw and channels information in an MS   
   create_input - Save the string in a text file with the given name           
'''

def phasediffabsdeg(c1, c2):
    try:
        a = c1.imag
        a = c2.imag
    except:
        print("Phase difference of real numbers is always zero.")
        return 0.

    a = math.atan2(c1.imag, c1.real)
    b = math.atan2(c2.imag, c2.real)
    diff = abs(a-b)
    if diff>np.pi:
        diff = 2*np.py - diff
    return diff/np.pi*180. # (degrees)

def compTables(referencetab, testtab, excludecols, tolerance=0.001, mode="percentage", startrow = 0, nrow = -1, rowincr = 1):

    """
    compTables - compare two CASA tables
    
       referencetab - the table which is assumed to be correct

       testtab - the table which is to be compared to referencetab

       excludecols - list of column names which are to be ignored

       tolerance - permitted fractional difference (default 0.001 = 0.1 percent)

       mode - comparison is made as "percentage", "absolute", "phaseabsdeg" (for complex numbers = difference of the phases in degrees)  
    """

    rval = True

    tb_local.open(referencetab)
    cnames = tb_local.colnames()

    tb_local2.open(testtab)

    try:
        for c in cnames:
            if c in excludecols:
                continue
            
            print("\nTesting column " + c)
            
            a = 0
            try:
                a = tb_local.getcol(c,startrow=startrow,nrow=nrow,rowincr=rowincr)
            except:
                rval = False
                print('Error accessing column ', c, ' in table ', referencetab)
                print(sys.exc_info()[0])
                break

            b = 0
            try:
                b = tb_local2.getcol(c,startrow=startrow,nrow=nrow,rowincr=rowincr)
            except:
                rval = False
                print('Error accessing column ', c, ' in table ', testtab)
                print(sys.exc_info()[0])
                break

            if not (len(a)==len(b)):
                print('Column ',c,' has different length in tables ', referencetab, ' and ', testtab)
                print(a)
                print(b)
                rval = False
                break
            else:
                differs = False
                if not (a==b).all():
                    for i in range(0,len(a)):
                        if (isinstance(a[i],float)):
                            if ((mode=="percentage") and (abs(a[i]-b[i]) > tolerance*abs(a[i]))) or ((mode=="absolute") and (abs(a[i]-b[i]) > tolerance)):
                                print("Column " + c + " differs")
                                print("Row=" + str(i))
                                print("Reference file value: " + str(a[i]))
                                print("Input file value: " + str(b[i]))
                                if (mode=="percentage"):
                                    print("Tolerance is {0}%; observed difference was {1} %".format (tolerance * 100, 100*abs(a[i]-b[i])/abs(a[i])))
                                else:
                                    print("Absolute tolerance is {0}; observed difference: {1}".format (tolerance, (abs(a[i]-b[i]))))
                                differs = True
                                rval = False
                                break
                        elif (isinstance(a[i],int) or isinstance(a[i],np.int32)):
                            if (abs(a[i]-b[i]) > 0):
                                print("Column " + c + " differs")
                                print("Row=" + str(i))
                                print("Reference file value: " + str(a[i]))
                                print("Input file value: " + str(b[i]))
                                if (mode=="percentage"):
                                    print("tolerance in % should be " + str(100*abs(a[i]-b[i])/abs(a[i])))
                                else:
                                    print("absolute tolerance should be " + str(abs(a[i]-b[i])))
                                differs = True
                                rval = False
                                break
                        elif (isinstance(a[i],str) or isinstance(a[i],np.bool_)):
                            if not (a[i]==b[i]):
                                print("Column " + c + " differs")
                                print("Row=" + str(i))
                                print("Reference file value: " + str(a[i]))
                                print("Input file value: " + str(b[i]))
                                if (mode=="percentage"):   
                                    print("tolerance in % should be " + str(100*abs(a[i]-b[i])/abs(a[i])))
                                else:
                                    print("absolute tolerance should be " + str(abs(a[i]-b[i])))
                                differs = True
                                rval = False
                                break
                        elif (isinstance(a[i],list)) or (isinstance(a[i],np.ndarray)):
                            for j in range(0,len(a[i])):
                                if differs: break
                                if ((isinstance(a[i][j],float)) or (isinstance(a[i][j],int))):
                                    if ((mode=="percentage") and (abs(a[i][j]-b[i][j]) > tolerance*abs(a[i][j]))) or ((mode=="absolute") and (abs(a[i][j]-b[i][j]) > tolerance)):
                                        print("Column " + c + " differs")
                                        print("(Row,Element)=(" + str(j) + "," + str(i) + ")")
                                        print("Reference file value: " + str(a[i][j]))
                                        print("Input file value: " + str(b[i][j]))
                                        if (mode=="percentage"):
                                            print("Tolerance in % should be " + str(100*abs(a[i][j]-b[i][j])/abs(a[i][j])))
                                        else:
                                            print("Absolute tolerance should be " + str(abs(a[i][j]-b[i][j])))
                                        differs = True
                                        rval = False
                                        break
                                elif (isinstance(a[i][j],list)) or (isinstance(a[i][j],np.ndarray)):
                                    it = range(0,len(a[i][j]))
                                    if mode=="percentage":
                                        diff = np.abs(np.subtract(a[i][j], b[i][j])) > tolerance * np.abs(a[i][j])
                                        it = np.where(diff)[0]
                                    elif (mode=="absolute"):
                                        diff = np.abs(np.subtract(a[i][j], b[i][j])) > tolerance
                                        it = np.where(diff)[0]
                                    for k in it:
                                        if differs: break
                                        if ( ((mode=="percentage") and (abs(a[i][j][k]-b[i][j][k]) > tolerance*abs(a[i][j][k]))) \
                                                 or ((mode=="absolute") and (abs(a[i][j][k]-b[i][j][k]) > tolerance)) \
                                                 or ((mode=="phaseabsdeg") and (phasediffabsdeg(a[i][j][k],b[i][j][k])>tolerance)) \
                                                 ):
                                            print("Column " + c + " differs")
                                            print("(Row,Channel,Corr)=(" + str(k) + "," + str(j) + "," + str(i) + ")")
                                            print("Reference file value: " + str(a[i][j][k]))
                                            print("Input file value: " + str(b[i][j][k]))
                                            if (mode=="percentage"):
                                                print("Tolerance in % should be " + str(100*abs(a[i][j][k]-b[i][j][k])/abs(a[i][j][k])))
                                            elif (mode=="absolute"):
                                                print("Absolute tolerance should be " + str(abs(a[i][j][k]-b[i][j][k])))
                                            elif (mode=="phaseabsdeg"):
                                                print("Phase tolerance in degrees should be " + str(phasediffabsdeg(a[i][j][k],b[i][j][k])))
                                            else:
                                                print("Unknown comparison mode: ",mode)
                                            differs = True
                                            rval = False
                                            break                                          
                                            
                        else:
                            print("Unknown data type: ",type(a[i]))
                            differs = True
                            rval = False
                            break
                
                if not differs: print("Column " + c + " PASSED")
    finally:
        tb_local.close()
        tb_local2.close()

    return rval

    
def compVarColTables(referencetab, testtab, varcol, tolerance=0.):
    '''Compare a variable column of two tables.
       referencetab  --> a reference table
       testtab       --> a table to verify
       varcol        --> the name of a variable column (str)
       Returns True or False.
    '''
    
    retval = True

    tb_local.open(referencetab)
    cnames = tb_local.colnames()

    tb_local2.open(testtab)
    col = varcol
    if tb_local.isvarcol(col) and tb_local2.isvarcol(col):
        try:
            # First check
            if tb_local.nrows() != tb_local2.nrows():
                print('Length of %s differ from %s, %s!=%s'%(referencetab,testtab,len(rk),len(tk)))
                retval = False
            else:
                for therow in range(tb_local.nrows()):
            
                    rdata = tb_local.getcell(col,therow)
                    tdata = tb_local2.getcell(col,therow)

#                    if not (rdata==tdata).all():
                    if not rdata.all()==tdata.all():
                        if (tolerance>0.):
                            differs=False
                            for j in range(0,len(rdata)):
###                                if (type(rdata[j])==float or type(rdata[j])==int):
                                if ((isinstance(rdata[j],float)) or (isinstance(rdata[j],int))):
                                    if (abs(rdata[j]-tdata[j]) > tolerance*abs(rdata[j]+tdata[j])):
#                                        print('Column ', col,' differs in tables ', referencetab, ' and ', testtab)
#                                        print(therow, j)
#                                        print(rdata[j])
#                                        print(tdata[j])
                                        differs = True
###                                elif (type(rdata[j])==list or type(rdata[j])==np.ndarray):
                                elif (isinstance(rdata[j],list)) or (isinstance(rdata[j],np.ndarray)):
                                    for k in range(0,len(rdata[j])):
                                        if (abs(rdata[j][k]-tdata[j][k]) > tolerance*abs(rdata[j][k]+tdata[j][k])):
#                                            print('Column ', col,' differs in tables ', referencetab, ' and ', testtab)
#                                            print(therow, j, k)
#                                            print(rdata[j][k])
#                                            print(tdata[j][k])
                                            differs = True
                                if differs:
                                    print('ERROR: Column %s of %s and %s do not agree within tolerance %s'%(col,referencetab, testtab, tolerance))
                                    retval = False
                                    break
                        else:
                            print('ERROR: Column %s of %s and %s do not agree.'%(col,referencetab, testtab))
                            print('ERROR: First row to differ is row=%s'%therow)
                            retval = False
                            break
        finally:
            tb_local.close()
            tb_local2.close()
    
    else:
        print('Columns are not varcolumns.')
        retval = False

    if retval:
        print('Column %s of %s and %s agree'%(col,referencetab, testtab))
        
    return retval

    
        
class DictDiffer(object):
    """
    Calculate the difference between two dictionaries as:
    (1) items added
    (2) items removed
    (3) keys same in both but changed values
    (4) keys same in both and unchanged values
    Example:
            mydiff = DictDiffer(dict1, dict2)
            mydiff.changed()  # to show what has changed
    """
    def __init__(self, current_dict, past_dict):
        self.current_dict, self.past_dict = current_dict, past_dict
        self.set_current, self.set_past = set(current_dict.keys()), set(past_dict.keys())
        self.intersect = self.set_current.intersection(self.set_past)
    def added(self):
        return self.set_current - self.intersect 
    def removed(self):
        return self.set_past - self.intersect 
    def changed(self):
        return set(o for o in self.intersect if self.past_dict[o] != self.current_dict[o])            
    def unchanged(self):
        return set(o for o in self.intersect if self.past_dict[o] == self.current_dict[o])


def verifyMS(msname, expnumspws, expnumchan, inspw, expchanfreqs=[], ignoreflags=False):
    '''Function to verify spw and channels information in an MS
       msname        --> name of MS to verify
       expnumspws    --> expected number of SPWs in the MS
       expnumchan    --> expected number of channels in spw
       inspw         --> SPW ID
       expchanfreqs  --> numpy array with expected channel frequencies
       ignoreflags   --> do not check the FLAG column
           Returns a list with True or False and a state message'''
    
    msg = ''
    tb_local.open(msname+'/SPECTRAL_WINDOW')
    nc = tb_local.getcell("NUM_CHAN", inspw)
    nr = tb_local.nrows()
    cf = tb_local.getcell("CHAN_FREQ", inspw)
    tb_local.close()
    # After channel selection/average, need to know the exact row number to check,
    # ignore this check in these cases.
    if not ignoreflags:
        tb_local.open(msname)
        dimdata = tb_local.getcell("FLAG", 0)[0].size
        tb_local.close()
        
    if not (nr==expnumspws):
        msg =  "Found "+str(nr)+", expected "+str(expnumspws)+" spectral windows in "+msname
        return [False,msg]
    if not (nc == expnumchan):
        msg = "Found "+ str(nc) +", expected "+str(expnumchan)+" channels in spw "+str(inspw)+" in "+msname
        return [False,msg]
    if not ignoreflags and (dimdata != expnumchan):
        msg = "Found "+ str(dimdata) +", expected "+str(expnumchan)+" channels in FLAG column in "+msname
        return [False,msg]

    if not (expchanfreqs==[]):
        print("Testing channel frequencies ...")
#        print(cf)
#        print(expchanfreqs)
        if not (expchanfreqs.size == expnumchan):
            msg =  "Internal error: array of expected channel freqs should have dimension ", expnumchan
            return [False,msg]
        df = (cf - expchanfreqs)/expchanfreqs
        if not (abs(df) < 1E-8).all:
            msg = "channel frequencies in spw "+str(inspw)+" differ from expected values by (relative error) "+str(df)
            return [False,msg]

    return [True,msg]


def getChannels(msname, spwid, chanlist):
    '''From a list of channel indices, return their frequencies
       msname       --> name of MS
       spwid        --> spw ID
       chanlist     --> list of channel indices
    Return a numpy array, the same size of chanlist, with the frequencies'''
    
    try:
        try:
            tb_local.open(msname+'/SPECTRAL_WINDOW')
        except:
            print('Cannot open table '+msname+'SPECTRAL_WINDOW')
            
        cf = tb_local.getcell("CHAN_FREQ", spwid)
        
        # Get only the requested channels
        b = [cf[i] for i in chanlist]
        selchans = np.array(b)
    
    finally:
        tb_local.close()
        
    return selchans

def get_channel_freqs_widths(msname, spwid):
    '''
    Get frequencies and widths of all the channels for an spw ID
       msname       --> name of MS
       spwid        --> spw ID

    Return two numpy arrays (frequencies, widths), each of the same length as the number of
    channels'''

    try:
        spw_table = os.path.join(msname, 'SPECTRAL_WINDOW')
        try:
            tb_local.open(spw_table)
        except RuntimeError:
            print('Cannot open table: {0}').format(spw_table)

        freqs = tb_local.getcell("CHAN_FREQ", spwid)
        widths = tb_local.getcell("CHAN_WIDTH", spwid)

    finally:
        tb_local.close()

    return freqs, widths

def getColDesc(table, colname):
    '''Get the description of a column in a table
       table    --> name of table or MS
       colname  --> column name
    Return a dictionary with the column description'''
    
    coldesc = {}
    try:
        try:
            tb_local.open(table)
            tcols = tb_local.colnames()
            if tcols.__contains__(colname):
                coldesc = tb_local.getcoldesc(colname)
        except:
            pass                        
    finally:
        tb_local.close()
        
    return coldesc

def getVarCol(table, colname):
    '''Return the requested variable column
       table    --> name of table or MS
       colname  --> column name
    Return the column as a dictionary'''
    
    col = {}
    try:
        try:
            tb_local.open(table)
            col = tb_local.getvarcol(colname)
        except:
            print('Cannot open table '+table)

    finally:
        tb_local.close()
        
    return col
   
def createInput(str_text, filename):
    '''Save the string in a text file with the given name
    str_text    --> string to save
    filename    --> name of the file to save
            It will remove the filename if it exist!'''
    
    inp = filename    
    cmd = str_text
    
    # remove file first
    if os.path.exists(inp):
        os.system('rm -f '+ inp)
    
    try:
        # save to a file    
        with open(inp, 'w') as f:
            f.write(cmd)
            
    finally:  
        f.close()
    
    return

def calculateHanning(dataB,data,dataA):
    '''Calculate the Hanning smoothing of each element'''
    const0 = 0.25
    const1 = 0.5
    const2 = 0.25
    S = const0*dataB + const1*data + const2*dataA
    return S


def getTileShape(mydict, column='DATA'):
    '''Return the value of TileShape for a given column
       in the dictionary from data managers (tb.getdminfo).
       mydict --> dictionary from tb.getdminfo()
       column --> column where to look for TileShape'''
    
    tsh = {}
    for key, value in mydict.items():
        if mydict[key]['COLUMNS'][0] == column:
             # Dictionary for requested column
            hyp = mydict[key]['SPEC']['HYPERCUBES']
                    
             # This is the HYPERCUBES dictionary
            for hk in hyp.keys():
                tsh = hyp[hk]['TileShape']
                break
                    
            break
    
    return tsh

def checkwithtaql(taqlstring):
    os.system('rm -rf mynewtable.tab')
    tb_local.create('mynewtable.tab')
    tb_local.open('mynewtable.tab',nomodify=False)
    rval = tb_local.taql(taqlstring)
    tb_local.close()
    therval = rval.nrows()
    tmpname = rval.name()
    rval.close()
    os.system('rm -rf mynewtable.tab')
    os.system('rm -rf '+tmpname)
    print("Found ", therval, " rows in selection.")
    return therval


def compcaltabnumcol(cal1, cal2, tolerance, colname1='CPARAM', colname2="CPARAM", testspw=None):
    print("Comparing column "+colname1+" of caltable "+cal1)
    print("     with column "+colname2+" of caltable "+cal2)
    if testspw!=None:
        print("for SPW "+str(testspw)+" only.")
    print("Discrepant row search ...")
    rval = False
    try:
        discrepantrows = -1
        if(testspw==None):
            discrepantrows = checkwithtaql("select from [select from "+cal1+" orderby TIME, FIELD_ID, SPECTRAL_WINDOW_ID, ANTENNA1, ANTENNA2 ] t1, [select from "+cal2+" orderby TIME, FIELD_ID, SPECTRAL_WINDOW_ID, ANTENNA1, ANTENNA2 ] t2 where (not all(near(t1."+colname1+",t2."+colname2+", "+str(tolerance)+")))")
        else:
            discrepantrows = checkwithtaql("select from [select from "+cal1+" where SPECTRAL_WINDOW_ID=="+str(testspw)+" orderby TIME, FIELD_ID, ANTENNA1, ANTENNA2 ] t1, [select from "+cal2+" where SPECTRAL_WINDOW_ID=="+str(testspw)+" orderby TIME, FIELD_ID, ANTENNA1, ANTENNA2 ] t2 where (not all(near(t1."+colname1+",t2."+colname2+", "+str(tolerance)+")))")
        if discrepantrows==0:
            print("The two columns agree.")
            rval = True
    except Exception as instance:
        print("Error: "+str(instance))

    return rval

                    
def compmsmainnumcol(vis1, vis2, tolerance, colname1='DATA', colname2="DATA"):
    print("Comparing column "+colname1+" of MS "+vis1)
    print("     with column "+colname2+" of MS "+vis2)
    print("Discrepant row search ...")
    rval = False
    try:
        discrepantrows = checkwithtaql("select from [select from "+vis1+" orderby TIME, DATA_DESC_ID, ANTENNA1, ANTENNA2 ] t1, [select from "+vis2+" orderby TIME, DATA_DESC_ID, ANTENNA1, ANTENNA2 ] t2 where (not all(near(t1."+colname1+",t2."+colname2+", "+str(tolerance)+")))")
        if discrepantrows==0:
            print("The two columns agree.")
            rval = True
    except Exception as instance:
        print("Error: "+str(instance))

    return rval

def compmsmainboolcol(vis1, vis2, colname1='FLAG', colname2='FLAG'):
    print("Comparing column "+colname1+" of MS "+vis1)
    print("     with column "+colname2+" of MS "+vis2)
    print("Discrepant row search ...")
    rval = False
    try:
        discrepantrows = checkwithtaql("select from [select from "+vis1+" orderby TIME, DATA_DESC_ID, ANTENNA1, ANTENNA2 ] t1, [select from "+vis2+" orderby TIME, DATA_DESC_ID, ANTENNA1, ANTENNA2 ] t2 where (not all(t1."+colname1+"==t2."+colname2+"))")
        if discrepantrows==0:
            print("The two columns agree.")
            rval = True
    except Exception as instance:
        print("Error: "+str(instance))

    return rval

def compareSubTables(input,reference,order=None,excluded_cols=[]):
    
    tbinput = tb_local
    tbinput.open(input)
    if order is not None:
        tbinput_sorted = tbinput.taql("SELECT * from " + input + " order by " + order)
    else:
        tbinput_sorted = tbinput
    
    tbreference = tb_local2
    tbreference.open(reference)
    if order is not None:
        tbreference_sorted = tbreference.taql("SELECT * from " + reference + " order by " + order)
    else:
        tbreference_sorted = tbreference
    
    columns = tbinput.colnames()
    for col in columns:
        if not col in excluded_cols:
            col_input = tbinput_sorted.getcol(col)
            col_reference = tbreference_sorted.getcol(col)
            if not (col_input == col_reference).all():
                tbinput.close()
                tbreference.close()
                return (False,col)

    tbinput.close()
    tbreference.close()
    if order is not None:
        tbinput_sorted.close()
        tbreference_sorted.close()

    return (True,"OK")

def getColShape(tab,col,start_row=0,nrow=1,row_inc=1):
    """ Get the shape of the given column.
    Keyword arguments:
        tab        --    input table or MS
        col        --    column to get the shape
        start_row  --    start row (default 0)
        nrow       --    number of rows to read (default 1)
        row_inc    --    increment of rows to read (default 1)
        
        Return a list of strings with the shape of each row in the column.
    
    """

    col_shape = []
    try:
        try:
            tb_local.open(tab)
            col_shape = tb_local.getcolshapestring(col,start_row,nrow,row_inc)
        except:
            print('Cannot get shape of col %s from table %s '%(col,tab))

    finally:
        tb_local.close()
            
    return col_shape

def findTemplate(testname,refimage,copy=False):
    """
    find a template image (or MS - it does assume its a directory)
    look in order in:
    REGRESSION_DATA/regression/testname/refimage
    CASAPATH/data/regression/testname/refimage
    REGRESSION_DATA/regression/testname/reference/refimage/
    CASAPATH/data/regression/testname/reference/refimage
    if copy=True, copy what's found to cwd
    """
    from os import F_OK
    try:
        datapaths=REGRESSION_DATA
    except:
        datapaths=[]
    datapaths.append(os.environ.get('CASAPATH').split()[0]+"/data/")
    possibilities=map(lambda x: x+'/regression/'+testname+'/'+refimage,datapaths)+map(lambda x: x+'/regression/'+testname+'/reference/'+refimage,datapaths) 

    #print(possibilities)
    from itertools import dropwhile
    try:
        found = dropwhile( lambda x: not os.access(x,F_OK),possibilities).next()
    except:
        raise IOError(" ERROR: "+refimage+" not found")
    if copy:
        from shutil import copytree
        print("Copying "+found)
        copytree(found,msname)
    return found


# As opposed to most other functions in this file, this doesn't use the table tool but the
# image tool
def compImages(im0,im1,keys=['flux','min','max','maxpos','rms'],tol=1e-4,verbose=False):
    """
    compare two images using imstat and the specified keys, 
    to a tolerance tol, and printing the comparison if verbose==True
    note that the string keys like 'blcf' will fail
    """
    from os import F_OK
    if isinstance(tol,float):
        tol=tol+np.zeros(len(keys))
    ims=[im0,im1]
    s=[]
    for i in range(2):
        if not os.access(ims[i],F_OK): 
            print(ims[i]+" not found")
            return False
        image_local.open(ims[1])
        s.append(image_local.statistics())
        image_local.done()
    status=True
    for ik in range(len(keys)):
        k=keys[ik]
        s0=s[0][k][0]
        s1=s[1][k][0]
        if abs(s0-s1)*2/(s0+s1)>tol[ik]: status=False
        if verbose:
            print(("%7s: "%k),s0,s1)
    return status


# As opposed to most other functions in this file, this doesn't use the table tool but the
# ms tool
def compMS(ms0,ms1,keys=['mean','min','max','rms'],ap="amp",tol=1e-4,verbose=False):
    """
    compare two MS using ms.statistics on amp or phase as specified, 
    and the specified keys, 
    to a tolerance tol, and printing the comparison if verbose==True
    """
    from os import F_OK
    if isinstance(tol,float):
        tol=tol+np.zeros(len(keys))
    mss=[ms0,ms1]
    s=[]
    for i in range(2):
        if not os.access(mss[i],F_OK): 
            print(mss[i]+" not found")
            return False
        ms_local.open(mss[1])
        stats = ms_local.statistics("DATA",ap)
        s.append(stats[stats.keys()[0]])
        ms_local.done()
    status=True
    for ik in range(len(keys)):
        k=keys[ik]
        s0=s[0][k]
        s1=s[1][k]
        if abs(s0-s1)*2/(s0+s1)>tol[ik]: status=False
        if verbose:
            print(("%7s: "%k),s0,s1)
    return status


def get_table_cache():
    cache = tb_local.showcache()
    # print('cache = {}'.format(cache))
    return cache

def is_casa6():
    try:
        # CASA 6
        from casatools import table
        return True
    except ImportError:
        try:
            # CASA 5
            from taskinit import tbtool
            return False
        except ImportError:
            raise Exception('Neither CASA5 nor CASA6')
    
class TableCacheValidator(object):
    def __init__(self):
        self.original_cache = get_table_cache()

    def validate(self):
        cache = get_table_cache()
        #print 'original {} current {}'.format(self.original_cache, cache)
        return len(cache) == 0 or cache == self.original_cache