Tri Dao
|
dfe29f5e2b
[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead
|
1 year ago |
Tri Dao
|
d0032700d1
Add tests for Pythia, GPT-JT, and RedPajama models
|
1 year ago |
Tri Dao
|
8a733cbd53
[Gen] Fix calling update_graph_cache in tests
|
1 year ago |
Tri Dao
|
913922cac5
[Gen] Refactor decoding function
|
1 year ago |
Tri Dao
|
0e8c46ae08
Run isort and black on test files
|
1 year ago |
Tri Dao
|
8e9820a55b
[Rotary] Fix tests when loading state dict with rotary inv_freqs
|
1 year ago |
Tri Dao
|
425dbcb6c6
[MHA] Implement MQA/GQA
|
1 year ago |
Tri Dao
|
b3177dfaf6
[GPT] Enable FlashAttention for GPT-J
|
1 year ago |
Tri Dao
|
96d10f6545
Implement LLaMa
|
1 year ago |
Tri Dao
|
605655bc66
[Gen] Fix FT kernel when using CG
|
1 year ago |
Tri Dao
|
393882bc08
[LayerNorm] Implement LN with parallel residual, support dim 8k
|
1 year ago |
Tri Dao
|
993d12448e
Implement GPT-NeoX
|
1 year ago |
Tri Dao
|
4d87e4d875
Implement GPT-J
|
1 year ago |