import collections
import glob
import itertools
import numpy
import os

try:
    from casatools import msmetadata
    from casatools import table
    from casatools import measures
    from casatools import quanta
    from casatools import ms as mstool
    from casatools import image
    from casatasks import casalog
except Exception:
    from __casac__.msmetadata import msmetadata
    from __casac__.table import table
    from __casac__.measures import measures
    from __casac__.quanta import quanta
    from __casac__.ms import ms as mstool
    from __casac__.image import image
    from taskinit import casalog

# Pair of MS-side metadata (MsMeta) and ephemeris-side metadata (EphemMeta).
MetaDataSet = collections.namedtuple('MetaDataSet', 'msmeta ephemmeta')

# Metadata read from the MeasurementSet by inspect_ms.
MsMeta = collections.namedtuple('MsMeta', 'positions times freqmin freqmax')

# Contents of an ephemeris table as read by inspect_ephem.
EphemMeta = collections.namedtuple(
    'EphemMeta', 'table times unit_time velocities unit_vel frame_vel'
)

# Frequency interval (min, max) plus the name of its reference frame.
FrequencyRange = collections.namedtuple('FrequencyRange', 'min max ref')

# Unit used for all internal velocity arithmetic.
COMMON_VELOCITY_UNIT = 'km/s'

# When True, debug_print forwards its messages to the CASA log.
DEBUG = False


def debug_print(msg):
    """Post msg to the CASA log, line by line, when DEBUG is enabled.

    Each newline-separated line is posted separately with a 'DEBUG: '
    prefix. Nothing happens while the module-level DEBUG flag is False.
    """
    if not DEBUG:
        return
    for line in msg.split('\n'):
        casalog.post('DEBUG: {}'.format(line))


def inspect_ms(vis, fieldid, spwid, chanstart=0, nchan=-1):
    """Collect antenna positions, target timestamps and a frequency range from an MS.

    Arguments:
        vis {str} -- name of MS
        fieldid {int} -- FIELD_ID for target
        spwid {int} -- SPW_ID for target spw
        chanstart {int} -- start channel (default: 0)
        nchan {int} -- number of channels
                       (default: -1 => from chanstart to end channel of spw)

    Returns:
        MsMeta -- namedtuple containing metadata
                      positions: list of antenna positions
                      times: TIME values of the rows selected by field, spw
                             and intent OBSERVE_TARGET#ON_SOURCE
                      freqmin: lower edge of the selected channels as a
                               frequency measure in the spw's MEAS_FREQ_REF
                      freqmax: upper edge of the selected channels as a
                               frequency measure in the spw's MEAS_FREQ_REF

    The edges lie half a (spw-mean) channel width outside the outermost
    selected channel centers.
    """
    msmd = msmetadata()
    msmd.open(vis)
    try:
        positions = list(map(msmd.antennaposition, msmd.antennaids()))
        chanfreqs = msmd.chanfreqs(spwid)
        chanwidths = msmd.chanwidths(spwid)
    finally:
        msmd.close()

    # Channel widths may be negative (e.g. when frequency decreases with
    # channel index). Use the magnitude so that the range is always widened,
    # never shrunk, by half a channel on each side.
    cw = abs(chanwidths.mean())
    start = chanstart
    end = min(chanstart + nchan, len(chanfreqs)) if nchan >= 0 else len(chanfreqs)
    debug_print('start {}, end {}'.format(start, end))
    assert 0 <= start
    assert 0 < end
    assert start < end
    selected = chanfreqs[start:end]
    freqmin = selected.min() - cw / 2
    freqmax = selected.max() + cw / 2
    debug_print('freqmin {}, freqmax {}'.format(freqmin, freqmax))

    ms = mstool()
    ms.open(vis)
    try:
        ms.msselect({'spw': str(spwid), 'field': str(fieldid), 'scanintent': 'OBSERVE_TARGET#ON_SOURCE'})
        data = ms.getdata(['time'])
        times = data['time']
    finally:
        ms.close()

    # Open before entering try so that close() is only called on an opened table.
    tb = table()
    tb.open(os.path.join(vis, 'SPECTRAL_WINDOW'))
    try:
        freq_ref_id = tb.getcell('MEAS_FREQ_REF', spwid)
    finally:
        tb.close()

    me = measures()
    try:
        # MEAS_FREQ_REF stores an index into the list of frequency reference codes.
        refmap = me.listcodes(me.frequency())['normal']
        assert 0 <= freq_ref_id < len(refmap)
        freq_ref_str = refmap[freq_ref_id]

        qa = quanta()
        metadata = MsMeta(
            positions,
            times,
            me.frequency(rf=freq_ref_str, v0=qa.quantity(freqmin, 'Hz')),
            me.frequency(rf=freq_ref_str, v0=qa.quantity(freqmax, 'Hz'))
        )
    finally:
        me.done()

    return metadata


def get_ephem_table(vis, fieldid):
    """Return a name of the Ephemeris table corresponding to given FIELD_ID

    Arguments:
        vis {str} -- name of the MS
        fieldid {int} -- FIELD_ID

    Returns:
        str -- name of the Ephemeris table

    Raises:
        RuntimeError -- if EPHEMERIS_ID of the field is negative, or if the
                        ephemeris table cannot be identified uniquely
    """
    tb = table()
    field_table = os.path.join(vis, 'FIELD')
    tb.open(field_table)
    try:
        ephem_id = tb.getcell('EPHEMERIS_ID', fieldid)
    finally:
        tb.close()

    if ephem_id < 0:
        raise RuntimeError(
            'FIELD_ID {} has no valid EPHEMERIS_ID ({})'.format(fieldid, ephem_id))

    pattern = os.path.join(field_table, 'EPHEM{}*'.format(ephem_id))
    return select_ephem_table(glob.glob(pattern), ephem_id)


def select_ephem_table(candidates, ephem_id):
    """Pick the single ephemeris table belonging to ephem_id from candidate paths.

    A glob such as 'EPHEM1*' also matches tables of other ids sharing the
    same decimal prefix (e.g. 'EPHEM10_...'). Those are discarded by
    requiring that the character right after 'EPHEM<id>' is not a digit.

    Arguments:
        candidates {list of str} -- paths of candidate tables
        ephem_id {int} -- EPHEMERIS_ID to look for

    Returns:
        str -- the matching path (as given in candidates)

    Raises:
        RuntimeError -- if zero or more than one candidate matches
    """
    prefix = 'EPHEM{}'.format(ephem_id)
    matches = []
    for path in candidates:
        # normpath strips a trailing separator so basename is the table name.
        name = os.path.basename(os.path.normpath(path))
        following = name[len(prefix):len(prefix) + 1]
        if name.startswith(prefix) and not following.isdigit():
            matches.append(path)

    if len(matches) != 1:
        raise RuntimeError(
            'expected exactly one ephemeris table for EPHEMERIS_ID {}, found {}: {}'.format(
                ephem_id, len(matches), matches))
    return matches[0]


def inspect_ephem(name):
    """Read times and radial velocities from an ephemeris table.

    Arguments:
        name {str} -- name of Ephemeris table

    Returns:
        EphemMeta -- data of Ephemeris table
                         table: name of the table (the argument itself)
                         times: values of the MJD column
                         unit_time: UNIT keyword of the MJD column
                         velocities: values of the RadVel column
                         unit_vel: UNIT keyword of the RadVel column
                         frame_vel: 'TOPO' or 'GEO', decided from the
                                    GeoDist table keyword (see below)
    """
    tb = table()
    tb.open(name)
    try:
        mjd = tb.getcol('MJD')
        mjd_unit = tb.getcolkeyword('MJD', 'UNIT')
        radvel = tb.getcol('RadVel')
        radvel_unit = tb.getcolkeyword('RadVel', 'UNIT')
        geo_dist = tb.getkeyword('GeoDist')
    finally:
        tb.close()

    # Same rule as FTMachine::initSourceFreqConv: GeoDist is the distance
    # from the geocenter in km. If it exceeds 1e-3 km (1 m) the velocities
    # are treated as TOPO (topocentric), otherwise as GEO (geocentric).
    qa = quanta()
    is_topocentric = qa.gt(qa.quantity(geo_dist, 'km'), qa.quantity(1.0e-3, 'km'))

    return EphemMeta(
        table=name,
        times=mjd,
        unit_time=mjd_unit,
        velocities=radvel,
        unit_vel=radvel_unit,
        frame_vel='TOPO' if is_topocentric else 'GEO'
    )


def update_measure(measures_instance, position=None, epoch=None, direction=None, comet_table=None):
    """Apply optional frame information to a measures tool.

    Frames are applied with doframe in the order position, epoch,
    direction. Arguments left as None are skipped. If comet_table is
    given, it is loaded with framecomet and the resulting 'COMET'
    direction is set as a frame as well.

    Arguments:
        measures_instance -- measures tool to update (modified in place)
        position -- position measure, or None
        epoch -- epoch measure, or None
        direction -- direction measure, or None
        comet_table {str} -- path of an existing ephemeris table, or None

    Returns:
        the same measures_instance, to allow call chaining
    """
    for frame in (position, epoch, direction):
        if frame is not None:
            measures_instance.doframe(frame)

    if comet_table is None:
        return measures_instance

    assert isinstance(comet_table, str)
    assert os.path.exists(comet_table)
    measures_instance.framecomet(comet_table)
    comet_direction = measures_instance.direction('COMET')
    measures_instance.doframe(comet_direction)
    return measures_instance


def get_doppler(measure_instance, radial_velocity, velocity_unit, velocity_frame):
    """Build a RELATIVISTIC doppler measure from an ephemeris radial velocity.

    The velocity is converted to COMMON_VELOCITY_UNIT. When velocity_frame
    is 'GEO', the GEO-frame velocity of an object at rest in TOPO is
    subtracted so that the result is a TOPO velocity.

    Arguments:
        measure_instance -- measures tool; expected to have its frame set
                            already (e.g. via update_measure)
        radial_velocity {float} -- radial velocity value
        velocity_unit {str} -- unit of radial_velocity
        velocity_frame {str} -- 'GEO' or 'TOPO' (as from inspect_ephem)

    Returns:
        doppler measure with reference 'RELATIVISTIC'
    """
    qa = quanta()
    velocity = qa.convert(qa.quantity(radial_velocity, velocity_unit), COMMON_VELOCITY_UNIT)

    if velocity_frame == 'GEO':
        # relative velocity between GEO and TOPO must be subtracted
        topo_at_rest = measure_instance.radialvelocity(rf='TOPO', v0=qa.quantity(0, COMMON_VELOCITY_UNIT))
        geo_offset = measure_instance.measure(v=topo_at_rest, rf='GEO')['m0']
        offset_in_common_unit = qa.convert(geo_offset, COMMON_VELOCITY_UNIT)
        debug_print('velocity in GEO frame. Require conversion to TOPO.')
        debug_print('Original Velocity: {value} {unit}'.format(**velocity))
        debug_print('Delta Velocity: {value} {unit}'.format(**offset_in_common_unit))
        velocity = qa.sub(velocity, geo_offset)

    debug_print('TOPO velocity: {value} {unit}'.format(**velocity))
    return measure_instance.doppler(rf='RELATIVISTIC', v0=velocity)


def ms_freq_range(metadataset):
    """Return the REST-frame frequency range covered by the MS selection.

    For every antenna position and every distinct target timestamp, the
    observed spw edges (msmeta.freqmin / freqmax) are converted to rest
    frequencies. The conversion uses the ephemeris radial velocity
    linearly interpolated at that timestamp. The result is the overall
    minimum / maximum of all converted edges.

    Arguments:
        metadataset {MetaDataSet} -- MS and ephemeris metadata

    Returns:
        FrequencyRange -- (min, max, 'REST'); min and max are quantity
                          dicts, or None when msmeta.times is empty
    """
    msmeta = metadataset.msmeta
    ephem_data = metadataset.ephemmeta
    ephem_table = ephem_data.table

    qa = quanta()
    min_frequency = None
    max_frequency = None
    eph_time = qa.convert(qa.quantity(ephem_data.times, ephem_data.unit_time), 's')['value']
    eph_vel = ephem_data.velocities
    # msmeta.times holds the TIME value of every selected MS row, so the same
    # timestamp typically appears many times. Each distinct timestamp yields
    # the same conversion, so converting it once gives identical min/max at
    # a fraction of the cost.
    timestamps = numpy.unique(msmeta.times)
    interpolated_velocities = numpy.interp(timestamps, eph_time, eph_vel)
    for position, (timestamp, velocity) in itertools.product(
            msmeta.positions, zip(timestamps, interpolated_velocities)):
        # Fresh, reset measures tool per sample (presumably so that frames
        # set in earlier iterations do not carry over).
        me = measures()
        me.done()
        epoch = me.epoch('UTC', qa.quantity(timestamp, 's'))
        me = update_measure(me, epoch=epoch, position=position, comet_table=ephem_table)
        doppler = get_doppler(me, velocity, ephem_data.unit_vel, ephem_data.frame_vel)
        fmin = me.torestfrequency(msmeta.freqmin, doppler)
        fmax = me.torestfrequency(msmeta.freqmax, doppler)
        assert fmin['refer'] == 'REST'
        assert fmax['refer'] == 'REST'
        if min_frequency is None or qa.lt(fmin['m0'], min_frequency):
            min_frequency = fmin['m0']
        if max_frequency is None or qa.gt(fmax['m0'], max_frequency):
            max_frequency = fmax['m0']
        debug_print('min freq: {}'.format(qa.tos(min_frequency)))
        debug_print('max freq: {}'.format(qa.tos(max_frequency)))

    return FrequencyRange(min_frequency, max_frequency, 'REST')


def image_freq_range(imagename):
    """Return the frequency range spanned by the spectral axis of an image.

    The range runs from the outer edge of the first channel (pixel -0.5) to
    the outer edge of the last channel (pixel nchan - 0.5). The spectral
    axis is taken to be pixel axis 3 and the other axes are evaluated at
    pixel 0. NOTE(review): this assumes the usual RA/Dec/Stokes/Freq axis
    order; confirm for non-standard images.

    Arguments:
        imagename {str} -- name of the image

    Returns:
        FrequencyRange -- (min, max, ref) where min <= max are frequency
                          quantity dicts and ref is the spectral reference
                          frame of the image
    """
    ia = image()
    ia.open(imagename)
    try:
        # coordsys() is inside the outer try so ia is closed even if it fails.
        csys = ia.coordsys()
        try:
            imshape = ia.shape()
            chmin = -0.5
            chmax = imshape[3] - 1 + 0.5
            refpix = [0, 0, 0, chmin]
            wmin = csys.toworld(refpix, format='m')
            refpix[3] = chmax
            wmax = csys.toworld(refpix, format='m')
        finally:
            csys.done()
    finally:
        ia.close()

    refmin = wmin['measure']['spectral']['frequency']['refer']
    refmax = wmax['measure']['spectral']['frequency']['refer']
    assert refmin == refmax
    fmin = wmin['measure']['spectral']['frequency']['m0']
    fmax = wmax['measure']['spectral']['frequency']['m0']

    # The spectral axis may run towards lower frequencies (negative
    # increment), in which case the first channel edge is the maximum.
    qa = quanta()
    if qa.gt(fmin, fmax):
        fmin, fmax = fmax, fmin

    return FrequencyRange(fmin, fmax, refmax)


def get_lorentz_factor(metadataset, observatory='ALMA'):
    """Return the target's TOPO radial velocity as a fraction of c.

    NOTE: despite its name, the returned value is beta = v / c (signed),
    not the Lorentz factor gamma. The name is kept for compatibility.

    The frame epoch is the first MS timestamp. The ephemeris radial
    velocity is linearly interpolated at that same instant, as done in
    ms_freq_range, and converted to TOPO via get_doppler.

    Arguments:
        metadataset {MetaDataSet} -- MS and ephemeris metadata
        observatory {str} -- observatory name passed to me.observatory to
                             set the frame position (default: 'ALMA')

    Returns:
        float -- radial velocity divided by the speed of light
    """
    msmeta = metadataset.msmeta
    ephem_data = metadataset.ephemmeta
    ephem_table = ephem_data.table

    me = measures()
    me.done()
    qa = quanta()
    first_time = msmeta.times[0]
    epoch = me.epoch('UTC', qa.quantity(first_time, 's'))
    position = me.observatory(observatory)
    me = update_measure(me, epoch=epoch, position=position, comet_table=ephem_table)

    # Evaluate the ephemeris velocity at the frame epoch rather than at the
    # first table row, whose time need not coincide with the observation.
    eph_time = qa.convert(qa.quantity(ephem_data.times, ephem_data.unit_time), 's')['value']
    velocity = numpy.interp(first_time, eph_time, ephem_data.velocities)

    doppler = get_doppler(me, velocity, ephem_data.unit_vel, ephem_data.frame_vel)
    refvel = qa.convert(doppler['m0'], COMMON_VELOCITY_UNIT)
    speed_of_light = qa.convert(qa.constants('c'), COMMON_VELOCITY_UNIT)
    return qa.div(refvel, speed_of_light)['value']


def frequency_value(freq):
    """Extract the numeric value from a frequency given as number or quantity.

    Arguments:
        freq -- a plain number (Python or numpy int/float scalar) or a
                quantity dict containing a 'value' key

    Returns:
        the number itself, the dict's 'value' entry, or None if freq is
        neither of the above
    """
    # numpy.float32 and numpy integer scalars are not subclasses of
    # float/int, so accept the numpy scalar hierarchies explicitly.
    if isinstance(freq, (int, float, numpy.integer, numpy.floating)):
        return freq
    if isinstance(freq, dict) and 'value' in freq:
        return freq['value']
    return None


def is_frequency_close(freq1, freq2, lorentz_factor, rtol=1e-1):
    """Check whether two frequencies agree within a velocity-scaled tolerance.

    The relative difference |freq2 - freq1| / |freq1| is compared against
    |lorentz_factor * rtol|. The tolerance scales with lorentz_factor, so a
    factor of 0 demands exact agreement. Values are compared without unit
    conversion, so quantity inputs should share a unit.

    Arguments:
        freq1 -- reference frequency (number or quantity dict with 'value');
                 must be non-zero
        freq2 -- frequency to compare (number or quantity dict with 'value')
        lorentz_factor {float} -- scale factor for the tolerance
        rtol {float} -- fraction of lorentz_factor used as tolerance
                        (default: 0.1)

    Returns:
        True if the relative difference is within the tolerance

    Raises:
        TypeError -- if a numeric value cannot be extracted from an input
    """
    val1 = frequency_value(freq1)
    val2 = frequency_value(freq2)
    if val1 is None or val2 is None:
        raise TypeError(
            'cannot extract frequency values from {!r} and {!r}'.format(freq1, freq2))
    tolerance = abs(lorentz_factor * rtol)
    reldiff = abs((val2 - val1) / val1)
    debug_print('values: {} {}'.format(val1, val2))
    debug_print('tolerance: {} (factor {})'.format(tolerance, lorentz_factor))
    debug_print('relative diff: {}'.format(reldiff))
    return reldiff <= tolerance


def get_metadataset(vis, fieldid, spwid, chanstart=0, nchan=-1):
    """Collect MS and ephemeris metadata for a target field and spw.

    Arguments:
        vis {str} -- name of MS
        fieldid {int} -- FIELD_ID for target
        spwid {int} -- SPW_ID for target spw
        chanstart {int} -- start channel (default: 0)
        nchan {int} -- number of channels
                       (default: -1 => from chanstart to end channel of spw)

    Returns:
        MetaDataSet -- (msmeta, ephemmeta) from inspect_ms and inspect_ephem
    """
    # Forward the channel selection; previously it was silently dropped and
    # the full spw was always used.
    msmeta = inspect_ms(vis, fieldid, spwid, chanstart=chanstart, nchan=nchan)

    ephem_table = get_ephem_table(vis, fieldid)

    ephem_data = inspect_ephem(ephem_table)

    return MetaDataSet(msmeta, ephem_data)