""" PerplexitySubsampler class: define and execute subsampling on a dataset, weighted by perplexity values """ from collections import namedtuple import numpy as np import scipy as sp from numpy.random import default_rng from scipy.stats import norm, uniform from typing import List, Tuple, Iterable Histo = namedtuple("HISTO", "counts edges centers") rng = default_rng() def histo_quantile(hcounts: np.ndarray, hedges: np.ndarray, perc_values: Iterable[float]) -> List[float]: """ Compute quantile values by using a histogram """ cs = np.cumsum(hcounts)/np.sum(hcounts) out = [] for p in perc_values: idx = np.searchsorted(cs, p) frac = (p - cs[idx-1]) / (cs[idx] - cs[idx-1]) r = hedges[idx] + (hedges[idx+1] - hedges[idx])*frac out.append(r) return out def _histo_inv_quantile(hedges: np.ndarray, hcounts: np.ndarray, perp_value: float) -> float: """ Using an histogram of values, estimate the quantile occupied by a given value It is therefore the inverse function of quantile() """ v = np.searchsorted(hedges, perp_value, side="right") frac = (perp_value - hedges[v-1]) / (hedges[v] - hedges[v-1]) return hcounts[:v-1].sum() + hcounts[v-1]*frac def subsample_frac(data: np.ndarray, frac: float) -> np.ndarray: """ Subsample an array to a given fraction """ return data[uniform.rvs(size=len(data)) < frac] # ------------------------------------------------------------------------- class PerplexitySubsampler: def __init__(self, perp_data: np.ndarray = None, perp_histogram: Tuple[np.ndarray, np.ndarray] = None, hbins: int = 1000): """ :param perp_data: a dataset of perplexity values :param perp_histo: a histogram computed over a dataset of perplexity values, passed as a tuple (counts, edges) :param hbins: number of bins to use for the histogram approximation (only used if `perp_data` is passed) Either `perp_data` or `perp_histogram` must be passed """ if perp_data is not None: # Get the P25 and P75 quartiles self.qr = np.quantile(perp_data, [0.25, 0.75]) # Build an histogram of perplexities range_max = self.qr[1]*10 counts, edges = np.histogram(perp_data, bins=hbins, range=[0, range_max]) counts[-1] += len(perp_data[perp_data > range_max]) self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2) elif perp_histogram is not None: edges = perp_histogram[1] self.histo = Histo(perp_histogram[0], edges, (edges[:-1] + edges[1:])/2) self.qr = histo_quantile(self.histo.counts, self.histo.edges, [0.25, 0.75]) else: raise Exception("Neither sample nor histogram provided") def _estimate(self, m: float, s: float, ratio: float) -> Tuple[float, float]: """ Estimate the quantiles to be retained in the 1st & 4th original quartiles """ # Compute the normalization factor gauss_weights = norm.pdf(self.histo.centers, loc=m, scale=s) hcounts = self.histo.counts adjusted_norm = (hcounts*gauss_weights).sum()/hcounts.sum()/ratio # Subsample the histogram hcounts_sub = self.histo.counts*gauss_weights/adjusted_norm sub_size = hcounts_sub.sum() # Estimate the quantiles at Xa & Xb ra = _histo_inv_quantile(self.histo.edges, hcounts_sub, self.qr[0])/sub_size rb = _histo_inv_quantile(self.histo.edges, hcounts_sub, self.qr[1])/sub_size #print(f"{m:10.2f} {s:10.2f} => {ra:.4} {1-rb:.4}") return ra, 1-rb def _error(self, point: np.ndarray, ratio: float, pa: float, pb: float) -> float: """ Estimate the error in probability mass results """ actual_pa, actual_pb = self._estimate(point[0], point[1], ratio) return abs(pa-actual_pa) + abs(pb-actual_pb) def set(self, ratio: float, pa: float, pb: float): """ Compute the parameters needed to achieve a desired sampling ratio & probability distribution :param ratio: the desired sampling ratio :param pa: the probability mass to be left in the first original perplexity quartile :param pb: the probability mass to be left in the fourth original perplexity quartile """ # Obtain the initial parameters for the gaussian weighting function # (assuming uniform data) sdev = (self.qr[0] - self.qr[1]) / (norm.ppf(pa) - norm.ppf(1-pb)) mean = self.qr[0] - norm.ppf(pa)*sdev # Optimize for the real data distribution initial = np.array([mean, sdev]) result = sp.optimize.minimize(self._error, initial, args=(ratio, pa, pb), method='nelder-mead', options={'xatol': 1e-8, 'disp': False}) self.mean, self.sdev = result.x # Now that we have the final parameters, compute the weighting # function over the histogram values gauss_weights = norm.pdf(self.histo.centers, loc=self.mean, scale=self.sdev) # Find the normalization needed to achieve the desired sampling ratio counts = self.histo.counts self.norm = (counts*gauss_weights).sum()/counts.sum()/ratio def subsample(self, data: np.ndarray) -> np.ndarray: """ Subsample a dataset according to the defined conditions Note: set() must have been called previously """ # Create the gaussian weight for each data point p = norm.pdf(data, loc=self.mean, scale=self.sdev)/self.norm #print(p) # Subsample data with probability according to the weight return data[uniform.rvs(size=len(p)) < p] def retain(self, perp: float) -> bool: """ Decide if a sample is to be retained based on its perplexity value Note: set() must have been called previously """ p = norm.pdf(perp, loc=self.mean, scale=self.sdev)/self.norm return rng.uniform() < p def subsample_piecewise(self, data: np.ndarray, pa: float, pb: float) -> np.ndarray: """ Creat a subsample by directly subsampling each region """ qr = self.qr data1 = subsample_frac(data[data < qr[0]], pa/0.25*self.ratio) data2 = subsample_frac(data[(data >= qr[0]) & (data <= qr[1])], (1-pa-pb)/0.5*self.ratio) data3 = subsample_frac(data[self.data > qr[1]], pb/0.25*self.ratio) return np.hstack([data1, data2, data3]) def verify(self, data: np.ndarray, data_sub: np.ndarray) -> Tuple: """ Check the statistics of a sample """ ratio = len(data_sub)/len(data) ra = len(data_sub[data_sub < self.qr[0]]) / len(data_sub) rb = len(data_sub[data_sub > self.qr[1]]) / len(data_sub) return ratio, ra, rb def check_results(s: PerplexitySubsampler, data_full: np.ndarray, data_sub: np.ndarray): """ Compute and print out the results for a subsample """ r, ra, rb = s.verify(data_full, data_sub) print("Sampling ratio:", r) print("Probability mass below Pa:", ra) print("Probability mass above Pb:", rb)