import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
[docs]class Plot:
"""Plot DataSet objects with matplotlib."""
##### Functions to control plot settings #####
def __init__(self,dim=2,elevation=20,angle=130):
"""Set default parameters"""
self.num_datasets = 0 # Total number of added data sets
self.datasets = [] # List of added data sets
self.initialised = False # Figure and axes initialised
self.finalised = False # Final plot properties adjusted
self.set_plot_size() # Initialise plot size to 4x4cm
self.set_text() # Initialise text size to 10pt
self.set_dimensions(dim=dim) # 2D/3D plot
self.set_axes() # Default axes labels
self.set_legend() # No legend
self.set_view(elevation=elevation,angle=angle) # Orientation for 3D plot
pylab.rcParams['axes.xmargin'] = 0.0 # Remove padding on x-axis
pylab.rcParams['axes.ymargin'] = 0.0 # Remove padding on y-axis
[docs] def set_plot_size(self, width = 4, height = 4):
"""Set plot size in cm."""
params = {"figure.figsize": (width, height)}
pylab.rcParams.update(params)
[docs] def set_text(self, font='serif', latex=False, legend = 10, title = 10, label = 10):
"""
Set font and text size.
:param latex: enable latex formatting - this is slower to render image but can give better fonts
:type latex: bool
"""
if latex:
params = {
'font.family' : 'serif',
'font.serif' : 'STIX',
'mathtext.fontset' : 'stix',
'text.usetex' : True,
'legend.title_fontsize': legend,
'legend.fontsize': legend,
'axes.labelsize': label,
'axes.titlesize': title,
'xtick.labelsize': label,
'ytick.labelsize': label
}
elif font == 'serif':
params = {
'font.family' : font,
'font.serif' : 'DejaVu Serif',
'mathtext.fontset' : 'dejavuserif',
'legend.title_fontsize': legend,
'legend.fontsize': legend,
'axes.labelsize': label,
'axes.titlesize': title,
'xtick.labelsize': label,
'ytick.labelsize': label
}
else:
params = {
'legend.title_fontsize': legend,
'legend.fontsize': legend,
'axes.labelsize': label,
'axes.titlesize': title,
'xtick.labelsize': label,
'ytick.labelsize': label
}
pylab.rcParams.update(params)
[docs] def set_legend(self,**kwargs):
"""
Set legend properties.
:param legend: turn legend on or off
:type legend: bool
:param title: set legend title
:type title: str
:param columns: number of columns in legend
:type columns: int
:param reverse: reverse order of legend
:type reverse: bool
:param location: location ('upper left' etc.)
:type location: str
"""
self.legend = kwargs.get("legend", False)
self.legend_title = kwargs.get("title", None)
self.legend_columns = kwargs.get("cols", 1)
self.legend_anchor = kwargs.get("anchor", None)
self.legend_reverse = kwargs.get("reverse", False)
self.legend_location = kwargs.get("location", 'best')
[docs] def set_dimensions(self,dim=2):
"""Set dimensionality of plot."""
if dim<2 or dim>3:
print("Must be 2D or 3D")
else:
self.dimensions = dim
[docs] def set_axes(self, **kwargs):
"""
Sets axes properties for x,y,z axes.
:param xlabel: x-axis label
:type xlabel: str
:param xlim: x-axis limits
:type xlim: tuple
:param xticks: positions of major and minor ticks
:type xticks: tuple
:param xlog: use logarithmic scale for x-axis
:type xlog: bool
"""
self.axis_xlabel = kwargs.get("xlabel", r"$x$")
self.axis_ylabel = kwargs.get("ylabel", r"$y$")
self.axis_zlabel = kwargs.get("zlabel", r"$z$")
self.axis_xlim = kwargs.get("xlim", None)
self.axis_ylim = kwargs.get("ylim", None)
self.axis_zlim = kwargs.get("zlim", None)
self.axis_xticks = kwargs.get("xticks", None)
self.axis_yticks = kwargs.get("yticks", None)
self.axis_zticks = kwargs.get("zticks", None)
self.axis_xlog = kwargs.get("xlog", False)
self.axis_ylog = kwargs.get("ylog", False)
self.axis_zlog = kwargs.get("zlog", False)
[docs] def set_view(self,elevation=None,angle=None):
"""Set view in 3D plot."""
self.view_elevation = elevation
self.view_angle = angle
##### Functions to add data sets #####
[docs] def add_dataset(self,dataset):
"""Add DataSet object."""
try:
self.datasets.append(dataset) # Append to data sets
self.num_datasets += 1
except:
print("Cannot add data set")
##### Plotting functions #####
def initialise_plot(self):
"""Initialise axis and figure"""
# Initialise plot if not already called as 2D or 3D plot
if self.initialised:
pass
else:
if self.dimensions == 2:
self.fig, self.ax = plt.subplots()
elif self.dimensions == 3:
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111, projection='3d')
self.initialised = True
[docs] def plot(self):
"""Plot graphs."""
if self.dimensions == 2:
self.initialise_plot()
for i,dataset in enumerate(self.datasets):
if dataset.plot_type == 'scatter':
self.scatter_2d(dataset)
elif dataset.plot_type == 'line':
self.line_2d(dataset)
elif dataset.plot_type == 'error_bar':
self.errorbar_2d(dataset)
elif dataset.plot_type == 'error_shade':
self.errorshade_2d(dataset)
elif dataset.plot_type == 'bar':
self.bar_2d(dataset,i)
elif dataset.plot_type == 'heat':
self.heat_2d(dataset)
elif dataset.plot_type == 'contour':
self.contour_2d(dataset)
elif self.dimensions == 3:
self.initialise_plot()
for dataset in self.datasets:
if dataset.plot_type == 'scatter':
self.scatter_3d(dataset)
elif dataset.plot_type == 'line':
self.line_3d(dataset)
elif dataset.plot_type == 'surface_mesh':
self.surfacemesh_3d(dataset)
elif dataset.plot_type == 'surface_points':
self.surfacepoints_3d(dataset)
def scatter_2d(self,dataset):
"""Scatter graph in 2D"""
if dataset.colour_map is None:
self.ax.scatter(dataset.data[:,0], dataset.data[:,1], label=dataset.label, zorder=dataset.zorder,
marker=dataset.marker_style, s=dataset.marker_size,
color=dataset.colour)
else:
self.ax.scatter(dataset.data[:,0], dataset.data[:,1], label=dataset.label, zorder=dataset.zorder,
marker=dataset.marker_style, s=dataset.marker_size,
c=dataset.colour, cmap=dataset.colour_map, norm=dataset.colour_norm)
def scatter_3d(self,dataset):
"""Scatter graph in 3D"""
if dataset.colour_map is None:
self.ax.scatter(dataset.data[:,0], dataset.data[:,1], dataset.data[:,2], label=dataset.label, zorder=dataset.zorder,
marker=dataset.marker_style, s=dataset.marker_size,
color=dataset.colour)
else:
self.ax.scatter(dataset.data[:,0], dataset.data[:,1], dataset.data[:,2], label=dataset.label, zorder=dataset.zorder,
marker=dataset.marker_style, s=dataset.marker_size,
c=dataset.colour, cmap=dataset.colour_map, norm=dataset.colour_norm)
def line_2d(self,dataset):
"""Line graph in 2D"""
self.ax.plot(dataset.data[:,0], dataset.data[:,1], label=dataset.label, zorder=dataset.zorder,
marker=dataset.marker_style, ms=dataset.marker_size,
lw=dataset.line_width, ls=dataset.line_style,
color=dataset.colour)
def line_3d(self,dataset):
"""Line graph in 3D"""
self.ax.plot(dataset.data[:,0], dataset.data[:,1], dataset.data[:,2], label=dataset.label, zorder=dataset.zorder,
marker=dataset.marker_style, ms=dataset.marker_size,
lw=dataset.line_width, ls=dataset.line_style,
color=dataset.colour)
def errorbar_2d(self,dataset):
"""Line graph with symmetric errors in 2D"""
self.ax.errorbar(dataset.data[:,0], dataset.data[:,1], xerr=dataset.error_x, yerr=dataset.error_y,
label= dataset.label, zorder=dataset.zorder, errorevery=dataset.error_interval,
marker=dataset.marker_style, ms=dataset.marker_size,
lw=dataset.line_width, ls=dataset.line_style,
elinewidth=dataset.error_width, capsize=dataset.error_cap,
color=dataset.colour)
def errorshade_2d(self,dataset):
"""Line graph with shaded region indicating y error"""
data = dataset.data
y1 = data[:,1] - dataset.error_y
y2 = data[:,1] + dataset.error_y
self.ax.fill_between(data[:,0],y1=y1,y2=y2,label=dataset.label, zorder=dataset.zorder, color=dataset.colour)
def bar_2d(self,dataset,shift):
"""Bar graph"""
total_bw = dataset.bar_width
bw = total_bw/self.num_datasets
data = dataset.data
data[:,0] = data[:,0] - total_bw/2 + bw/2 + shift*bw
self.ax.bar(data[:,0],data[:,1],label=dataset.label,zorder=dataset.zorder,
width=bw,color=dataset.colour,
xerr=dataset.error_x,yerr=dataset.error_y,error_kw={'zorder':dataset.zorder+self.num_datasets})
def heat_2d(self,dataset):
"""Heat map"""
x = dataset.data[0]
y = dataset.data[1]
z = dataset.data[2]
self.ax.imshow(z,origin="lower",cmap=dataset.colour_map,norm=dataset.colour_norm,aspect='auto',
extent=(np.min(x),np.max(x),np.min(y),np.max(y)),interpolation=dataset.surface_interpolation)
def contour_2d(self,dataset):
"""Contour plot"""
self.ax.contour(dataset.data[0],dataset.data[1],dataset.data[2],levels=dataset.contour_levels,cmap=dataset.colour_map,norm=dataset.colour_norm,
linewidths=dataset.line_width,linestyles=dataset.line_style)
def surfacemesh_3d(self,dataset):
"""Surface plot in 3D using mesh"""
self.ax.plot_surface(dataset.data[0],dataset.data[1],dataset.data[2],label=dataset.label, zorder=dataset.zorder,
cmap=dataset.colour_map,norm=dataset.colour_norm)
def surfacepoints_3d(self,dataset):
"""Surface plot in 3D using points"""
self.ax.plot_trisurf(dataset.data[:,0],dataset.data[:,0],dataset.data[:,2],label=dataset.label,zorder=dataset.zorder,
cmap=dataset.colour_map,norm=dataset.colour_norm)
def finalise_plot(self):
"""Finalise plot properties - called when saved or visualised"""
# Finalise plot if not already called
if self.finalised:
pass
else:
# Axes properties
# Labels
self.ax.set_xlabel(self.axis_xlabel)
self.ax.set_ylabel(self.axis_ylabel)
# X ticks
if self.axis_xticks is not None:
x_major = self.axis_xticks[0]
x_minor = self.axis_xticks[1]
else:
x_major_locator = self.ax.xaxis.get_major_locator()
auto_major = x_major_locator()
x_major = auto_major[1]-auto_major[0]
x_minor = x_major/5
# Y ticks
if self.axis_yticks is not None:
y_major = self.axis_yticks[0]
y_minor = self.axis_yticks[1]
else:
y_major_locator = self.ax.yaxis.get_major_locator()
auto_major = y_major_locator()
y_major = auto_major[1]-auto_major[0]
y_minor = y_major/5
# X limits
if self.axis_xlim is not None:
self.ax.set_xlim(self.axis_xlim)
else:
auto_xlim=self.ax.get_xlim()
xlim=[np.round(auto_xlim[0]/x_major)*x_major,np.round(auto_xlim[1]/x_major)*x_major]
# if xlim[0]>auto_xlim[0]: xlim[0]-=major_tick
# if xlim[1]<auto_xlim[1]: xlim[1]+=major_tick
self.ax.set_xlim(xlim)
x_minor_locator = ticker.MultipleLocator(x_minor)
x_major_locator = ticker.MultipleLocator(x_major)
self.ax.xaxis.set_minor_locator(x_minor_locator)
self.ax.xaxis.set_major_locator(x_major_locator)
# Y limits
if self.axis_ylim is not None:
self.ax.set_ylim(self.axis_ylim)
else:
auto_ylim=self.ax.get_ylim()
ylim=[(np.round(auto_ylim[0]/y_major)-0.5)*y_major,(np.round(auto_ylim[1]/y_major)+0.5)*y_major]
# if ylim[0]>auto_ylim[0]: ylim[0]-=major_tick
# if ylim[1]<auto_ylim[1]: ylim[1]+=major_tick
self.ax.set_ylim(ylim)
y_minor_locator = ticker.MultipleLocator(y_minor)
y_major_locator = ticker.MultipleLocator(y_major)
self.ax.yaxis.set_minor_locator(y_minor_locator)
self.ax.yaxis.set_major_locator(y_major_locator)
# Log scales
if self.axis_xlog: self.ax.set_xscale('log')
if self.axis_ylog: self.ax.set_yscale('log')
if self.dimensions == 3:
# 3D additions
self.ax.set_zlabel(self.axis_zlabel)
if self.axis_zlim is not None: self.ax.set_zlim(self.axis_zlim)
# No minor locator in 3D
if self.axis_zticks is not None:
major_locator = ticker.MultipleLocator(self.axis_zticks[0])
self.ax.zaxis.set_major_locator(major_locator)
if self.axis_zlog: self.ax.set_zscale('log')
# Set 3D specific options
self.ax.view_init(elev=self.view_elevation,azim=self.view_angle)
# Remove coloured panes and adjust grid
# self.ax.grid(False)
self.ax.xaxis.pane.fill = False
self.ax.yaxis.pane.fill = False
self.ax.zaxis.pane.fill = False
self.ax.xaxis.pane.set_edgecolor('k')
self.ax.yaxis.pane.set_edgecolor('k')
self.ax.zaxis.pane.set_edgecolor('k')
self.ax.xaxis.pane.set_alpha(1)
self.ax.yaxis.pane.set_alpha(1)
self.ax.zaxis.pane.set_alpha(1)
# Legend properties
if self.legend:
handles, labels = self.ax.get_legend_handles_labels()
if self.legend_reverse:
handles = handles[::-1]
labels = labels[::-1]
if self.legend_anchor is None:
legend = self.ax.legend(handles, labels, title=self.legend_title,
ncol=self.legend_columns, loc=self.legend_location)
else:
legend = self.ax.legend(handles, labels, title=self.legend_title, ncol=self.legend_columns,
bbox_to_anchor=self.legend_anchor)
legend.get_frame().set_edgecolor('grey')
# Reset ids to reuse auto-colours and markers
self.datasets[0].__class__.auto_id = 0
self.finalised = True
##### Save or visualise #####
[docs] def display(self):
"""Display figure."""
self.finalise_plot() # Apply final changes to plot
plt.show()
self.initialised = False
self.finalised = False
[docs] def save(self, name="plot", fmt="pdf", dpi_quality=400):
"""Save figure."""
self.finalise_plot() # Apply final changes to plot
filename = name+"."+fmt
if self.dimensions == 2:
plt.savefig(filename, dpi=dpi_quality, bbox_inches="tight")
elif self.dimensions == 3: # Prevent cutoff
plt.savefig(filename, dpi=dpi_quality)