# Page status captured alongside this file: "Spaces: Build error" (not part of the program).
from __future__ import print_function | |
from __future__ import absolute_import | |
__author__ = 'Taneem Jan, taneemishere.github.io' | |
from .Vocabulary import * | |
from .Utils import * | |
class Sampler:
    """Greedy token-sequence decoder driven by a model and a ``Vocabulary``.

    The sampler keeps a fixed-length context of previously emitted tokens,
    repeatedly asks the model for its next output, and appends the token
    with the highest score until ``END_TOKEN`` is produced or the step
    budget is exhausted.
    """

    def __init__(self, voc_path, input_shape, output_size, context_length):
        """Load the vocabulary and remember the decoding configuration.

        Args:
            voc_path: Passed unchanged to ``Vocabulary.retrieve``.
            input_shape: Stored on the instance and echoed to stdout; it is
                not otherwise used by this class.
            output_size: Length of the one-hot vectors built while decoding;
                also handed to ``Utils.sparsify``.
            context_length: Number of entries kept in the rolling context.
        """
        self.voc = Vocabulary()
        self.voc.retrieve(voc_path)

        self.input_shape = input_shape
        self.output_size = output_size
        self.context_length = context_length

        # Echo the configuration, one line per setting.
        summary = (
            ("Vocabulary size: {}", self.voc.size),
            ("Input shape: {}", self.input_shape),
            ("Output size: {}", self.output_size),
        )
        for template, value in summary:
            print(template.format(value))

    def predict_greedy(self, model, input_img, require_sparse_label=True, sequence_length=150, verbose=False):
        """Decode a token string by always picking the highest-scoring token.

        Args:
            model: Object exposing ``predict(input_img, context_array)``.
            input_img: Forwarded untouched to ``model.predict`` at every step.
            require_sparse_label: When true, context entries are one-hot
                vectors of length ``self.output_size``; otherwise they are
                the raw vocabulary indices.
            sequence_length: Maximum number of decoding steps.
            verbose: When true, print a progress line before each step.

        Returns:
            A ``(text, outputs)`` tuple. ``text`` is ``START_TOKEN`` followed
            by every emitted token (including ``END_TOKEN`` if it was
            reached). ``outputs`` is the list of raw ``model.predict``
            results, one per executed step.
        """
        token_ids = self.voc.vocabulary

        # Initial context: placeholders everywhere except the final slot,
        # which holds the start token.
        window = [token_ids[PLACEHOLDER]] * (self.context_length - 1)
        window.append(token_ids[START_TOKEN])
        if require_sparse_label:
            window = Utils.sparsify(window, self.output_size)

        decoded = START_TOKEN
        out_probas = []

        for step in range(sequence_length):
            if verbose:
                print("predicting {}/{}...".format(step, sequence_length))

            # The context is wrapped in an outer list so the array handed to
            # the model has a leading axis of length 1.
            step_output = model.predict(input_img, np.array([window]))
            # No axis is given, so this is the index of the maximum over the
            # whole (flattened) output.
            best_index = np.argmax(step_output)
            out_probas.append(step_output)

            # Slide the window: drop the oldest entry, append the new one.
            shifted = [window[pos] for pos in range(1, self.context_length)]
            if not require_sparse_label:
                shifted.append(best_index)
            else:
                one_hot = np.zeros(self.output_size)
                one_hot[best_index] = 1
                shifted.append(one_hot)
            window = shifted

            token = self.voc.token_lookup[best_index]
            decoded += token
            if token == END_TOKEN:
                break

        return decoded, out_probas