import argparse
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.cm import get_cmap
import metpy.calc as mpcalc
from metpy.units import units
from datetime import datetime
from geopy.geocoders import Nominatim
from sharppy.sharptab import winds, utils, params, thermo, interp, profile
import sharppy.plot.skew as skew

from metpy.plots import Hodograph
import glob, os
import matplotlib as mpl
from matplotlib.patches import Circle
from matplotlib import gridspec
from matplotlib.ticker import ScalarFormatter, MultipleLocator, NullFormatter
from matplotlib.collections import LineCollection
import matplotlib.transforms as transforms
from matplotlib.collections import PatchCollection
import matplotlib.colors as mcolors
import skewx

def plot_barbs(axes, p, u, v):
    """Draw a column of wind barbs along the right-hand edge of ``axes``.

    Every barb is placed at x = 1 in the ``axes.get_yaxis_transform(which='tick2')``
    coordinate system, with ``p`` as the vertical coordinate.  The barbs are then
    clipped to a box expressed in axes-fraction coordinates that extends 0.1 to
    either side of that column and 0.08 below/above the axes.

    Parameters
    ----------
    axes : matplotlib Axes
        Axes to draw on (the skew-T axes in this script).
    p : array-like
        Vertical coordinate of each barb.
    u, v : array-like
        Wind components, one per entry of ``p``.
    """
    column = np.empty_like(p)
    column.fill(1)
    x_edge = column[0]

    # Clip region half-extents, in axes-fraction units.
    half_width, half_height = 0.1, 0.08
    clip_extent = transforms.Bbox([[x_edge - half_width, -half_height],
                                   [x_edge + half_width, 1.0 + half_height]])

    barb_set = axes.barbs(
        column, p, u, v,
        length=6,
        transform=axes.get_yaxis_transform(which='tick2'),
        clip_on=True,
        zorder=2,
        flip_barb=True,
    )
    barb_set.set_clip_box(transforms.TransformedBbox(clip_extent, axes.transAxes))

def plot_skew(P_sounding, H_agl, T_sounding, Td_sounding, u_sounding, v_sounding, time_start, time_valid, lon, lat, leftmover=True, parcel_to_draw='ml'):
    """Draw a skew-T/log-p diagram, hodograph and parameter table for one sounding.

    The sounding is analysed with SHARPpy. The figure has three panels: the
    skew-T (upper left), the hodograph (upper right) and a text table of
    thermodynamic and kinematic parameters underneath.

    Parameters
    ----------
    P_sounding, H_agl, T_sounding, Td_sounding, u_sounding, v_sounding
        Pressure, height above ground, temperature, dewpoint and wind
        components, forwarded to ``profile.create_profile`` as
        ``pres``/``hght``/``tmpc``/``dwpc``/``u``/``v``.  ``create_skew``
        supplies hPa, m, degC and kt.
    time_start, time_valid
        Initialisation and valid times, printed in the hodograph title.
    lon, lat
        Objects exposing ``.values`` (xarray coordinates here), printed in the
        skew-T title.
    leftmover : bool, optional
        If true (default) the Bunkers left-mover is the storm motion used for
        SRH, the SRH-based composites and as the hodograph reference point;
        otherwise the right-mover is used.
    parcel_to_draw : {'ml', 'sb', 'mu'}, optional
        Parcel whose trace is drawn; any other value draws no trace.

    Returns
    -------
    fig, ax
        The figure and the skew-T axes.
    """
    prof = profile.create_profile(pres=P_sounding, hght=H_agl, tmpc=T_sounding,
                                  dwpc=Td_sounding, u=u_sounding, v=v_sounding)

    # --- Thermodynamic parameters ------------------------------------------
    sfc_pcl = params.parcelx(prof, flag=1)
    ml_pcl = params.parcelx(prof, flag=4)
    mu_pcl = params.parcelx(prof, flag=3)

    dcape, d_trace, d_ptrace = params.dcape(prof)
    cape_03 = np.round(ml_pcl.b3km)
    # Shown as "mm" in the table, hence the 25.4 factor.
    pw = np.round(params.precip_water(prof) * 25.4, 2)

    mulcl = np.round(mu_pcl.lclhght)
    mulfc = np.round(mu_pcl.lfchght)
    mucape = np.round(mu_pcl.bplus)
    mucin = np.round(mu_pcl.bminus)

    mllcl = np.round(ml_pcl.lclhght)
    mllfc = np.round(ml_pcl.lfchght)
    mlcape = np.round(ml_pcl.bplus)
    mlcin = np.round(ml_pcl.bminus)

    sblcl = np.round(sfc_pcl.lclhght)
    sblfc = np.round(sfc_pcl.lfchght)
    sbcape = np.round(sfc_pcl.bplus)
    sbcin = np.round(sfc_pcl.bminus)

    lr_700_500 = np.round(params.lapse_rate(prof, 700, 500), 2)
    lr_850_500 = np.round(params.lapse_rate(prof, 850, 500), 2)
    lr_0_3 = np.round(params.lapse_rate(prof, 0, 3000, pres=False), 2)
    lr_3_6 = np.round(params.lapse_rate(prof, 3000, 6000, pres=False), 2)
    p0c = params.temp_lvl(prof, 0)
    h0c = np.round(interp.hght(prof, p0c))

    # --- Kinematic parameters ----------------------------------------------
    P_sfc = prof.pres[prof.sfc]
    p_6 = interp.pres(prof, interp.to_msl(prof, 6000))
    p_3 = interp.pres(prof, interp.to_msl(prof, 3000))
    p_1 = interp.pres(prof, interp.to_msl(prof, 1000))

    # Hodograph markers as (height AGL in m, label in km); keeping them paired
    # guarantees every plotted point carries its own label.
    hodo_marks = [(500, '0.5'), (1000, '1'), (3000, '3'), (6000, '6'), (10000, '10')]
    in_pres = [interp.pres(prof, interp.to_msl(prof, h)) for h, _ in hodo_marks]
    in_wind = [interp.components(prof, p) for p in in_pres]

    eff_inflow = params.effective_inflow_layer(prof)
    eff_bot_m = interp.to_agl(prof, interp.hght(prof, eff_inflow[0]))
    eff_top_m = interp.to_agl(prof, interp.hght(prof, eff_inflow[1]))
    eff_wind = [interp.components(prof, p) for p in [eff_inflow[0], eff_inflow[1]]]

    bwd06 = np.round(winds.wind_shear(prof, P_sfc, p_6), 1)
    bwd06_mag = np.round(utils.mag(bwd06[0], bwd06[1]), 1)
    bwd03 = np.round(winds.wind_shear(prof, P_sfc, p_3), 1)
    bwd03_mag = np.round(utils.mag(bwd03[0], bwd03[1]), 1)
    bwd01 = np.round(winds.wind_shear(prof, P_sfc, p_1), 1)
    bwd01_mag = np.round(utils.mag(bwd01[0], bwd01[1]), 1)

    ebwd = winds.wind_shear(prof, pbot=eff_inflow[0], ptop=interp.pres(prof, mu_pcl.elhght * 0.5))
    eff_bwd = np.round(utils.mag(ebwd[0], ebwd[1]), 1)

    rstu, rstv, lstu, lstv = params.bunkers_storm_motion(prof)
    # Mean wind from the effective inflow base up to 65 % of the depth between
    # that base and the MU equilibrium level.  The fraction is applied to
    # height: 0.65 * EL *pressure* would put the top above the EL.
    mwind_top_m = eff_bot_m + 0.65 * (mu_pcl.elhght - eff_bot_m)
    mwind_top_p = interp.pres(prof, interp.to_msl(prof, mwind_top_m))
    # Keyword arguments on purpose: the positional slots after ptop in
    # mean_wind are the integration step and a storm-relative offset, so the
    # result stays ground-relative like everything else on the hodograph.
    mwind = winds.mean_wind(prof, pbot=eff_inflow[0], ptop=mwind_top_p)
    rst_dir, rst_wsp = np.round(utils.comp2vec(rstu, rstv), 1)
    lst_dir, lst_wsp = np.round(utils.comp2vec(lstu, lstv), 1)
    mwind_dir, mwind_wsp = np.round(utils.comp2vec(mwind[0], mwind[1]), 1)

    ship = np.round(params.ship(prof), 1)

    # Storm motion used for SRH, the SRH-based composites and the hodograph.
    stu, stv = (lstu, lstv) if leftmover else (rstu, rstv)
    srh0_500, _, _ = np.round(winds.helicity(prof, 0, 500, stu, stv))
    srh0_1000, _, _ = np.round(winds.helicity(prof, 0, 1000, stu, stv))
    srh0_3000, _, _ = np.round(winds.helicity(prof, 0, 3000, stu, stv))
    srh_eff = np.round(winds.helicity(prof, eff_bot_m, eff_top_m, stu, stv))[0]

    # The profile winds (and therefore the bulk shears) are in kt, while the
    # SHARPpy STP/SCP composites take bulk wind difference in m/s.
    bwd06_ms = utils.KTS2MS(bwd06_mag)
    eff_bwd_ms = utils.KTS2MS(eff_bwd)
    stp_fix = np.round(params.stp_fixed(sbcape, sblcl, srh0_1000, bwd06_ms), 1)
    scp = np.round(params.scp(mucape, srh_eff, eff_bwd_ms), 1)
    stp_cin = np.round(params.stp_cin(mlcape, srh_eff, eff_bwd_ms, mllcl, mlcin), 1)

    # --- Skew-T panel --------------------------------------------------------
    fig = plt.figure(figsize=(18, 16))
    gs = gridspec.GridSpec(6, 6)
    ax = plt.subplot(gs[:3, :3], projection='skewx')

    ax.semilogy(prof.tmpc[~prof.tmpc.mask], prof.pres[~prof.tmpc.mask], 'r', lw=2)
    ax.semilogy(prof.dwpc[~prof.dwpc.mask], prof.pres[~prof.dwpc.mask], 'g', lw=2)
    ax.semilogy(prof.vtmp[~prof.dwpc.mask], prof.pres[~prof.dwpc.mask], 'r--')
    ax.semilogy(prof.wetbulb[~prof.dwpc.mask], prof.pres[~prof.dwpc.mask], 'c-', lw=1)
    ax.semilogy(d_trace, d_ptrace, 'm--')
    ax.grid(True)

    skew.draw_dry_adiabats(ax, tmin=-70, alpha=0.2, color='y')
    skew.draw_moist_adiabats(ax, tmin=-70, alpha=0.2, color='darkblue')
    skew.draw_heights(ax, prof)

    # Levels up to 10 km feed both the wind barbs and the hodograph trace.
    below_10km = np.array(prof.hght) * units.m <= 10000 * units.m
    plot_barbs(ax, prof.pres[below_10km], prof.u[below_10km], prof.v[below_10km])

    # Optional decorations: a failure here should not abort the whole figure.
    try:
        skew.draw_effective_inflow_layer(ax, prof)
    except Exception:
        print("Couldn't plot effective inflow layer ...")

    parcels = {'ml': ml_pcl, 'sb': sfc_pcl, 'mu': mu_pcl}
    try:
        pcl = parcels.get(parcel_to_draw)
        if pcl is not None:
            ax.semilogy(pcl.ttrace, pcl.ptrace, 'k', lw=2)
    except Exception:
        print("Couldn't plot parcel traces ...")

    # Highlight the 0 C and -20 C isotherms.
    for T_hgz in (-20, 0):
        ax.axvline(T_hgz, color='b', ls='--')

    ax.yaxis.set_major_formatter(ScalarFormatter())
    ax.set_ylim(1030, 100)
    ax.set_yticks(np.linspace(100, 1000, 10))
    ax.xaxis.set_major_locator(MultipleLocator(10))
    ax.set_xlim(-50, 50)
    ax.tick_params(labelsize=14)
    ax.set_xlabel('T [°C]', fontsize=14)
    ax.set_ylabel('P [hPa]', fontsize=14)

    ax.set_title(f'LON: {np.round(lon.values, 2)}, LAT: {np.round(lat.values, 2)}' + '\n' + 'MPAS 3 km', loc='left')

    # --- Hodograph panel -----------------------------------------------------
    ax2 = plt.subplot(gs[:3, 3:])
    hodo = Hodograph(ax2)
    hodo.add_grid(linewidth=0.5)
    hodo.plot_colormapped((np.array(prof.u)*units('kt'))[below_10km], (np.array(prof.v)*units('kt'))[below_10km], (np.array(prof.hght)*units('m'))[below_10km],
                          colors=['m', 'crimson', 'r', 'y', 'g'], intervals=[0, 500, 1000, 3000, 6000, 10000] * units.m)

    # Centre the view on the selected storm motion.
    ax2.set_xlim(-60 + stu, 60 + stu)
    ax2.set_ylim(-60 + stv, 60 + stv)

    ax2.plot(lstu, lstv, marker='o', markersize=8, markerfacecolor='red', markeredgecolor='black')
    ax2.plot(rstu, rstv, marker='o', markersize=8, markerfacecolor='blue', markeredgecolor='black')
    # Annotations are best-effort: int() fails on masked (missing) directions.
    try:
        ax2.annotate(f'{rst_wsp}/{int(rst_dir)}',
                     xy=(rstu, rstv), xytext=(rstu + 2, rstv - 5), fontsize=12)
        ax2.annotate(f'{lst_wsp}/{int(lst_dir)}',
                     xy=(lstu, lstv), xytext=(lstu + 2, lstv + 5), fontsize=12)
    except Exception:
        pass
    try:
        ax2.annotate(f'{mwind_wsp}/{int(mwind_dir)}',
                     xy=(mwind[0], mwind[1]), xytext=(mwind[0] + 2, mwind[1] - 5), fontsize=12)
        ax2.plot(mwind[0], mwind[1], marker='s', markersize=8, color='gray', markeredgecolor='black')
    except Exception:
        pass

    for w, (_, label) in zip(in_wind, hodo_marks):
        ax2.plot(w[0], w[1], marker='o', markersize=2.5, color='black')
        ax2.annotate(f'{label}', xy=(w[0], w[1]), xytext=(w[0] - 2.5, w[1] - 2.5), fontsize=12)
    ax2.tick_params(axis='both', direction='in', left=False, bottom=False, labelleft=False, labelbottom=False, labelsize=12)

    # Effective inflow layer: lines from the storm motion to the layer's
    # bottom and top winds.
    ax2.plot([stu, eff_wind[0][0]], [stv, eff_wind[0][1]], ls='--', color='b')
    ax2.plot([stu, eff_wind[1][0]], [stv, eff_wind[1][1]], ls='--', color='b')

    ax2.set_title(f'Init: {time_start} UTC' +'\n' + f'Val: {time_valid} UTC', loc='right')

    # --- Parameter table -------------------------------------------------------
    ax3 = plt.subplot(gs[3, :])
    ax3.axis('off')
    ax3.tick_params(which='both', bottom=False, left=False, labelbottom=False, labelleft=False)

    ax3.text(0.0, 0.85, f'MULCL: {mulcl} m', fontsize=11)
    ax3.text(0.0, 0.7, f'MULFC: {mulfc} m', fontsize=11)
    ax3.text(0.0, 0.55, f'MUCAPE: {mucape} J/kg', fontsize=11)
    ax3.text(0.0, 0.4, f'MUCIN: {mucin} J/kg', fontsize=11)

    ax3.text(0.13, 0.85, f'MLLCL: {mllcl} m', fontsize=11)
    ax3.text(0.13, 0.7, f'MLLFC: {mllfc} m', fontsize=11)
    ax3.text(0.13, 0.55, f'MLCAPE: {mlcape} J/kg', fontsize=11)
    ax3.text(0.13, 0.4, f'MLCIN: {mlcin} J/kg', fontsize=11)

    ax3.text(0.26, 0.85, f'SBLCL: {sblcl} m', fontsize=11)
    ax3.text(0.26, 0.7, f'SBLFC: {sblfc} m', fontsize=11)
    ax3.text(0.26, 0.55, f'SBCAPE: {sbcape} J/kg', fontsize=11)
    ax3.text(0.26, 0.4, f'SBCIN: {sbcin} J/kg', fontsize=11)

    ax3.text(0.38, 0.85, r'$\Gamma_{700-500}$:' + f' {lr_700_500} °C/km', fontsize=11)
    ax3.text(0.38, 0.7, r'$\Gamma_{850-500}$:' + f' {lr_850_500} °C/km', fontsize=11)
    ax3.text(0.38, 0.55, r'$\Gamma_{0-3}$:' + f' {lr_0_3} °C/km', fontsize=11)
    ax3.text(0.38, 0.4, r'$\Gamma_{3-6}$:' + f' {lr_3_6} °C/km', fontsize=11)

    ax3.text(0.51, 0.85, f'PW: {pw} mm', fontsize=11)
    ax3.text(0.51, 0.7, r'CAPE$_{0-3}$:' + f' {cape_03} J/kg', fontsize=11)
    ax3.text(0.51, 0.55, f'DCAPE: {np.round(dcape)} J/kg', fontsize=11)
    ax3.text(0.51, 0.4, f'0°C HGT: {h0c} m', fontsize=11)

    ax3.text(0.63, 0.85, r'BWD$_{0-6}:$' + f'{bwd06_mag} kt', fontsize=11)
    ax3.text(0.63, 0.7, r'BWD$_{0-3}:$' + f'{bwd03_mag} kt', fontsize=11)
    ax3.text(0.63, 0.55, r'BWD$_{0-1}:$' + f'{bwd01_mag} kt', fontsize=11)
    ax3.text(0.63, 0.4, r'BWD$_{EFF}:$' + f'{eff_bwd} kt', fontsize=11)

    ax3.text(0.74, 0.85, r'SRH$_{0-0.5}:$' + f'{srh0_500} m²/s²', fontsize=11)
    ax3.text(0.74, 0.7, r'SRH$_{0-1}:$' + f'{srh0_1000} m²/s²', fontsize=11)
    ax3.text(0.74, 0.55, r'SRH$_{0-3}:$' + f'{srh0_3000} m²/s²', fontsize=11)
    ax3.text(0.74, 0.4, r'SRH$_{EFF}:$' + f'{srh_eff} m²/s²', fontsize=11)

    ax3.text(0.88, 0.85, f'SHIP: {ship}', fontsize=11)
    ax3.text(0.88, 0.7, f'SCP: {scp}', fontsize=11)
    ax3.text(0.88, 0.55, f'STP (CIN): {stp_cin}', fontsize=11)
    ax3.text(0.88, 0.4, f'STP (FIX): {stp_fix}', fontsize=11)

    gs.update(bottom=.025, hspace=0.4, wspace=0.2)

    return fig, ax

def extract_mpas_profile(ds):
    """Extract a single-column vertical profile from an MPAS history dataset.

    Level 0 of every returned profile is the near-surface diagnostic
    (surface pressure, 2 m temperature/humidity, 10 m wind).  The remaining
    levels are the model levels along ``nVertLevels``.

    Parameters
    ----------
    ds : xarray.Dataset
        MPAS output already reduced to one grid column and one time.

    Returns
    -------
    tuple
        ``(pressure [hPa], height [m], temperature [degC], dewpoint,
        u [kt], v [kt])``.  Heights are not yet relative to the ground; the
        caller is expected to subtract the first (surface) value.
    """
    total_pressure = (xr.concat([ds.surface_pressure, ds.pressure], 'nVertLevels').values * units.Pa).to('hPa')
    tmp = mpcalc.temperature_from_potential_temperature(ds.pressure * units.Pa, ds.theta * units.K)
    total_tmp = (xr.concat([ds.t2m * units.K, tmp], dim = 'nVertLevels').values * units.K).to('degC')
    # NOTE(review): relhum is passed without explicit units, so MetPy relies on
    # the variable's ``units`` attribute to tell percent from fraction - confirm.
    dwp = mpcalc.dewpoint_from_relative_humidity(tmp, ds.relhum)
    dwp2m = mpcalc.dewpoint_from_specific_humidity(ds.surface_pressure * units.Pa, ds.t2m * units.K, ds.q2)
    total_dewpoint = xr.concat([dwp2m, dwp], dim = 'nVertLevels')
    total_u = (xr.concat([ds.u10, ds.uReconstructZonal], dim = 'nVertLevels').values * units('m/s')).to('kt')
    total_v = (xr.concat([ds.v10, ds.uReconstructMeridional], dim = 'nVertLevels').values * units('m/s')).to('kt')

    zgrid = ds.zgrid
    if 'nVertLevelsP1' in zgrid.dims:
        # zgrid is defined on the nVertLevelsP1 layer interfaces, while the
        # other fields sit on the nVertLevels mass levels between them.
        # Averaging adjacent interfaces aligns heights with those levels.
        # The lowest interface (the ground) pairs with the surface entry,
        # so the length still matches the other profiles (1 + nVertLevels).
        z = np.asarray(zgrid.transpose('nVertLevelsP1', ...).values, dtype=float)
        z_mass = 0.5 * (z[:-1] + z[1:])
        total_hgt = np.concatenate([z[:1], z_mass], axis=0) * units.m
    else:
        total_hgt = zgrid * units.m

    return total_pressure, total_hgt, total_tmp, total_dewpoint, total_u, total_v

def create_skew(dir_files, dir_output, city = None, state = None):
    """Write one skew-T PNG per MPAS history file for a geocoded location.

    Parameters
    ----------
    dir_files : str
        Directory searched for ``*.nc`` history files (processed in sorted order).
    dir_output : str
        Directory the PNGs are written to; each file is named after its valid time.
    city, state : str or sequence of str, optional
        Place name geocoded with Nominatim as ``"<city>, <state>"``.
        Sequences of words (e.g. argparse ``nargs='+'`` lists) are joined with
        spaces.

    Raises
    ------
    ValueError
        If the place cannot be geocoded.
    """
    def _as_text(part):
        # argparse with nargs='+' yields lists of words; accept str or sequence.
        if part is None or isinstance(part, str):
            return part
        return ' '.join(part)

    query = f"{_as_text(city)}, {_as_text(state)}"
    geoloc = Nominatim(user_agent="Your_Name")
    location = geoloc.geocode(query)
    if location is None:
        raise ValueError(f"Could not geocode location {query!r}")

    for file in sorted(glob.glob(os.path.join(dir_files, '*.nc'))):
        # The context manager guarantees the file is closed even if plotting fails.
        with xr.open_dataset(file) as raw:
            ds = raw.sel(longitude = location.longitude, latitude = location.latitude, method = 'nearest')
            ds = ds.sel(Time = ds.Time[0], method = 'nearest')
            t0 = datetime.strptime(ds.Time_Start.values.astype('datetime64[s]')[0].astype('str'), '%Y-%m-%dT%H:%M:%S')
            time = datetime.strptime(ds.Time.values.astype('datetime64[s]').astype('str'), '%Y-%m-%dT%H:%M:%S')
            p, h, t, td, u, v = extract_mpas_profile(ds)
            fig, _ = plot_skew(P_sounding=p, H_agl=h - h[0], T_sounding=t, Td_sounding=td, u_sounding=u, v_sounding=v, time_start=t0, time_valid=time, lon=ds.longitude, lat=ds.latitude)
        fig.savefig(os.path.join(dir_output, f'{time}.png'), bbox_inches = 'tight', dpi = 300)
        # Release the figure so memory does not grow across many files.
        plt.close(fig)

if __name__ == '__main__':
    # Command-line entry point; guarded so importing the module has no side effects.
    parser = argparse.ArgumentParser(description='Plot MPAS output')
    parser.add_argument('--dir_files', metavar='-DF', nargs='+', required=True, help='path to history files')
    parser.add_argument('--dir_output', metavar='-DO', nargs='+', required=True, help='path to save profiles')
    parser.add_argument('--city', metavar='-C', nargs='+', help='cidade')
    parser.add_argument('--estado', metavar='-UF', nargs='+', help='estado')

    args = parser.parse_args()

    # nargs='+' splits multi-word names ("Sao Paulo") into lists; join them
    # back into a single place name before geocoding.
    city = ' '.join(args.city) if args.city else None
    state = ' '.join(args.estado) if args.estado else None

    create_skew(args.dir_files[0], args.dir_output[0], city, state)

