# weight_shapes.py
# Weight shapes are given in the format
#   ([K, N], TP_SPLIT_DIM)
# where TP_SPLIT_DIM is the index into [K, N] of the dimension that is divided
# by the tensor-parallel (TP) size.
# Examples:
#   A shape of ([14336, 4096], 0) gives the following GEMM shapes:
#     - TP1 : K = 14336, N = 4096
#     - TP2 : K = 7168, N = 4096
#   A shape of ([4096, 6144], 1) gives the following GEMM shapes:
#     - TP1 : K = 4096, N = 6144
#     - TP4 : K = 4096, N = 1536

# TP1 shapes
  11. WEIGHT_SHAPES = {
  12. "mistralai/Mistral-7B-v0.1": [
  13. ([4096, 6144], 1),
  14. ([4096, 4096], 0),
  15. ([4096, 28672], 1),
  16. ([14336, 4096], 0),
  17. ],
  18. "meta-llama/Llama-2-7b-hf": [
  19. ([4096, 12288], 1),
  20. ([4096, 4096], 0),
  21. ([4096, 22016], 1),
  22. ([11008, 4096], 0),
  23. ],
  24. "meta-llama/Llama-3-8b": [
  25. ([4096, 6144], 1),
  26. ([4096, 4096], 0),
  27. ([4096, 28672], 1),
  28. ([14336, 4096], 0),
  29. ],
  30. "meta-llama/Llama-2-13b-hf": [
  31. ([5120, 15360], 1),
  32. ([5120, 5120], 0),
  33. ([5120, 27648], 1),
  34. ([13824, 5120], 0),
  35. ],
  36. "meta-llama/Llama-2-70b-hf": [
  37. ([8192, 10240], 1),
  38. ([8192, 8192], 0),
  39. ([8192, 57344], 1),
  40. ([28672, 8192], 0),
  41. ],
  42. }