|
""" |
|
PerplexitySubsampler class: define and execute subsampling on a dataset, |
|
weighted by perplexity values |
|
""" |
|
|
|
from collections import namedtuple |
|
import numpy as np |
|
import scipy as sp |
|
from numpy.random import default_rng |
|
from scipy.stats import norm, uniform |
|
|
|
from typing import List, Tuple, Iterable |
|
|
|
# Lightweight histogram container: per-bin counts, bin edges and bin centers
# (the centers are filled in by PerplexitySubsampler as the edge midpoints)
Histo = namedtuple("HISTO", "counts edges centers")

# Module-level random generator, used by PerplexitySubsampler.retain()
rng = default_rng()
|
|
|
|
|
def histo_quantile(hcounts: np.ndarray, hedges: np.ndarray,
                   perc_values: Iterable[float]) -> List[float]:
    """
    Compute quantile values by using a histogram

    :param hcounts: number of samples in each histogram bin
    :param hedges: bin edges (one element more than `hcounts`)
    :param perc_values: the quantiles to compute, as fractions in [0, 1]
    :return: one value per requested quantile, obtained by linear
      interpolation inside the bin where the cumulative distribution
      reaches that quantile. Requests outside [0, 1] are clamped to the
      first/last edge.
    """
    # Cumulative distribution sampled at the bin edges. The leading 0 is the
    # mass below the first edge: without it, a quantile falling in the first
    # bin would be interpolated against the *last* cumulative value
    # (index -1 wraps around)
    cdf = np.concatenate(([0.0], np.cumsum(hcounts)/np.sum(hcounts)))
    last_bin = len(cdf) - 2
    out = []
    for p in perc_values:
        # Bin in which the cumulative distribution reaches p, clamped so
        # that p=0 or round-off around p=1 cannot index out of range
        idx = int(np.searchsorted(cdf, p)) - 1
        idx = min(max(idx, 0), last_bin)
        mass = cdf[idx+1] - cdf[idx]
        # An empty bin carries no mass to interpolate over
        frac = (p - cdf[idx]) / mass if mass > 0 else 0.0
        frac = min(max(frac, 0.0), 1.0)
        out.append(hedges[idx] + (hedges[idx+1] - hedges[idx])*frac)
    return out
|
|
|
|
|
def _histo_inv_quantile(hedges: np.ndarray, hcounts: np.ndarray,
                        perp_value: float) -> float:
    """
    Using a histogram of values, estimate the cumulative count lying
    below a given value

    The bin containing the value contributes proportionally to the position
    of the value inside it (linear interpolation). The result is expressed
    in the same units as `hcounts`; dividing it by `hcounts.sum()` gives the
    quantile occupied by the value, which makes this the inverse function
    of histo_quantile()

    :param hedges: bin edges (one element more than `hcounts`)
    :param hcounts: counts (or weights) of each histogram bin
    :param perp_value: the value to locate
    :return: the interpolated cumulative count below `perp_value`; 0 for a
      value at or below the first edge, the total count for a value at or
      above the last edge
    """
    # Values outside the histogram range would index past the edges array
    # (or wrap around to negative indices), so resolve them explicitly
    if perp_value <= hedges[0]:
        return 0.0
    if perp_value >= hedges[-1]:
        return hcounts.sum()
    # hedges[v-1] <= perp_value < hedges[v]: the value lies in bin v-1
    v = np.searchsorted(hedges, perp_value, side="right")
    frac = (perp_value - hedges[v-1]) / (hedges[v] - hedges[v-1])
    return hcounts[:v-1].sum() + hcounts[v-1]*frac
|
|
|
|
|
def subsample_frac(data: np.ndarray, frac: float) -> np.ndarray:
    """
    Randomly keep a given fraction of the elements of an array

    Each element is retained independently when a uniform random draw
    falls below `frac`, so the size of the result is only approximately
    `frac * len(data)`
    """
    draws = uniform.rvs(size=len(data))
    keep = draws < frac
    return data[keep]
|
|
|
|
|
|
|
|
|
|
|
|
|
class PerplexitySubsampler:
    """
    Subsample a dataset of perplexity values using Gaussian-shaped weights

    Usage: build the object from raw perplexity values or from a
    precomputed histogram, call `set()` to fit the sampling parameters,
    then call `subsample()` (whole array) or `retain()` (single value).
    """

    def __init__(self, perp_data: np.ndarray = None,
                 perp_histogram: Tuple[np.ndarray, np.ndarray] = None,
                 hbins: int = 1000):
        """
        :param perp_data: a dataset of perplexity values
        :param perp_histogram: a histogram computed over a dataset of
          perplexity values, passed as a tuple (counts, edges)
        :param hbins: number of bins to use for the histogram approximation
          (only used if `perp_data` is passed)

        Either `perp_data` or `perp_histogram` must be passed; if both are
        given, `perp_data` is used.

        :raises ValueError: if neither argument is provided
        """
        if perp_data is not None:
            perp_data = np.asarray(perp_data)
            # Boundaries of the 1st and 4th quartiles, from the raw data
            self.qr = np.quantile(perp_data, [0.25, 0.75])
            # The histogram spans [0, 10 x third quartile]; values above
            # that limit are folded into the last bin
            range_max = self.qr[1]*10
            counts, edges = np.histogram(perp_data, bins=hbins,
                                         range=[0, range_max])
            counts[-1] += len(perp_data[perp_data > range_max])
        elif perp_histogram is not None:
            counts = np.asarray(perp_histogram[0])
            edges = np.asarray(perp_histogram[1])
            # No raw data available: estimate the quartile boundaries
            # from the histogram itself
            self.qr = histo_quantile(counts, edges, [0.25, 0.75])
        else:
            raise ValueError("Neither sample nor histogram provided")

        self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2)

    def _gauss_norm(self, m: float, s: float,
                    ratio: float) -> Tuple[np.ndarray, float]:
        """
        Evaluate the Gaussian weights at the bin centers and the constant
        that rescales them so that the weighted histogram holds a fraction
        `ratio` of the original counts
        """
        gauss_weights = norm.pdf(self.histo.centers, loc=m, scale=s)
        hcounts = self.histo.counts
        scale = (hcounts*gauss_weights).sum()/hcounts.sum()/ratio
        return gauss_weights, scale

    def _estimate(self, m: float, s: float,
                  ratio: float) -> Tuple[float, float]:
        """
        Estimate the quantiles to be retained in the 1st & 4th original
        quartiles

        :param m: mean of the Gaussian weighting
        :param s: standard deviation of the Gaussian weighting
        :param ratio: the desired sampling ratio
        :return: fraction of the weighted histogram lying below the first
          quartile boundary, and fraction lying above the second one
        """
        gauss_weights, adjusted_norm = self._gauss_norm(m, s, ratio)

        # Histogram of the subsample expected from these weights
        hcounts_sub = self.histo.counts*gauss_weights/adjusted_norm
        sub_size = hcounts_sub.sum()

        edges = self.histo.edges
        ra = _histo_inv_quantile(edges, hcounts_sub, self.qr[0])/sub_size
        rb = _histo_inv_quantile(edges, hcounts_sub, self.qr[1])/sub_size

        return ra, 1-rb

    def _error(self, point: np.ndarray, ratio: float,
               pa: float, pb: float) -> float:
        """
        Estimate the error in probability mass results

        :param point: candidate (mean, standard deviation) pair
        :return: sum of the absolute deviations between the desired and
          the estimated probability masses; used as the objective function
          of the minimization in set()
        """
        actual_pa, actual_pb = self._estimate(point[0], point[1], ratio)
        return abs(pa-actual_pa) + abs(pb-actual_pb)

    def set(self, ratio: float, pa: float, pb: float):
        """
        Compute the parameters needed to achieve a desired sampling ratio &
        probability distribution

        :param ratio: the desired sampling ratio
        :param pa: the probability mass to be left in the first original
          perplexity quartile
        :param pb: the probability mass to be left in the fourth original
          perplexity quartile
        :raises ValueError: if `ratio` is not positive, or if `pa` and `pb`
          are not both in (0, 1) with a sum below 1

        Stores the fitted values in `self.mean`, `self.sdev`, `self.norm`
        and keeps `ratio` in `self.ratio`
        """
        # pa + pb >= 1 would give a zero denominator or a negative
        # standard deviation in the initial estimate below
        if not (0 < pa < 1 and 0 < pb < 1 and pa + pb < 1):
            raise ValueError("pa and pb must lie in (0, 1) and sum to less than 1")
        if not ratio > 0:
            raise ValueError("ratio must be positive")

        # Initial guess: the Gaussian whose cumulative distribution equals
        # pa at the first quartile boundary and 1-pb at the second one
        sdev = (self.qr[0] - self.qr[1]) / (norm.ppf(pa) - norm.ppf(1-pb))
        mean = self.qr[0] - norm.ppf(pa)*sdev

        # Refine the guess against the actual histogram
        initial = np.array([mean, sdev])
        result = sp.optimize.minimize(self._error, initial,
                                      args=(ratio, pa, pb),
                                      method='nelder-mead',
                                      options={'xatol': 1e-8, 'disp': False})
        self.mean, self.sdev = result.x
        # Needed by subsample_piecewise()
        self.ratio = ratio

        # Normalization turning the Gaussian density into a retention
        # probability that yields the desired sampling ratio
        _, self.norm = self._gauss_norm(self.mean, self.sdev, ratio)

    def subsample(self, data: np.ndarray) -> np.ndarray:
        """
        Subsample a dataset according to the defined conditions
        Note: set() must have been called previously

        :param data: array of perplexity values
        :return: the retained elements of `data`
        """
        # Retention probability of each element
        p = norm.pdf(data, loc=self.mean, scale=self.sdev)/self.norm
        return data[uniform.rvs(size=len(p)) < p]

    def retain(self, perp: float) -> bool:
        """
        Decide if a sample is to be retained based on its perplexity value
        Note: set() must have been called previously
        """
        p = norm.pdf(perp, loc=self.mean, scale=self.sdev)/self.norm
        return rng.uniform() < p

    def subsample_piecewise(self, data: np.ndarray,
                            pa: float, pb: float) -> np.ndarray:
        """
        Create a subsample by directly subsampling each region
        Note: set() must have been called previously (it provides the
        sampling ratio)

        :param data: array of perplexity values
        :param pa: the probability mass to be left below the first
          quartile boundary
        :param pb: the probability mass to be left above the second
          quartile boundary
        :return: the retained elements, grouped by region (low, middle,
          high) rather than in their original order
        """
        qr = self.qr
        ratio = self.ratio
        # The regions hold 25%, 50% and 25% of the original data, hence
        # the divisors used to turn each target mass into a sampling rate
        data1 = subsample_frac(data[data < qr[0]], pa/0.25*ratio)
        data2 = subsample_frac(data[(data >= qr[0]) & (data <= qr[1])],
                               (1-pa-pb)/0.5*ratio)
        data3 = subsample_frac(data[data > qr[1]], pb/0.25*ratio)
        return np.hstack([data1, data2, data3])

    def verify(self, data: np.ndarray, data_sub: np.ndarray) -> Tuple:
        """
        Check the statistics of a sample

        :param data: the full dataset
        :param data_sub: the subsample extracted from it
        :return: a tuple (sampling ratio, fraction of the subsample below
          the first quartile boundary, fraction above the second one)
        """
        ratio = len(data_sub)/len(data)
        ra = len(data_sub[data_sub < self.qr[0]]) / len(data_sub)
        rb = len(data_sub[data_sub > self.qr[1]]) / len(data_sub)
        return ratio, ra, rb
|
|
|
|
|
|
|
def check_results(s: PerplexitySubsampler,
                  data_full: np.ndarray, data_sub: np.ndarray):
    """
    Print the statistics of a subsample: the achieved sampling ratio and
    the probability mass found below and above the quartile boundaries,
    as reported by the subsampler's verify() method
    """
    ratio, mass_low, mass_high = s.verify(data_full, data_sub)
    report = (
        ("Sampling ratio:", ratio),
        ("Probability mass below Pa:", mass_low),
        ("Probability mass above Pb:", mass_high),
    )
    for label, value in report:
        print(label, value)
|
|