from __future__ import absolute_import
# get is_CASA6 (is_python3 is not imported here)
from casatasks.private.casa_transition import is_CASA6

import os
import sys
import numpy as np
import pylab as pl
from textwrap import wrap

from casatools import table, msmetadata, quanta, ms, measures, ctsys
from casatasks import casalog

def plotants( vis=None, figfile=None,
              antindex=None, logpos=None,
              exclude=None, checkbaselines=None,
              title=None, showgui=None ):
        """Plot the antenna distribution in the local reference frame:
                The location of the antennas in the MS will be plotted with
                X-toward local east; Y-toward local north.  The name of each
                antenna is shown next to its respective location.

                Keyword arguments:
                vis -- Name of input visibility file.
                                default: none. example: vis='ngc5921.ms'

                figfile -- Save the plotted figure in this file.
                                default: ''. example: figfile='myFigure.png'

                antindex -- Label antennas with name and antenna ID
                                default: False. example: antindex=True

                logpos -- Produce a logarithmic position plot.
                                default: False. example: logpos=True

                exclude -- antenna IDs or names to exclude from plotting
                                default: []. example: exclude=[2,3,4], exclude='DV15'

                checkbaselines -- Only plot antennas in the MAIN table.
                                This can be useful after a split.  WARNING:  Setting
                                checkbaselines to True will add to runtime in
                                proportion to the number of rows in the dataset.
                                default: False. example: checkbaselines=True

                title -- Title written along top of plot
                                default: ''

                showgui -- Show the plot in a GUI window.  The window is only
                                shown when this is true and ctsys reports that
                                none of the no-GUI, agg or pipeline modes is
                                active.

                Raises:
                ValueError -- vis is empty, or no antennas are left to plot.
                RuntimeError -- the exclude selection could not be parsed.

                You can zoom in by pressing the magnifier button (bottom,
                third from right) and making a rectangular region with
                the mouse.  Press the home button (left most button) to
                remove zoom.

                A hard-copy of this plot can be obtained by pressing the
                button on the right at the bottom of the display. A file
                dialog will allow you to choose the directory, filename,
                and format of the export.
        """

        # Fail early with a clear message instead of an opaque
        # TypeError/IndexError from indexing into None or ''.
        if not vis:
                raise ValueError("plotants: vis must be the name of a MeasurementSet")

        # remove trailing slash(es) so that the basename used for the title
        # is not empty; a path made only of slashes is left untouched
        vis = vis.rstrip('/') or vis

        # use ctsys values
        showplot = showgui and not ctsys.getnogui() and not ctsys.getagg() and not ctsys.getpipeline()

        if showplot:
                # Switch to interactive mode BEFORE show(): in non-interactive
                # mode show() blocks until every open figure window is closed.
                pl.ion()
                pl.show()
        else:
                pl.close()
                pl.ioff()
        pl.clf()

        myms = ms( )
        try:
                # translate names/IDs given in 'exclude' into antenna IDs
                exclude = myms.msseltoindex(vis, baseline=exclude)['antenna1'].tolist()
        except RuntimeError as rterr:  # MSSelection failed
                errmsg = str(rterr)
                # tidy up the selection error text before re-raising
                errmsg = errmsg.replace('specificion', 'specification')
                errmsg = errmsg.replace('Antenna Expression: ', '')
                raise RuntimeError("Exclude selection error: " + errmsg)

        telescope, names, ids, xpos, ypos, stations = getPlotantsAntennaInfo(vis,
                logpos, exclude, checkbaselines)
        if not names:
                raise ValueError("No antennas selected.  Exiting plotants.")

        if not title:
                msname = os.path.basename(vis)
                title = "Antenna Positions for "
                # long MS names go on a second title line
                if len(msname) > 55:
                        title += '\n'
                title += msname

        if logpos:
                plotAntennasLog(telescope, names, ids, xpos, ypos, antindex, stations)
        else:
                plotAntennas(telescope, names, ids, xpos, ypos, antindex, stations, showplot)
        pl.title(title, {'fontsize':12})

        if figfile:
                pl.savefig(figfile)

def getPlotantsAntennaInfo(msname, log, exclude, checkbaselines):
        """Collect names, IDs, plot coordinates and stations of the antennas.

        Arguments:
        msname -- path of the MS; its ANTENNA subtable is read from
                  msname + '/ANTENNA'.
        log -- true when a logarithmic plot is requested; for VLBA this
               selects metre offsets instead of degrees.
        exclude -- iterable of antenna IDs to drop from the result.
        checkbaselines -- when true, only antennas referenced by the
                  ANTENNA1/ANTENNA2 columns of the main table are kept.

        Returns the 6-tuple
        (telescope, antNames, antIdsUsed, antXs, antYs, stationNames).
        antXs/antYs are offsets from the observatory position computed with
        a spherical-Earth approximation (radius 6370000.), except for VLBA
        with log false, where they are longitude/latitude converted to
        degrees.  When no antenna is left, the name list and the last four
        elements are empty lists.
        """

        tb = table( )
        me = measures( )
        qa = quanta( )

        telescope, arrayPos = getPlotantsObservatoryInfo(msname)
        arrayWgs84 = me.measure(arrayPos, 'WGS84')
        # m0/m1 are used as longitude/latitude below
        # NOTE(review): np.cos(arrayLat) assumes the values are radians - confirm
        arrayLon, arrayLat = [arrayWgs84[i]['value'] for i in ['m0','m1']]

        # Open the ANTENNA subtable to get the names of the antennas in this MS and
        # their positions.  Note that the entries in the ANTENNA subtable are pretty
        # much in random order, so antNames translates between their index and name
        # (e.g., index 11 = STD155).  We'll need these indices for later, since the
        # main data table refers to the antennas by their indices, not names.
        tb.open(msname + '/ANTENNA')
        try:
                antNames = np.array(tb.getcol("NAME")).tolist()
                stationNames = np.array(tb.getcol("STATION")).tolist()
                rawPositions = tb.getcol('POSITION').transpose()
        finally:
                # always release the table, even if a column read fails
                tb.close()

        if telescope == 'VLBA':  # names = ant@station
                antNames = ['@'.join(antsta) for antsta in zip(antNames,stationNames)]
        # Build ITRF position measures from the antenna table positions
        antPositions = np.array([me.position('ITRF', qa.quantity(x, 'm'),
                qa.quantity(y, 'm'), qa.quantity(z, 'm'))
                for (x, y, z) in rawPositions])

        if checkbaselines:
                # Get antenna ids from main table; this will add to runtime
                tb.open(msname)
                try:
                        ants1 = tb.getcol('ANTENNA1')
                        ants2 = tb.getcol('ANTENNA2')
                finally:
                        tb.close()
                # unique, sorted IDs so the output order is deterministic
                antIdsUsed = np.unique(np.append(ants1, ants2)).tolist()
        else:
                # use them all!
                antIdsUsed = list(range(len(antNames)))

        # handle exclude -- remove from antIdsUsed
        for antId in exclude:
                antNameId = antNames[antId] + " (id " + str(antId) + ")"
                try:
                        antIdsUsed.remove(antId)
                except ValueError:
                        casalog.post("Cannot exclude antenna " + antNameId + ": not in main table", "WARN")
                else:
                        casalog.post("Exclude antenna " + antNameId)

        # apply antIdsUsed mask
        antNames = [antNames[i] for i in antIdsUsed]
        antPositions = [antPositions[i] for i in antIdsUsed]
        stationNames = [stationNames[i] for i in antIdsUsed]

        nAnts = len(antIdsUsed)
        casalog.post("Number of points being plotted: " + str(nAnts))
        if nAnts == 0: # excluded all antennas
                # Same arity as the normal return so the caller can unpack it
                # and report "no antennas selected" itself.
                return telescope, antNames, [], [], [], []

        # Get the names, indices, and lat/lon/alt coords of "good" antennas.
        antWgs84s = np.array([me.measure(pos, 'WGS84') for pos in antPositions])

        # Convert from lat, lon, alt to X, Y, Z (unless VLBA)
        # where X is east, Y is north, Z is up,
        # and 0, 0, 0 is the center
        # Note: this conversion is NOT exact, since it doesn't take into account
        # Earth's ellipticity!  But it's close enough.
        if telescope == 'VLBA' and not log:
                antLons, antLats = [[pos[i] for pos in antWgs84s] for i in ['m0','m1']]
                antXs = [qa.convert(lon, 'deg')['value'] for lon in antLons]
                antYs = [qa.convert(lat, 'deg')['value'] for lat in antLats]
        else:
                antLons, antLats = [np.array( [pos[i]['value']
                        for pos in antWgs84s]) for i in ['m0','m1']]
                radE = 6370000.
                antXs = (antLons - arrayLon) * radE * np.cos(arrayLat)
                antYs = (antLats - arrayLat) * radE
        return telescope, antNames, antIdsUsed, antXs, antYs, stationNames

def getPlotantsObservatoryInfo(msname):
        """Return (telescope, arrayPos) for the MS named msname.

        telescope is the first entry of the msmetadata observatory names and
        arrayPos is the msmetadata observatory position.
        """
        metadata = msmetadata()
        metadata.open(msname)
        try:
                telescope = metadata.observatorynames()[0]
                arrayPos = metadata.observatoryposition()
        finally:
                # close the tool even if a query fails (e.g. no observatory names)
                metadata.close()
        return telescope, arrayPos

def getAntennaLabelProps(telescope, station, log=False):
        """Return (vertical alignment, horizontal alignment, rotation angle)
        for an antenna label.

        Labels default to ('center', 'left', 0).  Only when station is
        non-empty and telescope contains "VLA" are they adjusted according
        to the station name, to make plots more readable (CAS-7120).

        Arguments:
        telescope -- telescope name.
        station -- station name of the antenna; may be empty.
        log -- true for the logarithmic plot, where stations containing
               'N' (but not 'E') are not rotated.
        """
        vAlign = 'center'
        hAlign = 'left'
        rotAngle = 0
        if not station or "VLA" not in telescope:
                return vAlign, hAlign, rotAngle

        # these have non-standard format:
        # strip off VLA: or VLA:_ prefix if any.  startswith() is used so
        # that only a real prefix is removed and an empty remainder
        # (station == 'VLA:') cannot raise an IndexError.
        if station.startswith('VLA:'):
                station = station[4:]
        if station.startswith('_'):
                station = station[1:]

        if station in ['W01', 'E01', 'W1', 'E1']:
                # innermost arm stations: centre the label
                vAlign = 'top'
                hAlign = 'center'
        else:
                vAlign = 'bottom'
                if 'W' in station or 'MAS' in station:
                        hAlign = 'right'
                        rotAngle = -35
                elif 'E' in station or ('N' in station and not log):
                        rotAngle = 35
        return vAlign, hAlign, rotAngle

def plotAntennasLog(telescope, names, ids, xpos, ypos, antindex, stations):
        """Draw the antennas on figure 1 as a polar plot with log radius.

        Each antenna is placed at its angle (clockwise from north) and at
        the natural log of its distance from a centre position: a fixed
        point for telescopes whose name contains 'VLA', otherwise the median
        of the offsets.  Dotted distance circles with labels are drawn for
        reference.  Antennas whose station is empty or contains 'OUT' are
        skipped.

        Arguments:
        telescope -- telescope name.
        names -- antenna names, used as labels.
        ids -- antenna IDs, appended to the label when antindex is true.
        xpos, ypos -- east/north offsets of the antennas (any sequence).
        antindex -- add the antenna ID to each label.
        stations -- station name of each antenna.
        """
        # The arithmetic and boolean masking below need numpy arrays.
        xpos = np.asarray(xpos, dtype=float)
        ypos = np.asarray(ypos, dtype=float)

        fig = pl.figure(1)
        # Set up subplot.
        ax = fig.add_subplot(1, 1, 1, polar=True, projection='polar')
        ax.set_theta_zero_location('N')
        ax.set_theta_direction(-1)
        # Do not show azimuth labels.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        # Do not show grid.
        ax.grid(False)

        # code from pipeline summary.py
        # PlotAntsChart draw_polarlog_ant_map_in_subplot
        if 'VLA' in telescope:
                # For (E)VLA, set a fixed local center position that has been
                # tuned to work well for its array configurations (CAS-7479).
                xcenter, ycenter = -32, 0
                rmin_min, rmin_max = 12.5, 350
        else:
                # For non-(E)VLA, take the median of antenna offsets as the
                # center for the plot.
                xcenter = np.median(xpos)
                ycenter = np.median(ypos)
                rmin_min, rmin_max = 3, 350

        # Derive radial offset w.r.t. center position.
        r = ((xpos-xcenter)**2 + (ypos-ycenter)**2)**0.5
        # Set rmin, clamp between a min and max value, ignore station
        # at r=0 if one is there.
        nonzero_r = r[r > 0]
        if nonzero_r.size:
                rmin = min(rmin_max, max(rmin_min, 0.8*np.min(nonzero_r)))
        else:
                # Every antenna sits on the centre (e.g. a single antenna):
                # np.min() of an empty array would raise, so fall back to
                # the lower clamp.
                rmin = rmin_min
        # Update r to move any points below rmin to r=rmin.
        r[r <= rmin] = rmin
        rmin = np.log(rmin)
        # Set rmax.
        rmax = np.log(1.5*np.max(r))
        # Derive angle of offset w.r.t. center position
        # (measured from north towards east).
        theta = np.arctan2(xpos-xcenter, ypos-ycenter)

        # Draw circles at specific distances from the center.
        angles = np.arange(0, 2.01*np.pi, 0.01*np.pi)
        show_circle = True
        circles = [30, 100, 300, 1000, 3000, 10000]
        if telescope == "VLBA":
                circles = [1e5, 3e5, 1e6, 3e6, 1e7]
        # Label placement is the same for every circle.
        va = 'top'
        circle_label_angle = -20.0 * np.pi / 180.
        for cr in circles:
                log_cr = np.log(cr)

                # Only draw circles outside rmin.
                if cr > np.min(r) and show_circle:

                        # Draw the circle.
                        radius = np.ones(len(angles))*log_cr
                        ax.plot(angles, radius, 'k:')

                        # Draw tick marks on the circle; the angular step
                        # shrinks with the radius so the arc length between
                        # ticks stays constant (1 km, or 10000 km for VLBA).
                        inc = 0.1*10000/cr
                        if telescope == "VLBA":
                                inc = 0.1*1e8/cr
                        if cr > 100:
                                for angle in np.arange(inc/2., 2*np.pi+0.05, inc):
                                        ax.plot([angle, angle],
                                                           [np.log(0.95*cr), np.log(1.05*cr)], 'k-')

                        # Add text label to circle to denote distance from center.
                        if cr >= 1000:
                                if log_cr < rmax:
                                        ax.text(circle_label_angle, log_cr,
                                                           '%d km' % (cr/1000), size=8, va=va)
                                        ax.text(circle_label_angle + np.pi, log_cr,
                                                           '%d km' % (cr / 1000), size=8, va=va)
                        else:
                                ax.text(circle_label_angle, log_cr, '%dm' % (cr),
                                        size=8, va=va)
                                ax.text(circle_label_angle + np.pi, log_cr,
                                                '%dm' % (cr), size=8, va=va)

                # Find out if most recently drawn circle was outside all antennas,
                # if so, no more circles will be drawn.
                if log_cr > rmax:
                        show_circle = False

        # plot points and antenna names/ids
        for i, (name, station) in enumerate(zip(names, stations)):
                if not station or 'OUT' in station:
                        continue
                log_r = np.log(r[i])
                ax.plot(theta[i], log_r, 'ko', ms=5, mfc='r')
                if antindex:
                        name += ' (' + str(ids[i]) + ')'
                # set alignment and rotation angle (for VLA)
                valign, halign, angle = getAntennaLabelProps(telescope, station, log=True)
                # adjust so text is not on the circle:
                yoffset = 0
                if halign == 'center':
                        yoffset = 0.1
                ax.text(theta[i], log_r+yoffset, ' '+name, size=8, va=valign, ha=halign,
                        rotation=angle, weight='semibold')

        # Set minimum and maximum radius.
        ax.set_rmax(rmax)
        ax.set_rmin(rmin)

        # Make room for 2-line title
        pl.subplots_adjust(top=0.88)

def plotAntennas(telescope, names, ids, xpos, ypos, antindex, stations, showplot):
        """Draw the antennas on figure 1 as a linear X/Y scatter plot.

        Antennas whose station is empty or contains 'OUT' are skipped.  For
        VLBA the axes are labelled as longitude/latitude in degrees;
        otherwise they are labelled X/Y in m, or in km when the median
        offset on either axis exceeds 1e6.

        Arguments:
        telescope -- telescope name.
        names -- antenna names, used as labels.
        ids -- antenna IDs, appended to the label when antindex is true.
        xpos, ypos -- antenna coordinates; they are not modified.
        antindex -- add the antenna ID to each label.
        stations -- station name of each antenna.
        showplot -- when true, show the figure once antennas were drawn.
        """
        fig = pl.figure(1)
        ax = fig.add_subplot(111)

        if telescope == 'VLBA':
                labelx = 'Longitude (deg)'
                labely = 'Latitude (deg)'
        else:
                # use m or km units
                units = ' (m)'
                if np.median(xpos) > 1e6 or np.median(ypos) > 1e6:
                        # Scale into new arrays: an in-place "/=" would
                        # silently change the caller's data.
                        xpos = np.asarray(xpos) / 1e3
                        ypos = np.asarray(ypos) / 1e3
                        units = ' (km)'
                labelx = 'X' + units
                labely = 'Y' + units

        # plot points and antenna names/ids
        drew_any = False
        for i, (x, y, name, station) in enumerate(zip(xpos, ypos, names, stations)):
                if not station or 'OUT' in station:
                        continue
                ax.plot(x, y, 'ro')
                if antindex:
                        name += ' (' + str(ids[i]) + ')'
                # set alignment and rotation angle (for VLA)
                valign, halign, angle = getAntennaLabelProps(telescope, station)
                # adjust so text is not on the circle:
                if halign == 'center':
                        y = y - 10
                ax.text(x, y, ' '+name, size=8, va=valign, ha=halign, rotation=angle,
                        weight='semibold')
                drew_any = True

        # Show the figure once, instead of once per antenna; as before it
        # is only shown when at least one antenna was drawn.
        if showplot and drew_any:
                fig.show()

        pl.xlabel(labelx)
        pl.ylabel(labely)
        pl.margins(0.1, 0.1)