# Source snapshot: Hugging Face blob view, commit 202eff6 (2.38 kB);
# web-page chrome removed so the file parses as Python.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch  # NOTE(review): not referenced in this module's visible code — confirm before removing
import nltk
import numpy as np
# Project-local prompt templates; each contains a "{}" placeholder filled via .format(noun).
from utilities.constants import IMAGENET_DEFAULT_TEMPLATES
# Fetch the NLTK resources needed by word_tokenize / pos_tag at import time
# (quiet=True suppresses download progress output; no-op if already present).
nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)
def get_tag(tokenized, tags):
    """Return the tokens whose part-of-speech tag matches any of *tags*.

    Args:
        tokenized: list of word tokens (e.g. from ``nltk.word_tokenize``).
        tags: a single POS tag string or a list/tuple of tag strings
            (e.g. ``'NN'`` or ``['NN', 'NNS', 'NNP']``).

    Returns:
        Words from *tokenized*, in original order, whose tag is in *tags*.
        Each word appears at most once per occurrence even if *tags*
        contains duplicates (the original inner loop appended once per
        duplicate tag).
    """
    if not isinstance(tags, (list, tuple)):
        tags = [tags]
    # Set membership is O(1) per token instead of scanning the tag list.
    wanted = set(tags)
    return [word for word, pos in nltk.pos_tag(tokenized) if pos in wanted]
def get_noun_phrase(tokenized):
    """Extract noun phrases from a list of tokens via regexp chunking.

    Grammar taken from the Su Nam Kim paper: an NBAR is a run of
    nouns/adjectives ending in a noun; an NP is an NBAR, or two NBARs
    joined by a preposition (in/of/...).

    Args:
        tokenized: list of word tokens (e.g. from ``nltk.word_tokenize``).

    Returns:
        List of unique noun-phrase strings. Consecutive NP subtrees with no
        intervening non-NP token are merged into a single phrase.

    Bug fix: the original never flushed ``current_chunk`` after the loop,
    so a noun phrase at the end of the sentence was silently dropped.
    """
    grammar = r"""
        NBAR:
            {<NN.*|JJ>*<NN.*>}  # Nouns and Adjectives, terminated with Nouns
        NP:
            {<NBAR>}
            {<NBAR><IN><NBAR>}  # Above, connected with in/of/etc...
    """
    chunker = nltk.RegexpParser(grammar)
    chunked = chunker.parse(nltk.pos_tag(tokenized))
    continuous_chunk = []
    current_chunk = []
    for subtree in chunked:
        if isinstance(subtree, nltk.Tree):
            # Accumulate consecutive NP subtrees; they are joined on flush.
            current_chunk.append(' '.join(token for token, pos in subtree.leaves()))
        elif current_chunk:
            # Non-NP token ends the current run — emit it (deduplicated).
            named_entity = ' '.join(current_chunk)
            if named_entity not in continuous_chunk:
                continuous_chunk.append(named_entity)
            current_chunk = []
    # Flush a trailing run: without this, a sentence-final noun phrase is lost.
    if current_chunk:
        named_entity = ' '.join(current_chunk)
        if named_entity not in continuous_chunk:
            continuous_chunk.append(named_entity)
    return continuous_chunk
def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
    """Build prompt sentences from the nouns (or noun phrases) of *text*.

    With probability *phrase_prob* the noun phrases of *text* are used;
    otherwise its individual NN/NNS/NNP tokens are used. Each noun is
    inserted into a randomly chosen template from
    ``IMAGENET_DEFAULT_TEMPLATES``.

    Args:
        text: input sentence.
        phrase_prob: probability of using noun phrases instead of single nouns.
        append_text: if True, append the original *text* to both outputs.

    Returns:
        Tuple ``(prompt_texts, nouns)`` of equal length.
    """
    tokens = nltk.word_tokenize(text)
    use_phrases = random.random() < phrase_prob
    if use_phrases:
        nouns = get_noun_phrase(tokens)
    else:
        nouns = get_tag(tokens, ['NN', 'NNS', 'NNP'])
    prompt_texts = []
    for noun in nouns:
        template = np.random.choice(IMAGENET_DEFAULT_TEMPLATES)
        prompt_texts.append(template.format(noun))
    if append_text:
        prompt_texts.append(text)
        nouns.append(text)
    return prompt_texts, nouns