# Source code for constants

# Parameter-count bookkeeping for GPT-2, exposed to callers through the GPT2
# dict defined at the bottom of this block.  All sizes are element counts.

# Number of transformer blocks ("transformer.h.0" .. "transformer.h.11").
__gpt2_num_blocks = 12

# wte (50257*768) + wpe (1024*768)
__gpt2_num_params_lm_head = 50257 * 768 + 1024 * 768

# Per-block weight matrices: weight decay (if any).
# NOTE(review): the four terms presumably follow the order of the ".weight"
# suffixes below (c_attn, c_proj, c_fc, c_proj) -- confirm against the model.
__gpt2_num_params_gpt2block_in_pg0 = 768 * 2304 + 768 * 768 + 768 * 3072 + 768 * 3072

# Per-block layer-norm parameters and biases: no weight decay (no matter what).
#   ln_1.weight:      768
#   ln_1.bias:        768
#   attn.c_attn.bias: 2304
#   attn.c_proj.bias: 768
#   ln_2.weight:      768
#   ln_2.bias:        768
#   mlp.c_fc.bias:    3072
#   mlp.c_proj.bias:  768
__gpt2_num_params_gpt2block_in_pg1 = 768 + 768 + 2304 + 768 + 768 + 768 + 3072 + 768

# Every parameter of one block, i.e. its group-0 part plus its group-1 part
# (equal to 768*2304 + 768*768 + 768*3072 + 768*3072 + 6*768 + 2304 + 3072).
__gpt2_num_params_gpt2block = (
    __gpt2_num_params_gpt2block_in_pg0 + __gpt2_num_params_gpt2block_in_pg1
)

# Inclusive (first, last) index pairs per component inside parameter group 0:
# "lmhead" comes first, then the blocks back to back.
# NOTE(review): presumably indices into the flattened group -- confirm with callers.
__gpt2_param_intervals_pg0 = {"lmhead": (0, __gpt2_num_params_lm_head - 1)}
__gpt2_param_intervals_pg0.update({
    f"gpt2block{i}": (
        __gpt2_num_params_lm_head + __gpt2_num_params_gpt2block_in_pg0 * i,
        __gpt2_num_params_lm_head + __gpt2_num_params_gpt2block_in_pg0 * (i + 1) - 1,
    )
    for i in range(__gpt2_num_blocks)
})

# Inclusive (first, last) index pairs per component inside parameter group 1:
# the blocks back to back, then "ln" (768 + 768 entries) at the end.
__gpt2_param_intervals_pg1 = {
    f"gpt2block{i}": (
        __gpt2_num_params_gpt2block_in_pg1 * i,
        __gpt2_num_params_gpt2block_in_pg1 * (i + 1) - 1,
    )
    for i in range(__gpt2_num_blocks)
}
__gpt2_param_intervals_pg1.update({
    "ln": (
        __gpt2_num_params_gpt2block_in_pg1 * __gpt2_num_blocks,
        __gpt2_num_params_gpt2block_in_pg1 * __gpt2_num_blocks + 768 + 768 - 1,
    )
})

# Per-block parameter-name suffixes, in the order they appear in each group.
__gpt2_block_suffixes_pg0 = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)
__gpt2_block_suffixes_pg1 = (
    "ln_1.weight",
    "ln_1.bias",
    "attn.c_attn.bias",
    "attn.c_proj.bias",
    "ln_2.weight",
    "ln_2.bias",
    "mlp.c_fc.bias",
    "mlp.c_proj.bias",
)

# Group 0: the two embedding tables first, then every block's weight matrices.
__gpt2_param_names_pg0 = ["transformer.wte.weight", "transformer.wpe.weight"] + [
    f"transformer.h.{i}.{suffix}"
    for i in range(__gpt2_num_blocks)
    for suffix in __gpt2_block_suffixes_pg0
]

# Group 1: every block's layer-norm parameters and biases, then ln_f.
__gpt2_param_names_pg1 = [
    f"transformer.h.{i}.{suffix}"
    for i in range(__gpt2_num_blocks)
    for suffix in __gpt2_block_suffixes_pg1
] + ["transformer.ln_f.weight", "transformer.ln_f.bias"]

# Public table:
#   "param_name_in_group": one {'parameter_names': [...]} entry per group
#                          (index 0 = group 0, index 1 = group 1).
#   "num_params":          the element counts defined above.
#   "param_intervals":     the inclusive index ranges defined above, per group.
GPT2 = {
    "param_name_in_group": [
        {'parameter_names': __gpt2_param_names_pg0},
        {'parameter_names': __gpt2_param_names_pg1},
    ],
    "num_params": {
        "lm_head": __gpt2_num_params_lm_head,
        "gpt2block": __gpt2_num_params_gpt2block,
        "gpt2block_in_pg0": __gpt2_num_params_gpt2block_in_pg0,
        "gpt2block_in_pg1": __gpt2_num_params_gpt2block_in_pg1,
    },
    "param_intervals": {
        "pg0": __gpt2_param_intervals_pg0,
        "pg1": __gpt2_param_intervals_pg1,
    },
}