from SimPEG import Mesh
import os
import sys
import pylab as plt
import numpy as np

# Directory holding the mesh and model files: <home>/Downloads/SimPEGtemp.
# expanduser('~') is used instead of os.getenv('HOME') because the latter
# returns None when HOME is unset, which makes os.path.abspath raise TypeError.
basePath = os.path.join(os.path.abspath(os.path.expanduser('~')), 'Downloads', 'SimPEGtemp') #hide

# Load the mesh
mesh = Mesh.TensorMesh.readUBC(os.path.join(basePath, "Mesh.msh"))

# Load the three models defined on that mesh
m_lp = Mesh.TensorMesh.readModelUBC(mesh, os.path.join(basePath, "SimPEG_MAG_lplq.sus"))
m_l2 = Mesh.TensorMesh.readModelUBC(mesh, os.path.join(basePath, "SimPEG_MAG_l2l2.sus"))
m_true = Mesh.TensorMesh.readModelUBC(mesh, os.path.join(basePath, "Initm.sus"))

# Replace every cell equal to -100 with NaN
# (presumably -100 is the no-data flag in the model files -- TODO confirm).
m_lp[m_lp == -100] = np.nan
m_l2[m_l2 == -100] = np.nan
m_true[m_true == -100] = np.nan

# Single figure that hosts every panel of the comparison plot
fig = plt.figure()

# Colour limits shared by all slices
vmin, vmax = 0.0, 0.015

# Plot window: 500 on either side of the centre in x and y, and from
# 500 below the reference elevation up to the reference elevation in z
xmin, xmax = 557300 - 500, 557300 + 500
ymin, ymax = 7133600 - 500, 7133600 + 500
zmin, zmax = 450 - 500, 450

# Slice indices: indz is used for the normal='Z' slices and indx for the
# normal='Y' slices further down
indz = indx = 17

# Three evenly spaced tick positions per axis, inset from the window edges
x = np.linspace(xmin + 200, xmax - 200, num=3)
y = np.linspace(zmin + 50, zmax - 50, num=3)

# Panel 1: normal='Z' slice of the m_l2 model, titled 'Smooth'
ax1 = plt.subplot(1, 1, 1)

# Halve the default axes size and move it left (x0 - 0.1) and up (y0 + 0.3)
# in figure coordinates; the other panels are placed relative to this box
frame = ax1.get_position()
ax1.set_position([frame.x0 - 0.1, frame.y0 + 0.3, frame.width * 0.5, frame.height * 0.5])

dat = mesh.plotSlice(
    m_l2,
    ax=ax1,
    normal='Z',
    ind=indz,
    clim=np.r_[vmin, vmax],
    pcolorOpts={'cmap': 'viridis'},
)
ax1.set_aspect('equal')
ax1.set_title('Smooth')
ax1.xaxis.set_visible(False)
ax1.set_xlim(xmin, xmax)
ax1.set_ylim(ymin, ymax)
ax1.set_ylabel('Northing (m)')

# Rotate the y tick labels by 90 degrees
plt.setp(ax1.get_yticklabels(), rotation=90)

# Panel 2: normal='Y' slice of the m_l2 model, placed below panel 1.
# pos (the ax1 box) is kept as a module-level name because the remaining
# panels are positioned from it as well.
pos = ax1.get_position()
ax2 = fig.add_axes([pos.x0 + 0.0525, pos.y0 - 0.315, pos.width * 0.725, pos.height])

dat = mesh.plotSlice(m_l2, ax=ax2, normal='Y', ind=indx, clim=np.r_[vmin, vmax], pcolorOpts={'cmap': 'viridis'})
ax2.set_aspect('equal')
# Empty title (presumably clears a default title set by plotSlice -- TODO confirm)
ax2.set_title('')
ax2.set_xlim(xmin, xmax)
ax2.set_ylim(zmin, zmax)

# Build the ticks as real lists: on Python 3 map() returns a lazy iterator,
# which is not a valid sequence of tick positions/labels for matplotlib.
xticks = [int(v) for v in x]
yticks = [int(v) for v in y]

ax2.set_xticks(xticks)
ax2.set_xticklabels([str(t) for t in xticks], size=12)
ax2.set_xlabel('Easting (m)')
ax2.set_ylabel('Elev. (m)')
ax2.set_yticks(yticks)
ax2.set_yticklabels([str(t) for t in yticks], size=12)

# Rotate the y tick labels by 90 degrees
plt.setp(ax2.get_yticklabels(), rotation=90)

## Compact model
# Panel 3: normal='Z' slice of the m_lp model, same size as the ax1 box
# (pos) and shifted right by 0.3 in figure coordinates
ax3 = fig.add_axes([pos.x0 + 0.3, pos.y0, pos.width, pos.height])
dat = mesh.plotSlice(
    m_lp,
    ax=ax3,
    normal='Z',
    ind=indz,
    clim=np.r_[vmin, vmax],
    pcolorOpts={'cmap': 'viridis'},
)
ax3.set_aspect('equal')
ax3.set_title('Compact')

# Both axes are hidden on this panel
ax3.xaxis.set_visible(False)
ax3.yaxis.set_visible(False)
ax3.set_xlim(xmin, xmax)
ax3.set_ylim(ymin, ymax)

# Panel 4: normal='Y' slice of the m_lp model, placed below panel 3.
# pos is the ax1 box captured earlier; offsets are in figure coordinates.
ax4 = fig.add_axes([pos.x0 + 0.3525, pos.y0 - 0.315, pos.width * 0.725, pos.height])

dat = mesh.plotSlice(m_lp, ax=ax4, normal='Y', ind=indx, clim=np.r_[vmin, vmax], pcolorOpts={'cmap': 'viridis'})
ax4.set_aspect('equal')
ax4.yaxis.set_visible(False)
# Empty title (presumably clears a default title set by plotSlice -- TODO confirm)
ax4.set_title('')
ax4.set_xlim(xmin, xmax)
ax4.set_ylim(zmin, zmax)

# Build the ticks as a real list: on Python 3 map() returns a lazy iterator,
# which is not a valid sequence of tick positions/labels for matplotlib.
xticks = [int(v) for v in x]
ax4.set_xticks(xticks)
ax4.set_xticklabels([str(t) for t in xticks], size=12)
# Empty x label (presumably clears a default label set by plotSlice -- TODO confirm)
ax4.set_xlabel('')

## True model
# Panel 5: normal='Z' slice of the m_true model, same size as the ax1 box
# (pos) and shifted right by 0.6 in figure coordinates
ax5 = fig.add_axes([pos.x0 + 0.6, pos.y0, pos.width, pos.height])
dat = mesh.plotSlice(
    m_true,
    ax=ax5,
    normal='Z',
    ind=indz,
    clim=np.r_[vmin, vmax],
    pcolorOpts={'cmap': 'viridis'},
)
ax5.set_aspect('equal')
ax5.set_title('True')

# Both axes are hidden on this panel
ax5.xaxis.set_visible(False)
ax5.yaxis.set_visible(False)
ax5.set_xlim(xmin, xmax)
ax5.set_ylim(ymin, ymax)

# Panel 6: normal='Y' slice of the m_true model, placed below panel 5.
# pos is the ax1 box captured earlier; offsets are in figure coordinates.
ax6 = fig.add_axes([pos.x0 + 0.6525, pos.y0 - 0.315, pos.width * 0.725, pos.height])

# dat is kept: the colour bar created afterwards uses dat[0] as its mappable.
dat = mesh.plotSlice(m_true, ax=ax6, normal='Y', ind=indx, clim=np.r_[vmin, vmax], pcolorOpts={'cmap': 'viridis'})
ax6.set_aspect('equal')
ax6.yaxis.set_visible(False)
# Empty title (presumably clears a default title set by plotSlice -- TODO confirm)
ax6.set_title('')
ax6.set_xlim(xmin, xmax)
ax6.set_ylim(zmin, zmax)

# Build the ticks as a real list: on Python 3 map() returns a lazy iterator,
# which is not a valid sequence of tick positions/labels for matplotlib.
xticks = [int(v) for v in x]
ax6.set_xticks(xticks)
ax6.set_xticklabels([str(t) for t in xticks], size=12)
# Empty x label (presumably clears a default label set by plotSlice -- TODO confirm)
ax6.set_xlabel('')

# Horizontal colour bar in its own axes, positioned from the ax4 box:
# same left edge and width, raised by 0.05, one tenth of the height.
anchor = ax4.get_position()
cbarax = fig.add_axes([anchor.x0, anchor.y0 + 0.05, anchor.width, anchor.height * 0.1])

# dat holds the return value of the most recent plotSlice call; its first
# element is used as the mappable. Four ticks span the shared colour limits.
cb = fig.colorbar(
    dat[0],
    cax=cbarax,
    orientation="horizontal",
    ax=ax4,
    ticks=np.linspace(vmin, vmax, 4),
    format='%.3f',
)
cbarax.tick_params(labelsize=12)
cb.set_label("Susceptibility (SI)", size=14)

# Delete the directory the mesh and model files were read from (basePath),
# including everything inside it. This runs before plt.show() is called, so
# the inputs are removed regardless of what happens during display.
import shutil
shutil.rmtree(basePath) #hide

# Display the assembled figure
plt.show()