mimir-perplexity / subsampler.py
versae's picture
Models and code
dcc5cd1
"""
PerplexitySubsampler class: define and execute subsampling on a dataset,
weighted by perplexity values
"""
from collections import namedtuple
import numpy as np
import scipy as sp
from numpy.random import default_rng
from scipy.stats import norm, uniform
from typing import List, Tuple, Iterable
# Lightweight record describing a histogram: per-bin counts, the bin edges
# (one more entry than counts) and the bin centers
Histo = namedtuple("HISTO", "counts edges centers")
# Module-level random generator, created without an explicit seed; used by
# PerplexitySubsampler.retain() to draw its uniform random number
rng = default_rng()
def histo_quantile(hcounts: np.ndarray, hedges: np.ndarray,
                   perc_values: Iterable[float]) -> List[float]:
    """
    Compute quantile values by using a histogram

    Each requested quantile is located in the bin where the cumulative
    distribution first reaches it, and then linearly interpolated between
    the edges of that bin.

    :param hcounts: histogram counts, one per bin
    :param hedges: histogram bin edges (one more element than `hcounts`)
    :param perc_values: the quantiles to compute, as fractions in [0, 1]
    :return: the estimated value for each requested quantile, in the same
      order as `perc_values`
    :raises ValueError: if the histogram contains no counts
    """
    total = np.sum(hcounts)
    if total <= 0:
        raise ValueError("Cannot compute quantiles of an empty histogram")
    # cs[i] is the fraction of the mass lying below the upper edge of bin i
    cs = np.cumsum(hcounts)/total
    last_bin = len(cs) - 1
    out = []
    for p in perc_values:
        # First bin whose cumulative mass reaches p; clamped so that rounding
        # in the cumulative sum cannot push the index past the last bin
        idx = min(int(np.searchsorted(cs, p)), last_bin)
        # Mass accumulated before this bin. For the first bin it is zero;
        # indexing cs[idx-1] there would wrap around to cs[-1]
        lower = cs[idx-1] if idx > 0 else 0.0
        span = cs[idx] - lower
        # An empty bin (only reachable for p == 0) maps to its lower edge
        frac = (p - lower) / span if span > 0 else 0.0
        r = hedges[idx] + (hedges[idx+1] - hedges[idx])*frac
        out.append(r)
    return out
def _histo_inv_quantile(hedges: np.ndarray, hcounts: np.ndarray,
perp_value: float) -> float:
"""
Using an histogram of values, estimate the quantile occupied
by a given value
It is therefore the inverse function of quantile()
"""
v = np.searchsorted(hedges, perp_value, side="right")
frac = (perp_value - hedges[v-1]) / (hedges[v] - hedges[v-1])
return hcounts[:v-1].sum() + hcounts[v-1]*frac
def subsample_frac(data: np.ndarray, frac: float) -> np.ndarray:
    """
    Randomly thin an array, keeping each element independently with
    probability `frac`
    """
    # One uniform draw in [0, 1) per element; an element survives when its
    # draw falls below the requested fraction
    draws = uniform.rvs(size=len(data))
    keep = draws < frac
    return data[keep]
# -------------------------------------------------------------------------
class PerplexitySubsampler:
    """
    Define and execute a subsampling of a dataset of perplexity values,
    weighting each value with a gaussian function of its perplexity.

    Typical usage: build the object from data or from a histogram, call
    set() to fit the weighting function, then call subsample() or retain().
    """

    def __init__(self, perp_data: np.ndarray = None,
                 perp_histogram: Tuple[np.ndarray, np.ndarray] = None,
                 hbins: int = 1000):
        """
        :param perp_data: a dataset of perplexity values
        :param perp_histogram: a histogram computed over a dataset of
          perplexity values, passed as a tuple (counts, edges)
        :param hbins: number of bins to use for the histogram approximation
          (only used if `perp_data` is passed)
        :raises ValueError: if neither `perp_data` nor `perp_histogram` is
          passed

        Either `perp_data` or `perp_histogram` must be passed; if both are,
        `perp_data` takes precedence
        """
        if perp_data is not None:
            # Accept any array-like (the comparisons below need an ndarray)
            perp_data = np.asarray(perp_data)
            # Get the P25 and P75 quartiles
            self.qr = np.quantile(perp_data, [0.25, 0.75])
            # Build an histogram of perplexities, covering up to 10x P75
            range_max = self.qr[1]*10
            counts, edges = np.histogram(perp_data, bins=hbins,
                                         range=[0, range_max])
            # Values above the histogram range are piled up in the last bin
            counts[-1] += len(perp_data[perp_data > range_max])
            self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2)
        elif perp_histogram is not None:
            # Convert to arrays so that the edge arithmetic is elementwise
            # even when plain lists are passed
            counts = np.asarray(perp_histogram[0])
            edges = np.asarray(perp_histogram[1])
            self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2)
            # No raw data available: estimate the quartiles from the histogram
            self.qr = histo_quantile(self.histo.counts, self.histo.edges,
                                     [0.25, 0.75])
        else:
            raise ValueError("Neither sample nor histogram provided")

    def _estimate(self, m: float, s: float,
                  ratio: float) -> Tuple[float, float]:
        """
        Estimate the quantiles to be retained in the 1st & 4th original
        quartiles

        :param m: mean of the gaussian weighting function
        :param s: standard deviation of the gaussian weighting function
        :param ratio: the desired sampling ratio
        :return: a tuple with the fraction of the subsample lying below the
          original P25 and the fraction lying above the original P75
        """
        # Compute the normalization factor
        gauss_weights = norm.pdf(self.histo.centers, loc=m, scale=s)
        hcounts = self.histo.counts
        adjusted_norm = (hcounts*gauss_weights).sum()/hcounts.sum()/ratio
        # Subsample the histogram
        hcounts_sub = self.histo.counts*gauss_weights/adjusted_norm
        sub_size = hcounts_sub.sum()
        # Estimate the quantiles at Xa & Xb (cumulative counts normalized by
        # the subsample size)
        ra = _histo_inv_quantile(self.histo.edges, hcounts_sub,
                                 self.qr[0])/sub_size
        rb = _histo_inv_quantile(self.histo.edges, hcounts_sub,
                                 self.qr[1])/sub_size
        return ra, 1-rb

    def _error(self, point: np.ndarray, ratio: float,
               pa: float, pb: float) -> float:
        """
        Estimate the error in probability mass results

        :param point: the (mean, sdev) pair of gaussian parameters to check
        :param ratio: the desired sampling ratio
        :param pa: target probability mass below the original P25
        :param pb: target probability mass above the original P75
        :return: sum of the absolute deviations from both targets
        """
        actual_pa, actual_pb = self._estimate(point[0], point[1], ratio)
        return abs(pa-actual_pa) + abs(pb-actual_pb)

    def set(self, ratio: float, pa: float, pb: float):
        """
        Compute the parameters needed to achieve a desired sampling ratio &
        probability distribution

        The results are stored in the object as `mean`, `sdev`, `norm` and
        `ratio`, to be used by the sampling methods.

        :param ratio: the desired sampling ratio
        :param pa: the probability mass to be left in the first original
          perplexity quartile
        :param pb: the probability mass to be left in the fourth original
          perplexity quartile
        """
        # Remember the sampling ratio: subsample_piecewise() needs it
        self.ratio = ratio
        # Obtain the initial parameters for the gaussian weighting function
        # (assuming uniform data)
        sdev = (self.qr[0] - self.qr[1]) / (norm.ppf(pa) - norm.ppf(1-pb))
        mean = self.qr[0] - norm.ppf(pa)*sdev
        # Optimize for the real data distribution
        initial = np.array([mean, sdev])
        result = sp.optimize.minimize(self._error, initial,
                                      args=(ratio, pa, pb),
                                      method='nelder-mead',
                                      options={'xatol': 1e-8, 'disp': False})
        self.mean, self.sdev = result.x
        # Now that we have the final parameters, compute the weighting
        # function over the histogram values
        gauss_weights = norm.pdf(self.histo.centers, loc=self.mean,
                                 scale=self.sdev)
        # Find the normalization needed to achieve the desired sampling ratio
        counts = self.histo.counts
        self.norm = (counts*gauss_weights).sum()/counts.sum()/ratio

    def subsample(self, data: np.ndarray) -> np.ndarray:
        """
        Subsample a dataset according to the defined conditions
        Note: set() must have been called previously

        :param data: array of perplexity values
        :return: the elements of `data` that were retained
        """
        # Create the gaussian weight for each data point
        p = norm.pdf(data, loc=self.mean, scale=self.sdev)/self.norm
        # Subsample data with probability according to the weight
        return data[uniform.rvs(size=len(p)) < p]

    def retain(self, perp: float) -> bool:
        """
        Decide if a sample is to be retained based on its perplexity value
        Note: set() must have been called previously

        :param perp: the perplexity value of the sample
        :return: True if the sample is to be kept
        """
        p = norm.pdf(perp, loc=self.mean, scale=self.sdev)/self.norm
        return rng.uniform() < p

    def subsample_piecewise(self, data: np.ndarray,
                            pa: float, pb: float) -> np.ndarray:
        """
        Create a subsample by directly subsampling each region (below P25,
        between P25 and P75, above P75) with a uniform probability
        Note: set() must have been called previously (it defines the
        sampling ratio)

        :param data: array of perplexity values
        :param pa: the probability mass to be left below the original P25
        :param pb: the probability mass to be left above the original P75
        :return: the retained elements, concatenated region by region
        """
        qr = self.qr
        ratio = self.ratio
        # Each region holds 25% / 50% / 25% of the original data, hence the
        # fraction to keep in each one
        data1 = subsample_frac(data[data < qr[0]], pa/0.25*ratio)
        data2 = subsample_frac(data[(data >= qr[0]) & (data <= qr[1])],
                               (1-pa-pb)/0.5*ratio)
        data3 = subsample_frac(data[data > qr[1]], pb/0.25*ratio)
        return np.hstack([data1, data2, data3])

    def verify(self, data: np.ndarray, data_sub: np.ndarray) -> Tuple:
        """
        Check the statistics of a sample

        :param data: the full dataset
        :param data_sub: the subsample extracted from it
        :return: a tuple (sampling ratio, fraction of the subsample below
          the original P25, fraction of the subsample above the original
          P75). Fractions that cannot be computed because of an empty
          dataset or subsample are reported as 0
        """
        n_full = len(data)
        n_sub = len(data_sub)
        ratio = n_sub/n_full if n_full else 0.0
        if n_sub == 0:
            # Nothing was sampled: avoid dividing by zero
            return ratio, 0.0, 0.0
        ra = len(data_sub[data_sub < self.qr[0]]) / n_sub
        rb = len(data_sub[data_sub > self.qr[1]]) / n_sub
        return ratio, ra, rb
def check_results(s: PerplexitySubsampler,
                  data_full: np.ndarray, data_sub: np.ndarray):
    """
    Print a short report for a subsample: the sampling ratio achieved and
    the probability mass found below Pa and above Pb, as returned by the
    sampler's verify() method
    """
    ratio, mass_low, mass_high = s.verify(data_full, data_sub)
    report = (
        ("Sampling ratio:", ratio),
        ("Probability mass below Pa:", mass_low),
        ("Probability mass above Pb:", mass_high),
    )
    for label, value in report:
        print(label, value)