Update modeling_rwkv.py
modeling_rwkv.py (CHANGED, +23 -10)
@@ -320,14 +320,16 @@ class RWKV(MyModule):
                 w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
             except:
                 w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
-            #
-            #
+            #del w['blocks.0.ln0.weight']
+            #del w['blocks.0.ln0.bias']
 
             print_need_newline = False
 
             REAL_TIME_FIRST = False
+            args.time_state = False
             for x in list(w.keys()):
                 if '.time_faaaa' in x: REAL_TIME_FIRST = True
+                if '.time_state' in x: args.time_state = True
             if REAL_TIME_FIRST:
                 w = {k.replace('.time_faaaa','.time_first') if '.time_faaaa' in k else k: v for k, v in w.items()}
             self.w = w
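This hunk keeps the `blocks.0.ln0.*` parameters around after load (the `del` calls stay commented out) and adds an `args.time_state` flag, set during the existing single pass over the checkpoint keys, to mark checkpoints that ship a trained initial WKV state. A minimal sketch of the one-pass capability scan, with a toy dict standing in for the loaded weights (contents illustrative):

# Toy sketch of the key-scan pattern used above; dict contents illustrative.
w = {
    'blocks.0.att.time_faaaa': 0,   # v5.2+ name for time_first
    'blocks.0.att.time_state': 0,   # only present in state-tuned checkpoints
}
REAL_TIME_FIRST = any('.time_faaaa' in k for k in w)
time_state = any('.time_state' in k for k in w)
if REAL_TIME_FIRST:
    # normalize back to the historical '.time_first' name
    w = {k.replace('.time_faaaa', '.time_first'): v for k, v in w.items()}
assert REAL_TIME_FIRST and time_state and 'blocks.0.att.time_first' in w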
@@ -377,7 +379,7 @@ class RWKV(MyModule):
                 elif '.ln_x' in x: # need fp32 for group_norm
                     w[x] = w[x].float()
                 else:
-                    if (len(w[x].shape) == 2) and ('emb' not in x):
+                    if (len(w[x].shape) == 2) and ('emb' not in x) and ('_w1' not in x) and ('_w2' not in x):
                         if WTYPE != torch.uint8:
                             w[x] = w[x].to(dtype=WTYPE)
                         else:
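The cast filter is widened: besides the embedding, the small low-rank time-mix factors (parameter names containing `_w1` / `_w2`) are now excluded from the 2-D matrix cast and keep their load precision. A sketch of the condition as a standalone predicate (hypothetical helper, not part of the library):

import torch

def should_cast(name: str, t: torch.Tensor) -> bool:
    # Only plain 2-D weight matrices are cast to the target dtype;
    # the embedding and the *_w1/*_w2 LoRA-style factors are skipped.
    return t.dim() == 2 and all(s not in name for s in ('emb', '_w1', '_w2'))

assert should_cast('blocks.0.att.key.weight', torch.zeros(4, 4))
assert not should_cast('blocks.0.att.time_maa_w1', torch.zeros(4, 4))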
@@ -436,10 +438,12 @@ class RWKV(MyModule):
                     torch.cuda.empty_cache()
 
                 shape = [i for i in w[x].shape if i != 1]
-                if len(shape) > 1:
-                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
+                if len(shape) > 2:
+                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
+                elif len(shape) > 1:
+                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}      "
                 else:
-                    shape = f" {str(shape[0]).rjust(5)}      "
+                    shape = f" {str(shape[0]).rjust(5)}            "
                 if layer_id == 0 or layer_id >= args.n_layer-1:
                     if print_need_newline:
                         prxxx('\n', end = '')
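The load-time shape logger gains a 3-D branch, presumably needed now that checkpoints may carry 3-D `time_state` tensors; the extra padding on the shorter branches keeps the log columns aligned. The same formatting written generically (a sketch; unlike the branches above it does not pad missing columns):

def fmt_shape(t):
    # Squeeze out size-1 dims, then right-justify up to three dims
    # in 5-wide columns.
    dims = [i for i in t.shape if i != 1]
    return " " + " ".join(str(d).rjust(5) for d in dims)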
@@ -498,7 +502,7 @@ class RWKV(MyModule):
         if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == '1':
             HEAD_SIZE = args.n_att // args.n_head
             rwkv6 = load(name="rwkv6", sources=[f"{current_path}/cuda/rwkv6_op.cpp", f"{current_path}/cuda/rwkv6.cu"],
-                            verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])
+                            verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])
 
             class RWKV_6(torch.autograd.Function):
                 @staticmethod
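On Windows (`os.name == "nt"`), where nvcc is driven through MSVC, the `-Xptxas -O3` flag is swapped for an empty string, evidently to avoid build problems there. A sketch of the same guard that filters the placeholder out entirely rather than passing an empty token to the compiler (assumed equivalent):

import os

HEAD_SIZE = 64  # illustrative; the real value is args.n_att // args.n_head
cflags = ["-res-usage", "--use_fast_math", "-O3",
          "-Xptxas -O3" if os.name != "nt" else "",
          "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", "-D_T_=4096"]
cflags = [f for f in cflags if f]  # drop the empty placeholder on Windows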
@@ -1024,15 +1028,24 @@ class RWKV(MyModule):
                 dev = dd.device
                 atype = dd.atype
                 state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-                state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
+                if args.time_state:
+                    state[i*3+1] = w[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+                else:
+                    state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
                 state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
 
-        if embs is None:
+        if embs is None and tokens is not None:
             seq_mode = len(tokens) > 1
             x = w['emb.weight'][tokens if seq_mode else tokens[0]]
-        else:
+        elif embs is not None and tokens is None:
             x = embs
             seq_mode = True
+        elif embs is not None and tokens is not None:
+            seq_mode = len(tokens) > 1
+            x = w['emb.weight'][tokens if seq_mode else tokens[0]]
+            x = torch.cat([x, embs], dim=0)
+        else:
+            raise ValueError('Either tokens or embs must be provided')
 
         for i in range(args.n_layer):
             bbb = f'blocks.{i}.'
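With `args.time_state` set, the per-layer WKV state slot (`state[i*3+1]`) is initialized from the checkpoint's trained `time_state` tensor instead of zeros, which is what lets state-tuned models start generation from a learned state. The tensor is stored per layer as `(n_head, head_size, head_size)`; the loader swaps its last two dims, forces fp32, and detaches it. A shape-only sketch (sizes illustrative):

import torch

n_head, head_size = 32, 64  # illustrative sizes
time_state = torch.rand(n_head, head_size, head_size)  # stands in for w[f'blocks.{i}.att.time_state']
s = time_state.transpose(1, 2).to(dtype=torch.float).requires_grad_(False).contiguous()
assert s.shape == (n_head, head_size, head_size)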
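The input dispatch now covers all four combinations of `tokens` and `embs`: token ids are embedded, raw embeddings pass through unchanged, and when both are given the embeddings are appended after the embedded tokens (note the mixed path assumes `len(tokens) > 1`; a single token gives a 1-D lookup that `torch.cat` would reject). A standalone rendering of the dispatch as a hypothetical helper:

import torch

def build_input(w, tokens=None, embs=None):
    # Mirrors the four-way dispatch in the hunk above.
    if embs is None and tokens is not None:
        seq_mode = len(tokens) > 1
        x = w['emb.weight'][tokens if seq_mode else tokens[0]]
    elif embs is not None and tokens is None:
        x, seq_mode = embs, True
    elif embs is not None and tokens is not None:
        seq_mode = len(tokens) > 1
        x = w['emb.weight'][tokens if seq_mode else tokens[0]]
        x = torch.cat([x, embs], dim=0)  # embeddings follow the tokens
    else:
        raise ValueError('Either tokens or embs must be provided')
    return x, seq_mode

w = {'emb.weight': torch.rand(10, 4)}
x, seq = build_input(w, tokens=[1, 2, 3], embs=torch.rand(2, 4))
assert x.shape == (5, 4) and seq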