Tri Dao
|
8c20cfef49
[Rotary] Support qkv block layout from GQA
|
3 months ago |
Antony Frolov
|
3566596ad8
Fix typo in RotaryEmbedding forward output type (#666)
|
1 year ago |
Katherine Crowson
|
4c8ff9154e
Fix NameError and typo in ApplyRotaryEmbQKV_ (#569)
|
1 year ago |
Tri Dao
|
1879e089c7
Reduce number of templates for headdim > 128
|
1 year ago |
Tri Dao
|
a86442f0f3
[Gen] Use flash_attn_with_kvcache in generation
|
1 year ago |
Tri Dao
|
b28ec236df
[Rotary] Implement varlen rotary
|
1 year ago |
Tri Dao
|
de2949f37d
[Rotary] Pass max_seqlen from mha.py to rotary during inference
|
1 year ago |
Tri Dao
|
942fcbf046
[Rotary] Implement rotary in Triton
|
1 year ago |
Tri Dao
|
f1a73d0740
Run isort and black on python files
|
1 year ago |
Tri Dao
|
425dbcb6c6
[MHA] Implement MQA/GQA
|
1 year ago |
Tri Dao
|
ec9f74ab9a
[Rotary] Don't store inv_freq in state_dict
|
1 year ago |
Volodymyr Kyrylov
|
70ab266a56
rotary: update cos/sin cache when switching from inference mode
|
1 year ago |
Tri Dao
|
62e9814466
[Rotary] Make sure frequency calculation is in fp32
|
1 year ago |
Tri Dao
|
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
|
1 year ago |
Tri Dao
|
e45a46a5b7
[Rotary] Implement GPT-J style (interleaved) rotary
|
1 year ago |
Tri Dao
|
85b8e3d334
[Docs] Mention that XPos's scale_base is recommended to be 512
|
1 year ago |
Tri Dao
|
1e712ea8b0
Implement TensorParallel for MHA
|
2 years ago |
Tri Dao
|
496e4f528c
Implement XPos (Sun et al.)
|
2 years ago |
Alexander Ploshkin
|
ee8984d2be
add asserts for sin shape
|
2 years ago |
Alexander Ploshkin
|
c7c66976cc
fix slicing dimensions
|
2 years ago |
Alexander Ploshkin
|
96656b9323
Remove redundant shape asserts in rotary embeddings
|
2 years ago |
Tri Dao
|
71f674ae23
[Rotary] Customize base, support seqlen_offset
|
2 years ago |
Tri Dao
|
d4b320b31f
Add MLP, MHA, Block, Embedding modules
|
2 years ago |