howard-hou committed (verified)
Commit 7fabc1b · 1 Parent(s): 98ba582

Update modeling_rwkv.py

Files changed (1):
  1. modeling_rwkv.py  +23 -10
modeling_rwkv.py CHANGED
@@ -320,14 +320,16 @@ class RWKV(MyModule):
                 w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
             except:
                 w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
-            # del w['blocks.0.ln0.weight']
-            # del w['blocks.0.ln0.bias']
+            #del w['blocks.0.ln0.weight']
+            #del w['blocks.0.ln0.bias']

             print_need_newline = False

             REAL_TIME_FIRST = False
+            args.time_state = False
             for x in list(w.keys()):
                 if '.time_faaaa' in x: REAL_TIME_FIRST = True
+                if '.time_state' in x: args.time_state = True
             if REAL_TIME_FIRST:
                 w = {k.replace('.time_faaaa','.time_first') if '.time_faaaa' in k else k: v for k, v in w.items()}
             self.w = w
@@ -377,7 +379,7 @@ class RWKV(MyModule):
                     elif '.ln_x' in x: # need fp32 for group_norm
                         w[x] = w[x].float()
                     else:
-                        if (len(w[x].shape) == 2) and ('emb' not in x):
+                        if (len(w[x].shape) == 2) and ('emb' not in x) and ('_w1' not in x) and ('_w2' not in x):
                             if WTYPE != torch.uint8:
                                 w[x] = w[x].to(dtype=WTYPE)
                             else:
@@ -436,10 +438,12 @@ class RWKV(MyModule):
                         torch.cuda.empty_cache()

                 shape = [i for i in w[x].shape if i != 1]
-                if len(shape) > 1:
-                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
+                if len(shape) > 2:
+                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
+                elif len(shape) > 1:
+                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
                 else:
-                    shape = f" {str(shape[0]).rjust(5)} "
+                    shape = f" {str(shape[0]).rjust(5)} "
                 if layer_id == 0 or layer_id >= args.n_layer-1:
                     if print_need_newline:
                         prxxx('\n', end = '')
@@ -498,7 +502,7 @@ class RWKV(MyModule):
         if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == '1':
             HEAD_SIZE = args.n_att // args.n_head
             rwkv6 = load(name="rwkv6", sources=[f"{current_path}/cuda/rwkv6_op.cpp", f"{current_path}/cuda/rwkv6.cu"],
-                            verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])
+                            verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])

             class RWKV_6(torch.autograd.Function):
                 @staticmethod
@@ -1024,15 +1028,24 @@ class RWKV(MyModule):
                         dev = dd.device
                         atype = dd.atype
                         state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-                        state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
+                        if args.time_state:
+                            state[i*3+1] = w[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
+                        else:
+                            state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
                         state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()

-            if embs is None:
+            if embs is None and tokens is not None:
                 seq_mode = len(tokens) > 1
                 x = w['emb.weight'][tokens if seq_mode else tokens[0]]
-            else:
+            elif embs is not None and tokens is None:
                 x = embs
                 seq_mode = True
+            elif embs is not None and tokens is not None:
+                seq_mode = len(tokens) > 1
+                x = w['emb.weight'][tokens if seq_mode else tokens[0]]
+                x = torch.cat([x, embs], dim=0)
+            else:
+                raise ValueError('Either tokens or embs must be provided')

             for i in range(args.n_layer):
                 bbb = f'blocks.{i}.'
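For readers of the first hunk: the new `args.time_state` flag records whether the checkpoint ships `blocks.{i}.att.time_state` tensors, and the last hunk above then seeds each layer's WKV state slot from those tensors instead of zeros. A minimal sketch for checking whether a given checkpoint carries them, assuming nothing beyond PyTorch; the file name is a placeholder, not part of this commit:

import torch

# Placeholder checkpoint path; any RWKV-6 checkpoint with a tuned initial state would do.
sd = torch.load("model.pth", map_location="cpu")
time_state_keys = sorted(k for k in sd if ".time_state" in k)
print(f"{len(time_state_keys)} time_state tensors found")
for k in time_state_keys[:3]:
    # Expected to be (n_head, head_size, head_size), matching the WKV state slot it seeds.
    print(k, tuple(sd[k].shape))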
 
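The reworked input branch (new lines 1037-1048) accepts token ids, precomputed embeddings, or both; when both are given, the token embeddings are looked up first and `embs` is concatenated after them along the sequence dimension. A hedged usage sketch, assuming the upstream `(logits, state)` return convention, an `embs` keyword argument on `forward` (the signature itself is outside this hunk), and an already constructed `model`:

import torch

# model = RWKV(model="path/to/checkpoint", strategy="cuda fp16")  # construction as in the upstream rwkv package; path is a placeholder

tokens = [510, 4342, 338]                     # example token ids; more than one keeps seq_mode True
embs = torch.randn(16, model.args.n_embd)     # e.g. 16 projected image embeddings; n_embd via model.args follows the upstream rwkv package

out, state = model.forward(tokens, None, embs=embs)   # token prefix + embedding suffix
out, state = model.forward(tokens, None)              # tokens only
out, state = model.forward(None, None, embs=embs)     # embeddings only

Note that with a single token the embedding lookup returns a 1-D vector, so the concatenation path is effectively meant for multi-token prompts.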