from collections import Counter
import datetime
import os
import shutil

from casatasks import casalog
from casatools import ms as mstool
from casatools import singledishms

from . import sdutil
from .mstools import write_history

ms = mstool()


@sdutil.callable_sdtask_decorator
def sdbaseline(infile=None, datacolumn=None, antenna=None, field=None,
               spw=None, timerange=None, scan=None, pol=None, intent=None,
               reindex=None, maskmode=None, thresh=None, avg_limit=None,
               minwidth=None, edge=None, blmode=None, dosubtract=None,
               blformat=None, bloutput=None, bltable=None, blfunc=None,
               order=None, npiece=None, applyfft=None, fftmethod=None,
               fftthresh=None, addwn=None, rejwn=None, clipthresh=None,
               clipniter=None, blparam=None, verbose=None,
               updateweight=None, sigmavalue=None,
               showprogress=None, minnrow=None,
               outfile=None, overwrite=None):

    temp_outfile = ''

    try:
        # CAS-12985 requests the following params be given case insensitively,
        # so they need to be converted to lowercase here (2021/1/28 WK)
        blfunc = blfunc.lower()
        blmode = blmode.lower()
        fftmethod = fftmethod.lower()
        if isinstance(fftthresh, str):
            fftthresh = fftthresh.lower()

        if (spw == ''):
            spw = '*'

        if not os.path.exists(infile):
            raise ValueError("infile='" + str(infile) + "' does not exist.")
        if (outfile == '') or not isinstance(outfile, str):
            outfile = infile.rstrip('/') + '_bs'
            casalog.post("outfile is empty or non-string. set to '" + outfile + "'")
        if (not overwrite) and dosubtract and os.path.exists(outfile):
            raise ValueError("outfile='%s' exists, and cannot overwrite it." % (outfile))
        if (blfunc == 'variable') and not os.path.exists(blparam):
            raise ValueError("input file '%s' does not exists" % blparam)

        if (blmode == 'fit'):
            temp_outfile = _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan,
                                   pol, intent, reindex, maskmode, thresh, avg_limit, minwidth,
                                   edge, dosubtract, blformat, bloutput, blfunc, order, npiece,
                                   applyfft, fftmethod, fftthresh, addwn, rejwn, clipthresh,
                                   clipniter, blparam, verbose, updateweight, sigmavalue,
                                   outfile, overwrite)
        elif (blmode == 'apply'):
            _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
                      reindex, bltable, updateweight, sigmavalue, outfile, overwrite)

        # Remove {WEIGHT|SIGMA}_SPECTRUM columns if updateweight=True (CAS-13161)
        if updateweight:
            with sdutil.table_manager(outfile, nomodify=False) as mytb:
                cols_spectrum = ['WEIGHT_SPECTRUM', 'SIGMA_SPECTRUM']
                cols_remove = [col for col in cols_spectrum if col in mytb.colnames()]
                if len(cols_remove) > 0:
                    mytb.removecols(' '.join(cols_remove))

        # Write history to outfile
        if dosubtract:
            param_names = sdbaseline.__code__.co_varnames[:sdbaseline.__code__.co_argcount]
            var_local = locals()
            param_vals = [var_local[p] for p in param_names]
            write_history(ms, outfile, 'sdbaseline', param_names,
                          param_vals, casalog)

    finally:
        if (not dosubtract):
            # Remove (skeleton) outfile
            if temp_outfile != '':
                outfile = temp_outfile
            remove_data(outfile)


blformat_item = ['csv', 'text', 'table']
blformat_ext = ['csv', 'txt', 'bltable']

mesg_invalid_wavenumber = 'wrong value given for addwn/rejwn'


def remove_data(filename):
    if not os.path.exists(filename):
        return

    if os.path.isdir(filename):
        shutil.rmtree(filename)
    elif os.path.isfile(filename):
        os.remove(filename)
    else:
        # could be a symlink
        os.remove(filename)


def is_empty(blformat):
    """Check if blformat is empty.

    returns True if blformat is None, '', [] and
    a string list containing only '' (i.e., ['', '', ..., ''])
    """
    if isinstance(blformat, list):
        return all(map(is_empty, blformat))

    return not blformat


def prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite):
    # force to string list
    blformat = force_to_string_list(blformat, 'blformat')
    bloutput = force_to_string_list(bloutput, 'bloutput')

    # the default bloutput value '' is expanded to a list
    # with length of blformat, and with '' throughout.
    if (bloutput == ['']):
        bloutput *= len(blformat)

    # check length
    if (len(blformat) != len(bloutput)):
        raise ValueError('blformat and bloutput must have the same length.')

    # check duplication
    if has_duplicate_nonnull_element(blformat):
        raise ValueError('duplicate elements in blformat.')
    if has_duplicate_nonnull_element_ex(bloutput, blformat):
        raise ValueError('duplicate elements in bloutput.')

    # fill bloutput items to be output, then rearrange them
    # in the order of blformat_item.
    bloutput = normalise_bloutput(infile, blformat, bloutput, overwrite)

    return blformat, bloutput


def force_to_string_list(s, name):
    mesg = '%s must be string or list of string.' % name
    if isinstance(s, str):
        s = [s]
    elif isinstance(s, list):
        for i in range(len(s)):
            if not isinstance(s[i], str):
                raise ValueError(mesg)
    else:
        raise ValueError(mesg)
    return s


def has_duplicate_nonnull_element(in_list):
    # return True if in_list has duplicated elements other than ''
    duplicates = [key for key, val in Counter(in_list).items() if val > 1]
    len_duplicates = len(duplicates)

    if (len_duplicates >= 2):
        return True
    elif (len_duplicates == 1):
        return (duplicates[0] != '')
    else:  # len_duplicates == 0
        return False


def has_duplicate_nonnull_element_ex(lst, base):
    # lst and base must have the same length.
    #
    # (1) extract elements from lst and make a new list
    #     if the element of base with the same index
    #     is not ''.
    # (2) check if the list made in (1) has duplicated
    #     elements other than ''.

    return has_duplicate_nonnull_element(
        [lst[i] for i in range(len(lst)) if base[i] != ''])


def normalise_bloutput(infile, blformat, bloutput, overwrite):
    return [get_normalised_name(infile, blformat, bloutput, item[0], item[1], overwrite)
            for item in zip(blformat_item, blformat_ext)]


def get_normalised_name(infile, blformat, bloutput, name, ext, overwrite):
    fname = ''
    blformat_lower = [s.lower() for s in blformat]
    if (name in blformat_lower):
        fname = bloutput[blformat_lower.index(name)]
        if (fname == ''):
            fname = infile + '_blparam.' + ext
    if os.path.exists(fname):
        if overwrite:
            remove_data(fname)
        else:
            raise ValueError(fname + ' exists.')
    return fname


def output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile):
    fname = bloutput[blformat_item.index('text')]
    if (fname == ''):
        return

    with open(fname, 'w') as f:
        info = [['Source Table', infile],
                ['Output File', outfile if (outfile != '') else infile],
                ['Mask mode', maskmode]]

        separator = '#' * 60 + '\n'

        f.write(separator)
        for i in range(len(info)):
            f.write('%12s: %s\n' % tuple(info[i]))
        f.write(separator)
        f.write('\n')


def get_temporary_file_name(basename):
    name = basename + '_sdbaseline_pid' + str(os.getpid()) + '_' \
        + datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    return name


def parse_wavenumber_param(wn):
    if isinstance(wn, bool):
        raise ValueError(mesg_invalid_wavenumber)
    elif isinstance(wn, list):
        __check_positive_or_zero(wn)
        wn_uniq = list(set(wn))
        wn_uniq.sort()
        return ','.join(__get_strlist(wn_uniq))
    elif isinstance(wn, tuple):
        __check_positive_or_zero(wn)
        wn_uniq = list(set(wn))
        wn_uniq.sort()
        return ','.join(__get_strlist(wn_uniq))
    elif isinstance(wn, int):
        __check_positive_or_zero(wn)
        return str(wn)
    elif isinstance(wn, str):
        if '.' in wn:
            # case of float value as string
            raise ValueError(mesg_invalid_wavenumber)
        elif ',' in wn:
            # cases 'a,b,c,...'
            val0 = wn.split(',')
            __check_positive_or_zero(val0)
            val = []
            for v in val0:
                val.append(int(v))
            val.sort()
            res = list(set(val))  # uniq
        elif '-' in wn:
            # case 'a-b' : return [a,a+1,...,b-1,b]
            val = wn.split('-')
            __check_positive_or_zero(val)
            val = [int(val[0]), int(val[1])]
            val.sort()
            res = [i for i in range(val[0], val[1] + 1)]
        elif '~' in wn:
            # case 'a~b' : return [a,a+1,...,b-1,b]
            val = wn.split('~')
            __check_positive_or_zero(val)
            val = [int(val[0]), int(val[1])]
            val.sort()
            res = [i for i in range(val[0], val[1] + 1)]
        elif wn[:2] == '<=' or wn[:2] == '=<':
            # cases '<=a','=<a' : return [0,1,...,a-1,a]
            val = wn[2:]
            __check_positive_or_zero(val)
            res = [i for i in range(int(val) + 1)]
        elif wn[-2:] == '>=' or wn[-2:] == '=>':
            # cases 'a>=','a=>' : return [0,1,...,a-1,a]
            val = wn[:-2]
            __check_positive_or_zero(val)
            res = [i for i in range(int(val) + 1)]
        elif wn[0] == '<':
            # case '<a' :         return [0,1,...,a-2,a-1]
            val = wn[1:]
            __check_positive_or_zero(val, False)
            res = [i for i in range(int(val))]
        elif wn[-1] == '>':
            # case 'a>' :         return [0,1,...,a-2,a-1]
            val = wn[:-1]
            __check_positive_or_zero(val, False)
            res = [i for i in range(int(val))]
        elif wn[:2] == '>=' or wn[:2] == '=>':
            # cases '>=a','=>a' : return [a,-999], which is
            #                     then interpreted in C++
            #                     side as [a,a+1,...,a_nyq]
            #                     (CAS-3759)
            val = wn[2:]
            __check_positive_or_zero(val)
            res = [int(val), -999]
        elif wn[-2:] == '<=' or wn[-2:] == '=<':
            # cases 'a<=','a=<' : return [a,-999], which is
            #                     then interpreted in C++
            #                     side as [a,a+1,...,a_nyq]
            #                     (CAS-3759)
            val = wn[:-2]
            __check_positive_or_zero(val)
            res = [int(val), -999]
        elif wn[0] == '>':
            # case '>a' :         return [a+1,-999], which is
            #                     then interpreted in C++
            #                     side as [a+1,a+2,...,a_nyq]
            #                     (CAS-3759)
            val0 = wn[1:]
            val = int(val0) + 1
            __check_positive_or_zero(val)
            res = [val, -999]
        elif wn[-1] == '<':
            # case 'a<' :         return [a+1,-999], which is
            #                     then interpreted in C++
            #                     side as [a+1,a+2,...,a_nyq]
            #                     (CAS-3759)
            val0 = wn[:-1]
            val = int(val0) + 1
            __check_positive_or_zero(val)
            res = [val, -999]
        else:
            # case 'a'
            __check_positive_or_zero(wn)
            res = [int(wn)]

        # return res
        return ','.join(__get_strlist(res))
    else:
        raise ValueError(mesg_invalid_wavenumber)


def __get_strlist(param):
    return [str(p) for p in param]


def check_fftthresh(fftthresh):
    """Validate fftthresh value.

    The fftthresh must be one of the following:
    (1) positive value (float, integer or string)
    (2) 'top' + positive integer value
    (3) positive float value + 'sigma'
    """
    has_invalid_type = False
    val_not_positive = False

    if isinstance(fftthresh, bool):
        # Checking for bool must precede checking for integer
        has_invalid_type = True
    elif isinstance(fftthresh, int) or isinstance(fftthresh, float):
        if (fftthresh <= 0.0):
            val_not_positive = True
    elif isinstance(fftthresh, str):
        try:
            if (3 < len(fftthresh)) and (fftthresh[:3] == 'top'):
                if (int(fftthresh[3:]) <= 0):
                    val_not_positive = True
            elif (5 < len(fftthresh)) and (fftthresh[-5:] == 'sigma'):
                if (float(fftthresh[:-5]) <= 0.0):
                    val_not_positive = True
            else:
                if (float(fftthresh) <= 0.0):
                    val_not_positive = True
        except Exception:
            raise ValueError('fftthresh has a wrong format.')
    else:
        has_invalid_type = True

    if has_invalid_type:
        raise ValueError('fftthresh must be float or integer or string.')
    if val_not_positive:
        raise ValueError('threshold given to fftthresh must be positive.')


def __check_positive_or_zero(param, allowzero=True):
    if isinstance(param, list) or isinstance(param, tuple):
        for i in range(len(param)):
            __do_check_positive_or_zero(int(param[i]), allowzero)
    elif isinstance(param, int):
        __do_check_positive_or_zero(param, allowzero)
    elif isinstance(param, str):
        __do_check_positive_or_zero(int(param), allowzero)
    else:
        raise ValueError(mesg_invalid_wavenumber)


def __do_check_positive_or_zero(param, allowzero):
    if (param < 0) or ((param == 0) and not allowzero):
        raise ValueError(mesg_invalid_wavenumber)


def prepare_for_baselining(**keywords):
    params = {}
    funcname = 'subtract_baseline'

    blfunc = keywords['blfunc']
    keys = ['datacolumn', 'outfile', 'bloutput', 'dosubtract', 'spw',
            'updateweight', 'sigmavalue']
    if blfunc in ['poly', 'chebyshev']:
        keys += ['blfunc', 'order']
    elif blfunc == 'cspline':
        keys += ['npiece']
        funcname += ('_' + blfunc)
    elif blfunc == 'sinusoid':
        keys += ['applyfft', 'fftmethod', 'fftthresh', 'addwn', 'rejwn']
        funcname += ('_' + blfunc)
    elif blfunc == 'variable':
        keys += ['blparam', 'verbose']
        funcname += ('_' + blfunc)
    else:
        raise ValueError("Unsupported blfunc = %s" % blfunc)
    if blfunc != 'variable':
        keys += ['clip_threshold_sigma', 'num_fitting_max']
        keys += ['linefinding', 'threshold', 'avg_limit', 'minwidth', 'edge']
    for key in keys:
        params[key] = keywords[key]

    baseline_func = getattr(keywords['sdms'], funcname)

    return params, baseline_func


def remove_sorted_table_keyword(infile):
    res = {'is_sorttab': False, 'sorttab_keywd': '', 'sorttab_name': ''}

    with sdutil.table_manager(infile, nomodify=False) as tb:
        sorttab_keywd = 'SORTED_TABLE'
        if sorttab_keywd in tb.keywordnames():
            res['is_sorttab'] = True
            res['sorttab_keywd'] = sorttab_keywd
            res['sorttab_name'] = tb.getkeyword(sorttab_keywd)
            tb.removekeyword(sorttab_keywd)

    return res


def restore_sorted_table_keyword(infile, sorttab_info):
    if sorttab_info['is_sorttab'] and (sorttab_info['sorttab_name'] != ''):
        with sdutil.table_manager(infile, nomodify=False) as tb:
            tb.putkeyword(sorttab_info['sorttab_keywd'], sorttab_info['sorttab_name'])


def _do_apply(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
              reindex, bltable, updateweight, sigmavalue, outfile, overwrite):
    if not os.path.exists(bltable):
        raise ValueError("file specified in bltable '%s' does not exist." % bltable)

    # Note: the condition "infile != outfile" in the following line is for safety
    # to prevent from accidentally removing infile by setting outfile=infile.
    # Don't remove it.
    if overwrite and (infile != outfile) and os.path.exists(outfile):
        remove_data(outfile)

    sorttab_info = remove_sorted_table_keyword(infile)

    with sdutil.tool_manager(infile, singledishms) as mysdms:
        selection = ms.msseltoindex(vis=infile, spw=spw, field=field,
                                    baseline=antenna, time=timerange,
                                    scan=scan)
        mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field,
                             antenna=antenna, timerange=timerange,
                             scan=scan, polarization=pol, intent=intent,
                             reindex=reindex)
        mysdms.apply_baseline_table(bltable=bltable,
                                    datacolumn=datacolumn,
                                    spw=spw,
                                    updateweight=updateweight,
                                    sigmavalue=sigmavalue,
                                    outfile=outfile)

    restore_sorted_table_keyword(infile, sorttab_info)


def _do_fit(infile, datacolumn, antenna, field, spw, timerange, scan, pol, intent,
            reindex, maskmode, thresh, avg_limit, minwidth, edge, dosubtract, blformat,
            bloutput, blfunc, order, npiece, applyfft, fftmethod, fftthresh, addwn,
            rejwn, clipthresh, clipniter, blparam, verbose, updateweight, sigmavalue,
            outfile, overwrite):

    temp_outfile = ''

    if (not dosubtract) and is_empty(blformat):
        raise ValueError("blformat must be specified when dosubtract is False")

    blformat, bloutput = prepare_for_blformat_bloutput(infile, blformat, bloutput, overwrite)

    output_bloutput_text_header(blformat, bloutput, blfunc, maskmode, infile, outfile)

    # Set temporary name for output MS if dosubtract is False and outfile exists
    # for not removing/overwriting outfile that already exists
    if os.path.exists(outfile):
        # Note: the condition "infile != outfile" in the following line is for safety
        # to prevent from accidentally removing infile by setting outfile=infile
        # Don't remove it.
        if dosubtract and overwrite and (infile != outfile):
            remove_data(outfile)
        elif (not dosubtract):
            outfile = get_temporary_file_name(infile)
            temp_outfile = outfile

    if (blfunc == 'variable'):
        sorttab_info = remove_sorted_table_keyword(infile)
    elif (blfunc == 'sinusoid'):
        addwn = parse_wavenumber_param(addwn)
        rejwn = parse_wavenumber_param(rejwn)
        check_fftthresh(fftthresh)

    with sdutil.tool_manager(infile, singledishms) as mysdms:
        selection = ms.msseltoindex(vis=infile, spw=spw, field=field, baseline=antenna,
                                    time=timerange, scan=scan)
        mysdms.set_selection(spw=sdutil.get_spwids(selection), field=field, antenna=antenna,
                             timerange=timerange, scan=scan, polarization=pol, intent=intent,
                             reindex=reindex)
        params, func = prepare_for_baselining(sdms=mysdms,
                                              blfunc=blfunc,
                                              datacolumn=datacolumn,
                                              outfile=outfile,
                                              bloutput=','.join(bloutput),
                                              dosubtract=dosubtract,
                                              spw=spw,
                                              pol=pol,
                                              linefinding=(maskmode == 'auto'),
                                              threshold=thresh,
                                              avg_limit=avg_limit,
                                              minwidth=minwidth,
                                              edge=edge,
                                              order=order,
                                              npiece=npiece,
                                              applyfft=applyfft,
                                              fftmethod=fftmethod,
                                              fftthresh=fftthresh,
                                              addwn=addwn,
                                              rejwn=rejwn,
                                              clip_threshold_sigma=clipthresh,
                                              num_fitting_max=clipniter + 1,
                                              blparam=blparam,
                                              verbose=verbose,
                                              updateweight=updateweight,
                                              sigmavalue=sigmavalue)
        func(**params)

    if (blfunc == 'variable'):
        restore_sorted_table_keyword(infile, sorttab_info)

    return temp_outfile