from __future__ import absolute_import

import collections

import torch


class Fold(object):

    class Node(object):
        """One pending op in the fold's dataflow graph."""

        def __init__(self, op, step, index, *args):
            self.op = op          # name of the method to call on the module
            self.step = step      # depth level at which the op is evaluated
            self.index = index    # position within this (step, op) batch
            self.args = args
            self.split_idx = -1   # >= 0 when this node is one element of a tuple result
            self.batch = True     # set to False via nobatch()

        def split(self, num):
            u"""Split resulting node, if function returns multiple values."""
            # Used when the network returns more than one output tensor.
            nodes = []
            for idx in range(num):
                nodes.append(Fold.Node(
                    self.op, self.step, self.index, *self.args))
                nodes[-1].split_idx = idx
            return tuple(nodes)
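
        # Example (hypothetical op): if 'lstm_cell' returned a (h, c) pair,
        # it would be registered as
        #     h, c = fold.add('lstm_cell', x, h_prev, c_prev).split(2)
        # so that each resulting Node selects one element of the tuple.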

        def nobatch(self):
            self.batch = False
            return self

        def get(self, values):
            if self.split_idx >= 0:
                return values[self.step][self.op][self.split_idx][self.index]
            else:
                return values[self.step][self.op][self.index]

        def __repr__(self):
            return u"[%d:%d]%s" % (
                self.step, self.index, self.op)

    def __init__(self, volatile=False, cuda=False, variable=True):
        self.steps = collections.defaultdict(
            lambda: collections.defaultdict(list))
        self.cached_nodes = collections.defaultdict(dict)
        self.total_nodes = 0
        self.volatile = volatile
        self._cuda = cuda
        self._variable = variable

    def __repr__(self):
        return str(self.steps.keys())

    def cuda(self):
        self._cuda = True
        return self

    def add(self, op, *args):
        u"""Add op to the fold."""
        self.total_nodes += 1
        # Only build a new node if this (op, args) pair has not been seen before.
        if args not in self.cached_nodes[op]:
            # args may hold just the node's features, or child nodes as well.
            # step is the depth level: one past the deepest child Fold.Node.
            step = max([0] + [arg.step + 1 for arg in args
                              if isinstance(arg, Fold.Node)])
            # len(self.steps[step][op]) is the index: it counts nodes per level.
            # steps stores the argument tuples grouped by step (level) and op.
            node = Fold.Node(op, step, len(self.steps[step][op]), *args)
            self.steps[step][op].append(args)
            self.cached_nodes[op][args] = node
        return self.cached_nodes[op][args]
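
    # Example (hypothetical ops): nodes with the same step and op are evaluated
    # in a single batched call, e.g.
    #     left = fold.add('encoder', leaf_a)            # step 0
    #     right = fold.add('encoder', leaf_b)           # step 0, batched with left
    #     parent = fold.add('encoder', left, right)     # step 1 = max child step + 1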

    def _batch_args(self, arg_lists, values, op):
        res = []
        for arg in arg_lists:
            r = []
            if isinstance(arg[0], Fold.Node):
                # A tuple of Fold nodes: fetch each node's already computed
                # value and stack them into a single batched tensor.
                if arg[0].batch:
                    for x in arg:
                        r.append(x.get(values))
                    res.append(torch.stack(r))
                # The nobatch case is never used in this variant.
                '''
                else:
                    for i in range(2, len(arg)):
                        if arg[i] != arg[0]:
                            raise ValueError(
                                u"Cannot use more than one nobatch argument, got: %s." % str(arg))
                    x = arg[0]
                    res.append(x.get(values))
                '''
            else:
                if isinstance(arg[0], torch.Tensor):
                    # A tuple of attribute tensors: stack them into one batch.
                    var = torch.stack(arg)
                    res.append(var)
                else:
                    # A tuple of tree nodes: extract what this op expects.
                    if op not in ("classifyLossEstimator", "calcularLossAtributo",
                                  "vectorMult", "sampleEncoder"):
                        # op is one of the networks: pass the node's radius.
                        var = arg[0].radius
                    elif op == "sampleEncoder":
                        var = arg
                    elif op == "calcularLossAtributo":
                        # Computing the MSE loss over node attributes.
                        var = [a.radius for a in arg]
                    elif op == "classifyLossEstimator":
                        # Computing the cross-entropy loss over child counts.
                        var = [a.childs() for a in arg]
                    elif op == "vectorMult":
                        if isinstance(arg, torch.Tensor):
                            var = arg
                        else:
                            var = list(arg)
                    res.append(var)
        return res
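
    # Example (hypothetical): if steps[1]['encoder'] == [(l1, r1), (l2, r2)],
    # then zip(*...) in apply() yields arg_lists == [(l1, l2), (r1, r2)]; each
    # tuple is stacked, so 'encoder' runs once over the whole batch.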

    def apply(self, nn, nodes):
        u"""Apply current fold to given neural module."""
        values = {}
        for step in sorted(self.steps.keys()):
            values[step] = {}
            for op in self.steps[step]:
                func = getattr(nn, op)
                # Gather the arguments of every node sharing this step and op.
                try:
                    batched_args = self._batch_args(
                        zip(*self.steps[step][op]), values, op)
                except Exception:
                    print("Error while executing node %s[%d] with args: %s" % (
                        op, step, self.steps[step][op]))
                    raise
                res = func(*batched_args)
                if isinstance(res, (tuple, list)):
                    values[step][op] = []
                    for x in res:
                        values[step][op].append(x)
                else:
                    if len(res.shape) == 1 and op != 'vectorAdder' and op != 'vectorMult':
                        # Flat outputs are reshaped into rows of four attributes.
                        values[step][op] = res.reshape(-1, 4)
                    else:
                        # The classifier's output vectors have three elements,
                        # so they are not reshaped.
                        values[step][op] = res
        try:
            return self._batch_args(nodes, values, op)
        except Exception:
            print("cannot batch")
            raise
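

if __name__ == "__main__":
    # Minimal usage sketch. TinyEncoder and its 'leaf'/'merge' ops are
    # hypothetical stand-ins for the real model used with this Fold variant;
    # they only illustrate the add()/apply() workflow.
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        def __init__(self):
            super(TinyEncoder, self).__init__()
            self.linear = nn.Linear(4, 4)

        def leaf(self, x):
            # x arrives stacked as a (batch, 4) tensor.
            return torch.relu(self.linear(x))

        def merge(self, left, right):
            return torch.relu(self.linear(left + right))

    fold = Fold()
    left = fold.add('leaf', torch.ones(4))
    right = fold.add('leaf', torch.zeros(4))
    root = fold.add('merge', left, right)
    # Both 'leaf' calls run as one batched forward pass, then 'merge'.
    result = fold.apply(TinyEncoder(), [[root]])
    print(result[0].shape)  # torch.Size([1, 4])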