import os
import shutil
import copy
from collections import namedtuple
import numpy
from scipy.optimize import curve_fit
import unittest

# Resolve the task, tool constructors and test-data directory for the
# running CASA flavour: CASA 6 (casatasks/casatools packages) or CASA 5
# (tasks/taskinit modules). Both branches bind the same names:
# sdsidebandsplit, quanta, image, stack_frame_find and datapath.
from casatasks.private.casa_transition import is_CASA6
if is_CASA6:
    from casatasks import sdsidebandsplit
    from casatools import quanta
    from casatools import image
    from casatools import ctsys
    # directory holding the input and reference images used by these tests
    datapath = ctsys.resolve('unittest/sdsidebandsplit/')

    # default isn't used in casatasks; defined as a no-op for CASA 5
    # compatibility (not called in the visible part of this file)
    def default(atask):
        pass

    # stack_frame_find: CASA 6 stand-in that returns a fresh, empty dict on
    # every call, so saving/restoring '__rethrow_casa_exceptions' in it
    # (as failureTestCase does) has no lasting effect.
    def stack_frame_find():
        return {}
else:
    from tasks import sdsidebandsplit
    from taskinit import qatool as quanta
    from taskinit import iatool as image
    from casa_stack_manip import stack_frame_find

    # The first whitespace-separated token of CASAPATH is used as the CASA
    # root. NOTE(review): raises AttributeError if CASAPATH is unset.
    datapath = os.environ.get('CASAPATH').split()[0] + '/casatestdata/unittest/sdsidebandsplit/'


# Gaussian fit
def gauss_func(x, *p):
    """
    Evaluate a Gaussian profile on top of a constant baseline.

    Parameters
        x : abscissa value(s); a scalar or a numpy array
        p : exactly four model parameters, in the order
            (amplitude, center, width, offset)

    Returns
        amplitude * exp(-(x - center)**2 / (2 * width**2)) + offset
    """
    amplitude, center, sigma, baseline = p
    exponent = -((x - center) ** 2) / (2. * sigma ** 2)
    return baseline + amplitude * numpy.exp(exponent)


def gauss_fit(x, y):
    """
    Fit gauss_func to the samples (x, y) with scipy.optimize.curve_fit.

    The initial guess is derived from the data:
      offset    = mean of y
      amplitude = largest absolute deviation of y from that mean
      center    = first x at which that largest deviation occurs
      width     = sum of absolute deviations divided by the amplitude

    Returns
        the (popt, pcov) pair returned by curve_fit; popt is ordered
        (amplitude, center, width, offset).
    """
    baseline = y.mean()
    deviation = numpy.abs(y - baseline)
    amplitude = deviation.max()
    peak_indices = numpy.where(deviation == amplitude)[0]
    center = x[peak_indices][0]
    sigma = deviation.sum() / amplitude
    initial_guess = (amplitude, center, sigma, baseline)
    return curve_fit(gauss_func, x, y, p0=initial_guess)


# SpectralInfo: a named tuple describing the expected spectral features of
# one channel range of a spectrum.
#   start, end : first and end channel of the range; used as a Python
#                slice (start:end), so `end` is exclusive
#   max, min   : maximum and minimum pixel value within the range
#   peak, center, width, offset :
#                parameters of a Gaussian fit (gauss_fit) to the range
SpectralInfo = namedtuple(
    'SpectralInfo',
    'start end max min peak center width offset',
)


class sdsidebandsplitTestBase(unittest.TestCase):
    """
    Common machinery for sdsidebandsplit tests.

    Provides the default task parameters (standard_param), staging of input
    images and cleanup of task outputs around every test, and run_test,
    which executes the task and verifies the coordinate system, shape and
    pixel data of the resulting sideband images.
    Subclasses must implement compare_image_data to define how pixel values
    are checked against a reference.
    """
    standard_param = dict(
        imagename=['onepix_noiseless_shift0.image', 'onepix_noiseless_shift-102.image',
                   'onepix_noiseless_shift8.image', 'onepix_noiseless_shift62.image',
                   'onepix_noiseless_shift88.image', 'onepix_noiseless_shift100.image'],
        outfile='separated.image',
        overwrite=False,
        signalshift=[0.0, -102, +8, +62, +88, +100],
        imageshift=[0.0, 102, -8, -62, -88, -100],
        getbothside=False,
        refchan=0.0,
        refval='805GHz',
        otherside=False,
        threshold=0.2
    )

    def update_task_param(self, new_param=None):
        """
        Overwrite standard task parameters and return a new dictionary
        with the updated parameters. standard_param itself is not modified
        (it is deep-copied).
        Note this method does not check the validity of parameter names in
        the input parameter.

        Parameter
            new_param : a dictionary of parameter names (key) and
                        values (value) to overwrite standard task
                        execution parameters in tests.
                        None (default) means no override.

        Raises
            TypeError : if new_param is neither None nor a dictionary.
        """
        if new_param is None:
            new_param = {}
        if not isinstance(new_param, dict):
            raise TypeError('The input should be a dictionary')
        updated_param = copy.deepcopy(self.standard_param)
        updated_param.update(new_param)
        return updated_param

    def _remove_input_images(self):
        """Remove local copies of the input images, if present."""
        for name in self.standard_param['imagename']:
            if os.path.exists(name):
                shutil.rmtree(name)

    def _remove_output_images(self):
        """Remove the signal/image sideband outputs of a task run, if present."""
        prefix = self.standard_param['outfile']
        for suffix in ('.signalband', '.imageband'):
            name = prefix + suffix
            if os.path.exists(name):
                shutil.rmtree(name)

    def setUp(self):
        # copy fresh input images (replacing stale local copies)
        self._remove_input_images()
        for name in self.standard_param['imagename']:
            shutil.copytree(datapath + name, name)
        # remove output files of previous run
        self._remove_output_images()

    def tearDown(self):
        self._remove_input_images()
        self._remove_output_images()

    def run_test(self, reference, **new_param):
        """
        Run sdsidebandsplit with given parameters and test result

        Arguments
            reference : a reference to compare results.
                        A dictionary with keys, 'signal' and 'image',
                        for signal and image sideband, respectively.
                        The data structure of 'signal' and 'image' values
                        depend on the implementation of test,
                        e.g., compare_image_data method.
                        'image' is only required when getbothside=True.
            other keyword arguments : test specific parameters to run tests
        """
        # Run task
        task_param = self.update_task_param(new_param)
        ret = sdsidebandsplit(**task_param)
        self.assertIsNone(ret, 'The return value of task should be None')
        # The first input image provides the expected coordinate system
        # and shape of the outputs.
        template_image = task_param['imagename'][0]
        self.assertTrue(os.path.exists(template_image),
                        "Could not find template image '%s'" % template_image)
        refcsys, refshape = self.get_image(template_image)
        self.assertIn('signal', reference,
                      'Internal Error: No valid reference value for signal sideband')
        # test signal band image
        imagename = task_param['outfile'] + '.signalband'
        self.check_result(imagename, refcsys, refshape, reference['signal'])
        if not task_param['getbothside']:
            return
        # test image band image
        imagename = task_param['outfile'] + '.imageband'
        self.assertIn('image', reference,
                      'Internal Error: No valid reference value for image sideband')
        # Expected image-sideband spectral axis: reference pixel = refchan,
        # reference value = refval converted to the template's spectral unit,
        # and the template's increment with its sign flipped.
        spid = refcsys.findaxisbyname('spectral')
        refcsys.setreferencepixel(task_param['refchan'], 'spectral')
        myqa = quanta()
        refval = myqa.convert(task_param['refval'], refcsys.units()[spid])['value']
        refcsys.setreferencevalue(refval, 'spectral')
        inc = refcsys.increment(format='n', type='spectral')['numeric'][0]
        refcsys.setincrement(-inc, 'spectral')

        self.check_result(imagename, refcsys, refshape, reference['image'])

    def get_image(self, imagename, getdata=False, getmask=False):
        """
        Returns image coordinate system object, shape.
        Optionally returns image pixel and mask values.
        Return values are in the order of
            csys, shape, data (optional), mask (optional).

        Parameters:
            imagename : the name of image
            getdata   : if True, returns image pixel values
            getmask   : if True, return image pixel mask
        """
        self.assertTrue(os.path.exists(imagename),
                        "Could not find image '%s'" % imagename)
        myia = image()
        myia.open(imagename)
        try:
            retval = [myia.coordsys(), myia.shape()]
            if getdata:
                retval.append(myia.getchunk())
            if getmask:
                retval.append(myia.getchunk(getmask=True))
        finally:
            myia.close()
        return retval

    def check_result(self, imagename, ref_csys, ref_shape, ref_value):
        """
        Compare an image with reference coordinate system, shape, and values.
        Details of tests should be defined by methods called from this method,
        i.e., compare_image_coordinate and compare_image_data.

        Arguments
            imagename : the name of image to be tested
            ref_csys  : the reference coordinate system
            ref_shape : the reference of image shape
            ref_value : the data structure which defines image data
        """
        self.assertTrue(os.path.exists(imagename),
                        "Output image '%s' does not exist." % imagename)
        mycsys, myshape, mydata = self.get_image(imagename, getdata=True)
        self.compare_image_coordinate(mycsys, myshape, ref_csys, ref_shape)
        self.compare_image_data(mydata, ref_value)

    def compare_image_data(self, data, reference):
        """
        Compare image pixel values with a reference.
        Must be implemented by subclasses; the structure of `reference`
        is subclass specific.
        """
        raise NotImplementedError('compare_image_data must be implemented by a subclass')

    def compare_image_coordinate(self, csys, shape, ref_csys, ref_shape):
        """
        This method compares a coordinate system and shape with reference ones.
        The order of axes should be the same.
        Tested items:
        - dimension of shape
        - shape of each dimension
        - coordinate types of axes
        - axes units
        - reference pixels
        - reference values
        - increments
        Reference pixels, reference values and increments are compared with
        assertAlmostEqual's default precision (7 decimal places).
        """
        # dimension
        self.assertEqual(len(shape), len(ref_shape), 'Dimension of shape differs from reference.')
        # dimension of csys
        self.assertEqual(ref_csys.naxes(), csys.naxes(),
                         'Dimension of coordinate system differs from reference')
        # confirm dimension of csys and shape
        self.assertEqual(len(shape), csys.naxes(),
                         'Dimension mismatch between shape and coordinate system')
        # query per-axis properties once instead of once per axis
        axis_types = csys.axiscoordinatetypes()
        ref_axis_types = ref_csys.axiscoordinatetypes()
        units = csys.units()
        ref_units = ref_csys.units()
        refpix = csys.referencepixel()['numeric']
        ref_refpix = ref_csys.referencepixel()['numeric']
        refval = csys.referencevalue()['numeric']
        ref_refval = ref_csys.referencevalue()['numeric']
        incr = csys.increment()['numeric']
        ref_incr = ref_csys.increment()['numeric']
        for i, ref_len in enumerate(ref_shape):
            # shape of each dimension
            self.assertEqual(shape[i], ref_len,
                             'Shape in %d-th dimension differs' % i)
            # axis type
            self.assertEqual(axis_types[i], ref_axis_types[i],
                             'Axis coordinate type does not match in dimension %d' % i)
            # axis unit
            self.assertEqual(units[i], ref_units[i],
                             'Axis unit does not match in dimension %d' % i)
            # msg must be passed by keyword: the third positional parameter of
            # assertAlmostEqual is `places`, and a string there makes any
            # non-identical comparison raise TypeError instead of comparing.
            self.assertAlmostEqual(refpix[i], ref_refpix[i],
                                   msg='Reference pixel does not match in dimension %d' % i)
            self.assertAlmostEqual(refval[i], ref_refval[i],
                                   msg='Reference value does not match in dimension %d' % i)
            self.assertAlmostEqual(incr[i], ref_incr[i],
                                   msg='Increment does not match in dimension %d' % i)


class failureTestCase(sdsidebandsplitTestBase):
    """
    A class to test invalid task parameters to run sdsidebandsplit.
    Implemented based on test case table attached to CAS-8091
    """
    def setUp(self):
        # Set '__rethrow_casa_exceptions' in the frame returned by
        # stack_frame_find (presumably so CASA 5 tasks re-raise exceptions
        # instead of only logging them). The previous value (None if absent)
        # is restored in tearDown.
        self.g = stack_frame_find()
        self.rethrow_backup = self.g.get('__rethrow_casa_exceptions')
        self.g['__rethrow_casa_exceptions'] = True
        super(failureTestCase, self).setUp()

    def tearDown(self):
        if self.rethrow_backup is None:
            self.g.pop('__rethrow_casa_exceptions', None)
        else:
            self.g['__rethrow_casa_exceptions'] = self.rethrow_backup
        super(failureTestCase, self).tearDown()
        del self.g

    def run_exception(self, ref_message, **new_param):
        """
        Run the task with the standard parameters updated by new_param and
        assert that it raises an exception whose message matches the
        regular expression ref_message.
        """
        task_param = self.update_task_param(new_param)
        # assertRaisesRegexp is a deprecated alias removed in Python 3.12,
        # while Python 2 (CASA 5) only provides that old name.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None)
        if assert_raises_regex is None:
            assert_raises_regex = self.assertRaisesRegexp
        assert_raises_regex(Exception, ref_message, sdsidebandsplit, **task_param)

    # T-001
    def test_imagename_1image(self):
        """test failure: len(imagename)<2"""
        imagename = [self.standard_param['imagename'][0]]
        ref_message = 'At least two valid input data are required for processing'
        self.run_exception(ref_message, imagename=imagename)

    # T-005
    def test_imagename_invalidname(self):
        """test failure: imagename includes an invalid (nonexistent) image name"""
        invalid_name = 'invalid.image'
        imagename = self.standard_param['imagename'][:-2] + [invalid_name]
        if is_CASA6:
            ref_message = '.*cReqPathVec type.*'
        else:
            ref_message = 'Could not find %s' % invalid_name
        self.run_exception(ref_message, imagename=imagename)

    # T-006
    def test_outfile_undefined(self):
        """test failure: outfile is empty"""
        ref_message = 'Output file name is undefined.'
        self.run_exception(ref_message, outfile='')

    # T-008, T-009
    def test_outfile_exists(self):
        """test failure: overwrite=F and outfile already exists."""
        for sideband in ('signalband', 'imageband'):
            print('Test %s' % sideband)
            name = self.standard_param['outfile'] + '.' + sideband
            os.mkdir(name)
            try:
                ref_message = 'Image %s already exists.' % name
                # the imageband output only matters when getbothside=True
                # (an existing imageband is allowed otherwise; see T-010)
                param = dict(getbothside=(sideband == 'imageband'))
                self.run_exception(ref_message, **param)
            finally:
                # remove it even if the assertion fails, so the next
                # sideband case starts clean
                shutil.rmtree(name)

    # T-012
    def test_shifts_undefined(self):
        """test failure: both signalshift and imageshift are undefined"""
        ref_message = 'Channel shift was not been set.'
        self.run_exception(ref_message, signalshift=[], imageshift=[])

    # T-014, T-015, T-017, T-018
    def test_shift_wrong_length(self):
        """test failure: length of signalshift or imageshift does not match len(imagename)"""
        ref_message = "The number of shift should match that of images"
        for sideband in ['signalshift', 'imageshift']:
            myshift = self.standard_param[sideband] + [50]
            # one element fewer and one element more than the number of images
            for shift in (myshift[:5], myshift):
                print('Test len(%s)=%d' % (sideband, len(shift)))
                param = {sideband: shift}
                self.run_exception(ref_message, **param)

    # T-022, T-023, T-024
    def test_refval_invalid(self):
        """test failure: refval is invalid (empty, a negative frequency or not a frequency)"""
        ref_message = ('refval is not a dictionary',
                       'Frequency should be positive',
                       'From/to units not consistent.')
        for refval, message in zip(('', '-100GHz', '300K'), ref_message):
            print("Test refval='%s'" % refval)
            self.run_exception(message, refval=refval, getbothside=True)

    # T-027, T-031
    def test_threshold_outofrange(self):
        """test failure: threshold = 0.0, 1.0"""
        ref_message = 'Rejection limit should be > 0.0 and < 1.0'
        for thres in (0.0, 1.0):
            print('Test threshold=%f' % thres)
            self.run_exception(ref_message, threshold=thres)


class standardTestCase(sdsidebandsplitTestBase):
    """
    A class to test valid task parameters to run sdsidebandsplit.
    Implemented based on test case table attached to CAS-8091
    The input images are synthesized spectra of
    1 x 1 pixel, stokes I, 4080 channels.
    """

    standard_reference = dict(signal=(SpectralInfo(0, 1500, 4.06522, 0.99518, 2.96347, 898.4841, 30.48852, 1.08165),
                                      SpectralInfo(1700, 2700, 6.05933, 1.02671, 4.92205, 2297.6872, 19.75842, 1.14450)),
                              image=(SpectralInfo(1000, 2000, 8.07553, 1.05448, 6.94301, 1600.1052, 19.95953, 1.13069),
                                     SpectralInfo(2500, 3500, 3.06859, 1.00263, 1.99147, 2999.9953, 9.92538, 1.07988)))

    def assertAlmostEqual2(self, first, second, eps=1.0e-7, msg=None):
        """
        Assert that first agrees with second within a relative tolerance eps,
        i.e. |first - second| / |second| <= eps. When second is 0 the test
        falls back to the absolute check |first| <= eps.
        """
        if second == 0:
            self.assertLessEqual(abs(first), eps, msg)
        else:
            reldiff = abs((first - second) / second)
            self.assertLessEqual(reldiff, eps, msg)

    def compare_image_data(self, data, reference):
        """
        Compare image data with reference.

        Arguments
            data      : pixel value of image
            reference : a list of namedtuple, SpectralInfo,
                        which defines spectral feature of segments of spectrum
                        See the beginning of this code about definition of SpectralInfo.
        """
        self.assertEqual(data.shape, (1, 1, 1, 4080), 'Data shape is not expected one')
        for seg in reference:
            sp = data[0, 0, 0, seg.start:seg.end]
            x = numpy.arange(seg.start, seg.end)
            # relative tolerances: 0.1% for max, 1% for min
            self.assertAlmostEqual2(sp.max(), seg.max, 1e-3, 'Max comparison failed')
            self.assertAlmostEqual2(sp.min(), seg.min, 0.01, 'Min comparison failed')
            # compare gaussian fit
            fitp, _ = gauss_fit(x, sp)
            self.assertAlmostEqual2(fitp[0], seg.peak, 1e-3, 'Peak comparison failed')
            self.assertAlmostEqual2(fitp[1], seg.center, 1e-3, 'Peak position comparison failed')
            # gauss_func only depends on width**2, so the fitted width may
            # come out with either sign; compare magnitudes
            self.assertAlmostEqual2(numpy.abs(fitp[2]), numpy.abs(seg.width), 1e-3, 'Width comparison failed')
            self.assertAlmostEqual2(fitp[3], seg.offset, 1e-3, 'Offset comparison failed')

    # T-002
    def test_imagename_2images(self):
        """len(imagename)==2"""
        reference = dict(signal=[SpectralInfo(0, 1500, 4.0, 1.0, 2.86439, 898.70730, 30.33598, 1.09944),
                                 SpectralInfo(1500, 3000, 6.0, 1.0, 4.91510, 2297.94793, 19.555875, 1.10176954)])
        imagename = self.standard_param['imagename'][:2]
        signalshift = self.standard_param['signalshift'][:2]
        imageshift = self.standard_param['imageshift'][:2]
        self.run_test(reference, imagename=imagename,
                      signalshift=signalshift, imageshift=imageshift)

    # T-007
    def test_imagename_6images(self):
        """standard run: valid outfile, len(imagename)==6"""
        self.run_test(self.standard_reference)

    # T-010
    def test_imageband_exists_signalonly(self):
        """imageband image exists but only signal band is solved (must succeed)"""
        imageband = self.standard_param['outfile'] + '.imageband'
        os.mkdir(imageband)
        self.assertTrue(os.path.exists(imageband), "Failed to create '%s'" % imageband)
        self.run_test(self.standard_reference, getbothside=False, overwrite=False)

    # T-011
    def test_overwrite(self):
        """overwrite = True"""
        for sideband in ['.signalband', '.imageband']:
            name = self.standard_param['outfile'] + sideband
            os.mkdir(name)
            self.assertTrue(os.path.exists(name), "Failed to create '%s'" % name)
        self.run_test(self.standard_reference, overwrite=True, getbothside=True)

    # T-013
    def test_signalshift_from_imageshift(self):
        """obtain signalshift from imageshift"""
        self.run_test(self.standard_reference, signalshift=[])

    # T-016
    def test_imageshift_from_signalshift(self):
        """obtain imageshift from signalshift"""
        self.run_test(self.standard_reference, imageshift=[])

    # T-019
    def test_getbothside(self):
        """getbothside = True"""
        self.run_test(self.standard_reference, getbothside=True)

    # T-020
    def test_refchan_negative(self):
        """refchan = -1.0"""
        self.run_test(self.standard_reference, getbothside=True, refchan=-1.0)

    # T-021
    def test_refchan_large(self):
        """refchan > nchan"""
        self.run_test(self.standard_reference, getbothside=True, refchan=5000.0)

    # T-025
    def test_otherside(self):
        """otherside = True"""
        reference = dict(signal=(SpectralInfo(0, 1500, 2.77938, -0.198638, 2.90884, 898.4125, 29.57543, -0.12153),
                                 SpectralInfo(1500, 3000, 4.74292, -0.20648, 4.88534, 2298.0446, 19.75160, -0.13152)),
                         image=(SpectralInfo(1000, 2200, 6.67820, -0.24841, 6.84674, 1599.9131, 19.75747, -0.15339),
                                SpectralInfo(2500, 3500, 1.88367, -0.11315, 1.96069, 2999.8335, 9.94222, -0.07489)))
        for doboth in [False, True]:
            print('getbothside = %s' % str(doboth))
            self.run_test(reference, otherside=True, getbothside=doboth, overwrite=True)

    # T-028, T-029, T-030
    def test_threshold(self):
        """various threshold values"""
        ref_small = dict(signal=(SpectralInfo(0, 1500, 4.09385, 1.08961, 2.99224, 898.04215, 29.99680, 1.09845),
                                 SpectralInfo(1500, 3000, 6.08972, 1.08962, 4.99175, 2297.9978, 19.98419, 1.09853)))
        ref_mid = dict(signal=(SpectralInfo(0, 1500, 4.06263, 0.96078, 2.97665, 898.26048, 29.81733, 1.07312),
                               SpectralInfo(1500, 3000, 6.02824, 0.99354, 4.87862, 2297.8214, 19.10228, 1.19242)))
        ref_large = dict(signal=(SpectralInfo(0, 1500, 3.99807, 0.97891, 2.96490, 898.00416, -29.48243, 1.04387),
                                 SpectralInfo(1500, 3000, 5.99822, 0.96876, 4.83709, 2298.0035, -18.79443, 1.21969)))
        for val, ref in zip((0.0001, 0.5, 0.9999), (ref_small, ref_mid, ref_large)):
            print('Threshold=%f' % val)
            self.run_test(ref, threshold=val, overwrite=True)


class MultiPixTestCase(sdsidebandsplitTestBase):
    """
    A class to test sdsidebandsplit with multi-pixel images.
    Implemented based on test case table attached to CAS-8091 (T-032)
    """
    # Same task parameters as the base class; only the input images differ.
    # Deep-copied so the list values are not shared with the base class.
    standard_param = dict(
        copy.deepcopy(sdsidebandsplitTestBase.standard_param),
        imagename=['multipix_noiseless_shift0.image', 'multipix_noiseless_shift-102.image',
                   'multipix_noiseless_shift8.image', 'multipix_noiseless_shift62.image',
                   'multipix_noiseless_shift88.image', 'multipix_noiseless_shift100.image'],
    )

    def compare_image_data(self, data, reference):
        """
        Compare image data with a reference image.

        Checks the array shape exactly, then the max, min and standard
        deviation of the pixel values, each to 3 decimal places.

        Arguments
            data      : pixel value of image
            reference : a reference image name
        """
        self.assertTrue(os.path.exists(reference),
                        "Could not find reference image '%s'" % reference)
        myia = image()
        myia.open(reference)
        try:
            ref_data = myia.getchunk()
        finally:
            myia.close()
        self.assertEqual(data.shape, ref_data.shape, 'Image shape comparison failed')
        self.assertAlmostEqual(data.max(), ref_data.max(), places=3, msg='Max comparison failed')
        self.assertAlmostEqual(data.min(), ref_data.min(), places=3, msg='Min comparison failed')
        self.assertAlmostEqual(data.std(), ref_data.std(), places=3, msg='StdDev comparison failed')

    # T-032
    def test_multi_pixels(self):
        """images with 10x10 spatial pixel"""
        reference = dict(signal=datapath + 'ref_multipix.signalband',
                         image=datapath + 'ref_multipix.imageband')
        self.run_test(reference, getbothside=True)


def suite():
    """Return the list of test case classes defined in this module."""
    test_classes = [failureTestCase, standardTestCase, MultiPixTestCase]
    return test_classes


# Run the tests directly only when executed as a script under CASA 6.
if is_CASA6 and __name__ == '__main__':
    unittest.main()