Browse Source

bump torch to 2.3.1

AlpinDale 6 months ago
parent
commit
b1e61268a8
6 changed files with 10 additions and 10 deletions
  1. 1 1
      .github/workflows/publish.yml
  2. 1 1
      CMakeLists.txt
  3. 1 1
      environment.yaml
  4. 1 1
      pyproject.toml
  5. 1 1
      requirements-build.txt
  6. 5 5
      requirements-cuda.txt

+ 1 - 1
.github/workflows/publish.yml

@@ -49,7 +49,7 @@ jobs:
       matrix:
           os: ['ubuntu-20.04']
           python-version: ['3.8', '3.9', '3.10', '3.11']
-          pytorch-version: ['2.3.0']  # Must be the most recent version that meets requirements-cuda.txt.
+          pytorch-version: ['2.3.1']  # Must be the most recent version that meets requirements-cuda.txt.
           cuda-version: ['11.8', '12.1']
 
     steps:

+ 1 - 1
CMakeLists.txt

@@ -32,7 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # requirements.txt files and should be kept consistent.  The ROCm torch
 # versions are derived from Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
 set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
 
 #

+ 1 - 1
environment.yaml

@@ -7,7 +7,7 @@ channels:
   - defaults
 dependencies:
   - python=3.11.*
-  - pytorch=2.3.0
+  - pytorch=2.3.1
   - pytorch-cuda=12.1.*
   - cuda-nvcc=12.1.*
   - cuda-libraries-dev=12.1.*

+ 1 - 1
pyproject.toml

@@ -4,7 +4,7 @@ requires = [
     "ninja",
     "packaging",
     "setuptools >= 49.4.0",
-    "torch == 2.3.0",
+    "torch == 2.3.1",
     "wheel",
 ]
 build-backend = "setuptools.build_meta"

+ 1 - 1
requirements-build.txt

@@ -3,5 +3,5 @@ cmake>=3.21
 ninja
 packaging
 setuptools>=49.4.0
-torch==2.3.0
+torch==2.3.1
 wheel

+ 5 - 5
requirements-cuda.txt

@@ -3,10 +3,10 @@
 
 # Dependencies for NVIDIA GPUs
 nvidia-ml-py == 12.555.43
-torch == 2.3.0
-torchvision == 0.18.0  # for phi3v
-xformers == 0.0.26.post1  # Requires torch 2.3.0
-triton >= 2.2.0
-vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
+torch == 2.3.1
+torchvision == 0.18.1  # for phi3v
+xformers == 0.0.27  # Requires torch 2.3.1
+triton >= 2.2.1
+vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1
 causal-conv1d >= 1.2.1
 mamba-ssm >= 1.2.2