"""
06 Aug 2013
"""
from __future__ import print_function
from warnings import warn
from subprocess import Popen
from itertools import product
import numpy as np
import copy
from scipy import interpolate
try:
from matplotlib.ticker import MultipleLocator
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import FuncFormatter
except ImportError:
warn('matplotlib not found\n')
try:
basestring
except NameError:
basestring = str
# Capitalized English ordinal names for the integers 1 to 12.
NTH = {
1 : "First",
2 : "Second",
3 : "Third",
4 : "Fourth",
5 : "Fifth",
6 : "Sixth",
7 : "Seventh",
8 : "Eighth",
9 : "Ninth",
10: "Tenth",
11: "Eleventh",
12: "Twelfth"
}
def setup_plot(axe, figsize=None):
    """
    Return the axes to draw on, building a styled one when none is given.

    :param axe: an existing matplotlib axes object. If given, it is returned
       as is.
    :param None figsize: (width, height) of the figure created when no axes
       is passed; (11, 5) is used if not set.

    :returns: a matplotlib axes object
    """
    if axe:
        return axe
    fig = plt.figure(figsize=figsize if figsize else (11, 5))
    ax = fig.add_subplot(111)
    # light grey background with white grid lines drawn below the data
    ax.patch.set_facecolor('lightgrey')
    ax.patch.set_alpha(0.4)
    ax.grid(ls='-', color='w', lw=1.5, alpha=0.6, which='major')
    ax.grid(ls='-', color='w', lw=1, alpha=0.3, which='minor')
    ax.set_axisbelow(True)
    ax.minorticks_on()  # always on, not only for log
    # remove tick marks, both major and minor
    no_ticks = dict(axis='both', direction='out', top=False, right=False,
                    left=False, bottom=False)
    ax.tick_params(**no_ticks)
    ax.tick_params(which='minor', **no_ticks)
    return ax
def tadbit_savefig(savefig):
    """
    Save the current matplotlib figure, taking the format from the file name.

    :param savefig: path to the output file. Its extension must be one of
       png, pdf, ps, eps or svg. If the file name has no extension, a warning
       is issued and the figure is written in png format.

    :raises NotImplementedError: if the extension is not a supported format
    """
    from os.path import splitext

    supported = ['png', 'pdf', 'ps', 'eps', 'svg']
    # splitext copes with extensions of any length (the previous slicing of
    # the last four characters mistook e.g. '.jpeg' for "no extension")
    form = splitext(savefig)[1][1:]
    if not form:  # no dot in file name
        warn('WARNING: file extension not found saving in png')
        form = 'png'
    if form not in supported:
        raise NotImplementedError('File extension must be one of %s' % (
            supported))
    plt.savefig(savefig, format=form)
def nicer(res, sep=' ', comma='', allowed_decimals=0):
    """
    Write a resolution (a number of bases) in a human readable way.

    The largest unit among Gb, Mb and kb that leaves at most
    *allowed_decimals* decimals is used, otherwise the value is given in b.

    :param res: number to write
    :param ' ' sep: character between number and unit (e.g. default: '125 kb')
    :param '' comma: character to separate groups of thousands
    :param 0 allowed_decimals: if 1 '1900 kb' would be written as '1.9 Mb'

    :returns: a string with the number followed by its unit
    """
    def _fmt(value):
        # plain (never scientific) notation: the 'g' presentation type keeps
        # only 6 significant digits and turned 1234567 into '1.23457e+06'
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return '{:,}'.format(value).replace(',', comma)

    if not res:
        return _fmt(res) + sep + 'b'
    for power, unit in ((9, 'Gb'), (6, 'Mb'), (3, 'kb')):
        if not res % 10 ** (power - allowed_decimals):
            return _fmt(res / 10. ** power) + sep + unit
    return _fmt(res) + sep + 'b'
# ANSI escape sequences opening a colored text span, keyed by None or by an
# integer from 0 to 10 (looked up by colorize when ftype is 'ansi').
COLOR = {None: '\033[31m', # red
0 : '\033[34m', # blue
1 : '\033[34m', # blue
2 : '\033[34m', # blue
3 : '\033[36m', # cyan
4 : '\033[0m' , # reset attributes (terminal default color)
5 : '\033[1m' , # bold
6 : '\033[33m', # yellow
7 : '\033[33m', # yellow
8 : '\033[35m', # purple
9 : '\033[35m', # purple
10 : '\033[31m' # red
}
# HTML opening <span> tags, keyed by None or by an integer from 0 to 10
# (looked up by colorize when ftype is anything but 'ansi').
COLORHTML = {None: '<span style="color:red;">' , # red
0 : '<span>' , # no color set
1 : '<span style="color:blue;">' , # blue
2 : '<span style="color:blue;">' , # blue
3 : '<span style="color:purple;">' , # purple
4 : '<span style="color:purple;">' , # purple
5 : '<span style="color:teal;">' , # teal
6 : '<span style="color:teal;">' , # teal
7 : '<span style="color:olive;">' , # olive
8 : '<span style="color:olive;">' , # olive
9 : '<span style="color:red;">' , # red
10 : '<span style="color:red;">' # red
}
def colorize(string, num, ftype='ansi'):
    """
    Wrap a string in the color markup associated to a given number.

    :param string: the string to colorize
    :param num: a number between 0 and 10, or None (colored as 10)
    :param 'ansi' ftype: 'ansi' to use ANSI escape codes (for printing in a
       shell); any other value produces an HTML <span> element

    :returns: the string 'decorated' with the color code
    """
    if ftype == 'ansi':
        opening, closing = COLOR[num], '\033[m'
    else:
        opening, closing = COLORHTML[num], '</span>'
    return '%s%s%s' % (opening, string, closing)
def color_residues(x, **kwargs):
    """
    Function to color residues from blue to red.

    :param x: a sequence with one element per particle; only its length is
       used
    :param kwargs: accepted for compatibility with the other coloring
       functions, not used

    :returns: a list of rgb tuples (red, green, blue), each between 0 and 1.
    """
    total = len(x)
    # red grows linearly with the position, blue is its complement
    reds = (float(position) / total for position in range(1, total + 1))
    return [(red, 0, 1 - red) for red in reds]
def tad_coloring(x, mstart=None, mend=None, tads=None, **kwargs):
    """
    Colors TADs from blue to red (first to last TAD). TAD borders are displayed
    in scale of grey, from light to dark grey (again first to last border)

    :param x: not used by this function
    :param None mstart: first position of the region to color. If None, the
       region is not bounded on this side.
    :param None mend: last position of the region to color. If None, the
       region is not bounded on this side.
    :param tads: a dictionary of TADs, Experiments.tads can be directly passed.
       Keys are expected to be numbers, values dictionaries with 'start' and
       'end' keys.

    :raises Exception: if no TAD is passed, or none falls in the region

    :returns: a list of rgb tuples (red, green, blue), each between 0 and 1.
    """
    if not tads:
        raise Exception('ERROR: TAD borders are not predicted yet.')
    # None cannot be ordered against numbers in Python 3: use open bounds
    if mstart is None:
        mstart = float('-inf')
    if mend is None:
        mend = float('inf')
    ltads = [t for t in tads if tads[t]['end'] > mstart
             and tads[t]['start'] < mend]
    ntads = len(ltads)
    if not ntads:
        raise Exception('ERROR: TAD borders are not predicted yet.')
    first_tad = min(ltads)
    grey = 0.95
    grey_step = (0.95 - 0.4) / ntads
    result = []
    for t in tads:
        tad = tads[t]
        if tad['end'] < mstart or tad['start'] > mend:
            continue
        red = float(t + 1 - first_tad) / ntads
        # one colored entry per position of the TAD inside the region
        for _ in range(int(max(tad['start'], mstart)),
                       int(min(tad['end'], mend))):
            result.append((red, 0, 1 - red))
        # the border is only drawn when the TAD ends inside the region
        if tad['end'] <= mend:
            result.append((grey, grey, grey))
            grey -= grey_step
    return result
def tad_border_coloring(x, mstart=None, mend=None, tads=None, **kwargs):
    """
    Colors TAD borders from blue to red (bad to good score). TAD are displayed
    in scale of grey, from light to dark grey (first to last particle in the
    TAD)

    :param x: not used by this function
    :param None mstart: first position of the region to color. If None, the
       region is not bounded on this side.
    :param None mend: last position of the region to color. If None, the
       region is not bounded on this side.
    :param tads: a dictionary of TADs, Experiments.tads can be directly passed.
       Values are dictionaries with 'start', 'end' and 'score' keys; the
       score is divided by 10 to get the red component.

    :raises Exception: if no TAD is passed

    :returns: a list of rgb tuples (red, green, blue), each between 0 and 1.
    """
    if not tads:
        raise Exception('ERROR: TAD borders are not predicted yet.')
    # None cannot be ordered against numbers in Python 3: use open bounds
    if mstart is None:
        mstart = float('-inf')
    if mend is None:
        mend = float('inf')
    result = []
    for t in tads:
        tad = tads[t]
        if tad['end'] < mstart or tad['start'] > mend:
            continue
        # the grey scale restarts at each TAD
        grey = 0.95
        grey_step = (0.95 - 0.4) / (tad['end'] - tad['start'] + 1)
        red = float(tad['score']) / 10
        for _ in range(int(max(tad['start'], mstart)),
                       int(min(tad['end'], mend))):
            result.append((grey, grey, grey))
            grey -= grey_step
        # the border is only drawn when the TAD ends inside the region
        if tad['end'] <= mend:
            result.append((red, 0, 1 - red))
    return result
def augmented_dendrogram(clust_count=None, dads=None, objfun=None, color=False,
                         axe=None, savefig=None, *args, **kwargs):
    """
    Draw a dendrogram of clusters of models in which the length of the
    terminal branches depends on the objective function of the cluster and
    their width on the number of models in the cluster.

    :param clust_count: dictionary with, for each cluster number, its number
       of models
    :param dads: dictionary looked up with a cluster number to get the number
       of the node that is placed above it in the tree
    :param objfun: dictionary with, for each cluster number, its objective
       function value
    :param False color: use the branch colors computed by scipy instead of
       grey
    :param None axe: a matplotlib axes object to draw on. If None, a new
       figure is created.
    :param None savefig: path to a file where to save the image generated;
       if None (and no axes was passed), the image is shown using matplotlib
       GUI.
    :param args: positional arguments passed to
       scipy.cluster.hierarchy.dendrogram
    :param kwargs: 'figsize' (default (8, 8)), 'fontsize' (default 8),
       'width_factor' (default 5.0) and 'no_plot' (default False, skips the
       drawing of the branches)

    :returns: the dictionary returned by scipy.cluster.hierarchy.dendrogram
    """
    from scipy.cluster.hierarchy import dendrogram

    fontsize = kwargs.get('fontsize', 8)
    # scipy draws the dendrogram on the current pyplot figure: it is computed
    # on a fresh figure that is wiped right after, only the layout is kept.
    fig = plt.figure(figsize=kwargs.get('figsize', (8, 8)))
    ddata = dendrogram(*args)
    plt.clf()
    if axe:
        # NOTE(review): the scratch figure created above is left open here
        ax = axe
        fig = axe.get_figure()
    else:
        ax = fig.add_subplot(111)
        ax.patch.set_facecolor('lightgrey')
        ax.patch.set_alpha(0.4)
        ax.grid(ls='-', color='w', lw=1.5, alpha=0.6, which='major')
        ax.grid(ls='-', color='w', lw=1, alpha=0.3, which='minor')
        ax.set_axisbelow(True)
        # remove tick marks
        ax.tick_params(axis='both', direction='out', top=False, right=False,
                       left=False, bottom=False)
        ax.tick_params(axis='both', direction='out', top=False, right=False,
                       left=False, bottom=False, which='minor')
    # set dict to store data of each cluster (count and energy), depending on
    # x position in graph.
    leaves = {}
    dist = ddata['icoord'][0][2] - ddata['icoord'][0][1]
    for i, x in enumerate(ddata['leaves']):
        leaves[dist*i + dist/2] = x
    minnrj = min(objfun.values())
    maxnrj = max(objfun.values())
    difnrj = maxnrj - minnrj
    total = max(clust_count.values())
    if not kwargs.get('no_plot', False):
        for i, d, c in zip(ddata['icoord'], ddata['dcoord'],
                           ddata['color_list']):
            branch_color = c if color else 'grey'
            # drawn through the axes (not pyplot) so that a user supplied
            # axes really receives the horizontal lines
            ax.hlines(d[1], i[1], i[2], lw=2, color=branch_color)
            # for each branch
            for i1, d1, d2 in zip(i[1:3], [d[0], d[3]], [d[1], d[2]]):
                try:
                    lw = (fig.get_figwidth() *
                          kwargs.get('width_factor', 5.0) *
                          float(clust_count[leaves[i1] + 1]) / total)
                except KeyError:
                    lw = 1.0
                cluster = leaves[i1] + 1
                nrj = objfun[cluster] if cluster in objfun else maxnrj
                # shift the bottom of the branch according to the energy
                d1 = d1 - (difnrj - (nrj - minnrj))
                ax.vlines(i1, d1, d2, lw=lw, color=branch_color)
                if cluster in objfun:
                    ax.annotate("%.3g" % (cluster),
                                (i1, d1), size=fontsize,
                                xytext=(0, -8),
                                textcoords='offset points',
                                va='top', ha='center')
            # register the node joining the two branches at its x position
            leaves[(i[1] + i[2])/2] = dads[leaves[i[1]] + 1] - 1
    try:
        cutter = 10**int(np.log10(difnrj))
    except OverflowError: # case that the two are exactly the same
        cutter = 1
    cut = 10 if cutter >= 10 else 1
    bot = (-int(difnrj)//cutter * cutter) or -1 # do not want this to be null
    # just to display nice numbers
    form = lambda x: ''.join([(s + ',') if not i%3 and i else s
                              for i, s in enumerate(str(x)[::-1])][::-1])
    tick_offsets = list(range(0, -bot-bot//cut, -bot//cut))
    ax.set_yticks([bot + i for i in tick_offsets])
    ax.set_yticklabels(["%s" % (form(int(minnrj)//cutter * cutter + i))
                        for i in tick_offsets],
                       size=fontsize)
    ax.set_ylabel('Minimum IMP objective function',
                  size=fontsize + 1)
    ax.set_xticks([])
    xmin, xmax = ax.get_xlim()
    ax.set_xlim((xmin - 2, xmax + 2))
    ax.figure.suptitle("Dendogram of clusters of 3D models",
                       size=fontsize + 2)
    ax.set_title("Branch length proportional to model's objective function " +
                 "final value\n" +
                 "Branch width to the number of models in the cluster " +
                 "(relative to %s models)" % (total),
                 size=fontsize)
    if savefig:
        tadbit_savefig(savefig)
    elif not axe:
        plt.show()
    return ddata
def _style_hist_box_axes(ax):
    """Apply the light grey background, white grid and hidden ticks."""
    ax.patch.set_facecolor('lightgrey')
    ax.patch.set_alpha(0.4)
    ax.grid(ls='-', color='w', lw=1.5, alpha=0.6, which='major')
    ax.grid(ls='-', color='w', lw=1, alpha=0.3, which='minor')
    ax.set_axisbelow(True)
    ax.minorticks_on() # always on, not only for log
    # remove tick marks
    ax.tick_params(axis='both', direction='out', top=False, right=False,
                   left=False, bottom=False)
    ax.tick_params(axis='both', direction='out', top=False, right=False,
                   left=False, bottom=False, which='minor')


def plot_hist_box(data, part1, part2, axe=None, savefig=None):
    """
    Plot a histogram of distances topped by the corresponding boxplot.

    :param data: list of distances between two particles
    :param part1: name of the first particle (used in the title)
    :param part2: name of the second particle (used in the title)
    :param None axe: a matplotlib axes object; only its figure is used, two
       new axes are added to it. If None, a new figure is created.
    :param None savefig: path to a file where to save the image generated;
       if None (and no axes was passed), the image is shown using matplotlib
       GUI.
    """
    # setup the figure and axes
    fig = axe.get_figure() if axe else plt.figure(figsize=(6, 6))
    bpAx = fig.add_axes([0.2, 0.7, 0.7, 0.2]) # left, bottom, width, height:
                                              # (adjust as necessary)
    _style_hist_box_axes(bpAx)
    # plot stuff
    bp = bpAx.boxplot(data, vert=False)
    plt.setp(bp['boxes'], color='black')
    plt.setp(bp['whiskers'], color='black')
    plt.setp(bp['medians'], color='darkred')
    plt.setp(bp['fliers'], color='darkred', marker='+')
    # mean of the distribution (float division also under Python 2)
    bpAx.plot(sum(data) / float(len(data)), 1,
              color='w', marker='*', markeredgecolor='k')
    box_x = bp['boxes'][0].get_xdata()
    box_y = bp['boxes'][0].get_ydata()
    median_x = bp['medians'][0].get_xdata()[0]
    # values at both ends of the box, written above it
    for value in (box_x[0], box_x[2]):
        bpAx.annotate('%.1f' % (value), (value, box_y[1]),
                      va='bottom', ha='center', xytext=(0, 2),
                      textcoords='offset points',
                      size='small')
    # median, written below the box
    bpAx.annotate('%.1f' % (median_x), (median_x, box_y[0]),
                  va='top', ha='center', xytext=(0, -2),
                  textcoords='offset points', color='darkred',
                  size='small')
    histAx = fig.add_axes([0.2, 0.2, 0.7, 0.5]) # left specs should match and
                                                # bottom + height on this line should
                                                # equal bottom on bpAx line
    _style_hist_box_axes(histAx)
    histAx.hist(data, bins=20, alpha=0.5, color='darkgreen')
    # confirm that the axes line up
    xlims = np.array([bpAx.get_xlim(), histAx.get_xlim()])
    for ax in [bpAx, histAx]:
        ax.set_xlim([xlims.min(), xlims.max()])
    bpAx.set_xticklabels([]) # clear out overlapping xlabels
    bpAx.set_yticks([]) # don't need that 1 tick mark
    # labels set on the histogram axes itself, whatever the current pyplot
    # axes is
    histAx.set_xlabel('Distance between particles (nm)')
    histAx.set_ylabel('Number of observations')
    bpAx.set_title('Histogram and boxplot of distances between particles ' +
                   '%s and %s' % (part1, part2))
    if savefig:
        tadbit_savefig(savefig)
    elif not axe:
        plt.show()
def plot_3d_model(x, y, z, label=False, axe=None, thin=False, savefig=None,
                  show_axe=False, azimuth=-90, elevation=0., color='index',
                  smooth=0.001, particle_size=50, alpha_part=0.5, lw_main=3,
                  **kwargs):
    """
    Given a 3 lists of coordinates (x, y, z) plots a three-dimentional model
    using matplotlib

    :param x: list of x coordinates, one per particle
    :param y: list of y coordinates, one per particle
    :param z: list of z coordinates, one per particle
    :param False label: show labels (the index of each particle); only used
       when *thin* is False
    :param None axe: a matplotlib 3D axes object to define the plot
       appearance. If passed, the figure is neither shown nor saved here.
    :param False thin: draw a thin black line instead of representing particles
       and edges
    :param 0.001 smooth: connection between particles is smoothed according to
       the input condition (passed as `s` to scipy's splprep); pass False to
       link particles with straight segments
    :param None savefig: path to a file where to save the image generated;
       if None, the image will be shown using matplotlib GUI (the extension
       of the file name will determine the desired format).
    :param False show_axe: display the 3D axes
    :param -90 azimuth: angle to rotate camera along the y axis
    :param 0 elevation: angle to rotate camera along the x axis
    :param 'index' color: can be:

         * a string as:
             * '**index**' to color particles according to their position in the
               model (:func:`pytadbit.utils.extraviews.color_residues`)
             * '**tad**' to color particles according to the TAD they belong to
               (:func:`pytadbit.utils.extraviews.tad_coloring`)
             * '**border**' to color particles marking borders. Color according to
               their score (:func:`pytadbit.utils.extraviews.tad_border_coloring`)
               coloring function like.
         * a function, that takes as argument the x coordinates and any other
           parameter passed through the kwargs.
         * a list of (r, g, b) tuples (as long as the number of particles).
           Each r, g, b between 0 and 1.
    :param 50 particle_size: size of the markers representing particles
    :param 0.5 alpha_part: transparency of the markers representing particles
    :param 3 lw_main: line width of the path linking particles
    :param kwargs: 'figsize' (default (8, 8)) for the figure created when no
       axes is passed; everything is also passed to the coloring function

    :raises Exception: if 'tad' or 'border' coloring is asked without tads
    :raises NotImplementedError: if the coloring name is unknown
    :raises TypeError: if color is not a string, a function or a list
    """
    show = False
    if isinstance(color, basestring):
        if color == 'index':
            color = color_residues(x, **kwargs)
        elif color in ('tad', 'border'):
            if 'tads' not in kwargs:
                raise Exception('ERROR: missing TADs\n ' +
                                'pass an Experiment.tads disctionary\n')
            if color == 'tad':
                color = tad_coloring(x, **kwargs)
            else:
                color = tad_border_coloring(x, **kwargs)
        else:
            raise NotImplementedError(('%s type of coloring is not yet ' +
                                       'implemeted\n') % color)
    elif hasattr(color, '__call__'): # its a function
        color = color(x, **kwargs)
    elif not isinstance(color, list):
        raise TypeError('one of function, list or string is required\n')
    if not axe:
        fig = plt.figure(figsize=kwargs.get('figsize', (8, 8)))
        axe = fig.add_subplot(1, 1, 1, projection='3d')
        show = True
    if not show_axe:
        axe.set_axis_off()
    axe.view_init(elev=elevation, azim=azimuth)
    smoothed = smooth is not False
    if smoothed:
        # interpolate a spline through the particles and sample 500 points
        # along it
        tck, _ = interpolate.splprep([x, y, z], s=smooth)
        xs, ys, zs = interpolate.splev(np.linspace(0, 1, 500), tck)
    if thin:
        if smoothed:
            axe.plot(xs, ys, zs, color='black', lw=1, alpha=0.2)
        else:
            axe.plot(x, y, z, color='black', lw=1, alpha=0.2)
    else:
        if smoothed:
            axe.plot(xs, ys, zs, color='lightgrey', lw=lw_main, alpha=0.8)
        else:
            # one segment per pair of consecutive particles, colored as the
            # first particle of the pair
            for i in range(len(x) - 1):
                axe.plot(x[i:i+2], y[i:i+2], z[i:i+2],
                         color=color[i], lw=lw_main)
        if label:
            # each particle is labelled once, whatever the kind of path
            for i in range(len(x)):
                axe.text(x[i], y[i], z[i], str(i), size=7)
        axe.scatter(x, y, z, color=color, s=particle_size, alpha=alpha_part)
    axe.pbaspect = [1,1,1]
    if show:
        if savefig:
            tadbit_savefig(savefig)
        else:
            plt.show()
def chimera_view(cmm_files, chimera_bin='chimera', #shape='tube',
                 chimera_cmd=None, savefig=None, center_of_mass=False,
                 gyradius=False, align=True, grid=False, highlight='all',
                 **kwargs):
    """
    Open a list of .cmm files with Chimera (http://www.cgl.ucsf.edu/chimera)
    to view models.

    The commands are written to a temporary file that is passed to Chimera,
    which is launched without waiting for it to finish.

    :param cmm_files: list of .cmm files
    :param 'chimera' chimera_bin: path to chimera binary
    :param None chimera_cmd: list of command lines for chimera. If passed,
       they replace the default commands (and savefig is not used).
    :param None savefig: path to a file where to save generated image
       (.png) or movie (.mov or webm)
    :param False center_of_mass: if True, draws the center of mass
    :param False gyradius: increment the radius of the center of mass by the
       value of the radius of girations given.
    :param True align: align models
    :param False grid: tile models
    :param 'all' higlight: higlights a given model, or group of models.
       Can be either 'all', or a model number
    :param kwargs: 'extra', a string of commands appended to the default ones

    :raises Exception: if the format of savefig is not supported
    """
    import os
    import tempfile

    nmodels = len(cmm_files)
    # the whole script is built first, so that nothing is written (and no
    # file is left open) if the arguments turn out to be invalid
    script = ['open %s\n' % (cmm_file) for cmm_file in cmm_files]
    if nmodels > 1:
        # NOTE(review): model 0 is never colored in black, even when another
        # model is highlighted -- confirm this is intended
        for i in range(1, nmodels):
            if align:
                script.append('match #%s #%s\n' % (i, 0))
            if highlight != 'all' and highlight != i:
                script.append('color black #%s\n' % (i))
    if not chimera_cmd:
        the_shape = '''bonddisplay never #%s
shape tube #%s radius 15 bandLength 300 segmentSubdivisions 1 followBonds on
~show #%s'''
        if highlight != 'all':
            shown = [highlight]
        else:
            shown = list(range(nmodels))
        shapes = '\n'.join([the_shape % (mdl, mdl, mdl) for mdl in shown])
        script.append('''
focus
set bg_color white
windowsize 800 600
represent wire
%s
clip yon -500
~label
set subdivision 1
set depth_cue
set dc_color black
set dc_start 0.5
set dc_end 1
scale 0.8%s\n
''' % (shapes, '\ntile' if grid else ''))
        if center_of_mass:
            script.append('define centroid radius %s color 1,0,0,0.2\n' % (
                gyradius if gyradius else 10))
        script.append(kwargs.get('extra', ''))
        if savefig:
            if savefig.endswith('.png'):
                script.append('copy file %s png' % (savefig))
            elif savefig[-4:] in ('.mov', 'webm'):
                script.append('''
movie record supersample 1
turn y 3 120
wait 120
movie stop
movie encode output %s
''' % (savefig))
            else:
                raise Exception('Not supported format, must be png, mov or webm\n')
    else:
        script.append('\n'.join(chimera_cmd) + '\n')
    # one command file per call: Chimera is started asynchronously, a shared
    # path could be overwritten by the next call before being read
    handle, pref_f = tempfile.mkstemp(prefix='tadbit_chimera_', suffix='.cmd')
    with os.fdopen(handle, 'w') as out:
        out.write(''.join(script))
    Popen('%s %s' % (chimera_bin, pref_f), shell=True)
def plot_3d_optimization_result(result,
                                axes=('scale', 'maxdist', 'upfreq', 'lowfreq')):
    """
    Displays a three dimensional scatter plot representing the result of the
    optimization.

    :param result: a tuple of three elements: the names of the axes, the list
       of values explored along each axis, and a 4D numpy array containing
       the correlation values
    :param 'scale','maxdist','upfreq','lowfreq' axes: tuple of axes to
       represent. The order will define which parameter will be placed on the
       w, z, y or x axe. One subplot is drawn per value of the first one.

    :raises IndexError: if all correlation values are NaN
    """
    # importing Axes3D makes the '3d' projection available in matplotlib
    # versions that do not register it by default
    from mpl_toolkits.mplot3d import Axes3D

    ori_axes, axes_range, result = result
    trans = [ori_axes.index(a) for a in axes]
    axes_range = [axes_range[i] for i in trans]
    # transpose results
    result = result.transpose(trans)
    round_decs = 6
    wax = [my_round(i, round_decs) for i in axes_range[0]]
    zax = [my_round(i, round_decs) for i in axes_range[1]]
    yax = [my_round(i, round_decs) for i in axes_range[2]]
    xax = [my_round(i, round_decs) for i in axes_range[3]]
    # best correlation with its parameters, in the same order as `axes`
    scored = [(result[i, j, k, l], wax[i], zax[j], yax[k], xax[l])
              for i in range(len(wax))
              for j in range(len(zax))
              for k in range(len(yax))
              for l in range(len(xax))
              if not np.isnan(result[i, j, k, l])]
    if not scored:
        raise IndexError('no valid (non-NaN) correlation value in result')
    best = sorted(scored, key=lambda s: s[0], reverse=True)[0]
    # coordinates of each point of the grid, same order as the colors below
    points = list(product(axes_range[1], axes_range[2], axes_range[3]))
    x = [p[0] for p in points]
    y = [p[1] for p in points]
    z = [p[2] for p in points]
    ncols = int(np.sqrt(len(wax)) + 0.999)
    nrows = int(np.sqrt(len(wax)) + 0.5)
    fig = plt.figure(figsize=((ncols)*6,(nrows)*4.5))
    for i in range(len(wax)):
        col = [result[i, j, k, l] for j in range(len(axes_range[1]))
               for k in range(len(axes_range[2]))
               for l in range(len(axes_range[3]))]
        # subplot positions are counted from 1
        ax = fig.add_subplot(nrows, ncols, i + 1, projection='3d')
        ax.set_xlabel(axes[1])
        ax.set_ylabel(axes[2])
        ax.set_zlabel(axes[3])
        lol = ax.scatter(x, y, z, c=col, s=100, alpha=0.9)
        cbar = fig.colorbar(lol)
        cbar.ax.set_ylabel('Correlation value')
        tit = 'Optimal IMP parameters (subplot %s=%s)\n' % (axes[0], wax[i])
        tit += 'Best: %s=%%s, %s=%%s, %s=%%s, %s=%%s' % (axes[0], axes[1],
                                                         axes[2], axes[3])
        ax.set_title(tit % tuple([my_round(r, 3) for r in best[1:]]))
    plt.show()
def my_round(num, val):
    """
    Round a number, returning an integer when no decimal part is left.

    :param num: number to round
    :param val: number of decimals to keep

    :returns: the rounded number, as an int if it has no decimal part
    """
    rounded = round(num, val)
    whole = int(rounded)
    if rounded == whole:
        return whole
    return rounded
def plot_2d_optimization_result(result,
                                axes=('scale', 'kbending', 'maxdist',
                                      'lowfreq', 'upfreq'), dcutoff=None,
                                show_best=0, skip=None, savefig=None,
                                clim=None, cmap='inferno'):
    """
    A grid of heatmaps representing the result of the optimization. In the optimization
    up to 5 parameters can be optimized: 'scale', 'kbending', 'maxdist', 'lowfreq', and 'upfreq'.
    The maps will be divided in different rows depending on the 'scale' and 'kbending' values.
    In each row there will be different maps depending the 'maxdist' values.
    Each map has 'upfreq' values along the x-axes, and 'lowfreq' values along the y-axes.

    :param result: a tuple of three elements: the names of the axes, the list
       of values explored along each axis, and a numpy array containing the
       computed correlation values. Results with only four axes (no
       'kbending') are converted, with a single kbending value of 0.
    :param 'scale','kbending','maxdist','lowfreq','upfreq' axes: tuple of axes
       to represent. The order is important here. It will define which parameter
       will be placed respectively on the v, w, z, y, or x axes.
    :param 0 show_best: number of best correlation values to mark (with
       their rank) in the heatmaps.
    :param None skip: a dict can be passed here in order to fix a given parameter value,
       e.g.: {'scale': 0.001, 'kbending': 30, 'maxdist': 500} will represent all the
       correlation values at fixed 'scale', 'kbending', and 'maxdist' values,
       respectively equal to 0.001, 30, and 500.
    :param None dcutoff: The distance cutoff (dcutoff) used to compute the contact matrix
       in the models. Only used as a label in the plot ('NA' if None).
    :param None savefig: path to a file where to save the generated image.
       If None, the image will be displayed using matplotlib GUI. NOTE: the extension
       of the file name will automatically determine the desired format.
    :param None clim: color scale. If None, the max and min values of the input are used.
    :param inferno cmap: matplotlib colormap

    :raises Exception: if a key of skip is one of the two last axes
    """
    from mpl_toolkits.axes_grid1 import AxesGrid
    import matplotlib.patches as patches

    ori_axes, axes_range, result = result

    # Compatibility with the OLD version, in which kbending was not optimized:
    # a kbending axis with a single value is inserted in second position.
    if len(axes_range) == 4:
        full_axes = ('scale', 'kbending', 'maxdist', 'lowfreq', 'upfreq')
        old_axes = list(ori_axes)
        order = [old_axes.index(a) for a in full_axes if a != 'kbending']
        # a new list is built, the input ranges are left untouched
        axes_range = [axes_range[i] for i in order]
        axes_range.insert(1, [0.0])
        result = np.expand_dims(np.asarray(result).transpose(order), 1)
        ori_axes = full_axes

    trans = [ori_axes.index(a) for a in axes]
    axes_range = [axes_range[i] for i in trans]

    # transpose results
    result = result.transpose(trans)

    # set NaNs
    result = np.ma.array(result, mask=np.isnan(result))
    cmap = copy.copy(plt.get_cmap(cmap))
    cmap.set_bad('w', 1.)

    # defines axes
    if clim:
        vmin = clim[0]
        vmax = clim[1]
    else:
        vmin = result.min()
        vmax = result.max()

    round_decs = 6
    # Here we round the values in axes_range and pass from the
    # 5 parameters to the cartesian axes names.
    vax = [my_round(i, round_decs) for i in axes_range[0]] # scale
    wax = [my_round(i, round_decs) for i in axes_range[1]] # kbending
    zax = [my_round(i, round_decs) for i in axes_range[2]] # maxdist
    yax = [my_round(i, round_decs) for i in axes_range[3]] # lowfreq
    xax = [my_round(i, round_decs) for i in axes_range[4]] # upfreq

    # This part marks the set of best correlations that the
    # user wants to be highlighted in the plot
    vax_range = list(range(len(vax)))[::-1] # scale
    wax_range = list(range(len(wax)))[::-1] # kbending
    zax_range = list(range(len(zax)))       # maxdist
    yax_range = list(range(len(yax)))       # lowfreq
    xax_range = list(range(len(xax)))       # upfreq
    indeces_sets = product(vax_range, wax_range,
                           zax_range, yax_range,
                           xax_range)

    # masked values (printed as '--') are left out
    sort_result = sorted([(result[indeces_set], vax[indeces_set[0]], wax[indeces_set[1]],
                           zax[indeces_set[2]], yax[indeces_set[3]], xax[indeces_set[4]])
                          for indeces_set in indeces_sets
                          if str(result[indeces_set]) != '--'],
                         key=lambda x: x[0], reverse=True)[:show_best+1]

    # This part allows the user to "skip" some parameters to show.
    # This means to fix the value of certain parameters.
    skip = {} if not skip else skip
    for i, parameter in enumerate(axes):
        if not parameter in skip:
            continue
        if i == 0:
            vax_range = [vax.index(skip[parameter])]
        elif i == 1:
            wax_range = [wax.index(skip[parameter])]
        elif i == 2:
            zax_range = [zax.index(skip[parameter])]
        else:
            raise Exception(('ERROR: skip keys must be one of the three first' +
                             ' keywords passed as axes parameter'))

    # best number of rows/columns
    ncols = len(zax_range)
    nrows = len(vax_range) * len(wax_range)

    # width and height of each heatmap. These dimensions of each heatmap
    # depend on the number of values on the x-axes, len(xax), related to
    # 'upfreq', and on the y-axes, len(yax), related to 'lowfreq'. width and
    # height are also multiplied by the ncols, that is the number of
    # heatmaps per row (one for each value of 'maxdist'), and nrows, that is
    # the number of heatmaps per column (one for each combination of 'scale' and
    # 'kbending' values).
    width = max(4, (float(ncols) * len(xax)) / 3)
    height = max(3, (float(nrows) * len(yax)) / 3)

    # Definition of the heatmap object
    heatmap = plt.figure(figsize=(width, height))

    # Here we define the grid of heatmaps (one extra row and one extra
    # column are left empty around the maps).
    grid = AxesGrid(heatmap, [.2, .2, .6, .5],
                    nrows_ncols = (nrows + 1, ncols + 1),
                    axes_pad = 0.0,
                    label_mode = "1",
                    share_all = False,
                    cbar_location="right",
                    cbar_mode="single",
                    # cbar_size="%s%%" % (20./ width),
                    cbar_pad="30%",
    )

    # label of the distance cutoff (not defined by default)
    if dcutoff is None:
        dcutoff_label = 'NA'
    else:
        dcutoff_label = str(my_round(dcutoff, round_decs))

    cell = ncols
    used = []

    for row in product(vax_range, wax_range):
        cell += 1

        for column in zax_range:
            used.append(cell)
            # Setting the values in the heatmap
            im = grid[cell].imshow(result[row[0], row[1], column, :, :],
                                   interpolation="nearest", origin='lower',
                                   vmin=vmin, vmax=vmax, cmap=cmap)

            # Setting the ticks of the heatmap
            grid[cell].tick_params(axis='both', direction='out', top=False,
                                   right=False, left=False, bottom=False)

            # rank of the best correlations falling in this heatmap
            for j, best in enumerate(sort_result[:-1], 1):
                if best[1] == vax[row[0]] and best[2] == wax[row[1]] and best[3] == zax[column]:
                    grid[cell].text(xax.index(best[5]), yax.index(best[4]), str(j),
                                    {'ha':'center', 'va':'center'}, size=8)

            # header (maxdist values) above the first row of heatmaps
            if row[0] == vax_range[0] and row[1] == wax_range[0]:
                rect = patches.Rectangle((-0.5, len(yax)-0.5), len(xax), 1.5,
                                         facecolor='grey', alpha=0.5)
                rect.set_clip_on(False)
                grid[cell].add_patch(rect)
                # Set up label in the heatmap (for maxdist)
                if column == 0:
                    grid[cell].text(- (len(xax) / 2 + 0.5), len(yax)+0.25,
                                    axes[2],
                                    {'ha':'center', 'va':'center'}, size=8)
                grid[cell].text(len(xax) / 2. - 0.5, len(yax)+0.25,
                                str(my_round(zax[column], round_decs)),
                                {'ha':'center', 'va':'center'}, size=8)
            cell += 1

        # rectangle at the right of the row, with the values of the two first
        # axes and the distance cutoff
        rect = patches.Rectangle((len(xax)-.5, -0.5), 2.5, len(yax),
                                 facecolor='grey', alpha=0.5)
        rect.set_clip_on(False)
        grid[cell-1].add_patch(rect)
        grid[cell-1].text(len(xax) + 1.0, len(yax) / 2.,
                          str(my_round(vax[row[0]], round_decs)) + '\n' +
                          str(my_round(wax[row[1]], round_decs)) + '\n' +
                          dcutoff_label,
                          {'ha':'center', 'va':'center'},
                          rotation=90, size=8)

    grid[cell-1].text(len(xax) - 0.2, len(yax) + 1.2,
                      axes[0] + '\n' + axes[1] + '\ndcutoff',
                      {'ha':'left', 'va':'center'},
                      rotation=90, size=8)

    # hide the cells of the grid that hold no heatmap
    for i in range(cell+1):
        if not i in used:
            grid[i].set_visible(False)

    # Ticks and labels are only set on the lower-left heatmap (label_mode
    # is "1" and axes are not shared).
    # grid.axes_llc.set_ylim(-0.5, len(yax)+1)
    grid.axes_llc.set_xticks(list(range(0, len(xax), 2)))
    grid.axes_llc.set_yticks(list(range(0, len(yax), 2)))
    grid.axes_llc.set_xticklabels([my_round(i, round_decs) for i in xax][::2], size=9)
    grid.axes_llc.set_yticklabels([my_round(i, round_decs) for i in yax][::2], size=9)
    grid.axes_llc.set_xlabel(axes[4], size=9)
    grid.axes_llc.set_ylabel(axes[3], size=9)

    # Color bar settings
    cb = plt.colorbar(im, cax=grid.cbar_axes[0])
    cb.ax.set_ylabel('Correlation value', size=9)
    cb.ax.tick_params(labelsize=9)

    title = 'Optimal IMP parameters\n'
    heatmap.suptitle(title, size=12)

    #plt.tight_layout()
    if savefig:
        tadbit_savefig(savefig)
    else:
        plt.show()
def compare_models(sm1, sm2, cutoff=150,
                   models1=None, cluster1=None,
                   models2=None, cluster2=None):
    """
    Plots the difference of contact maps of two group of structural models
    (contacts of the second group minus contacts of the first one).

    :param sm1: a StructuralModel
    :param sm2: a StructuralModel
    :param 150 cutoff: distance threshold (nm) to determine if two
       particles are in contact
    :param None models1: passed as `models` when computing the contact map of
       sm1: if None (default) the contact map will be computed using all the
       models. A list of numbers corresponding to a given set of models can
       be passed
    :param None cluster1: passed as `cluster` when computing the contact map
       of sm1: compute the contact map only for the models in the cluster
       number 'cluster'
    :param None models2: same as models1, for sm2
    :param None cluster2: same as cluster1, for sm2
    """
    first = sm1.get_contact_matrix(models=models1, cluster=cluster1,
                                   cutoff=cutoff)
    second = sm2.get_contact_matrix(models=models2, cluster=cluster2,
                                    cutoff=cutoff)
    # element-wise difference; the size of the first matrix is used for both
    # dimensions
    size = len(first)
    difference = []
    for i in range(size):
        difference.append([second[i][j] - first[i][j] for j in range(size)])
    fig = plt.figure(figsize=(8, 6))
    axe = fig.add_subplot(111)
    im = axe.imshow(difference, origin='lower', interpolation="nearest")
    axe.set_ylabel('Particle')
    axe.set_xlabel('Particle')
    cbar = axe.figure.colorbar(im)
    cbar.ax.set_ylabel('Signed log difference between models')
    plt.show()
def _tad_density_plot(xpr, maxys=None, fact_res=1., axe=None,
                      focus=None, extras=None, normalized=True,
                      savefig=None, shape='ellipse'):
    """
    Draws each TAD of an experiment as a filled shape and marks each TAD
    border with a marker colored according to the score of the border.

    If the TADs have no 'height' value, the height of a TAD is computed as
    the sum of the interactions inside the TAD divided by the sum expected
    from the mean of each diagonal of the whole matrix.

    :param xpr: experiment-like object; the attributes read here are `tads`,
       `size`, `name`, `_zeros`, `norm` and `hic_data`
    :param None maxys: list to which the height of each TAD is appended. A
       new list is used if anything else than a list is passed
    :param 1. fact_res: factor by which the positions of the TADs are
       divided
    :param None axe: matplotlib Axes to draw on. If not given, a new figure
       is created, completed with a title and a legend, and finally saved
       or shown
    :param None focus: (begin, end) pair, only the TADs found inside this
       region are drawn
    :param None extras: list of x coordinates to be marked with a red cross
    :param True normalized: use normalized data (`xpr.norm`) if available,
       otherwise raw data (`xpr.hic_data`) is used
    :param None savefig: path to a file where to save the image. Only used
       when the figure is created by this function; if None the image is
       shown
    :param 'ellipse' shape: shape used to represent the TADs, one of
       'ellipse', 'rectangle' or 'triangle'

    :raises NotImplementedError: if `shape` is not one of the valid shapes
    """
    from matplotlib.cm import jet
    show = False
    if focus:
        figsiz = 4 + (focus[1] - focus[0]) / 30
        beg, end = focus
        tads = {t: xpr.tads[t] for t in xpr.tads
                if (xpr.tads[t]['start'] + 1 >= beg
                    and xpr.tads[t]['end'] <= end)}
        if not tads:
            warn('WARNING: Experiment %s has no TADs in the region %d-%d' % (
                xpr.name, focus[0], focus[1]))
    else:
        figsiz = 4 + (xpr.size) / 30
        tads = xpr.tads

    if not axe:
        fig = plt.figure(figsize=(figsiz, 1 + 1 * 1.8))
        axe = fig.add_subplot(111)
        fig.subplots_adjust(hspace=0)
        show = True

    zsin = np.sin(np.linspace(0, np.pi))

    # each shape returns 52 y values, matching the 52 x values given to fill
    shapes = {'ellipse'   : lambda h: [0] + list(h * zsin) + [0],
              'rectangle' : lambda h: [0] + [h] * 50 + [0],
              'triangle'  : lambda h: ([h/25 * i for i in range(26)] +
                                       [h/25 * i for i in range(25, -1, -1)])}
    try:
        shape = shapes[shape]
    except KeyError:
        # Undocumented extra shape: the requested name, once ROT13 encoded,
        # is compared to a word built from characters of the ROT13 text
        # stored in the `this` module.
        # NOTE: the first import of `this` prints the Zen of Python.
        import this
        table = ''.join([this.d.get(chr(i), chr(i)) for i in range(256)])
        hidden = ''.join(
            [this.s[i].upper() if this.s[i - 1] == 'v' else this.s[i]
             for i in [24, 36, 163, 8, 6, 16, 36]])
        if shape.translate(table) == hidden:
            shape = lambda h: (
                [h / 25 * i for i in range(25)] + [h + 0.2] * 2 +
                [h / 25 * i for i in range(24, -1, -1)])
        else:
            raise NotImplementedError(
                '%s not valid, use one of ellipse, rectangle or triangle'
                % shape)
    maxys = maxys if isinstance(maxys, list) else []
    zeros = xpr._zeros or {}
    if normalized and xpr.norm:
        norms = xpr.norm[0]
    elif xpr.hic_data:
        if normalized:
            warn("WARNING: weights not available, using raw data")
        norms = xpr.hic_data[0]
    else:
        warn("WARNING: raw Hi-C data not available, " +
             "TAD's height fixed to 1")
        norms = None
    if tads and 'height' not in tads[list(tads.keys())[0]]:
        # mean value of each diagonal of the matrix (filtered bins count as 0)
        diags = []
        siz = xpr.size
        sp1 = siz + 1
        if norms:
            for k in range(1, siz):
                s_k = siz * k
                diags.append(sum([norms[i * sp1 + s_k]
                                  if not (i in zeros
                                          or (i + k) in zeros) else 0.
                                  for i in range(siz - k)]) / (siz - k))
        for tad in tads:
            start, end = (int(tads[tad]['start']) + 1,
                          int(tads[tad]['end']) + 1)
            if norms:
                matrix = sum([norms[i + siz * j]
                              if not (i in zeros
                                      or j in zeros) else 0.
                              for i in range(start - 1, end - 1)
                              for j in range(i + 1, end - 1)])
            try:
                if norms:
                    # observed interactions over the ones expected from the
                    # mean of the diagonals crossing the TAD
                    height = float(matrix) / sum(
                        [diags[i-1] * (end - start - i)
                         for i in range(1, end - start)])
                else:
                    height = 1.
            except ZeroDivisionError:
                height = 0.
            maxys.append(height)
            start = float(start) / fact_res
            end = float(end) / fact_res
            axe.fill([start] + list(np.linspace(start, end)) + [end],
                     shape(height),
                     alpha=.8 if height > 1 else 0.4,
                     facecolor='grey', edgecolor='grey')
    else:
        for tad in tads:
            start, end = (int(tads[tad]['start']) + 1,
                          int(tads[tad]['end']) + 1)
            height = float(tads[tad]['height'])
            maxys.append(height)
            axe.fill([start] + list(np.linspace(start, end)) + [end],
                     shape(height),
                     alpha=.8 if height > 1 else 0.4,
                     facecolor='grey', edgecolor='grey')
    if extras:
        axe.plot(extras, [.5 for _ in range(len(extras))], 'rx')
    axe.grid()
    axe.patch.set_visible(False)
    axe.set_ylabel('Relative\nHi-C count')
    # TAD borders
    last_tad = None
    for key in tads:
        tad = tads[key]
        if not tad['end']:
            continue
        last_tad = tad
        axe.plot(((tad['end'] + 1.) / fact_res, ), (0., ),
                 color=jet(tad['score'] / 10) if tad['score'] else 'w',
                 mec=jet(tad['score'] / 10) if tad['score'] else 'k',
                 marker=6, ms=9, alpha=1, clip_on=False)
    # ticks are placed up to the last border drawn (if any)
    if last_tad is not None:
        axe.set_xticks([1] + list(range(100, int(last_tad['end'] + 1), 50)))
    axe.minorticks_on()
    axe.xaxis.set_minor_locator(MultipleLocator(10))
    try:
        axe.hlines(1, tads[list(tads.keys())[0]]['start'], end, 'k', lw=1.5)
    except IndexError:  # no TADs
        pass
    if show:
        fig.suptitle("TAD borders", size='x-large')
        plt.subplots_adjust(top=0.76)
        fig.set_facecolor('white')
        # points drawn only to get one legend entry per score
        plots = []
        for scr in range(1, 11):
            plots += plt.plot((100,), (100,), marker=6, ms=9,
                              color=jet(float(scr) / 10), mec='none')
        try:
            axe.legend(plots,
                       [str(scr) for scr in range(1, 11)],
                       numpoints=1, title='Border scores',
                       fontsize='small', loc='lower left',
                       bbox_to_anchor=(1, 0.1))
        except TypeError:  # 'fontsize' not accepted
            axe.legend(plots,
                       [str(scr) for scr in range(1, 11)],
                       numpoints=1, title='Border scores',
                       loc='lower left',
                       bbox_to_anchor=(1, 0.1))
        # no height available when no TAD was drawn
        axe.set_ylim((0, (max(maxys) if maxys else 1.) + 0.4))
        if savefig:
            tadbit_savefig(savefig)
        else:
            plt.show()
def plot_compartments(crm, first, cmprts, matrix, show, savefig,
                      vmin=-1, vmax=1, whichpc=1,showAB=False):
    """
    Plots a matrix (its color bar is labelled as Pearson correlation
    coefficients) and, on top of it, the profile of a principal component
    together with, when available, the density profile of the compartments.

    :param crm: key of the chromosome in `cmprts`, also used in warnings
    :param first: values of the principal component to be plotted; they are
       scaled by their largest absolute value
    :param cmprts: dictionary with, for each chromosome, a list of
       compartments. Each compartment is a dictionary with the keys
       'start', 'end', 'dens' and, optionally, 'type' ('A' or 'B')
    :param matrix: square matrix to be plotted
    :param show: if True the figure is displayed
    :param savefig: path to a file where to save the figure (not saved if
       empty or None)
    :param -1 vmin: minimum value of the color scale
    :param 1 vmax: maximum value of the color scale
    :param 1 whichpc: rank (1 to 12) of the principal component, only used
       to label the plot
    :param False showAB: marks A compartments in red and B compartments in
       blue. Needs the density of the compartments
    """
    # density profile: each bin takes the density of the last compartment
    # started at or before it
    heights = []
    val = 0
    for i in range(len(matrix)):
        try:
            val = [c['dens'] for c in cmprts[crm] if c['start']==i][0]
        except IndexError:  # no compartment starts at this bin
            pass
        except KeyError:    # no density available
            heights = []
            break
        heights.append(val-1)
    if heights:
        maxheights = max([abs(f) for f in heights])
        try:
            heights = [f / maxheights for f in heights]
        except ZeroDivisionError:
            warn('WARNING: no able to plot chromosome %s' % crm)
            return

    # definitions for the axes
    left, width = 0.1, 0.75
    bottom, height = 0.1, 0.75
    bottom_h = left + height + 0.02

    rect_scatter = [left               , bottom  , width, height]
    rect_histx   = [left               , bottom_h, width, 0.08  ]
    rect_histy   = [left + width + 0.02, bottom  , 0.02 , 0.75  ]

    # start with a rectangular Figure
    plt.figure(1, figsize=(14, 14))

    axim = plt.axes(rect_scatter)
    axex = plt.axes(rect_histx, sharex=axim)
    axey = plt.axes(rect_histy)
    axey.set_ylabel('density (orange)')
    im = axim.imshow(matrix, interpolation='nearest', cmap='coolwarm',
                     vmin=vmin, vmax=vmax)
    axim.minorticks_on()
    cb = plt.colorbar(im, cax=axey)
    cb.set_label('Pearson product-moment correlation coefficients')
    axim.grid()

    # scale first PC
    mfirst = np.nanmax((np.nanmax(first), abs(np.nanmin(first))))
    first = [f / mfirst for f in first]

    axex.plot(first, color='green', alpha=0.5)
    if heights:
        # black line below the orange one to outline it
        axex.plot(heights, color='black', alpha=1, linewidth=2)
        axex.plot(heights, color='orange', alpha=1, linewidth=1)

    # sliding mean of the PC over `div` sub-steps per bin, to get a smoother
    # filled profile
    div = 1000 // len(matrix) + 1
    _div = float(div)
    half_first = np.array([sum(first[(i + j) // div] for j in range(div)) / _div
                           for i in range(len(first) * div - div + 1)])

    with np.errstate(invalid='ignore'):
        axex.fill_between([i / _div for i in range(len(half_first))],
                          [0] * len(half_first), half_first,
                          where=half_first > 0, color='olive', alpha=0.5)
        axex.fill_between([i / _div for i in range(len(half_first))],
                          [0] * len(half_first), half_first,
                          where=half_first < 0, color='darkgreen', alpha=0.5)
    axex.set_yticks([0])
    if heights:
        axex.set_ylabel('%s PC (green)\nrich in A (orange)' % (NTH[whichpc]))
    else:
        axex.set_ylabel('%s PC (green)' % (NTH[whichpc]))

    # COMPARTMENTS A/B
    if showAB and heights:
        a_comp = []
        b_comp = []
        for cmprt in cmprts[crm]:
            try:
                if cmprt['type'] == 'A':
                    a_comp.append((cmprt['start'] - 0.5, cmprt['end'] + 0.5))
                elif cmprt['type'] == 'B':
                    b_comp.append((cmprt['start'] - 0.5, cmprt['end'] + 0.5))
            except KeyError:
                # no type defined, use the density instead
                if cmprt['dens'] > 1:
                    a_comp.append((cmprt['start'] - 0.5, cmprt['end'] + 0.5))
                else:
                    b_comp.append((cmprt['start'] - 0.5, cmprt['end'] + 0.5))
        a_comp.sort()
        b_comp.sort()
        axex.hlines([0.05]*len(a_comp), [a[0] for a in a_comp],
                    [a[1] for a in a_comp], color='red' , linewidth=6)
        axex.hlines([-0.05]*len(b_comp), [b[0] for b in b_comp],
                    [b[1] for b in b_comp], color='blue' , linewidth=6)
    elif showAB:
        warn('WARNING: not displaying AB compartments, need rich in A regions')
    axex.grid()
    # TODO: these two lines conflict with matplotlib 2.0.2, and plot are still beautiful so...
    # axex.minorticks_on()
    # axex.grid(b=True, which='minor')

    plt.setp(axex.get_xticklabels(), visible=False)
    axex.set_xlim((-0.5, len(matrix) - 0.5))
    axim.set_ylim((-0.5, len(matrix) - 0.5))
    # save before showing: a blocking show closes the figure, and saving
    # afterwards would write an empty image
    if savefig:
        tadbit_savefig(savefig)
    if show:
        plt.show()
    plt.close('all')
def plot_compartments_summary(crm, cmprts, show, savefig, title=None):
    """
    Plots the position of the A (red) and B (blue) compartments of a
    chromosome along a single line.

    All open figures are closed before and after plotting.

    :param crm: key of the chromosome in `cmprts`, also used in the default
       title
    :param cmprts: dictionary with, for each chromosome, a list of
       compartments. Each compartment is a dictionary with the keys
       'start', 'end' and either 'type' ('A' or 'B') or 'dens' (used when
       'type' is missing: A if higher than 1, B otherwise)
    :param show: if True the figure is displayed
    :param savefig: path to a file where to save the figure (not saved if
       empty or None)
    :param None title: title of the plot, by default 'Chromosome' followed
       by `crm`
    """
    plt.close('all')

    compartments = cmprts[crm]
    if not compartments:
        warn('WARNING: no compartments to plot for chromosome %s' % crm)
        return

    a_comp = []
    b_comp = []
    breaks = []
    for cmprt in compartments:
        breaks.append(cmprt['start'])
        try:
            if cmprt['type'] == 'A':
                a_comp.append((cmprt['start'], cmprt['end']))
            elif cmprt['type'] == 'B':
                b_comp.append((cmprt['start'], cmprt['end']))
        except KeyError:
            # no type defined, use the density instead
            if cmprt['dens'] > 1:
                a_comp.append((cmprt['start'], cmprt['end']))
            else:
                b_comp.append((cmprt['start'], cmprt['end']))
    a_comp.sort()
    b_comp.sort()

    # width of the figure grows with the end of the last compartment
    last_end = compartments[-1]['end']
    fig, ax = plt.subplots(figsize=(3 + last_end / 100., 2))
    plt.subplots_adjust(top=0.7, bottom=0.25)
    plt.hlines([0.05]*len(a_comp), [a[0] for a in a_comp],
               [a[1] for a in a_comp], color='red' , linewidth=6)
    plt.hlines([-0.05]*len(b_comp), [b[0] for b in b_comp],
               [b[1] for b in b_comp], color='blue' , linewidth=6)
    plt.vlines(breaks, [-0.3]*len(breaks), [0.3]*len(breaks), color='black',
               linestyle=':')
    plt.title(title if title else ('Chromosome %s' % crm))
    plt.text(1, 0.55, 'A compartments')
    plt.text(1, -0.75, 'B compartments')
    plt.xlabel('Genomic bin')
    plt.ylim((-1, 1))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.set_yticks([])
    plt.xlim((0, last_end))
    for item in [fig, ax]:
        item.patch.set_visible(False)
    # save before showing: a blocking show closes the figure, and saving
    # afterwards would write an empty image
    if savefig:
        tadbit_savefig(savefig)
    if show:
        plt.show()
    plt.close('all')
def pcolormesh_45deg(matrix, axe=None, **kwargs):
    """
    Draws a square matrix rotated by 45 degrees, its diagonal lying along
    the x axis.

    :param matrix: square numpy array with the values to plot
    :param None axe: matplotlib Axes to draw on, by default a new subplot
       is created in the current figure
    :param kwargs: extra parameters for the pcolormesh function of
       matplotlib. If neither 'vmin' nor 'vmax' is given, the color range
       is set to the range of finite values of the matrix

    :returns: the object returned by pcolormesh
    """
    if axe is None:
        # 'axe' being a named parameter, it is never found in kwargs
        axe = plt.subplot(111)
    n_bins = matrix.shape[0]
    if not ('vmin' in kwargs or 'vmax' in kwargs):
        try:
            finite_values = matrix[np.isfinite(matrix)]
            kwargs['vmin'] = np.min(finite_values)
            kwargs['vmax'] = np.max(finite_values)
        except ValueError:  # probably empty
            pass
    # corners of the cells: columns run from 0 to n_bins, rows from n_bins
    # down to 0
    cols, rows = np.meshgrid(np.arange(n_bins + 1),
                             np.arange(n_bins, -1, -1))
    # rotate and scale the grid of corners
    x_coords = 0.5 * (cols + rows) - 0.5
    y_coords = (cols - rows).astype(float)
    # plot
    mesh = axe.pcolormesh(x_coords, y_coords, np.flipud(matrix), **kwargs)
    for side in ('right', 'left', 'top'):
        axe.spines[side].set_visible(False)
    axe.set_yticks([])
    return mesh
def add_subplot_axes(ax,rect,axisbg='w'):
    """
    Adds a new Axes inside an existing one.

    from https://stackoverflow.com/questions/17458580/embedding-small-plots-inside-subplots-in-matplotlib/35966183

    :param ax: matplotlib Axes inside which the new Axes is placed
    :param rect: [x, y, width, height] of the new Axes, as fractions of
       `ax`
    :param 'w' axisbg: kept for compatibility, currently not used

    :returns: the new Axes, the size of its tick labels being scaled by
       the square root of the relative width (x axis) and height (y axis)
    """
    # use the figure holding `ax`, which may not be the current figure
    fig = ax.get_figure()
    box = ax.get_position()
    width = box.width
    height = box.height
    # origin of the new Axes: from `ax` fractions to figure fractions
    inax_position = ax.transAxes.transform(rect[0:2])
    transFigure = fig.transFigure.inverted()
    infig_position = transFigure.transform(inax_position)
    x = infig_position[0]
    y = infig_position[1]
    width *= rect[2]
    height *= rect[3]
    subax = fig.add_axes([x,y,width,height])
    x_labelsize = subax.get_xticklabels()[0].get_size()
    y_labelsize = subax.get_yticklabels()[0].get_size()
    x_labelsize *= rect[2]**0.5
    y_labelsize *= rect[3]**0.5
    subax.xaxis.set_tick_params(labelsize=x_labelsize)
    subax.yaxis.set_tick_params(labelsize=y_labelsize)
    return subax
def plot_HiC_matrix(matrix, bad_color=None, triangular=False, axe=None,
                    transform=np.log2, rescale_zeros=True, figsize=None,
                    tad_def=None, **kwargs):
    """
    Plot HiC matrix with histogram of values inside color bar.

    :param matrix: list of lists with values to be plotted
    :param None bad_color: plots NaNs in a given color
    :param False triangular: represents only half matrix horizontally
    :param None axe: matplotlib Axes inside which the matrix and the color
       bar are placed. If not given a new figure is created
    :param np.log2 transform: function applied to the matrix before
       plotting
    :param True rescale_zeros: replaces zeroes by half of the smallest
       non-zero value in order to avoid -inf with log transform. The
       replacement is done on a copy, the input matrix is not modified
    :param None figsize: tuple with the width and height of the wanted
       image (default is (16, 10) for triangular and (16, 14) for square
       matrices)
    :param None tad_def: dictionary with tad definition to include in
       the plot (each value needs the keys 'start', 'end' and 'score')
    :param kwargs: extra parameters for the imshow function of matplotlib
       (also passed to pcolormesh_45deg for triangular matrices)

    :returns: two axes object, the first corresponding to the matrix,
       the second to the color bar
    """
    if bad_color is not None:
        kwargs['cmap'] = copy.copy(plt.get_cmap(kwargs.get('cmap', None)))
        kwargs['cmap'].set_bad(bad_color, 1.)

    if not isinstance(matrix, (np.ndarray, np.generic)):
        matrix = np.asarray(matrix)

    # remove zeroes from the matrix in order to avoid -inf with log transform
    if rescale_zeros:
        # work on a float copy: the caller's matrix is left untouched, and
        # the replacement value is not truncated in integer matrices
        matrix = matrix.astype(float)
        try:
            # nanmin: NaNs are expected and should not define the minimum
            mini = np.nanmin(matrix[np.nonzero(matrix)]) / 2.
        except ValueError:  # no non-zero value
            mini = 0.
        matrix[matrix == 0] = mini

    with np.errstate(divide='ignore', invalid='ignore'):
        matrix = np.ma.masked_where(np.isnan(matrix), transform(matrix))

    # position of the matrix and of the color bar
    if triangular:
        default_figsize = (16, 10)
        matrix_rect = [0.05, 0.15, 0.9, 0.72]
        cbar_rect = [0.63, 0.775, 0.32, 0.07]
    else:
        default_figsize = (16, 14)
        matrix_rect = [0.1, 0.1, 0.7, 0.8]
        cbar_rect = [0.82, 0.1, 0.07, 0.8]
    if not axe:
        plt.figure(figsize=figsize if figsize else default_figsize)
        axe1 = plt.axes(matrix_rect)
        axe2 = plt.axes(cbar_rect)
    else:
        axe1 = add_subplot_axes(axe, matrix_rect)
        axe2 = add_subplot_axes(axe, cbar_rect)

    if triangular:
        pcolormesh_45deg(matrix, axe=axe1, **kwargs)
        axe1.set_xlim(-0.5, len(matrix[0]) - 0.5)
        axe1.set_ylim(0, len(matrix))
    else:
        axe1.imshow(matrix, interpolation='None', origin='lower', **kwargs)
        axe1.set_xlim(-0.5, len(matrix[0]) - 0.5)
        axe1.set_ylim(-0.5, len(matrix) - 0.5)

    # finite, non-masked values, used for the histogram and the color range
    data = matrix.compressed()
    data = data[np.isfinite(data)]
    try:
        mindata = np.nanmin(data)
        maxdata = np.nanmax(data)
    except ValueError:  # nothing to plot
        mindata = maxdata = 0.
    gradient = np.linspace(maxdata, mindata, max((len(matrix),
                                                  len(matrix[0]))))
    if not triangular:
        gradient = np.dstack((gradient, gradient))[0]
    else:
        gradient = [gradient[::-1]]
    try:
        h = axe2.hist(data, color='darkgrey', linewidth=2,
                      orientation='vertical' if triangular else 'horizontal',
                      bins=50, histtype='step', density=True)
    except AttributeError:  # older versions of matplotlib
        h = axe2.hist(data, color='darkgrey', linewidth=2,
                      orientation='vertical' if triangular else 'horizontal',
                      bins=50, histtype='step')
    # color gradient drawn behind the histogram
    axe2.imshow(gradient, aspect='auto', extent=(
        (mindata, maxdata, 0, max(h[0])) if triangular else
        (0, max(h[0]), mindata, maxdata)), **kwargs)

    transform_name = str(transform)
    forced_range = 'vmin' in kwargs and 'vmax' in kwargs
    values_label = 'Hi-C %sinteractions%s' % (
        'Log2 ' if 'log2' in transform_name else
        'Log ' if 'log' in transform_name else '',
        '\n(Forced color range %s-%s)' % (kwargs['vmin'], kwargs['vmax'])
        if forced_range else '')
    if triangular:
        axe2.set_yticks([])
        axe2.set_xlabel(values_label)
        axe2.set_ylabel('Count')
    else:
        axe2.yaxis.tick_right()
        axe2.yaxis.set_label_position("right")
        axe2.set_xticks([])
        axe2.set_ylabel(values_label, rotation=-90,
                        labelpad=20 if forced_range else 10)
        axe2.set_xlabel('Count')

    if tad_def:
        pwidth = 1
        for tad in tad_def.values():
            # width of the line proportional to the score of the border
            nwidth = float(abs(tad['score'])) / 4
            t_start = int(tad['start']) - 0.5
            t_end = int(tad['end']) + 0.5
            if not triangular:
                axe1.hlines(t_start, t_start, t_end, colors='k', lw=pwidth)
                axe1.hlines(t_end  , t_start, t_end, colors='k', lw=nwidth)
                axe1.vlines(t_start, t_start, t_end, colors='k', lw=pwidth)
                axe1.vlines(t_end  , t_start, t_end, colors='k', lw=nwidth)
            else:
                pol1 = plt.Polygon([(t_start,0),
                                    (t_start+(t_end-t_start)/2,(t_end-t_start)),
                                    (t_end,0)], ls="--", lw=nwidth, fill=False)
                axe1.add_patch(pol1)
            if tad['score'] < 0:
                # TADs with negative score are hatched
                for j in range(0, int(t_end) - int(t_start), 2):
                    axe1.plot((t_start    , t_start + j),
                              (t_end   - j, t_end      ), color='k')
                    axe1.plot((t_end      , t_end   - j),
                              (t_start + j, t_start    ), color='k')
    return axe1, axe2
def format_HiC_axes(axe1, start1, end1, start2, end2, reso, regions,
                    section_pos, sections, xtick_rotation, triangular=False):
    """
    Sets the ticks and the labels of the axes of a Hi-C matrix plot.

    With one or two regions the axes are labelled with genomic coordinates;
    with more regions, ticks mark the limits between chromosomes and each
    chromosome name is written at the middle of its section.

    :param axe1: matplotlib Axes holding the matrix
    :param start1: first position plotted in the x axis (0 if None)
    :param end1: last position plotted in the x axis (if None, the value
       found in `sections` for the first region)
    :param start2: first position plotted in the y axis (0 if None)
    :param end2: last position plotted in the y axis (if None, the value
       found in `sections` for the last region)
    :param reso: resolution, number of positions per bin
    :param regions: list of the names of the regions plotted
    :param section_pos: dictionary with, for each chromosome, its
       (begin, end) positions
    :param sections: dictionary with, for each region, the value used as
       default end position
    :param xtick_rotation: rotation of the labels of the x ticks
    :param False triangular: if True the y axis is left without label
    """
    if len(regions) > 2:
        # several chromosomes: limits as major ticks, names as minor ticks
        limits = [0]
        names = []
        for name in section_pos:
            section_beg = section_pos[name][0]
            section_end = section_pos[name][1]
            limits.append(limits[-1] + (section_end - section_beg) // reso + 1)
            names.append(name)
        centers = [float(lower + upper) / 2
                   for lower, upper in zip(limits[:-1], limits[1:])]
        for set_ticks, set_ticklabels, axis in (
                (axe1.set_yticks, axe1.set_yticklabels, axe1.yaxis),
                (axe1.set_xticks, axe1.set_xticklabels, axe1.xaxis)):
            set_ticks(limits)
            set_ticklabels('')
            set_ticks(centers, minor=True)
            set_ticklabels(names, minor=True)
            # only the names are wanted at the minor ticks, not the marks
            for tick in axis.get_minor_ticks():
                tick.tick1line.set_visible(False)
                tick.tick2line.set_visible(False)
        axe1.set_xlabel('Chromosomes')
        if not triangular:
            axe1.set_ylabel('Chromosomes')
        return

    # one or two regions: genomic coordinates
    pltbeg1 = 0 if start1 is None else start1
    pltend1 = sections[regions[0]] if end1 is None else end1
    if len(regions) == 1:
        pltbeg2, pltend2 = pltbeg1, pltend1
    else:
        pltbeg2 = 0 if start2 is None else start2
        pltend2 = sections[regions[-1]] if end2 is None else end2

    coordinates = '{}:{:,}-{:,}'
    axe1.set_xlabel(coordinates.format(regions[0], pltbeg1 or 1, pltend1))
    if not triangular:
        axe1.set_ylabel(coordinates.format(regions[-1], pltbeg2 or 1,
                                           pltend2))

    def tick_formatter(offset):
        """
        Builds the function converting a bin number into a position.
        """
        def formatter(tick, _=None):
            position = int(tick * reso + offset)
            return nicer(position or 1, comma=',', allowed_decimals=1)
        return formatter

    axe1.xaxis.set_major_formatter(FuncFormatter(tick_formatter(pltbeg1)))
    axe1.yaxis.set_major_formatter(FuncFormatter(tick_formatter(pltbeg2)))
    if triangular:
        axe1.set_yticks([])

    plt.setp(axe1.get_xticklabels(), rotation=xtick_rotation,
             ha='left' if xtick_rotation else 'center')
def _format_axes(axe1, start1, end1, start2, end2, reso, regions,
                 section_pos, sections, xtick_rotation, triangular=False):
    """
    Sets the ticks and the labels of the axes of a Hi-C matrix plot.

    Private alias of format_HiC_axes, kept for backward compatibility: the
    two functions had exactly the same body.

    :param axe1: matplotlib Axes holding the matrix
    :param start1: first position plotted in the x axis (0 if None)
    :param end1: last position plotted in the x axis
    :param start2: first position plotted in the y axis (0 if None)
    :param end2: last position plotted in the y axis
    :param reso: resolution, number of positions per bin
    :param regions: list of the names of the regions plotted
    :param section_pos: dictionary with, for each chromosome, its
       (begin, end) positions
    :param sections: dictionary with, for each region, the value used as
       default end position
    :param xtick_rotation: rotation of the labels of the x ticks
    :param False triangular: if True the y axis is left without label
    """
    format_HiC_axes(axe1, start1, end1, start2, end2, reso, regions,
                    section_pos, sections, xtick_rotation,
                    triangular=triangular)