# Author: Naandhu — commit d705756 ("add gitignore")
# Preprocess function, stored as source text so it can be embedded in the
# model config below and re-created later (e.g. with exec()).
# Backslashes are doubled because this is a regular (non-raw) string literal.
preprocess_code = """
def preprocess(text, remove_stopwords=True):
    # Clean raw text before classification.
    #
    # Steps: strip HTML tags, URLs, punctuation; lower-case; turn non-breaking
    # spaces and newline/tab/carriage-return characters into spaces;
    # optionally drop spaCy stop words; collapse runs of whitespace.
    #
    # Args:
    #     text: value to clean; converted with str() first.
    #     remove_stopwords: when True (the default) load spaCy's
    #         "en_core_web_sm" pipeline and drop tokens flagged is_stop.
    #
    # Returns:
    #     The cleaned string. On any error the error is printed and the text
    #     as processed so far is returned (best-effort behaviour).
    import re
    import string

    try:
        text = str(text)
        # remove html tags
        text = re.sub(r"<.*?>", "", text)
        # remove URLs: http(s)://..., www...., and dotted tokens like example.com
        url_pattern = r"https?://\\S+|www\\.\\S+|\\S+\\.\\S{2,}"
        text = re.sub(url_pattern, "", text)
        # remove punctuation (str methods return new strings, so reassign)
        text = text.translate(str.maketrans("", "", string.punctuation))
        # lower case
        text = text.lower().strip()
        # replace non-breaking spaces with plain spaces so words are not glued together
        text = text.replace("\\xa0", " ")
        # replace newline / tab / carriage-return characters with spaces
        text = re.sub(r"[\\n\\t\\r]", " ", text)
        if remove_stopwords:
            import spacy
            spacy.prefer_gpu()  # use GPU if available; may reduce the run time
            nlp = spacy.load("en_core_web_sm")  # NOTE: reloaded on every call
            doc = nlp(text)
            text = " ".join(token.text for token in doc if not token.is_stop)
        # collapse runs of whitespace and trim the ends
        text = re.sub(r"\\s+", " ", text).strip()
    except Exception as err:
        print(f"error occurred in preprocess: {err!r}")
    return text
"""
# Post-process function, stored as source text so it can be embedded in the
# model config below and re-created later (e.g. with exec()).
postprocess_code = """
def post_process(output):
    # Return the class label with the highest score in `output.logits`.
    #
    # Args:
    #     output: object exposing a `logits` tensor for a single example.
    #         After squeeze() it must be 1-D with one score per entry of
    #         `classes` (presumably shape (24,) or (1, 24); TODO confirm
    #         against the model head).
    #
    # Returns:
    #     The label from `classes` with the highest score, or None (after
    #     printing the error) when the output cannot be decoded.
    import torch

    classes = ['ACCOUNTANT', 'ADVOCATE', 'AGRICULTURE', 'APPAREL', 'ARTS', 'AUTOMOBILE', 'AVIATION', 'BANKING', 'BPO', 'BUSINESS-DEVELOPMENT', 'CHEF', 'CONSTRUCTION', 'CONSULTANT', 'DESIGNER', 'DIGITAL-MEDIA', 'ENGINEERING', 'FINANCE', 'FITNESS', 'HEALTHCARE', 'HR', 'INFORMATION-TECHNOLOGY', 'PUBLIC-RELATIONS', 'SALES', 'TEACHER']
    try:
        logits = output.logits
        probs = torch.sigmoid(logits.squeeze().cpu())
        if probs.dim() != 1:
            # several examples (or a scalar) cannot be mapped to one label
            raise ValueError(f"expected scores for one example, got shape {tuple(probs.shape)}")
        # sigmoid is monotonic, so this argmax equals the argmax of the raw logits
        return classes[int(probs.argmax())]
    except Exception as err:
        print(f"Some Error occurred in post_process: {err!r}")
        return None
"""
from transformers import PretrainedConfig, AutoModel
class CustomConfig(PretrainedConfig):
    """Transformers config that also carries pre/post-processing code.

    Args:
        preprocess_function: value stored verbatim on the config, e.g. a
            string of Python source for a preprocessing function. This class
            never executes it.
        postprocess_function: value stored verbatim on the config, e.g. a
            string of Python source for a post-processing function.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, preprocess_function=None, postprocess_function=None, **kwargs):
        # Custom attributes first; the base class then handles the standard
        # options. It never touches these two names, so the order is irrelevant.
        self.preprocess_function = preprocess_function
        self.postprocess_function = postprocess_function
        super().__init__(**kwargs)
# Embed both function sources in the config, then serialize it with
# save_pretrained() into the "config with functions" directory.
config = CustomConfig(
    preprocess_function=preprocess_code,
    postprocess_function=postprocess_code,
)
config.save_pretrained("config with functions")