from __future__ import absolute_import

import contextlib
import os
import shutil
from collections import Counter

import numpy

# get is_CASA6 and is_python3
# NOTE: this wildcard import is what brings the is_CASA6 / is_python3 flags
# used below (and inside sdbaseline) into this module's namespace.
from casatasks.private.casa_transition import *
if is_CASA6:
    # CASA6 layout: tool classes come from casatools and the helper modules
    # are imported relative to this package.
    from casatools import singledishms, table, msmetadata
    from casatools import ms as mstool
    from casatasks import casalog
    from .mstools import write_history
    from . import sdutil

    # Module-level tool instances shared by every function in this file.
    ms = mstool()
    sdms = singledishms()
    tb = table()
    msmd = msmetadata()
else:
    # Pre-CASA6 layout: the same four tool instances are obtained from
    # taskinit.gentools, and the helper modules are top-level imports.
    from taskinit import gentools, casalog
    from mstools import write_history
    import sdutil
    ms, sdms, tb, msmd = gentools(['ms', 'sdms', 'tb', 'msmd'])


def sdbaseline(infile=None, datacolumn=None, antenna=None, field=None,
               spw=None, timerange=None, scan=None, pol=None, intent=None,
               reindex=None, maskmode=None, thresh=None, avg_limit=None,
               minwidth=None, edge=None, blmode=None, dosubtract=None,
               blformat=None, bloutput=None, bltable=None, blfunc=None,
               order=None, npiece=None, applyfft=None, fftmethod=None,
               fftthresh=None, addwn=None, rejwn=None, clipthresh=None,
               clipniter=None, blparam=None, verbose=None,
               updateweight=None, sigmavalue=None,
               showprogress=None, minnrow=None,
               outfile=None, overwrite=None):
    """Fit and/or subtract spectral baselines of a single-dish MS.

    Depending on ``blmode`` the task either applies an existing baseline
    table (``'apply'``) through ``sdms.apply_baseline_table`` or performs a
    new fit (``'fit'``) through one of the ``sdms.subtract_baseline*``
    methods selected by ``blfunc``. Afterwards the WEIGHT_SPECTRUM and
    SIGMA_SPECTRUM columns are removed from ``outfile`` when
    ``updateweight`` is true, and the task parameters are written to the
    history of ``outfile``.

    Notes on parameter handling visible in this function:

    * ``blfunc``, ``blmode`` and ``fftmethod`` are lower-cased, as is
      ``fftthresh`` when it is a string (CAS-12985).
    * an empty or non-string ``outfile`` is replaced by ``infile`` (with
      trailing '/' stripped) plus ``'_bs'``.
    * an empty ``spw`` is replaced by ``'*'``.
    * ``showprogress`` and ``minnrow`` are accepted but not used here.
    * the history records the values *after* these normalisations.

    Raises:
        Exception: if ``infile`` does not exist, or ``outfile`` exists and
            ``overwrite`` is false.
        ValueError: if ``maskmode`` is ``'interact'``, if ``blfunc`` is
            ``'variable'`` and ``blparam`` does not exist, or if
            ``blmode`` is ``'apply'`` and ``bltable`` does not exist.
        Any exception raised by the helper functions or the CASA tools is
        propagated unchanged.
    """
    casalog.origin('sdbaseline')

    # CAS-12985 requests the following params be given case insensitively,
    # so they need to be converted to lowercase here (2021/1/28 WK)
    blfunc = blfunc.lower()
    blmode = blmode.lower()
    fftmethod = fftmethod.lower()
    if isinstance(fftthresh, str):
        fftthresh = fftthresh.lower()

    if not os.path.exists(infile):
        raise Exception("infile='" + str(infile) + "' does not exist.")
    if (outfile == '') or not isinstance(outfile, str):
        outfile = infile.rstrip('/') + '_bs'
        casalog.post("outfile is empty or non-string. set to '" + outfile + "'")
    if os.path.exists(outfile) and not overwrite:
        raise Exception("outfile='%s' exists, and cannot overwrite it." % (outfile))
    if maskmode == 'interact':
        raise ValueError("maskmode='%s' is not supported yet" % maskmode)
    if blfunc == 'variable' and not os.path.exists(blparam):
        raise ValueError("input file '%s' does not exists" % blparam)
    blparam_file = infile + '_blparam.txt'
    if os.path.exists(blparam_file):
        remove_data(blparam_file)  # CAS-11781

    if spw == '':
        spw = '*'

    if blmode == 'apply':
        if not os.path.exists(bltable):
            raise ValueError("file specified in bltable '%s' does not exist." % bltable)

        # The SORTED_TABLE keyword is removed only for the duration of the
        # processing; the finally clause guarantees that the input MS gets
        # it back even when the processing fails.
        sorttab_info = remove_sorted_table_keyword(infile)
        try:
            if overwrite and os.path.exists(outfile) and (infile != outfile):
                # no shell involved, so any path name is handled safely
                remove_data(outfile)

            selection = ms.msseltoindex(vis=infile, spw=spw, field=field,
                                        baseline=antenna, time=timerange,
                                        scan=scan)
            sdms.open(infile)
            try:
                sdms.set_selection(spw=sdutil.get_spwids(selection),
                                   field=field, antenna=antenna,
                                   timerange=timerange, scan=scan,
                                   polarization=pol, intent=intent,
                                   reindex=reindex)
                sdms.apply_baseline_table(bltable=bltable,
                                          datacolumn=datacolumn,
                                          spw=spw,
                                          updateweight=updateweight,
                                          sigmavalue=sigmavalue,
                                          outfile=outfile)
            finally:
                # always release the tool, also on failure
                sdms.close()
        finally:
            restore_sorted_table_keyword(infile, sorttab_info)

    elif blmode == 'fit':

        if blfunc == 'sinusoid':
            addwn = sdutil.parse_wavenumber_param(addwn)
            rejwn = sdutil.parse_wavenumber_param(rejwn)
            check_fftthresh(fftthresh)

        blformat, bloutput = prepare_for_blformat_bloutput(infile, blformat,
                                                           bloutput, overwrite)

        output_bloutput_text_header(blformat, bloutput,
                                    blfunc, maskmode,
                                    infile, outfile)

        # Only blfunc='variable' needs the SORTED_TABLE keyword removed;
        # None marks "nothing to restore" for the other functions.
        sorttab_info = None
        if blfunc == 'variable':
            sorttab_info = remove_sorted_table_keyword(infile)
        try:
            if overwrite and os.path.exists(outfile) and (infile != outfile):
                # no shell involved, so any path name is handled safely
                remove_data(outfile)

            selection = ms.msseltoindex(vis=infile, spw=spw, field=field,
                                        baseline=antenna, time=timerange,
                                        scan=scan)
            sdms.open(infile)
            try:
                sdms.set_selection(spw=sdutil.get_spwids(selection),
                                   field=field, antenna=antenna,
                                   timerange=timerange, scan=scan,
                                   polarization=pol, intent=intent,
                                   reindex=reindex)
                params, func = prepare_for_baselining(blfunc=blfunc,
                                                      datacolumn=datacolumn,
                                                      outfile=outfile,
                                                      bloutput=','.join(bloutput),
                                                      dosubtract=dosubtract,
                                                      spw=spw,
                                                      pol=pol,
                                                      linefinding=(maskmode == 'auto'),
                                                      threshold=thresh,
                                                      avg_limit=avg_limit,
                                                      minwidth=minwidth,
                                                      edge=edge,
                                                      order=order,
                                                      npiece=npiece,
                                                      applyfft=applyfft,
                                                      fftmethod=fftmethod,
                                                      fftthresh=fftthresh,
                                                      addwn=addwn,
                                                      rejwn=rejwn,
                                                      clip_threshold_sigma=clipthresh,
                                                      num_fitting_max=clipniter + 1,
                                                      blparam=blparam,
                                                      verbose=verbose,
                                                      updateweight=updateweight,
                                                      sigmavalue=sigmavalue)
                func(**params)
            finally:
                # always release the tool, also on failure
                sdms.close()
        finally:
            if sorttab_info is not None:
                restore_sorted_table_keyword(infile, sorttab_info)

    # Remove {WEIGHT|SIGMA}_SPECTRUM columns if updateweight=True (CAS-13161)
    if updateweight:
        with sdutil.tbmanager(outfile, nomodify=False) as mytb:
            cols_remove = []
            for col in ['WEIGHT_SPECTRUM', 'SIGMA_SPECTRUM']:
                if col in mytb.colnames():
                    cols_remove.append(col)
            if len(cols_remove) > 0:
                mytb.removecols(' '.join(cols_remove))

    # Write history to outfile. Only the task parameters (the first
    # co_argcount local names) are recorded, with their current values.
    param_names = sdbaseline.__code__.co_varnames[:sdbaseline.__code__.co_argcount]
    if is_python3:
        local_vars = locals()
        param_vals = [local_vars[p] for p in param_names]
    else:
        param_vals = [eval(p) for p in param_names]
    write_history(ms, outfile, 'sdbaseline', param_names,
                  param_vals, casalog)


# Supported baseline-parameter output formats, each paired with the file
# extension used when an output name has to be generated automatically.
# The order of this table defines the order of the normalised bloutput list.
_BLFORMAT_PAIRS = (
    ('csv', 'csv'),
    ('text', 'txt'),
    ('table', 'bltable'),
)
blformat_item, blformat_ext = map(list, zip(*_BLFORMAT_PAIRS))


def remove_data(filename):
    """Remove ``filename`` from disk, whatever kind of entry it is.

    Directories are removed recursively, everything else is unlinked.
    Symbolic links are never followed: only the link itself is removed,
    including links to directories (which ``shutil.rmtree`` refuses) and
    dangling links (for which ``os.path.exists`` is False). A path that
    does not exist at all is silently ignored.
    """
    if os.path.islink(filename):
        # must come first: isdir()/isfile() follow links
        os.remove(filename)
    elif os.path.isdir(filename):
        shutil.rmtree(filename)
    elif os.path.lexists(filename):
        # regular file or any other kind of directory entry
        os.remove(filename)

def check_fftthresh(fftthresh):
    """Validate the ``fftthresh`` parameter.

    Accepted values are a positive number, or a string of one of the forms
    ``'top<N>'`` (N a positive integer), ``'<x>sigma'`` or ``'<x>'`` (x a
    positive float).

    Raises:
        ValueError: if the type is not float/int/str, if a string cannot
            be parsed, or if the resulting threshold is not positive.
    """
    if not isinstance(fftthresh, (float, int, str)):
        raise ValueError('fftthresh must be float or integer or string.')

    if isinstance(fftthresh, str):
        # First extract the numeric part; any parsing problem is reported
        # as a format error.
        try:
            if len(fftthresh) > 3 and fftthresh.startswith('top'):
                threshold = int(fftthresh[3:])
            elif len(fftthresh) > 5 and fftthresh.endswith('sigma'):
                threshold = float(fftthresh[:-5])
            else:
                threshold = float(fftthresh)
        except Exception:
            raise ValueError('fftthresh has a wrong format.')
    else:
        threshold = fftthresh

    # The positivity requirement is common to all accepted forms.
    if threshold <= 0:
        raise ValueError('threshold given to fftthresh must be positive.')

def prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite):
    """Validate ``blformat``/``bloutput`` and normalise the output names.

    Both arguments may be a string or a list of strings. The default
    ``bloutput`` (``''`` or ``['']``) is expanded to one empty name per
    ``blformat`` entry. The caller's lists are never modified.

    Returns:
        tuple: ``(blformat, bloutput)`` where ``blformat`` is the input as
        a list of strings and ``bloutput`` is the list produced by
        ``normalise_bloutput``.

    Raises:
        ValueError: if either argument is not a string or list of strings,
            if the two lists differ in length, or if they contain
            duplicate non-empty entries.
    """
    # force to string list
    blformat = force_to_string_list(blformat, 'blformat')
    bloutput = force_to_string_list(bloutput, 'bloutput')

    # the default bloutput value '' is expanded to a list
    # with length of blformat, and with '' throughout.
    # A new list is built on purpose: an in-place '*=' would modify
    # a list object handed in by the caller.
    if bloutput == ['']:
        bloutput = [''] * len(blformat)

    # check length
    if len(blformat) != len(bloutput):
        raise ValueError('blformat and bloutput must have the same length.')

    # check duplication
    if has_duplicate_nonnull_element(blformat):
        raise ValueError('duplicate elements in blformat.')
    if has_duplicate_nonnull_element_ex(bloutput, blformat):
        raise ValueError('duplicate elements in bloutput.')

    # fill bloutput items to be output, then rearrange them
    # in the order of blformat_item.
    bloutput = normalise_bloutput(infile, blformat, bloutput, overwrite)

    return blformat, bloutput

def force_to_string_list(s, name):
    """Return ``s`` as a list of strings.

    A single string is wrapped into a one-element list; a list whose
    elements are all strings is returned as is. ``name`` is only used in
    the error message.

    Raises:
        ValueError: for any other input.
    """
    if isinstance(s, str):
        return [s]
    if isinstance(s, list) and all(isinstance(elem, str) for elem in s):
        return s
    raise ValueError('%s must be string or list of string.' % name)

def has_duplicate_nonnull_element(in_list):
    """Return True if ``in_list`` holds a repeated element other than ''.

    Repetitions of the empty string are ignored.
    """
    occurrences = Counter(in_list)
    return any(count > 1 and elem != ''
               for elem, count in occurrences.items())


def has_duplicate_nonnull_element_ex(lst, base):
    """Duplicate check on ``lst`` restricted by ``base``.

    ``lst`` and ``base`` must have the same length. Only the elements of
    ``lst`` whose counterpart in ``base`` (same index) is not '' are taken
    into account; the result tells whether those elements contain a
    repeated value other than ''.
    """
    selected = []
    for idx, elem in enumerate(lst):
        if base[idx] != '':
            selected.append(elem)
    return has_duplicate_nonnull_element(selected)

def normalise_bloutput(infile, blformat, bloutput, overwrite):
    """Return the output file names arranged in the order of blformat_item.

    One name is produced per supported format by ``get_normalised_name``;
    formats that were not requested yield ''.
    """
    return [
        get_normalised_name(infile, blformat, bloutput, fmt, ext, overwrite)
        for fmt, ext in zip(blformat_item, blformat_ext)
    ]

def get_normalised_name(infile, blformat, bloutput, name, ext, overwrite):
    """Return the output file name to use for the format ``name``.

    If ``name`` is among the requested formats (``blformat``, compared
    case-insensitively) the matching entry of ``bloutput`` is used, or
    ``infile + '_blparam.' + ext`` when that entry is ''. If the format
    was not requested, '' is returned.

    An already existing file or directory of that name is deleted when
    ``overwrite`` is true. The deletion is done without a shell, so names
    containing spaces or shell metacharacters are handled correctly.

    Raises:
        Exception: if the name exists and ``overwrite`` is false.
        OSError: if the existing entry cannot be removed.
    """
    fname = ''
    requested = [fmt.lower() for fmt in blformat]
    if name in requested:
        fname = bloutput[requested.index(name)]
        if fname == '':
            fname = infile + '_blparam.' + ext

    if os.path.exists(fname):
        if not overwrite:
            raise Exception(fname + ' exists.')
        if os.path.isdir(fname) and not os.path.islink(fname):
            shutil.rmtree(fname)
        else:
            # regular file, or a symlink (only the link is removed)
            os.remove(fname)
    return fname

def output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile):
    """Write the header of the text-format baseline output file.

    The file name is the entry of ``bloutput`` at the position of 'text'
    in ``blformat_item``; nothing is done when that entry is ''. The
    header consists of the source table, the output file (``infile`` is
    shown when ``outfile`` is '') and the mask mode, framed by two
    separator lines. ``blformat`` is accepted but not used here.

    ``blfunc`` is validated before the file is created, so an unsupported
    value does not leave an empty or unclosed file behind.

    Raises:
        ValueError: if ``blfunc`` is not one of the supported functions.
    """
    fname = bloutput[blformat_item.index('text')]
    if fname == '':
        return

    supported_blfunc = ('poly', 'chebyshev', 'cspline', 'sinusoid', 'variable')
    if blfunc.lower() not in supported_blfunc:
        raise ValueError("Unsupported blfunc = %s" % blfunc)

    info = [('Source Table', infile),
            ('Output File', outfile if (outfile != '') else infile),
            ('Mask mode', maskmode)]
    separator = '#' * 60 + '\n'

    # 'with' guarantees the file is closed even if a write fails
    with open(fname, 'w') as f:
        f.write(separator)
        for label, value in info:
            f.write('%12s: %s\n' % (label, value))
        f.write(separator)
        f.write('\n')

def prepare_for_baselining(**keywords):
    """Select the sdms method for ``blfunc`` and collect its arguments.

    ``keywords`` must contain 'blfunc' plus every argument needed by the
    selected method; entries that the method does not take are ignored.

    Returns:
        tuple: ``(params, baseline_func)`` where ``params`` is the dict of
        keyword arguments and ``baseline_func`` the bound method of the
        module-level ``sdms`` tool to be called with them.

    Raises:
        ValueError: if ``blfunc`` is not supported.
        KeyError: if a required entry is missing from ``keywords``.
    """
    blfunc = keywords['blfunc']

    # (accepted blfunc values, function-specific arguments, method suffix)
    variants = (
        (('poly', 'chebyshev'), ['blfunc', 'order'], ''),
        (('cspline',), ['npiece'], '_cspline'),
        (('sinusoid',),
         ['applyfft', 'fftmethod', 'fftthresh', 'addwn', 'rejwn'],
         '_sinusoid'),
        (('variable',), ['blparam', 'verbose'], '_variable'),
    )
    for accepted, specific_keys, suffix in variants:
        if blfunc in accepted:
            break
    else:
        raise ValueError("Unsupported blfunc = %s" % blfunc)

    # arguments common to every method come first
    wanted = ['datacolumn', 'outfile', 'bloutput', 'dosubtract', 'spw',
              'updateweight', 'sigmavalue'] + specific_keys
    if blfunc != 'variable':
        # clipping and line-finding settings are passed to all but 'variable'
        wanted += ['clip_threshold_sigma', 'num_fitting_max']
        wanted += ['linefinding', 'threshold', 'avg_limit', 'minwidth', 'edge']

    params = {key: keywords[key] for key in wanted}
    baseline_func = getattr(sdms, 'subtract_baseline' + suffix)

    return params, baseline_func
    
    
def remove_sorted_table_keyword(infile):
    """Remove the SORTED_TABLE keyword from ``infile`` if it is present.

    Returns:
        dict: ``'is_sorttab'`` tells whether the keyword was found; if so,
        ``'sorttab_keywd'`` and ``'sorttab_name'`` hold the keyword name
        and its former value so that ``restore_sorted_table_keyword`` can
        put it back. Otherwise both are ''.
    """
    keyword = 'SORTED_TABLE'
    res = {'is_sorttab': False, 'sorttab_keywd': '', 'sorttab_name': ''}

    with sdutil.tbmanager(infile, nomodify=False) as table_tool:
        if keyword in table_tool.keywordnames():
            res['is_sorttab'] = True
            res['sorttab_keywd'] = keyword
            # remember the value before the keyword is dropped
            res['sorttab_name'] = table_tool.getkeyword(keyword)
            table_tool.removekeyword(keyword)

    return res

def restore_sorted_table_keyword(infile, sorttab_info):
    """Put back a keyword removed by ``remove_sorted_table_keyword``.

    ``sorttab_info`` is the dict returned by that function. Nothing is
    done unless it records that the keyword was present with a non-empty
    value.
    """
    if not sorttab_info['is_sorttab'] or sorttab_info['sorttab_name'] == '':
        return

    with sdutil.tbmanager(infile, nomodify=False) as table_tool:
        table_tool.putkeyword(sorttab_info['sorttab_keywd'],
                              sorttab_info['sorttab_name'])