# Source code for pytadbit.utils.hic_filtering

"""
06 Aug 2013
"""
from __future__ import print_function

from warnings                  import warn
from sys                       import stderr
from re                        import sub
from pytadbit.utils.extraviews import tadbit_savefig

import numpy as np

try:
    from matplotlib import pyplot as plt
except ImportError:
    warn('matplotlib not found\n')


def get_r2(fun, X, Y, *args):
    """
    Coefficient of determination (R^2) of the model *fun* against the
    observed values.

    :param fun: fitted model, called as ``fun(X[i], *args)``
    :param X: abscissa values
    :param Y: observed ordinate values
    :param args: extra parameters forwarded to *fun*

    :returns: ``1 - SSerr / SStot``; 1.0 for a perfect fit, smaller (possibly
       negative) for worse fits
    """
    # hoist the mean out of the loop: the previous version recomputed
    # np.mean(Y) for every element, making the total sum quadratic in len(Y)
    mean_y = np.mean(Y)
    sstot = sum((y - mean_y) ** 2 for y in Y)
    sserr = sum((y - fun(x, *args)) ** 2 for x, y in zip(X, Y))
    return 1 - sserr / sstot


def filter_by_mean(matrx, draw_hist=False, silent=False, bads=None, savefig=None):
    """
    fits the distribution of Hi-C interaction count by column in the matrix to
    a polynomial. Then searches for the first concave-down root of the
    polynomial's derivative that lies below the 5th percentile of the column
    sums; columns whose total count is under this value are flagged as bad.

    :param matrx: Hi-C matrix (dict-like, linearized: cell (i, j) stored at
       key i + j * size) of a given experiment
    :param False draw_hist: plot the histogram of column sums together with
       the fitted polynomial and the chosen cutoff
    :param False silent: do not warn on stderr about removed columns
    :param None bads: dictionary with indexes of already-filtered columns as
       keys; those columns are excluded from the fit, and the dictionary is
       updated in place
    :param None savefig: path to a file where to save the image generated
       (only used when draw_hist is True); if None, the image will be shown
       using matplotlib GUI

    :returns: the (updated) *bads* dictionary, mapping the index of each
       filtered-out column to its sum of interactions
    """
    nbins = 100
    if not bads:
        bads = {}
    # get sum of columns (skipping rows/columns already flagged in bads),
    # sorted by increasing total count
    cols = []
    size = len(matrx)
    for c in sorted([[matrx.get(i+j*size, 0) for j in range(size) if not j in bads]
                     for i in range(size) if not i in bads], key=sum):
        cols.append(sum(c))
    cols = np.array(cols)
    if draw_hist:
        plt.figure(figsize=(9, 9))
    # the cutoff root found below must stay under this 5th percentile
    # NOTE(review): relies on np.percentile raising IndexError for empty
    # input — behavior differs across numpy versions; confirm
    try:
        percentile = np.percentile(cols, 5)
    except IndexError:
        warn('WARNING: no columns to filter out')
        return bads
    # mad = np.median([abs(median - c ) for c in cols])
    # best fit found so far: (R2, order, poly1d, coefficients, root)
    best =(float('-inf'), float('-inf'), float('-inf'), float('-inf'))
    # bin the sum of columns
    xmin = min(cols)
    xmax = max(cols)
    y = np.linspace(xmin, xmax, nbins)
    hist = np.digitize(cols, y)
    x = [sum(hist == i) for i in range(1, nbins + 1)]
    if draw_hist:
        hist = plt.hist(cols, bins=100, alpha=.3, color='grey')
    xp = list(range(0, int(cols[-1])))
    # check if the binning is correct
    # we want at list half of the bins with some data
    try:
        cnt = 0
        # drop the largest column sums (outliers) until at least half of the
        # bins are populated; give up after 10000 iterations
        while list(x).count(0) > len(x)/2:
            cnt += 1
            cols = cols[:-1]
            xmin = min(cols)
            xmax = max(cols)
            y = np.linspace(xmin, xmax, nbins)
            hist = np.digitize(cols, y)
            x = [sum(hist == i) for i in range(1, nbins + 1)]
            if draw_hist:
                plt.clf()
                hist = plt.hist(cols, bins=100, alpha=.3, color='grey')
            xp = list(range(0, int(cols[-1])))
            if cnt > 10000:
                raise ValueError
        # find best polynomial fit in a given range
        for order in range(6, 18):
            # fit column-sum value as a function of bin count, then look at
            # the two largest roots of the derivative
            z = np.polyfit(y, x, order)
            zp = np.polyder(z, m=1)
            roots = np.roots(np.polyder(z))
            # check that we are concave down, otherwise take next root
            pente = np.polyval(zp, abs(roots[-2] - roots[-1]) / 2 + roots[-1])
            if pente > 0:
                root = roots[-1]
            else:
                root = roots[-2]
            # root must be higher than zero
            if root <= 0:
                continue
            # and lower than the median
            if root >= percentile:
                continue
            p  = np.poly1d(z)
            # NOTE(review): z was fitted as polyfit(y, x) (x as a function of
            # y), but R2 is computed applying p to x — arguments look
            # inverted; confirm intent
            R2 = get_r2(p, x, y)
            # try to avoid very large orders by weigthing negatively their fit
            if order > 13:
                R2 -= float(order)/30
            if best[0] < R2:
                best = (R2, order, p, z, root)
        try:
            # if no fit was accepted, best still holds the four -inf
            # placeholders and this unpacking raises, landing in the bare
            # except below
            p, z, root = best[2:]
            if draw_hist:
                xlims = plt.xlim()
                ylims = plt.ylim()
                a = plt.plot(xp, p(xp), "--", color='k')
                b = plt.vlines(root, 0, plt.ylim()[1], colors='r', linestyles='dashed')
                # c = plt.vlines(median - mad * 1.5, 0, 110, colors='g',
                #                linestyles='dashed')
                try:
                    plt.legend(a+[b], ['polyfit \n%s' % (
                        ''.join([sub('e([-+][0-9]+)', 'e^{\\1}',
                                     '$%s%.1fx^%s$' % ('+' if j>0 else '', j,
                                                       '{' + str(i) + '}'))
                                 for i, j in enumerate(list(p)[::-1])])),
                                           'first solution of polynomial derivation'],
                               fontsize='x-small')
                except TypeError:
                    # older matplotlib versions do not accept the fontsize
                    # keyword in legend()
                    plt.legend(a+[b], ['polyfit \n%s' % (
                        ''.join([sub('e([-+][0-9]+)', 'e^{\\1}',
                                     '$%s%.1fx^%s$' % ('+' if j>0 else '', j,
                                                       '{' + str(i) + '}'))
                                 for i, j in enumerate(list(p)[::-1])])),
                                           'first solution of polynomial derivation'])
                # plt.legend(a+[b]+[c], ['polyfit \n{}'.format (
                #     ''.join([sub('e([-+][0-9]+)', 'e^{\\1}',
                #                  '${}{:.1}x^{}$'.format ('+' if j>0 else '', j,
                #                                         '{' + str(i) + '}'))
                #              for i, j in enumerate(list(p)[::-1])])),
                #                        'first solution of polynomial derivation',
                #                        'median - (1.5 * median absolute deviation)'],
                #            fontsize='x-small')
                plt.ylim([0, ylims[1]])
                plt.xlim(xlims)
                plt.xlabel('Sum of interactions')
                # NOTE(review): second xlabel overwrites the first — this one
                # was probably meant to be plt.ylabel; confirm
                plt.xlabel('Number of columns with a given value')
                if savefig:
                    tadbit_savefig(savefig)
                else:
                    plt.show()
            # label as bad the columns with sums lower than the root
            for i, col in enumerate([[matrx.get(i+j*size, 0)
                                      for j in range(size)]
                                     for i in range(size)]):
                if sum(col) < root:
                    bads[i] = sum(col)
            # now stored in Experiment._zeros, used for getting more accurate z-scores
            if bads and not silent:
                stderr.write(('\nWARNING: removing columns having less than %s ' +
                              'counts:\n %s\n') % (
                                 round(root, 3), ' '.join(
                                     ['%5s'%str(i + 1) + (''if (j + 1) % 20 else '\n')
                                      for j, i in enumerate(sorted(bads.keys()))])))
        # NOTE(review): bare except — catches the unpacking failure above but
        # also silently swallows any plotting/stderr error
        except:
            if not silent:
                stderr.write('WARNING: Too many zeroes to filter columns.' +
                             ' SKIPPING...\n')
            if draw_hist:
                plt.xlabel('Sum of interactions')
                # NOTE(review): same probable xlabel/ylabel mix-up as above
                plt.xlabel('Number of columns with a given value')
                if savefig:
                    tadbit_savefig(savefig)
                else:
                    plt.show()
    except ValueError:
        # raised by the binning loop above after 10000 failed attempts
        if not silent:
            stderr.write('WARNING: Too few data to filter columns based on ' +
                         'mean value.\n')
    if draw_hist:
        plt.close('all')
    return bads


def filter_by_zero_count(matrx, perc_zero, min_count=None, silent=True):
    """
    Flag columns of a Hi-C matrix that are too empty to be reliable.

    :param matrx: Hi-C matrix of a given experiment
    :param perc_zero: percentage of cells with no count allowed to consider
       a column as valid.
    :param None min_count: minimum number of reads mapped to a bin
       (recommended value could be 2500). If set this option overrides the
       perc_zero filtering... This option is slightly slower.
    :param True silent: do not report removed columns on stderr

    :returns: a dictionary which has as keys the index of the filtered out
       columns.
    """
    size = len(matrx)
    if min_count is None:
        # count, for each column, how many of its cells hold no value
        cols = [size] * size
        for key in matrx:
            cols[key // size] -= 1
        min_val = int(size * float(perc_zero) / 100)
        is_bad = lambda v: v > min_val
    else:
        if matrx.symmetricized:
            stderr.write('\nWARNING: Using twice min_count as the matrix was '
                         'symmetricized and contains twice as many '
                         'interactions as the original\n')
            min_count *= 2
        # accumulate per-column totals (keys index a linear representation
        # of the matrix)
        cols = [0] * size
        for key, val in matrx.items():
            cols[key // size] += val
        min_val = size - min_count
        is_bad = lambda v: v < min_count
    bads = dict((idx, True) for idx, val in enumerate(cols) if is_bad(val))
    if bads and not silent:
        listing = ' '.join('%5s' % str(i + 1) + ('' if (j + 1) % 20 else '\n')
                           for j, i in enumerate(sorted(bads.keys())))
        if min_count is None:
            stderr.write(('\nWARNING: removing columns having more than %s ' +
                          'zeroes:\n %s\n') % (min_val, listing))
        else:
            stderr.write(('\nWARNING: removing columns having less than %s ' +
                          'counts:\n %s\n') % (int(size - min_val), listing))
    return bads


def hic_filtering_for_modelling(matrx, silent=False, perc_zero=90, auto=True,
                                min_count=None, draw_hist=False, savefig=None,
                                diagonal=True):
    """
    Call filtering functions to remove artifactual columns in a given Hi-C
    matrix. Detects columns with very low interaction counts, columns with
    a zero on the diagonal, and columns whose sum is NaN. The result is
    meant to be stored in the dictionary Experiment._zeros.

    :param matrx: Hi-C matrix of a given experiment
    :param False silent: does not warn for removed columns
    :param None savefig: path to a file where to save the image generated;
       if None, the image will be shown using matplotlib GUI (the extension
       of the file name will determine the desired format).
    :param True diagonal: remove row/columns with zero in the diagonal
    :param 90 perc_zero: maximum percentage of cells with no interactions
       allowed.
    :param None min_count: minimum number of reads mapped to a bin
       (recommended value could be 2500). If set this option overrides the
       perc_zero filtering... This option is slightly slower.
    :param True auto: if False, only filters based on the given percentage
       of zeros

    :returns: a tuple with the dictionary of filtered-out column indexes and
       a boolean telling whether NaN column sums were encountered
    """
    bads = filter_by_zero_count(matrx, perc_zero, min_count=min_count,
                                silent=silent)
    if auto:
        bads.update(filter_by_mean(matrx, draw_hist=draw_hist, silent=silent,
                                   savefig=savefig, bads=bads))
    # also removes rows or columns containing a NaN
    has_nans = False
    size = len(matrx)
    for i in range(size):
        if matrx.get(i + i * size, 0) == 0 and diagonal:
            if not i in bads:
                bads[i] = None
        elif repr(sum([matrx.get(i + j * size, 0)
                       for j in range(size)])) == 'nan':
            has_nans = True
            if not i in bads:
                bads[i] = None
    return bads, has_nans
def _best_window_size(sorted_prc, size, beg, end, verbose=False):
    """
    Search for best window size.

    Between given begin and end percentiles of the distribution of cis
    interactions, searches for a window size (number of bins) where all
    median values are between median +/- stddev of the global measure.

    :param sorted_prc: list of percentages of cis interactions by bins,
       sorted by the total interactions in the corresponding bins.
    :param size: total number of bins
    :param beg: starting position of the region with expected 'normal'
       behavior of the cis-percentage
    :param end: last position of the region with expected 'normal'
       behavior of the cis-percentage
    :param False verbose: print running information

    :returns: window size
    """
    if verbose:
        print(' -> defining window in number of bins to average values of\n'
              '    percentage of cis interactions')
    nwins = min((1000, size // 10))
    if nwins < 100:
        warn('WARNING: matrix probably too small to automatically filter '
             'out bins\n')
    win_size = 1
    prevn = 0
    count = 0
    # iterate over possible window sizes (use logspace to gain some time)
    for n in np.logspace(1, 4, num=100):
        n = int(n)
        if n == prevn:
            continue
        prevn = n
        tmp_std = []
        tmp_med = []
        for k in range(int(size * beg), int(size * end),
                       (int(size * end) - int(size * beg)) // nwins):
            vals = sorted_prc[k:k + n]
            tmp_std.append(np.std(vals))
            tmp_med.append(np.median(vals))
        med_mid = np.median([tmp_med[i] for i in range(nwins)])
        results = [m - s < med_mid < m + s
                   for m, s in zip(tmp_med, tmp_std)]
        if all(results):
            # remember the first window size of a run of 10 stable ones
            if not count:
                win_size = n
            count += 1
            if count == 10:
                break
        else:
            count = 0
    if verbose:
        print(' * first window size with stable median of cis-percentage: '
              '%d' % (win_size))
    return win_size


def filter_by_cis_percentage(cisprc, beg=0.3, end=0.8, sigma=2, verbose=False,
                             size=None, min_perc=None, max_perc=None,
                             savefig=None):
    """
    Define artifactual columns with either too low or too high counts of
    interactions by comparing their percentage of cis interactions
    (inter-chromosomal).

    :param cisprc: dictionary with counts of cis-percentage by bin number.
       Values of the dictionary are tuple with, as first element the number
       of cis interactions and as second element the total number of
       interactions.
    :param 0.3 beg: proportion of bins to be considered as possibly having
       low counts
    :param 0.8 end: proportion of bins to be considered as possibly having
       high counts
    :param 2 sigma: number of standard deviations used to define lower and
       upper ranges in the variation of the percentage of cis interactions
    :param None size: size of the genome, in number of bins (otherwise
       inferred from cisprc dictionary)
    :param None savefig: path to save image of the distribution of cis
       percentages and total counts by bin.

    :returns: dictionary of bins to be filtered out (with either too low or
       too high counts of interactions).
    """
    sorted_sum, indices = list(zip(*sorted((cisprc[i][1], i) for i in cisprc)))
    sorted_prc = [float(cisprc[i][0]) / cisprc[i][1] for i in indices]
    size = (max(indices) + 1) if not size else size

    win_size = _best_window_size(sorted_prc, size, beg, end, verbose=verbose)

    # define confidance bands, compute median plus/minus one standard
    # deviation over a sliding window
    errors_pos = []
    errors_neg = []
    for k in range(0, size, 1):
        vals = sorted_prc[k:k + win_size]
        std = np.std(vals)
        med = np.median(vals)
        errors_pos.append(med + std * sigma)
        errors_neg.append(med - std * sigma)

    # calculate median and variation of median plus/minus one standard
    # deviation for values between percentile 10 and 90 of the distribution
    # of the percentage of cis interactions
    #  - for median plus one standard deviation
    std_err_pos = np.std(errors_pos[int(size * beg):int(size * end)])
    med_err_pos = np.median(errors_pos[int(size * beg):int(size * end)])
    #  - for median minus one standard deviation
    std_err_neg = np.std(errors_neg[int(size * beg):int(size * end)])
    med_err_neg = np.median(errors_neg[int(size * beg):int(size * end)])

    # define cutoffs, values of cis percentage plus 1 stddev should be
    # between the general median +/- 2 stddev of the distribution of the cis
    # percentage plus 1 stddev. Same on the side of median cis percentage
    # minus 1 stddev
    beg_pos = med_err_pos - std_err_pos * sigma
    end_pos = med_err_pos + std_err_pos * sigma
    beg_neg = med_err_neg - std_err_neg * sigma
    end_neg = med_err_neg + std_err_neg * sigma

    # left cutoff: first bin followed by 10 consecutive bins inside both
    # confidence bands
    cutoffL = None
    passed = 0
    consecutive = 10
    for cutoffL, (p, n) in enumerate(zip(errors_pos, errors_neg)):
        if (beg_pos < p < end_pos) and (beg_neg < n < end_neg):
            if passed >= consecutive:
                break
            passed += 1
        else:
            passed = 0
    else:
        if min_perc is None:
            raise Exception('ERROR: left cutoff not found!!!\n'
                            '  define it by hand with min_perc')
        else:
            cutoffL = min_perc / 100. * size + consecutive
    cutoffL -= consecutive  # rescale, we asked for XX consecutive

    # right cutoff: same search from the other end of the distribution
    cutoffR = None
    passed = 0
    for cutoffR, (p, n) in enumerate(list(zip(errors_pos, errors_neg))[::-1]):
        cutoffR = size - cutoffR
        if (beg_pos < p < end_pos) and (beg_neg < n < end_neg):
            if passed >= consecutive:
                break
            passed += 1
        else:
            passed = 0
    else:
        if max_perc is None:
            raise Exception('ERROR: right cutoff not found!!!\n'
                            '  define it by hand with max_perc')
        else:
            cutoffR = max_perc / 100. * size - consecutive
    cutoffR += consecutive  # rescale, we asked for XX consecutive

    # user-given percentages override the automatic search
    if min_perc:
        cutoffL = min_perc / 100. * size
    if max_perc:
        cutoffR = max_perc / 100. * size

    min_count = sorted_sum[int(cutoffL)]
    try:
        max_count = sorted_sum[int(cutoffR)]
    except IndexError:  # all good
        max_count = sorted_sum[-1] + 1

    if verbose:
        print(' * Lower cutoff applied until bin number: %d' % (cutoffL))
        print(' * too few interactions defined as less than %9d '
              'interactions' % (min_count))
        print(' * Upper cutoff applied until bin number: %d' % (cutoffR))
        print(' * too much interactions defined as more than %9d '
              'interactions' % (max_count))

    # plot
    if savefig:
        if verbose:
            print(' -> Making plot...')
        fig = plt.figure(figsize=(20, 11))
        ax1 = fig.add_subplot(111)
        plt.subplots_adjust(left=0.25, bottom=0.2)
        line1 = ax1.plot(
            [float(cisprc.get(i, [0, 0])[0]) / cisprc.get(i, [1, 1])[1]
             for i in indices],
            '.', color='grey', alpha=0.2,
            label='cis interactions ratio by bin', zorder=1)
        line2 = ax1.plot(
            list(range(0, len(indices), 20)),
            [sum(float(cisprc.get(j, [0, 0])[0]) / cisprc.get(j, [1, 1])[1]
                 for j in indices[k:k + win_size]) / win_size
             for k in range(0, len(indices), 20)],
            '.', color='k', alpha=0.3,
            label='cis interactions ratio by %d bin' % win_size, zorder=1)
        for k, (p, n) in enumerate(zip(errors_pos[::size // 100],
                                       errors_neg[::size // 100])):
            ax1.vlines(k * (size // 100), (p + n) // 2, p,
                       color='red', alpha=0.6)
            ax1.vlines(k * (size // 100), n, (p + n) // 2,
                       color='blue', alpha=0.6)
        ax1.plot(list(range(0, size, size // 100)),
                 errors_neg[::size // 100], 'b^', mec='blue', alpha=0.5)
        ax1.plot(list(range(0, size, size // 100)),
                 errors_pos[::size // 100], 'rv', mec='red', alpha=0.5)
        ax1.fill_between([0, size], beg_pos, end_pos,
                         color='red', alpha=0.3, zorder=2)
        ax1.text(-size / 15., (end_pos + beg_pos) / 2,
                 'Confidance band for\nupper stddev of median',
                 color='red', ha='right', va='center')
        ax1.fill_between([0, size], beg_neg, end_neg,
                         color='blue', alpha=0.3, zorder=2)
        ax1.text(-size / 15., (end_neg + beg_neg) / 2,
                 'Confidance band for\nlower stddev of median',
                 color='blue', ha='right', va='center')
        ax1.set_ylim((0, 1.1))
        ax1.set_ylabel('Ratio of cis interactions ratio')
        ax1.fill_betweenx([0, 1.1], cutoffL, cutoffR,
                          color='green', alpha=0.2)
        ax1.text((cutoffR + cutoffL) / 2, -0.1,
                 ('Kept bins, top and bottom deviations from median '
                  'cis-ratio\nshould be inside their respective '
                  'confidance bands'),
                 ha='center', color='green')
        ax2 = fig.add_subplot(111, sharex=ax1, frameon=False)
        line3 = ax2.plot(sorted_sum, 'rx', alpha=0.4,
                         label='Log sum of interactions by bin')
        ax2.set_yscale('log')
        ax2.yaxis.tick_right()
        ax2.yaxis.set_label_position("right")
        ax2.set_ylabel('Log interaction counts')
        lns = line1 + line2 + line3
        labs = [l.get_label() for l in lns]
        ax2.legend(lns, labs, loc=0, bbox_to_anchor=(0, 0), frameon=False)
        ax3 = fig.add_subplot(111, frameon=False)
        ax3.xaxis.tick_top()
        ax3.set_xticks(list(range(100)), minor=True)
        ax3.set_xticks(list(range(0, 100, 5)), minor=False)
        ax3.set_yticks([])
        ax3.set_xticklabels([])
        for p in range(5, 100, 5):
            ax3.text(p, 99, '%d%%' % p, va='top', ha='left', size=9)
        ax3.tick_params(direction='in', axis='x', which='both')
        ax3.set_xlim(0, 100)
        ax3.set_ylim(0, 100)
        ax3.grid(which='major')
        ax3.grid(which='minor', alpha=0.5)
        if min_perc:
            plt.title('Setting from %.2f%% to %.2f%%' % (
                100 * float(cutoffL) / len(indices),
                100 * float(cutoffR) / len(indices)))
        else:
            plt.title('Keeping from %.2f%% to %.2f%%' % (
                100 * float(cutoffL) / len(indices),
                100 * float(cutoffR) / len(indices)))
        ax1.set_xlim((0, len(indices)))
        tadbit_savefig(savefig)
        plt.close('all')

    # classify bins: below min_count (null or low) or above max_count (high)
    badcol = {}
    countL = 0
    countZ = 0
    countU = 0
    for c in range(size):
        if cisprc.get(c, [0, 0])[1] < min_count:
            badcol[c] = cisprc.get(c, [0, 0])[1]
            countL += 1
            if not c in cisprc:
                countZ += 1
        elif cisprc[c][1] > max_count:  # don't need get here, already
                                        # cought in previous condition
            badcol[c] = cisprc.get(c, [0, 0])[1]
            countU += 1
    print(' => %d BAD bins (%d/%d/%d null/low/high counts) of %d (%.1f%%)' % (
        len(badcol), countZ, countL, countU, size,
        float(len(badcol)) / size * 100))

    return badcol